Define Map.{union,difference,intersection}WithKey using mergeWithKey.
authorMilan Straka <fox@ucw.cz>
Sat, 28 Apr 2012 14:36:31 +0000 (16:36 +0200)
committerMilan Straka <fox@ucw.cz>
Sat, 28 Apr 2012 14:36:31 +0000 (16:36 +0200)
The resulting implementations are approximately 40-50% faster, although
for some input data the performance is worse. This happens
* in unionWithKey, if the data are disjunct: 15% slowdown
* in differenceWithKey, as now we recurse over the first tree and not
  the second. The slowdown happens also only for disjunct data: 30%.

See the SetOperations benchmark for yourself if you are interested.

Data/Map/Base.hs
Data/Map/Strict.hs

index 2402289..e4ee00e 100644 (file)
@@ -250,7 +250,6 @@ module Data.Map.Base (
             , delta
             , join
             , merge
-            , splitLookupWithKey
             , glue
             , trim
             , trimLookupLo
@@ -1264,36 +1263,11 @@ unionWith f m1 m2
 -- > unionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "5:a|A"), (7, "C")]
 
 unionWithKey :: Ord k => (k -> a -> a -> a) -> Map k a -> Map k a -> Map k a
-unionWithKey _ Tip t2  = t2
-unionWithKey _ t1 Tip  = t1
-unionWithKey f t1 t2 = hedgeUnionWithKey f NothingS NothingS t1 t2
+unionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) id id t1 t2
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE unionWithKey #-}
 #endif
 
