Add Map.mergeWithKey.
authorMilan Straka <fox@ucw.cz>
Sat, 28 Apr 2012 14:35:51 +0000 (16:35 +0200)
committerMilan Straka <fox@ucw.cz>
Sat, 28 Apr 2012 14:35:51 +0000 (16:35 +0200)
Data/Map/Base.hs
Data/Map/Lazy.hs
Data/Map/Strict.hs
tests/map-properties.hs

index fd8612f..2402289 100644 (file)
@@ -148,6 +148,9 @@ module Data.Map.Base (
             , intersectionWith
             , intersectionWithKey
 
+            -- ** Universal combining function
+            , mergeWithKey
+
             -- * Traversal
             -- ** Map
             , map
@@ -259,6 +262,7 @@ module Data.Map.Base (
 
 import Prelude hiding (lookup,map,filter,foldr,foldl,null)
 import qualified Data.Set.Base as Set
+import Data.StrictPair
 import Data.Monoid (Monoid(..))
 import Control.Applicative (Applicative(..), (<$>))
 import Data.Traversable (Traversable(traverse))
@@ -411,23 +415,6 @@ lookup = go
 {-# INLINE lookup #-}
 #endif
 
--- See Note: Type of local 'go' function
-lookupAssoc :: Ord k => k -> Map k a -> Maybe (k,a)
-lookupAssoc = go
-  where
-    go :: Ord k => k -> Map k a -> Maybe (k,a)
-    STRICT_1_OF_2(go)
-    go _ Tip = Nothing
-    go k (Bin _ kx x l r) = case compare k kx of
-      LT -> go k l
-      GT -> go k r
-      EQ -> Just (kx,x)
-#if __GLASGOW_HASKELL__ >= 700
-{-# INLINABLE lookupAssoc #-}
-#else
-{-# INLINE lookupAssoc #-}
-#endif
-
 -- | /O(log n)/. Is the key a member of the map? See also 'notMember'.
 --
 -- > member 5 (fromList [(5,'a'), (3,'b')]) == True
@@ -1461,6 +1448,69 @@ intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
 #endif
 
 
+{--------------------------------------------------------------------
+  MergeWithKey
+--------------------------------------------------------------------}
+
+-- | /O(n+m)/. A high-performance universal combining function. This function
+-- is used to define 'unionWith', 'unionWithKey', 'differenceWith',
+-- 'differenceWithKey', 'intersectionWith', 'intersectionWithKey' and can be
+-- used to define other custom combine functions.
+--
+-- Please make sure you know what is going on when using 'mergeWithKey',
+-- otherwise you can be surprised by unexpected code growth or even
+-- corruption of the data structure.
+--
+-- When 'mergeWithKey' is given three arguments, it is inlined to the call
+-- site. You should therefore use 'mergeWithKey' only to define your custom
+-- combining functions. For example, you could define 'unionWithKey',
+-- 'differenceWithKey' and 'intersectionWithKey' as
+--
+-- > myUnionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) id id m1 m2
+-- > myDifferenceWithKey f m1 m2 = mergeWithKey f id (const empty) m1 m2
+-- > myIntersectionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) (const empty) (const empty) m1 m2
+--
+-- When calling @'mergeWithKey' combine only1 only2@, a function combining two
+-- 'IntMap's is created, such that
+--
+-- * if a key is present in both maps, it is passed with both corresponding
+--   values to the @combine@ function. Depending on the result, the key is either
+--   present in the result with specified value, or is left out;
+--
+-- * a nonempty subtree present only in the first map is passed to @only1@ and
+--   the output is added to the result;
+--
+-- * a nonempty subtree present only in the second map is passed to @only2@ and
+--   the output is added to the result.
+--
+-- The @only1@ and @only2@ methods /must return a map with a subset (possibly empty) of the keys of the given map/.
+-- The values can be modified arbitrarily. Most common variants of @only1@ and
+-- @only2@ are 'id' and @'const' 'empty'@, but for example @'map' f@ or
+-- @'filterWithKey' f@ could be used for any @f@.
+
+mergeWithKey :: Ord k => (k -> a -> b -> Maybe c) -> (Map k a -> Map k c) -> (Map k b -> Map k c)
+             -> Map k a -> Map k b -> Map k c
+mergeWithKey f g1 g2 = go
+  where
+    go Tip t2 = g2 t2
+    go t1 Tip = g1 t1
+    go t1 t2 = hedgeMerge NothingS NothingS t1 t2
+
+    hedgeMerge _   _   t1  Tip = g1 t1
+    hedgeMerge blo bhi Tip (Bin _ kx x l r) = g2 $ join kx x (filterGt blo l) (filterLt bhi r)
+    hedgeMerge blo bhi (Bin _ kx x l r) t2 = let l' = hedgeMerge blo bmi l (trim blo bmi t2)
+                                                 (found, trim_t2) = trimLookupLo kx bhi t2
+                                                 r' = hedgeMerge bmi bhi r trim_t2
+                                             in case found of
+                                                  Nothing -> case g1 (singleton kx x) of
+                                                               Tip -> merge l' r'
+                                                               (Bin _ _ x' Tip Tip) -> join kx x' l' r'
+                                                               _ -> error "mergeWithKey: Given function only1 does not fulfil required conditions (see documentation)"
+                                                  Just x2 -> case f kx x x2 of
+                                                               Nothing -> merge l' r'
+                                                               Just x' -> join kx x' l' r'
+      where bmi = JustS kx
+{-# INLINE mergeWithKey #-}
 
 {--------------------------------------------------------------------
   Submap
@@ -2182,17 +2232,28 @@ trim (JustS lk) (JustS hk) t = middle lk hk t  where middle lo hi (Bin _ k _ _ r
 {-# INLINABLE trim #-}
 #endif
 
-trimLookupLo :: Ord k => k -> MaybeS k -> Map k a -> (Maybe (k,a), Map k a)
-trimLookupLo _  _  Tip = (Nothing, Tip)
-trimLookupLo lo hi t@(Bin _ kx x l r)
-  = case compare lo kx of
-      LT -> case compare' kx hi of
-              LT -> (lookupAssoc lo t, t)
-              _  -> trimLookupLo lo hi l
-      GT -> trimLookupLo lo hi r
-      EQ -> (Just (kx,x),trim (JustS lo) hi r)
-  where compare' _    NothingS   = LT
-        compare' kx' (JustS hi') = compare kx' hi'
+-- Helper function for 'mergeWithKey'. The @'trimLookupLo' lk hk t@ performs both
+-- @'trim' (JustS lk) hk t@ and @'lookup' lk t@.
+
+-- See Note: Type of local 'go' function
+trimLookupLo :: Ord k => k -> MaybeS k -> Map k a -> (Maybe a, Map k a)
+trimLookupLo lk NothingS t = greater lk t
+  where greater :: Ord k => k -> Map k a -> (Maybe a, Map k a)
+        greater lo t'@(Bin _ kx x l r) = case compare lo kx of LT -> lookup lo l `strictPair` t'
+                                                               EQ -> (Just x, r)
+                                                               GT -> greater lo r
+        greater _ Tip = (Nothing, Tip)
+trimLookupLo lk (JustS hk) t = middle lk hk t
+  where middle :: Ord k => k -> k -> Map k a -> (Maybe a, Map k a)
+        middle lo hi t'@(Bin _ kx x l r) = case compare lo kx of LT | kx < hi -> lookup lo l `strictPair` t'
+                                                                    | otherwise -> middle lo hi l
+                                                                 EQ -> Just x `strictPair` lesser hi r
+                                                                 GT -> middle lo hi r
+        middle _ _ Tip = (Nothing, Tip)
+
+        lesser :: Ord k => k -> Map k a -> Map k a
+        lesser hi (Bin _ k _ l _) | k >= hi = lesser hi l
+        lesser _ t' = t'
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE trimLookupLo #-}
 #endif
index 1a44933..ddccffd 100644 (file)
@@ -109,6 +109,9 @@ module Data.Map.Lazy (
             , intersectionWith
             , intersectionWithKey
 
+            -- ** Universal combining function
+            , mergeWithKey
+
             -- * Traversal
             -- ** Map
             , M.map
index c0b7af3..1de7316 100644 (file)
@@ -116,6 +116,9 @@ module Data.Map.Strict
     , intersectionWith
     , intersectionWithKey
 
+    -- ** Universal combining function
+    , mergeWithKey
+
     -- * Traversal
     -- ** Map
     , map
@@ -237,6 +240,7 @@ import Data.Map.Base hiding
     , differenceWithKey
     , intersectionWith
     , intersectionWithKey
+    , mergeWithKey
     , map
     , mapWithKey
     , mapAccum
@@ -846,6 +850,71 @@ intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
 {-# INLINABLE intersectionWithKey #-}
 #endif
 
+
+{--------------------------------------------------------------------
+  MergeWithKey
+--------------------------------------------------------------------}
+
+-- | /O(n+m)/. A high-performance universal combining function. This function
+-- is used to define 'unionWith', 'unionWithKey', 'differenceWith',
+-- 'differenceWithKey', 'intersectionWith', 'intersectionWithKey' and can be
+-- used to define other custom combine functions.
+--
+-- Please make sure you know what is going on when using 'mergeWithKey',
+-- otherwise you can be surprised by unexpected code growth or even
+-- corruption of the data structure.
+--
+-- When 'mergeWithKey' is given three arguments, it is inlined to the call
+-- site. You should therefore use 'mergeWithKey' only to define your custom
+-- combining functions. For example, you could define 'unionWithKey',
+-- 'differenceWithKey' and 'intersectionWithKey' as
+--
+-- > myUnionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) id id m1 m2
+-- > myDifferenceWithKey f m1 m2 = mergeWithKey f id (const empty) m1 m2
+-- > myIntersectionWithKey f m1 m2 = mergeWithKey (\k x1 x2 -> Just (f k x1 x2)) (const empty) (const empty) m1 m2
+--
+-- When calling @'mergeWithKey' combine only1 only2@, a function combining two
+-- 'IntMap's is created, such that
+--
+-- * if a key is present in both maps, it is passed with both corresponding
+--   values to the @combine@ function. Depending on the result, the key is either
+--   present in the result with specified value, or is left out;
+--
+-- * a nonempty subtree present only in the first map is passed to @only1@ and
+--   the output is added to the result;
+--
+-- * a nonempty subtree present only in the second map is passed to @only2@ and
+--   the output is added to the result.
+--
+-- The @only1@ and @only2@ methods /must return a map with a subset (possibly empty) of the keys of the given map/.
+-- The values can be modified arbitrarily. Most common variants of @only1@ and
+-- @only2@ are 'id' and @'const' 'empty'@, but for example @'map' f@ or
+-- @'filterWithKey' f@ could be used for any @f@.
+
+mergeWithKey :: Ord k => (k -> a -> b -> Maybe c) -> (Map k a -> Map k c) -> (Map k b -> Map k c)
+             -> Map k a -> Map k b -> Map k c
+mergeWithKey f g1 g2 = go
+  where
+    go Tip t2 = g2 t2
+    go t1 Tip = g1 t1
+    go t1 t2 = hedgeMerge NothingS NothingS t1 t2
+
+    hedgeMerge _   _   t1  Tip = g1 t1
+    hedgeMerge blo bhi Tip (Bin _ kx x l r) = g2 $ join kx x (filterGt blo l) (filterLt bhi r)
+    hedgeMerge blo bhi (Bin _ kx x l r) t2 = let l' = hedgeMerge blo bmi l (trim blo bmi t2)
+                                                 (found, trim_t2) = trimLookupLo kx bhi t2
+                                                 r' = hedgeMerge bmi bhi r trim_t2
+                                             in case found of
+                                                  Nothing -> case g1 (singleton kx x) of
+                                                               Tip -> merge l' r'
+                                                               (Bin _ _ x' Tip Tip) -> join kx x' l' r'
+                                                               _ -> error "mergeWithKey: Given function only1 does not fulfil required conditions (see documentation)"
+                                                  Just x2 -> case f kx x x2 of
+                                                               Nothing -> merge l' r'
+                                                               Just x' -> x' `seq` join kx x' l' r'
+      where bmi = JustS kx
+{-# INLINE mergeWithKey #-}
+
 {--------------------------------------------------------------------
   Filter and partition
 --------------------------------------------------------------------}
index 77b2236..4b2817f 100644 (file)
@@ -8,6 +8,7 @@ import Data.Map.Lazy as Data.Map
 
 import Data.Monoid
 import Data.Maybe hiding (mapMaybe)
+import qualified Data.Maybe as Maybe (mapMaybe)
 import Data.Ord
 import Data.Function
 import Prelude hiding (lookup, null, map, filter, foldr, foldl)
@@ -152,6 +153,7 @@ main = defaultMainWithOpts
          , testProperty "intersectionWithModel" prop_intersectionWithModel
          , testProperty "intersectionWithKey"  prop_intersectionWithKey
          , testProperty "intersectionWithKeyModel" prop_intersectionWithKeyModel
+         , testProperty "mergeWithKey model"   prop_mergeWithKeyModel
          , testProperty "fromAscList"          prop_ordered
          , testProperty "fromList then toList" prop_list
          , testProperty "toDescList"           prop_descList
@@ -935,6 +937,41 @@ prop_intersectionWithKeyModel xs ys
           ys' = List.nubBy ((==) `on` fst) ys
           f k l r = k + 2 * l + 3 * r
 
+prop_mergeWithKeyModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
+prop_mergeWithKeyModel xs ys
+  = and [ testMergeWithKey f keep_x keep_y
+        | f <- [ \_k x1  _x2 -> Just x1
+               , \_k _x1 x2  -> Just x2
+               , \_k _x1 _x2 -> Nothing
+               , \k  x1  x2  -> if k `mod` 2 == 0 then Nothing else Just (2 * x1 + 3 * x2)
+               ]
+        , keep_x <- [ True, False ]
+        , keep_y <- [ True, False ]
+        ]
+
+    where xs' = List.nubBy ((==) `on` fst) xs
+          ys' = List.nubBy ((==) `on` fst) ys
+
+          xm = fromList xs'
+          ym = fromList ys'
+
+          testMergeWithKey f keep_x keep_y
+            = toList (mergeWithKey f (keep keep_x) (keep keep_y) xm ym) == emulateMergeWithKey f keep_x keep_y
+              where keep False _ = empty
+                    keep True  m = m
+
+                    emulateMergeWithKey f keep_x keep_y
+                      = Maybe.mapMaybe combine (sort $ List.union (List.map fst xs') (List.map fst ys'))
+                        where combine k = case (List.lookup k xs', List.lookup k ys') of
+                                            (Nothing, Just y) -> if keep_y then Just (k, y) else Nothing
+                                            (Just x, Nothing) -> if keep_x then Just (k, x) else Nothing
+                                            (Just x, Just y) -> (\v -> (k, v)) `fmap` f k x y
+
+          -- We prevent inlining testMergeWithKey to disable the SpecConstr
+          -- optimalization. There are too many call patterns here so several
+          -- warnings are issued if testMergeWithKey gets inlined.
+          {-# NOINLINE testMergeWithKey #-}
+
 ----------------------------------------------------------------
 
 prop_ordered :: Property