Add IntMap.mergeWithKey.
authorMilan Straka <fox@ucw.cz>
Sat, 14 Apr 2012 14:55:46 +0000 (16:55 +0200)
committerMilan Straka <fox@ucw.cz>
Sat, 14 Apr 2012 14:55:46 +0000 (16:55 +0200)
The slightly more general internal function mergeWithKey' is used to
define both mergeWithKey and other combining functions union*,
difference*, intersection*.

The resulting implementations of union*, difference* and intersection*
are no slower than before, and up to 30% faster in case of two large
interleaved maps. Measured by benchmarks/SetOperations-IntMap.hs.

Data/IntMap/Base.hs
Data/IntMap/Lazy.hs
Data/IntMap/Strict.hs
tests/intmap-properties.hs

index 3d5f923..42399e9 100644 (file)
@@ -77,6 +77,10 @@ module Data.IntMap.Base (
             , intersectionWith
             , intersectionWithKey
 
+            -- ** Universal combining function
+            , mergeWithKey
+            , mergeWithKey'
+
             -- * Traversal
             -- ** Map
             , map
@@ -668,24 +672,8 @@ unionsWith f ts
 -- > union (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "a"), (7, "C")]
 
 union :: IntMap a -> IntMap a -> IntMap a
-union t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = union1
-  | shorter m2 m1  = union2
-  | p1 == p2       = Bin p1 m1 (union l1 l2) (union r1 r2)
-  | otherwise      = join p1 t1 p2 t2
-  where
-    union1  | nomatch p2 p1 m1  = join p1 t1 p2 t2
-            | zero p2 m1        = Bin p1 m1 (union l1 t2) r1
-            | otherwise         = Bin p1 m1 l1 (union r1 t2)
-
-    union2  | nomatch p1 p2 m2  = join p1 t1 p2 t2
-            | zero p1 m2        = Bin p2 m2 (union t1 l2) r2
-            | otherwise         = Bin p2 m2 l2 (union t1 r2)
-
-union (Tip k x) t = insert k x t
-union t (Tip k x) = insertWith (\_ y -> y) k x t  -- right bias
-union Nil t       = t
-union t Nil       = t
+union m1 m2
+  = mergeWithKey' Bin const id id m1 m2
 
 -- | /O(n+m)/. The union with a combining function.
 --
@@ -701,24 +689,8 @@ unionWith f m1 m2
 -- > unionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "5:a|A"), (7, "C")]
 
 unionWithKey :: (Key -> a -> a -> a) -> IntMap a -> IntMap a -> IntMap a