-hedgeUnionWithKey :: Ord a
-                  => (a -> b -> b -> b)
-                  -> MaybeS a -> MaybeS a
-                  -> Map a b -> Map a b
-                  -> Map a b
-hedgeUnionWithKey _ _     _     t1 Tip
-  = t1
-hedgeUnionWithKey _ blo bhi Tip (Bin _ kx x l r)
-  = join kx x (filterGt blo l) (filterLt bhi r)
-hedgeUnionWithKey f blo bhi (Bin _ kx x l r) t2
-  = join kx newx (hedgeUnionWithKey f blo bmi l lt)
-                 (hedgeUnionWithKey f bmi bhi r gt)
-  where
-    bmi        = JustS kx
-    lt         = trim blo bmi t2
-    (found,gt) = trimLookupLo kx bhi t2
-    newx       = case found of
-                   Nothing -> x
-                   Just (_,y) -> f kx x y
-#if __GLASGOW_HASKELL__ >= 700
-{-# INLINABLE hedgeUnionWithKey #-}
-#endif
-
 {--------------------------------------------------------------------
   Difference
 --------------------------------------------------------------------}
@@ -1350,40 +1324,11 @@ differenceWith f m1 m2
 -- >     == singleton 3 "3:b|B"
 
 differenceWithKey :: Ord k => (k -> a -> b -> Maybe a) -> Map k a -> Map k b -> Map k a
-differenceWithKey _ Tip _   = Tip
-differenceWithKey _ t1 Tip  = t1
-differenceWithKey f t1 t2   = hedgeDiffWithKey f NothingS NothingS t1 t2
+differenceWithKey f t1 t2 = mergeWithKey f id (const Tip) t1 t2
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE differenceWithKey #-}
 #endif
 
-hedgeDiffWithKey :: Ord a
-                 => (a -> b -> c -> Maybe b)
-                 -> MaybeS a -> MaybeS a
-                 -> Map a b -> Map a c
-                 -> Map a b
-hedgeDiffWithKey _ _     _     Tip _
-  = Tip
-hedgeDiffWithKey _ blo bhi (Bin _ kx x l r) Tip
-  = join kx x (filterGt blo l) (filterLt bhi r)
-hedgeDiffWithKey f blo bhi t (Bin _ kx x l r)
-  = case found of
-      Nothing -> merge tl tr
-      Just (ky,y) ->
-          case f ky y x of
-            Nothing -> merge tl tr
-            Just z  -> join ky z tl tr
-  where
-    bmi        = JustS kx
-    lt         = trim blo bmi t
-    (found,gt) = trimLookupLo kx bhi t
-    tl         = hedgeDiffWithKey f blo bmi lt l
-    tr         = hedgeDiffWithKey f bmi bhi gt r
-#if __GLASGOW_HASKELL__ >= 700
-{-# INLINABLE hedgeDiffWithKey #-}
-#endif
-
-
 
 {--------------------------------------------------------------------
   Intersection
@@ -1432,17 +1377,7 @@ intersectionWith f m1 m2
 
 
 intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
-intersectionWithKey _ Tip _ = Tip
-intersectionWithKey _ _ Tip = Tip
-intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
-   if s1 >= s2 then
-     case splitLookupWithKey k2 t1 of
-       (lt, Just (k, x), gt) -> join k (f k x x2) (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
-       (lt, Nothing, gt) -> merge (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
-   else
-      case splitLookup k1 t2 of
-        (lt, Just x, gt) -> join k1 (f k1 x1 x) (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
-        (lt, Nothing, gt) -> merge (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
+intersectionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) (const Tip) (const Tip) t1 t2
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE intersectionWithKey #-}
 #endif
@@ -2333,19 +2268,6 @@ splitLookup k t = k `seq`
 {-# INLINABLE splitLookup #-}
 #endif
 
--- | /O(log n)/.
-splitLookupWithKey :: Ord k => k -> Map k a -> (Map k a,Maybe (k,a),Map k a)
-splitLookupWithKey k t = k `seq`
-  case t of
-    Tip            -> (Tip,Nothing,Tip)
-    Bin _ kx x l r -> case compare k kx of
-      LT -> let (lt,z,gt) = splitLookupWithKey k l in (lt,z,join kx x gt r)
-      GT -> let (lt,z,gt) = splitLookupWithKey k r in (join kx x l lt,z,gt)
-      EQ -> (l,Just (kx, x),r)
-#if __GLASGOW_HASKELL__ >= 700
-{-# INLINABLE splitLookupWithKey #-}
-#endif
-
 {--------------------------------------------------------------------
   Utility functions that maintain the balance properties of the tree.
   All constructors assume that all values in [l] < [k] and all values
index 1de7316..82b9bda 100644 (file)
@@ -716,36 +716,11 @@ unionWith f m1 m2
 -- > unionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "5:a|A"), (7, "C")]
 
 unionWithKey :: Ord k => (k -> a -> a -> a) -> Map k a -> Map k a -> Map k a
-unionWithKey _ Tip t2  = t2
-unionWithKey _ t1 Tip  = t1
-unionWithKey f t1 t2 = hedgeUnionWithKey f NothingS NothingS t1 t2
+unionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) id id t1 t2
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE unionWithKey #-}
 #endif
 
-hedgeUnionWithKey :: Ord a
-                  => (a -> b -> b -> b)
-                  -> MaybeS a -> MaybeS a
-                  -> Map a b -> Map a b
-                  -> Map a b
-hedgeUnionWithKey _ _     _     t1 Tip
-  = t1
-hedgeUnionWithKey _ blo bhi Tip (Bin _ kx x l r)
-  = join kx x (filterGt blo l) (filterLt bhi r)
-hedgeUnionWithKey f blo bhi (Bin _ kx x l r) t2
-  = newx `seq` join kx newx (hedgeUnionWithKey f blo bmi l lt)
-                            (hedgeUnionWithKey f bmi bhi r gt)
-  where
-    bmi        = JustS kx
-    lt         = trim blo bmi t2
-    (found,gt) = trimLookupLo kx bhi t2
-    newx       = case found of
-                   Nothing -> x
-                   Just (_,y) -> f kx x y
-#if __GLASGOW_HASKELL__ >= 700
-{-# INLINABLE hedgeUnionWithKey #-}
-#endif
-
 {--------------------------------------------------------------------
   Difference
 --------------------------------------------------------------------}
@@ -779,38 +754,11 @@ differenceWith f m1 m2
 -- >     == singleton 3 "3:b|B"
 
 differenceWithKey :: Ord k => (k -> a -> b -> Maybe a) -> Map k a -> Map k b -> Map k a
-differenceWithKey _ Tip _   = Tip
-differenceWithKey _ t1 Tip  = t1
-differenceWithKey f t1 t2   = hedgeDiffWithKey f NothingS NothingS t1 t2
+differenceWithKey f t1 t2 = mergeWithKey f id (const Tip) t1 t2
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE differenceWithKey #-}
 #endif
 
-hedgeDiffWithKey :: Ord a
-                 => (a -> b -> c -> Maybe b)
-                 -> MaybeS a -> MaybeS a
-                 -> Map a b -> Map a c
-                 -> Map a b
-hedgeDiffWithKey _ _     _     Tip _
-  = Tip
-hedgeDiffWithKey _ blo bhi (Bin _ kx x l r) Tip
-  = join kx x (filterGt blo l) (filterLt bhi r)
-hedgeDiffWithKey f blo bhi t (Bin _ kx x l r)
-  = case found of
-      Nothing -> merge tl tr
-      Just (ky,y) ->
-          case f ky y x of
-            Nothing -> merge tl tr
-            Just z  -> z `seq` join ky z tl tr
-  where
-    bmi        = JustS kx
-    lt         = trim blo bmi t
-    (found,gt) = trimLookupLo kx bhi t
-    tl         = hedgeDiffWithKey f blo bmi lt l
-    tr         = hedgeDiffWithKey f bmi bhi gt r
-#if __GLASGOW_HASKELL__ >= 700
-{-# INLINABLE hedgeDiffWithKey #-}
-#endif
 
 {--------------------------------------------------------------------
   Intersection
@@ -835,17 +783,7 @@ intersectionWith f m1 m2
 
 
 intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
-intersectionWithKey _ Tip _ = Tip
-intersectionWithKey _ _ Tip = Tip
-intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
-   if s1 >= s2 then
-     case splitLookupWithKey k2 t1 of
-       (lt, Just (k, x), gt) -> case f k x x2 of x' -> x' `seq` join k x' (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
-       (lt, Nothing, gt) -> merge (intersectionWithKey f lt l2) (intersectionWithKey f gt r2)
-   else
-      case splitLookup k1 t2 of
-        (lt, Just x, gt) -> case f k1 x1 x of x' -> x' `seq` join k1 x' (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
-        (lt, Nothing, gt) -> merge (intersectionWithKey f l1 lt) (intersectionWithKey f r1 gt)
+intersectionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) (const Tip) (const Tip) t1 t2
 #if __GLASGOW_HASKELL__ >= 700
 {-# INLINABLE intersectionWithKey #-}
 #endif