Improve {Map,IntMap}.intersection* and its tests.
authorMilan Straka <fox@ucw.cz>
Sat, 3 Mar 2012 10:28:18 +0000 (11:28 +0100)
committerMilan Straka <fox@ucw.cz>
Sun, 4 Mar 2012 15:38:12 +0000 (16:38 +0100)
* Add tests for intersectionWith*.
* Add specific Map.intersection implementation instead of using
  Map.intersectionWithKey.
* Refactor Map.intersectionWithKey implementatioin.

Data/Map/Base.hs
Data/Map/Strict.hs
tests/intmap-properties.hs
tests/map-properties.hs

index 3c5171e..b0c19e4 100644 (file)
@@ -1189,8 +1189,17 @@ hedgeDiffWithKey f blo bhi t (Bin _ kx x l r)
 -- > intersection (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "a"
 
 intersection :: Ord k => Map k a -> Map k b -> Map k a
-intersection m1 m2
-  = intersectionWithKey (\_ x _ -> x) m1 m2
+intersection Tip _ = Tip
+intersection _ Tip = Tip
+intersection t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 _ l2 r2) =
+   if s1 >= s2 then
+     case splitLookupWithKey k2 t1 of
+       (lt, Just (k, x), gt) -> join k x (intersection lt l2) (intersection gt r2)
+       (lt, Nothing, gt) -> merge (intersection lt l2) (intersection gt r2)
+   else
+      case splitLookup k1 t2 of
+        (lt, Just _, gt) -> join k1 x1 (intersection l1 lt) (intersection r1 gt)
+        (lt, Nothing, gt) -> merge (intersection l1 lt) (intersection r1 gt)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE intersection #-}
 #endif
@@ -1218,18 +1227,13 @@ intersectionWithKey _ Tip _ = Tip
 intersectionWithKey _ _ Tip = Tip
 intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
    if s1 >= s2 then