-unionWithKey f t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = union1
-  | shorter m2 m1  = union2
-  | p1 == p2       = Bin p1 m1 (unionWithKey f l1 l2) (unionWithKey f r1 r2)
-  | otherwise      = join p1 t1 p2 t2
-  where
-    union1  | nomatch p2 p1 m1  = join p1 t1 p2 t2
-            | zero p2 m1        = Bin p1 m1 (unionWithKey f l1 t2) r1
-            | otherwise         = Bin p1 m1 l1 (unionWithKey f r1 t2)
-
-    union2  | nomatch p1 p2 m2  = join p1 t1 p2 t2
-            | zero p1 m2        = Bin p2 m2 (unionWithKey f t1 l2) r2
-            | otherwise         = Bin p2 m2 l2 (unionWithKey f t1 r2)
-
-unionWithKey f (Tip k x) t = insertWithKey f k x t
-unionWithKey f t (Tip k x) = insertWithKey (\k' x' y' -> f k' y' x') k x t  -- right bias
-unionWithKey _ Nil t  = t
-unionWithKey _ t Nil  = t
+unionWithKey f m1 m2
+  = mergeWithKey' Bin (\(Tip k1 x1) (Tip _k2 x2) -> Tip k1 (f k1 x1 x2)) id id m1 m2
 
 {--------------------------------------------------------------------
   Difference
@@ -728,27 +700,8 @@ unionWithKey _ t Nil  = t
 -- > difference (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 3 "b"
 
 difference :: IntMap a -> IntMap b -> IntMap a
-difference t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = difference1
-  | shorter m2 m1  = difference2
-  | p1 == p2       = bin p1 m1 (difference l1 l2) (difference r1 r2)
-  | otherwise      = t1
-  where
-    difference1 | nomatch p2 p1 m1  = t1
-                | zero p2 m1        = bin p1 m1 (difference l1 t2) r1
-                | otherwise         = bin p1 m1 l1 (difference r1 t2)
-
-    difference2 | nomatch p1 p2 m2  = t1
-                | zero p1 m2        = difference t1 l2
-                | otherwise         = difference t1 r2
-
-difference t1@(Tip k _) t2
-  | member k t2  = Nil
-  | otherwise    = t1
-
-difference Nil _       = Nil
-difference t (Tip k _) = delete k t
-difference t Nil       = t
+difference m1 m2
+  = mergeWithKey (\_ _ _ -> Nothing) id (const Nil) m1 m2
 
 -- | /O(n+m)/. Difference with a combining function.
 --
@@ -770,30 +723,8 @@ differenceWith f m1 m2
 -- >     == singleton 3 "3:b|B"
 
 differenceWithKey :: (Key -> a -> b -> Maybe a) -> IntMap a -> IntMap b -> IntMap a
-differenceWithKey f t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = difference1
-  | shorter m2 m1  = difference2
-  | p1 == p2       = bin p1 m1 (differenceWithKey f l1 l2) (differenceWithKey f r1 r2)
-  | otherwise      = t1
-  where
-    difference1 | nomatch p2 p1 m1  = t1
-                | zero p2 m1        = bin p1 m1 (differenceWithKey f l1 t2) r1
-                | otherwise         = bin p1 m1 l1 (differenceWithKey f r1 t2)
-
-    difference2 | nomatch p1 p2 m2  = t1
-                | zero p1 m2        = differenceWithKey f t1 l2
-                | otherwise         = differenceWithKey f t1 r2
-
-differenceWithKey f t1@(Tip k x) t2
-  = case lookup k t2 of
-      Just y  -> case f k x y of
-                   Just y' -> Tip k y'
-                   Nothing -> Nil
-      Nothing -> t1
-
-differenceWithKey _ Nil _       = Nil
-differenceWithKey f t (Tip k y) = updateWithKey (\k' x -> f k' x y) k t
-differenceWithKey _ t Nil       = t
+differenceWithKey f m1 m2
+  = mergeWithKey f id (const Nil) m1 m2
 
 
 {--------------------------------------------------------------------
@@ -804,29 +735,8 @@ differenceWithKey _ t Nil       = t
 -- > intersection (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "a"
 
 intersection :: IntMap a -> IntMap b -> IntMap a
-intersection t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = intersection1
-  | shorter m2 m1  = intersection2
-  | p1 == p2       = bin p1 m1 (intersection l1 l2) (intersection r1 r2)
-  | otherwise      = Nil
-  where
-    intersection1 | nomatch p2 p1 m1  = Nil
-                  | zero p2 m1        = intersection l1 t2
-                  | otherwise         = intersection r1 t2
-
-    intersection2 | nomatch p1 p2 m2  = Nil
-                  | zero p1 m2        = intersection t1 l2
-                  | otherwise         = intersection t1 r2
-
-intersection t1@(Tip k _) t2
-  | member k t2  = t1
-  | otherwise    = Nil
-intersection t (Tip k _)
-  = case lookup k t of
-      Just y  -> Tip k y
-      Nothing -> Nil
-intersection Nil _ = Nil
-intersection _ Nil = Nil
+intersection m1 m2
+  = mergeWithKey' bin const (const Nil) (const Nil) m1 m2
 
 -- | /O(n+m)/. The intersection with a combining function.
 --
@@ -842,31 +752,108 @@ intersectionWith f m1 m2
 -- > intersectionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "5:a|A"
 
 intersectionWithKey :: (Key -> a -> b -> c) -> IntMap a -> IntMap b -> IntMap c
-intersectionWithKey f t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = intersection1
-  | shorter m2 m1  = intersection2
-  | p1 == p2       = bin p1 m1 (intersectionWithKey f l1 l2) (intersectionWithKey f r1 r2)
-  | otherwise      = Nil
-  where
-    intersection1 | nomatch p2 p1 m1  = Nil
-                  | zero p2 m1        = intersectionWithKey f l1 t2
-                  | otherwise         = intersectionWithKey f r1 t2
-
-    intersection2 | nomatch p1 p2 m2  = Nil
-                  | zero p1 m2        = intersectionWithKey f t1 l2
-                  | otherwise         = intersectionWithKey f t1 r2
-
-intersectionWithKey f (Tip k x) t2
-  = case lookup k t2 of
-      Just y  -> Tip k (f k x y)
-      Nothing -> Nil
-intersectionWithKey f t1 (Tip k y)
-  = case lookup k t1 of
-      Just x  -> Tip k (f k x y)
-      Nothing -> Nil
-intersectionWithKey _ Nil _ = Nil
-intersectionWithKey _ _ Nil = Nil
+intersectionWithKey f m1 m2
+  = mergeWithKey' bin (\(Tip k1 x1) (Tip _k2 x2) -> Tip k1 (f k1 x1 x2)) (const Nil) (const Nil) m1 m2
 
+{--------------------------------------------------------------------
+  MergeWithKey
+--------------------------------------------------------------------}
+
+-- | /O(n+m)/. A high-performance universal combining function. Using
+-- 'mergeWithKey', all combining functions can be defined without any loss of
+-- efficiency (with exception of 'union', 'difference' and 'intersection',
+-- where sharing of some nodes is lost with 'mergeWithKey').
+--
+-- Please make sure you know what is going on when using 'mergeWithKey',
+-- otherwise you can be surprised by unexpected code growth or even
+-- corruption of the data structure.
+--
+-- When 'mergeWithKey' is given three arguments, it is inlined to the call
+-- site. You should therefore use 'mergeWithKey' only to define your custom
+-- combining functions. For example, you could define 'unionWithKey',
+-- 'differenceWithKey' and 'intersectionWithKey' as
+--
+-- > myUnionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) id id m1 m2
+-- > myDifferenceWithKey f m1 m2 = mergeWithKey f id (const empty) m1 m2
+-- > myIntersectionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) (const empty) (const empty) m1 m2
+--
+-- When calling @'mergeWithKey' combine only1 only2@, a function combining two
+-- 'IntMap's is created, such that
+--
+-- * if a key is present in both maps, it is passed with both corresponding
+--   values to the @combine@ function. Depending on the result, the key is either
+--   present in the result with specified value, or is left out;
+--
+-- * a nonempty subtree present only in the first map is passed to @only1@ and
+--   the output is added to the result;
+--
+-- * a nonempty subtree present only in the second map is passed to @only2@ and
+--   the output is added to the result.
+--
+-- The @only1@ and @only2@ methods /must return a map with a subset (possibly empty) of the keys of the given map/.
+-- The values can be modified arbitrarily.  Most common variants of @only1@ and
+-- @only2@ are 'id' and @'const' 'empty'@, but for example @'map' f@ or
+-- @'filterWithKey' f@ could be used for any @f@.
+
+mergeWithKey :: (Key -> a -> b -> Maybe c) -> (IntMap a -> IntMap c) -> (IntMap b -> IntMap c)
+             -> IntMap a -> IntMap b -> IntMap c
+mergeWithKey f g1 g2 = mergeWithKey' bin combine g1 g2
+  where combine (Tip k1 x1) (Tip _k2 x2) = case f k1 x1 x2 of Nothing -> Nil
+                                                              Just x -> Tip k1 x
+        {-# INLINE combine #-}
+{-# INLINE mergeWithKey #-}
+
+-- Slightly more general version of mergeWithKey. It differs in the following:
+--
+-- * the combining function operates on maps instead of keys and values. The
+--   reason is to enable sharing in union, difference and intersection.
+--
+-- * mergeWithKey' is given an equivalent of bin. The reason is that in union*,
+--   Bin constructor can be used, because we know both subtrees are nonempty.
+
+mergeWithKey' :: (Prefix -> Mask -> IntMap c -> IntMap c -> IntMap c)
+              -> (IntMap a -> IntMap b -> IntMap c) -> (IntMap a -> IntMap c) -> (IntMap b -> IntMap c)
+              -> IntMap a -> IntMap b -> IntMap c
+mergeWithKey' bin' f g1 g2 = go
+  where
+    go t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
+      | shorter m1 m2  = merge1
+      | shorter m2 m1  = merge2
+      | p1 == p2       = bin' p1 m1 (go l1 l2) (go r1 r2)
+      | otherwise      = maybe_join p1 (g1 t1) p2 (g2 t2)
+      where
+        merge1 | nomatch p2 p1 m1  = maybe_join p1 (g1 t1) p2 (g2 t2)
+               | zero p2 m1        = bin' p1 m1 (go l1 t2) (g1 r1)
+               | otherwise         = bin' p1 m1 (g1 l1) (go r1 t2)
+        merge2 | nomatch p1 p2 m2  = maybe_join p1 (g1 t1) p2 (g2 t2)
+               | zero p1 m2        = bin' p2 m2 (go t1 l2) (g2 r2)
+               | otherwise         = bin' p2 m2 (g2 l2) (go t1 r2)
+
+    go t1'@(Bin _ _ _ _) t2@(Tip k2 x2) = merge t1'
+      where merge t1@(Bin p1 m1 l1 r1) | nomatch k2 p1 m1 = maybe_join p1 (g1 t1) k2 (g2 t2)
+                                       | zero k2 m1 = bin' p1 m1 (merge l1) (g1 r1)
+                                       | otherwise  = bin' p1 m1 (g1 l1) (merge r1)
+            merge t1@(Tip k1 x1) | k1 == k2 = f t1 t2
+                                 | otherwise = maybe_join k1 (g1 t1) k2 (g2 t2)
+            merge Nil = g2 t2
+
+    go t1@(Bin _ _ _ _) Nil = g1 t1
+
+    go t1@(Tip k1 x1) t2' = merge t2'
+      where merge t2@(Bin p2 m2 l2 r2) | nomatch k1 p2 m2 = maybe_join k1 (g1 t1) p2 (g2 t2)
+                                       | zero k1 m2 = bin' p2 m2 (merge l2) (g2 r2)
+                                       | otherwise  = bin' p2 m2 (g2 l2) (merge r2)
+            merge t2@(Tip k2 x2) | k1 == k2 = f t1 t2
+                                 | otherwise = maybe_join k1 (g1 t1) k2 (g2 t2)
+            merge Nil = g1 t1
+
+    go Nil t2 = g2 t2
+
+    maybe_join _ Nil _ t2 = t2
+    maybe_join _ t1 _ Nil = t1
+    maybe_join p1 t1 p2 t2 = join p1 t1 p2 t2
+    {-# INLINE maybe_join #-}
+{-# INLINE mergeWithKey' #-}
 
 {--------------------------------------------------------------------
   Min\/Max
index e6b2d99..a8c6578 100644 (file)
@@ -109,6 +109,9 @@ module Data.IntMap.Lazy (
             , intersectionWith
             , intersectionWithKey
 
+            -- ** Universal combining function
+            , mergeWithKey
+
             -- * Traversal
             -- ** Map
             , IM.map
index 17d6b5d..1fd49a0 100644 (file)
@@ -113,6 +113,9 @@ module Data.IntMap.Strict (
             , intersectionWith
             , intersectionWithKey
 
+            -- ** Universal combining function
+            , mergeWithKey
+
             -- * Traversal
             -- ** Map
             , map
@@ -216,6 +219,7 @@ import Data.IntMap.Base hiding
     , differenceWithKey
     , intersectionWith
     , intersectionWithKey
+    , mergeWithKey
     , updateMinWithKey
     , updateMaxWithKey
     , updateMax
@@ -520,24 +524,8 @@ unionWith f m1 m2
 -- > unionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "5:a|A"), (7, "C")]
 
 unionWithKey :: (Key -> a -> a -> a) -> IntMap a -> IntMap a -> IntMap a
-unionWithKey f t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = union1
-  | shorter m2 m1  = union2
-  | p1 == p2       = Bin p1 m1 (unionWithKey f l1 l2) (unionWithKey f r1 r2)
-  | otherwise      = join p1 t1 p2 t2
-  where
-    union1  | nomatch p2 p1 m1  = join p1 t1 p2 t2
-            | zero p2 m1        = Bin p1 m1 (unionWithKey f l1 t2) r1
-            | otherwise         = Bin p1 m1 l1 (unionWithKey f r1 t2)
-
-    union2  | nomatch p1 p2 m2  = join p1 t1 p2 t2
-            | zero p1 m2        = Bin p2 m2 (unionWithKey f t1 l2) r2
-            | otherwise         = Bin p2 m2 l2 (unionWithKey f t1 r2)
-
-unionWithKey f (Tip k x) t = insertWithKey f k x t
-unionWithKey f t (Tip k x) = insertWithKey (\k' x' y' -> f k' y' x') k x t  -- right bias
-unionWithKey _ Nil t  = t
-unionWithKey _ t Nil  = t
+unionWithKey f m1 m2
+  = mergeWithKey' Bin (\(Tip k1 x1) (Tip _k2 x2) -> Tip k1 $! f k1 x1 x2) id id m1 m2
 
 {--------------------------------------------------------------------
   Difference
@@ -563,31 +551,8 @@ differenceWith f m1 m2
 -- >     == singleton 3 "3:b|B"
 
 differenceWithKey :: (Key -> a -> b -> Maybe a) -> IntMap a -> IntMap b -> IntMap a
-differenceWithKey f t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = difference1
-  | shorter m2 m1  = difference2
-  | p1 == p2       = bin p1 m1 (differenceWithKey f l1 l2) (differenceWithKey f r1 r2)
-  | otherwise      = t1
-  where
-    difference1 | nomatch p2 p1 m1  = t1
-                | zero p2 m1        = bin p1 m1 (differenceWithKey f l1 t2) r1
-                | otherwise         = bin p1 m1 l1 (differenceWithKey f r1 t2)
-
-    difference2 | nomatch p1 p2 m2  = t1
-                | zero p1 m2        = differenceWithKey f t1 l2
-                | otherwise         = differenceWithKey f t1 r2
-
-differenceWithKey f t1@(Tip k x) t2
-  = case lookup k t2 of
-      Just y  -> case f k x y of
-                   Just y' -> y' `seq` Tip k y'
-                   Nothing -> Nil
-      Nothing -> t1
-
-differenceWithKey _ Nil _       = Nil
-differenceWithKey f t (Tip k y) = updateWithKey (\k' x -> f k' x y) k t
-differenceWithKey _ t Nil       = t
-
+differenceWithKey f m1 m2
+  = mergeWithKey f id (const Nil) m1 m2
 
 {--------------------------------------------------------------------
   Intersection
@@ -607,31 +572,56 @@ intersectionWith f m1 m2
 -- > intersectionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "5:a|A"
 
 intersectionWithKey :: (Key -> a -> b -> c) -> IntMap a -> IntMap b -> IntMap c
-intersectionWithKey f t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
-  | shorter m1 m2  = intersection1
-  | shorter m2 m1  = intersection2
-  | p1 == p2       = bin p1 m1 (intersectionWithKey f l1 l2) (intersectionWithKey f r1 r2)
-  | otherwise      = Nil
-  where
-    intersection1 | nomatch p2 p1 m1  = Nil
-                  | zero p2 m1        = intersectionWithKey f l1 t2
-                  | otherwise         = intersectionWithKey f r1 t2
-
-    intersection2 | nomatch p1 p2 m2  = Nil
-                  | zero p1 m2        = intersectionWithKey f t1 l2
-                  | otherwise         = intersectionWithKey f t1 r2
-
-intersectionWithKey f (Tip k x) t2
-  = case lookup k t2 of
-      Just y  -> Tip k $! f k x y
-      Nothing -> Nil
-intersectionWithKey f t1 (Tip k y)
-  = case lookup k t1 of
-      Just x  -> Tip k $! f k x y
-      Nothing -> Nil
-intersectionWithKey _ Nil _ = Nil
-intersectionWithKey _ _ Nil = Nil
+intersectionWithKey f m1 m2
+  = mergeWithKey' bin (\(Tip k1 x1) (Tip _k2 x2) -> Tip k1 $! f k1 x1 x2) (const Nil) (const Nil) m1 m2
+
+{--------------------------------------------------------------------
+  MergeWithKey
+--------------------------------------------------------------------}
 
+-- | /O(n+m)/. A high-performance universal combining function. Using
+-- 'mergeWithKey', all combining functions can be defined without any loss of
+-- efficiency (with exception of 'union', 'difference' and 'intersection',
+-- where sharing of some nodes is lost with 'mergeWithKey').
+--
+-- Please make sure you know what is going on when using 'mergeWithKey',
+-- otherwise you can be surprised by unexpected code growth or even
+-- corruption of the data structure.
+--
+-- When 'mergeWithKey' is given three arguments, it is inlined to the call
+-- site. You should therefore use 'mergeWithKey' only to define your custom
+-- combining functions. For example, you could define 'unionWithKey',
+-- 'differenceWithKey' and 'intersectionWithKey' as
+--
+-- > myUnionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) id id m1 m2
+-- > myDifferenceWithKey f m1 m2 = mergeWithKey f id (const empty) m1 m2
+-- > myIntersectionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) (const empty) (const empty) m1 m2
+--
+-- When calling @'mergeWithKey' combine only1 only2@, a function combining two
+-- 'IntMap's is created, such that
+--
+-- * if a key is present in both maps, it is passed with both corresponding
+--   values to the @combine@ function. Depending on the result, the key is either
+--   present in the result with specified value, or is left out;
+--
+-- * a nonempty subtree present only in the first map is passed to @only1@ and
+--   the output is added to the result;
+--
+-- * a nonempty subtree present only in the second map is passed to @only2@ and
+--   the output is added to the result.
+--
+-- The @only1@ and @only2@ methods /must return a map with a subset (possibly empty) of the keys of the given map/.
+-- The values can be modified arbitrarily.  Most common variants of @only1@ and
+-- @only2@ are 'id' and @'const' 'empty'@, but for example @'map' f@ or
+-- @'filterWithKey' f@ could be used for any @f@.
+
+mergeWithKey :: (Key -> a -> b -> Maybe c) -> (IntMap a -> IntMap c) -> (IntMap b -> IntMap c)
+             -> IntMap a -> IntMap b -> IntMap c
+mergeWithKey f g1 g2 = mergeWithKey' bin combine g1 g2
+  where combine (Tip k1 x1) (Tip _k2 x2) = case f k1 x1 x2 of Nothing -> Nil
+                                                              Just x -> x `seq` Tip k1 x
+        {-# INLINE combine #-}
+{-# INLINE mergeWithKey #-}
 
 {--------------------------------------------------------------------
   Min\/Max
index 0bed938..679a92a 100644 (file)
@@ -8,6 +8,7 @@ import Data.IntMap.Lazy as Data.IntMap
 
 import Data.Monoid
 import Data.Maybe hiding (mapMaybe)
+import qualified Data.Maybe as Maybe (mapMaybe)
 import Data.Ord
 import Data.Function
 import Prelude hiding (lookup, null, map, filter, foldr, foldl)
@@ -126,6 +127,7 @@ main = defaultMainWithOpts
              , testProperty "intersection model"   prop_intersectionModel
              , testProperty "intersectionWith model" prop_intersectionWithModel
              , testProperty "intersectionWithKey model" prop_intersectionWithKeyModel
+             , testProperty "mergeWithKey model"   prop_mergeWithKeyModel
              , testProperty "fromAscList"          prop_ordered
              , testProperty "fromList then toList" prop_list
              , testProperty "toDescList"           prop_descList
@@ -758,6 +760,36 @@ prop_intersectionWithKeyModel xs ys
           ys' = List.nubBy ((==) `on` fst) ys
           f k l r = k + 2 * l + 3 * r
 
+prop_mergeWithKeyModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
+prop_mergeWithKeyModel xs ys
+  = and [ testMergeWithKey f keep_x keep_y
+        | f <- [ \_k x1  _x2 -> Just x1
+               , \_k _x1 x2  -> Just x2
+               , \_k _x1 _x2 -> Nothing
+               , \k  x1  x2  -> if k `mod` 2 == 0 then Nothing else Just (2 * x1 + 3 * x2)
+               ]
+        , keep_x <- [ True, False ]
+        , keep_y <- [ True, False ]
+        ]
+
+    where xs' = List.nubBy ((==) `on` fst) xs
+          ys' = List.nubBy ((==) `on` fst) ys
+
+          xm = fromList xs'
+          ym = fromList ys'
+
+          testMergeWithKey f keep_x keep_y
+            = toList (mergeWithKey f (keep keep_x) (keep keep_y) xm ym) == emulateMergeWithKey f keep_x keep_y
+              where keep False _ = empty
+                    keep True  m = m
+
+                    emulateMergeWithKey f keep_x keep_y
+                      = Maybe.mapMaybe combine (sort $ List.union (List.map fst xs') (List.map fst ys'))
+                        where combine k = case (List.lookup k xs', List.lookup k ys') of
+                                            (Nothing, Just y) -> if keep_y then Just (k, y) else Nothing
+                                            (Just x, Nothing) -> if keep_x then Just (k, x) else Nothing
+                                            (Just x, Just y) -> (\v -> (k, v)) `fmap` f k x y
+
 ----------------------------------------------------------------
 
 prop_ordered :: Property