-      let (lt,found,gt) = splitLookupWithKey k2 t1
-          tl            = intersectionWithKey f lt l2
-          tr            = intersectionWithKey f gt r2
-      in case found of
-      Just (k,x) -> join k (f k x x2) tl tr
-      Nothing -> merge tl tr
-   else let (lt,found,gt) = splitLookup k1 t2
-            tl            = intersectionWithKey f l1 lt
-            tr            = intersectionWithKey f r1 gt
-      in case found of
-      Just x -> join k1 (f k1 x1 x) tl tr
-      Nothing -> merge tl tr
+     case splitLookupWithKey k2 t1 of
+       (lt, Just (k, x), gt) -> join k (f k x x2) (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
+       (lt, Nothing, gt) -> merge (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
+   else
+      case splitLookup k1 t2 of
+        (lt, Just x, gt) -> join k1 (f k1 x1 x) (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
+        (lt, Nothing, gt) -> merge (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE intersectionWithKey #-}
 #endif
index 79b44c5..e99521d 100644 (file)
@@ -824,18 +824,13 @@ intersectionWithKey _ Tip _ = Tip
 intersectionWithKey _ _ Tip = Tip
 intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
    if s1 >= s2 then
-      let (lt,found,gt) = splitLookupWithKey k2 t1
-          tl            = intersectionWithKey f lt l2
-          tr            = intersectionWithKey f gt r2
-      in case found of
-      Just (k,x) -> join k (f k x x2) tl tr
-      Nothing -> merge tl tr
-   else let (lt,found,gt) = splitLookup k1 t2
-            tl            = intersectionWithKey f l1 lt
-            tr            = intersectionWithKey f r1 gt
-      in case found of
-      Just x -> let x' = f k1 x1 x in x' `seq` join k1 x' tl tr
-      Nothing -> merge tl tr
+     case splitLookupWithKey k2 t1 of
+       (lt, Just (k, x), gt) -> case f k x x2 of x' -> x' `seq` join k x' (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
+       (lt, Nothing, gt) -> merge (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
+   else
+      case splitLookup k1 t2 of
+        (lt, Just x, gt) -> case f k1 x1 x of x' -> x' `seq` join k1 x' (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
+        (lt, Nothing, gt) -> merge (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE intersectionWithKey #-}
 #endif
index b6dcc12..fc24534 100644 (file)
@@ -124,6 +124,8 @@ main = defaultMainWithOpts
              , testProperty "union sum"            prop_unionSum
              , testProperty "difference model"     prop_differenceModel
              , testProperty "intersection model"   prop_intersectionModel
+             , testProperty "intersectionWith model" prop_intersectionWithModel
+             , testProperty "intersectionWithKey model" prop_intersectionWithKeyModel
              , testProperty "fromAscList"          prop_ordered
              , testProperty "fromList then toList" prop_list
              , testProperty "toDescList"           prop_descList
@@ -740,6 +742,22 @@ prop_intersectionModel xs ys
   = sort (keys (intersection (fromListWith (+) xs) (fromListWith (+) ys)))
     == sort (nub ((List.intersect) (Prelude.map fst xs) (Prelude.map fst ys)))
 
+prop_intersectionWithModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
+prop_intersectionWithModel xs ys
+  = toList (intersectionWith f (fromList xs') (fromList ys'))
+    == [(kx, f vx vy ) | (kx, vx) <- List.sort xs', (ky, vy) <- ys', kx == ky]
+    where xs' = List.nubBy ((==) `on` fst) xs
+          ys' = List.nubBy ((==) `on` fst) ys
+          f l r = l + 2 * r
+
+prop_intersectionWithKeyModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
+prop_intersectionWithKeyModel xs ys
+  = toList (intersectionWithKey f (fromList xs') (fromList ys'))
+    == [(kx, f kx vx vy) | (kx, vx) <- List.sort xs', (ky, vy) <- ys', kx == ky]
+    where xs' = List.nubBy ((==) `on` fst) xs
+          ys' = List.nubBy ((==) `on` fst) ys
+          f k l r = k + 2 * l + 3 * r
+
 ----------------------------------------------------------------
 
 prop_ordered :: Property
index b6dc089..4e9cfe5 100644 (file)
@@ -143,6 +143,10 @@ main = defaultMainWithOpts
          , testProperty "difference model"     prop_differenceModel
          , testProperty "intersection"         prop_intersection
          , testProperty "intersection model"   prop_intersectionModel
+         , testProperty "intersectionWith"     prop_intersectionWith
+         , testProperty "intersectionWithModel" prop_intersectionWithModel
+         , testProperty "intersectionWithKey"  prop_intersectionWithKey
+         , testProperty "intersectionWithKeyModel" prop_intersectionWithKeyModel
          , testProperty "fromAscList"          prop_ordered
          , testProperty "fromList then toList" prop_list
          , testProperty "toDescList"           prop_descList
@@ -869,6 +873,28 @@ prop_intersectionModel xs ys
   = sort (keys (intersection (fromListWith (+) xs) (fromListWith (+) ys)))
     == sort (nub ((List.intersect) (Prelude.map fst xs) (Prelude.map fst ys)))
 
+prop_intersectionWith :: (Int -> Int -> Maybe Int) -> IMap -> IMap -> Bool
+prop_intersectionWith f t1 t2 = valid (intersectionWith f t1 t2)
+
+prop_intersectionWithModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
+prop_intersectionWithModel xs ys
+  = toList (intersectionWith f (fromList xs') (fromList ys'))
+    == [(kx, f vx vy) | (kx, vx) <- List.sort xs', (ky, vy) <- ys', kx == ky]
+    where xs' = List.nubBy ((==) `on` fst) xs
+          ys' = List.nubBy ((==) `on` fst) ys
+          f l r = l + 2 * r
+
+prop_intersectionWithKey :: (Int -> Int -> Int -> Maybe Int) -> IMap -> IMap -> Bool
+prop_intersectionWithKey f t1 t2 = valid (intersectionWithKey f t1 t2)
+
+prop_intersectionWithKeyModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
+prop_intersectionWithKeyModel xs ys
+  = toList (intersectionWithKey f (fromList xs') (fromList ys'))
+    == [(kx, f kx vx vy) | (kx, vx) <- List.sort xs', (ky, vy) <- ys', kx == ky]
+    where xs' = List.nubBy ((==) `on` fst) xs
+          ys' = List.nubBy ((==) `on` fst) ys
+          f k l r = k + 2 * l + 3 * r
+
 ----------------------------------------------------------------
 
 prop_ordered :: Property