Continue to improve map functions
authorDavid Feuer <David.Feuer@gmail.com>
Mon, 1 Aug 2016 21:40:14 +0000 (17:40 -0400)
committerDavid Feuer <David.Feuer@gmail.com>
Wed, 3 Aug 2016 21:16:32 +0000 (17:16 -0400)
Rewrite `unionWith`, `intersectionWithKey`, etc., as independent
functions. Writing either in terms of the other leads to closures
being allocated with extra indirection for the passed function.
`mergeWithKey` misses singleton optimizations for unions. For the
rest, I think `mergeWithKey` is hard to understand, and it's not
immediately obvious how the parts are supposed to fit together.
Since it's used only to reduce *source* code size, and not actual
*generated* code size, I'd rather avoid it for the most part.
I've left `differenceWith` and `differenceWithKey` alone, as they
appear to be rather deeply tied to the concepts in `mergeWithKey`.

Data/Map/Base.hs
Data/Map/Strict.hs
Data/Set/Base.hs
Data/Utils/PtrEquality.hs
tests/map-properties.hs
tests/set-properties.hs

index adb7004..07b7127 100644 (file)
@@ -305,6 +305,7 @@ import Prelude hiding (lookup, map, filter, foldr, foldl, null)
 
 import qualified Data.Set.Base as Set
 import Data.Set.Base (Set)
+import Data.Utils.PtrEquality (ptrEq)
 import Data.Utils.StrictFold
 import Data.Utils.StrictPair
 import Data.Utils.StrictMaybe
@@ -671,13 +672,23 @@ singleton k x = Bin 1 k x Tip Tip
 insert :: Ord k => k -> a -> Map k a -> Map k a
 insert = go
   where
+    -- Unlike insertR, we only get sharing here
+    -- when the inserted value is at the same address
+    -- as the present value. We try anyway. If we decide
+    -- not to, then Data.Map.Strict should probably
+    -- get its own union implementation.
     go :: Ord k => k -> a -> Map k a -> Map k a
     go !kx x Tip = singleton kx x
-    go kx x (Bin sz ky y l r) =
+    go !kx x t@(Bin sz ky y l r) =
         case compare kx ky of
-            LT -> balanceL ky y (go kx x l) r
-            GT -> balanceR ky y l (go kx x r)
-            EQ -> Bin sz kx x l r
+            LT | l' `ptrEq` l -> t
+               | otherwise -> balanceL ky y l' r
+               where !l' = go kx x l
+            GT | r' `ptrEq` r -> t
+               | otherwise -> balanceR ky y l r'
+               where !r' = go kx x r
+            EQ | kx `ptrEq` ky && x `ptrEq` y -> t
+               | otherwise -> Bin sz kx x l r
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insert #-}
 #else
@@ -693,10 +704,14 @@ insertR = go
   where
     go :: Ord k => k -> a -> Map k a -> Map k a
     go !kx x Tip = singleton kx x
-    go kx x t@(Bin _ ky y l r) =
+    go kx x t@(Bin sz ky y l r) =
         case compare kx ky of
-            LT -> balanceL ky y (go kx x l) r
-            GT -> balanceR ky y l (go kx x r)
+            LT | l' `ptrEq` l -> t
+               | otherwise -> balanceL ky y l' r
+               where !l' = go kx x l
+            GT | r' `ptrEq` r -> t
+               | otherwise -> balanceR ky y l r'
+               where !r' = go kx x r
             EQ -> t
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insertR #-}
@@ -715,13 +730,47 @@ insertR = go
 -- > insertWith (++) 5 "xxx" empty                         == singleton 5 "xxx"
 
 insertWith :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
-insertWith f = insertWithKey (\_ x' y' -> f x' y')
+insertWith = go
+  where
+    -- We have no hope of making pointer equality tricks work
+    -- here, because lazy insertWith *always* changes the tree,
+    -- either adding a new entry or replacing an element with a
+    -- thunk.
+    go :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
+    go _ !kx x Tip = singleton kx x
+    go f !kx x (Bin sy ky y l r) =
+        case compare kx ky of
+            LT -> balanceL ky y (go f kx x l) r
+            GT -> balanceR ky y l (go f kx x r)
+            EQ -> Bin sy kx (f x y) l r
+
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insertWith #-}
 #else
 {-# INLINE insertWith #-}
 #endif
 
+-- | A helper function for 'unionWith'. When the key is already in
+-- the map, the key is left alone, not replaced. The combining
+-- function is flipped--it is applied to the old value and then the
+-- new value.
+
+insertWithR :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
+insertWithR = go
+  where
+    go :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
+    go _ !kx x Tip = singleton kx x
+    go f !kx x (Bin sy ky y l r) =
+        case compare kx ky of
+            LT -> balanceL ky y (go f kx x l) r
+            GT -> balanceR ky y l (go f kx x r)
+            EQ -> Bin sy ky (f y x) l r
+#if __GLASGOW_HASKELL__
+{-# INLINABLE insertWithR #-}
+#else
+{-# INLINE insertWithR #-}
+#endif
+
 -- | /O(log n)/. Insert with a function, combining key, new value and old value.
 -- @'insertWithKey' f key value mp@
 -- will insert the pair (key, value) into @mp@ if key does
@@ -751,6 +800,26 @@ insertWithKey = go
 {-# INLINE insertWithKey #-}
 #endif
 
+-- | A helper function for 'unionWithKey'. When the key is already in
+-- the map, the key is left alone, not replaced. The combining
+-- function is flipped--it is applied to the old value and then the
+-- new value.
+insertWithKeyR :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> Map k a
+insertWithKeyR = go
+  where
+    go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> Map k a
+    go _ !kx x Tip = singleton kx x
+    go f kx x (Bin sy ky y l r) =
+        case compare kx ky of
+            LT -> balanceL ky y (go f kx x l) r
+            GT -> balanceR ky y l (go f kx x r)
+            EQ -> Bin sy ky (f ky y x) l r
+#if __GLASGOW_HASKELL__
+{-# INLINABLE insertWithKeyR #-}
+#else
+{-# INLINE insertWithKeyR #-}
+#endif
+
 -- | /O(log n)/. Combines insert operation with old value retrieval.
 -- The expression (@'insertLookupWithKey' f k x map@)
 -- is a pair where the first element is equal to (@'lookup' k map@)
@@ -770,19 +839,19 @@ insertWithKey = go
 -- See Note: Type of local 'go' function
 insertLookupWithKey :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a
                     -> (Maybe a, Map k a)
-insertLookupWithKey = go
+insertLookupWithKey f0 k0 x0 = toPair . go f0 k0 x0
   where
-    go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> (Maybe a, Map k a)
-    go _ !kx x Tip = (Nothing, singleton kx x)
+    go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> StrictPair (Maybe a) (Map k a)
+    go _ !kx x Tip = (Nothing :*: singleton kx x)
     go f kx x (Bin sy ky y l r) =
         case compare kx ky of
-            LT -> let !(found, l') = go f kx x l
+            LT -> let !(found :*: l') = go f kx x l
                       !t' = balanceL ky y l' r
-                  in (found, t')
-            GT -> let !(found, r') = go f kx x r
+                  in (found :*: t')
+            GT -> let !(found :*: r') = go f kx x r
                       !t' = balanceR ky y l r'
-                  in (found, t')
-            EQ -> (Just y, Bin sy kx (f kx x y) l r)
+                  in (found :*: t')
+            EQ -> (Just y :*: Bin sy kx (f kx x y) l r)
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insertLookupWithKey #-}
 #else
@@ -805,10 +874,14 @@ delete = go
   where
     go :: Ord k => k -> Map k a -> Map k a
     go !_ Tip = Tip
-    go k (Bin _ kx x l r) =
+    go k t@(Bin _ kx x l r) =
         case compare k kx of
-            LT -> balanceR kx x (go k l) r
-            GT -> balanceL kx x l (go k r)
+            LT | l' `ptrEq` l -> t
+               | otherwise -> balanceR kx x l' r
+               where !l' = go k l
+            GT | r' `ptrEq` r -> t
+               | otherwise -> balanceL kx x l r'
+               where !r' = go k r
             EQ -> glue l r
 #if __GLASGOW_HASKELL__
 {-# INLINABLE delete #-}
@@ -913,22 +986,22 @@ updateWithKey = go
 
 -- See Note: Type of local 'go' function
 updateLookupWithKey :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> (Maybe a,Map k a)
-updateLookupWithKey = go
+updateLookupWithKey f0 k0 = toPair . go f0 k0
  where
-   go :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> (Maybe a,Map k a)
-   go _ !_ Tip = (Nothing,Tip)
+   go :: Ord k => (k -> a -> Maybe a) -> k -> Map k a -> StrictPair (Maybe a) (Map k a)
+   go _ !_ Tip = (Nothing :*: Tip)
    go f k (Bin sx kx x l r) =
           case compare k kx of
-               LT -> let !(found,l') = go f k l
+               LT -> let !(found :*: l') = go f k l
                          !t' = balanceR kx x l' r
-                     in (found, t')
-               GT -> let !(found,r') = go f k r
+                     in (found :*: t')
+               GT -> let !(found :*: r') = go f k r
                          !t' = balanceL kx x l r'
-                     in (found, t')
+                     in (found :*: t')
                EQ -> case f kx x of
-                       Just x' -> (Just x', Bin sx kx x' l r)
+                       Just x' -> (Just x' :*: Bin sx kx x' l r)
                        Nothing -> let !glued = glue l r
-                                  in (Just x, glued)
+                                  in (Just x :*: glued)
 #if __GLASGOW_HASKELL__
 {-# INLINABLE updateLookupWithKey #-}
 #else
@@ -1524,8 +1597,11 @@ union t1 Tip  = t1
 union t1 (Bin _ k x Tip Tip) = insertR k x t1
 union (Bin _ k x Tip Tip) t2 = insert k x t2
 union Tip t2 = t2
-union (Bin _ k1 x1 l1 r1) t2 = case split k1 t2 of
-  (l2, r2) -> link k1 x1 (union l1 l2) (union r1 r2)
+union t1@(Bin _ k1 x1 l1 r1) t2 = case split k1 t2 of
+  (l2, r2) | l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1 -> t1
+           | otherwise -> link k1 x1 l1l2 r1r2
+           where !l1l2 = union l1 l2
+                 !r1r2 = union r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE union #-}
 #endif
@@ -1538,8 +1614,17 @@ union (Bin _ k1 x1 l1 r1) t2 = case split k1 t2 of
 -- > unionWith (++) (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "aA"), (7, "C")]
 
 unionWith :: Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
-unionWith f t1 t2
-  = mergeWithKey (\_ x1 x2 -> Just $ f x1 x2) id id t1 t2
+-- QuickCheck says pointer equality never happens here.
+unionWith _f t1 Tip = t1
+unionWith f t1 (Bin _ k x Tip Tip) = insertWithR f k x t1
+unionWith f (Bin _ k x Tip Tip) t2 = insertWith f k x t2
+unionWith _f Tip t2 = t2
+unionWith f (Bin _ k1 x1 l1 r1) t2 = case splitLookup k1 t2 of
+  (l2, mb, r2) -> case mb of
+      Nothing -> link k1 x1 l1l2 r1r2
+      Just x2 -> link k1 (f x1 x2) l1l2 r1r2
+    where !l1l2 = unionWith f l1 l2
+          !r1r2 = unionWith f r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE unionWith #-}
 #endif
@@ -1551,7 +1636,16 @@ unionWith f t1 t2
 -- > unionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "5:a|A"), (7, "C")]
 
 unionWithKey :: Ord k => (k -> a -> a -> a) -> Map k a -> Map k a -> Map k a
-unionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) id id t1 t2
+unionWithKey _f t1 Tip = t1
+unionWithKey f t1 (Bin _ k x Tip Tip) = insertWithKeyR f k x t1
+unionWithKey f (Bin _ k x Tip Tip) t2 = insertWithKey f k x t2
+unionWithKey _f Tip t2 = t2
+unionWithKey f (Bin _ k1 x1 l1 r1) t2 = case splitLookup k1 t2 of
+  (l2, mb, r2) -> case mb of
+      Nothing -> link k1 x1 l1l2 r1r2
+      Just x2 -> link k1 (f k1 x1 x2) l1l2 r1r2
+    where !l1l2 = unionWithKey f l1 l2
+          !r1r2 = unionWithKey f r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE unionWithKey #-}
 #endif
@@ -1567,8 +1661,13 @@ unionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) id id t1 t2
 difference :: Ord k => Map k a -> Map k b -> Map k a
 difference Tip _   = Tip
 difference t1 Tip  = t1
-difference t1 (Bin _ k _ l2 r2) = case split k t1 of
-  (l1, r1) -> merge (difference l1 l2) (difference r1 r2)
+difference t1 (Bin _ k _ l2 r2) = case splitMember k t1 of
+  (l1, b, r1)
+     | not b && l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1 -> t1
+     | otherwise -> merge l1l2 r1r2
+     where
+       !l1l2 = difference l1 l2
+       !r1r2 = difference r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE difference #-}
 #endif
@@ -1584,8 +1683,13 @@ difference t1 (Bin _ k _ l2 r2) = case split k t1 of
 withoutKeys :: Ord k => Map k a -> Set k -> Map k a
 withoutKeys Tip _ = Tip
 withoutKeys m Set.Tip = m
-withoutKeys m (Set.Bin _ k ls rs) = case split k m of
-  (lm, rm) -> merge (withoutKeys lm ls) (withoutKeys rm rs)
+withoutKeys m (Set.Bin _ k ls rs) = case splitMember k m of
+  (lm, b, rm)
+     | not b && lm' `ptrEq` lm && rm' `ptrEq` rm -> m
+     | otherwise -> merge lm' rm'
+     where
+       !lm' = withoutKeys lm ls
+       !rm' = withoutKeys rm rs
 #if __GLASGOW_HASKELL__
 {-# INLINABLE withoutKeys #-}
 #endif
@@ -1633,11 +1737,13 @@ differenceWithKey f t1 t2 = mergeWithKey f id (const Tip) t1 t2
 intersection :: Ord k => Map k a -> Map k b -> Map k a
 intersection Tip _ = Tip
 intersection _ Tip = Tip
-intersection (Bin _ k x l1 r1) t2 = case mb of
-  Nothing -> merge l1l2 r1r2
-  Just _ -> link k x l1l2 r1r2
+intersection t1@(Bin _ k x l1 r1) t2
+  | mb = if l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1
+         then t1
+         else link k x l1l2 r1r2
+  | otherwise = merge l1l2 r1r2
   where
-    !(l2, mb, r2) = splitLookup k t2
+    !(l2, mb, r2) = splitMember k t2
     !l1l2 = intersection l1 l2
     !r1r2 = intersection r1 r2
 #if __GLASGOW_HASKELL__
@@ -1655,8 +1761,10 @@ intersection (Bin _ k x l1 r1) t2 = case mb of
 restrictKeys :: Ord k => Map k a -> Set k -> Map k a
 restrictKeys Tip _ = Tip
 restrictKeys _ Set.Tip = Tip
-restrictKeys (Bin _ k x l1 r1) s
-  | b = link k x l1l2 r1r2
+restrictKeys m@(Bin _ k x l1 r1) s
+  | b = if l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1
+        then m
+        else link k x l1l2 r1r2
   | otherwise = merge l1l2 r1r2
   where
     !(l2, b, r2) = Set.splitMember k s
@@ -1671,8 +1779,17 @@ restrictKeys (Bin _ k x l1 r1) s
 -- > intersectionWith (++) (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "aA"
 
 intersectionWith :: Ord k => (a -> b -> c) -> Map k a -> Map k b -> Map k c
-intersectionWith f t1 t2 =
-  mergeWithKey (\_ x1 x2 -> Just $ f x1 x2) (const Tip) (const Tip) t1 t2
+-- We have no hope of pointer equality tricks here because every single
+-- element in the result will be a thunk.
+intersectionWith f Tip _ = Tip
+intersectionWith f _ Tip = Tip
+intersectionWith f t1@(Bin _ k x1 l1 r1) t2 = case mb of
+    Just x2 -> link k (f x1 x2) l1l2 r1r2
+    Nothing -> merge l1l2 r1r2
+  where
+    !(l2, mb, r2) = splitLookup k t2
+    !l1l2 = intersectionWith f l1 l2
+    !r1r2 = intersectionWith f r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE intersectionWith #-}
 #endif
@@ -1683,8 +1800,15 @@ intersectionWith f t1 t2 =
 -- > intersectionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "5:a|A"
 
 intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
-intersectionWithKey f t1 t2 =
-  mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) (const Tip) (const Tip) t1 t2
+intersectionWithKey f Tip _ = Tip
+intersectionWithKey f _ Tip = Tip
+intersectionWithKey f t1@(Bin _ k x1 l1 r1) t2 = case mb of
+    Just x2 -> link k (f k x1 x2) l1l2 r1r2
+    Nothing -> merge l1l2 r1r2
+  where
+    !(l2, mb, r2) = splitLookup k t2
+    !l1l2 = intersectionWithKey f l1 l2
+    !r1r2 = intersectionWithKey f r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE intersectionWithKey #-}
 #endif
@@ -1854,9 +1978,13 @@ filter p m
 
 filterWithKey :: (k -> a -> Bool) -> Map k a -> Map k a
 filterWithKey _ Tip = Tip
-filterWithKey p (Bin _ kx x l r)
-  | p kx x    = link kx x (filterWithKey p l) (filterWithKey p r)
-  | otherwise = merge (filterWithKey p l) (filterWithKey p r)
+filterWithKey p t@(Bin _ kx x l r)
+  | p kx x    = if pl `ptrEq` l && pr `ptrEq` r
+                then t
+                else link kx x pl pr
+  | otherwise = merge pl pr
+  where !pl = filterWithKey p l
+        !pr = filterWithKey p r
 
 -- | /O(n)/. Partition the map according to a predicate. The first
 -- map contains all elements that satisfy the predicate, the second all
@@ -1882,9 +2010,14 @@ partitionWithKey :: (k -> a -> Bool) -> Map k a -> (Map k a,Map k a)
 partitionWithKey p0 t0 = toPair $ go p0 t0
   where
     go _ Tip = (Tip :*: Tip)
-    go p (Bin _ kx x l r)
-      | p kx x    = link kx x l1 r1 :*: merge l2 r2
-      | otherwise = merge l1 r1 :*: link kx x l2 r2
+    go p t@(Bin _ kx x l r)
+      | p kx x    = (if l1 `ptrEq` l && r1 `ptrEq` r
+                     then t
+                     else link kx x l1 r1) :*: merge l2 r2
+      | otherwise = merge l1 r1 :*:
+                    (if l2 `ptrEq` l && r2 `ptrEq` r
+                     then t
+                     else link kx x l2 r2)
       where
         (l1 :*: l2) = go p l
         (r1 :*: r2) = go p r
@@ -2646,7 +2779,7 @@ split !k0 t0 = toPair $ go k0 t0
   where
     go k t =
       case t of
-        Tip            -> (Tip :*: Tip)
+        Tip            -> Tip :*: Tip
         Bin _ kx x l r -> case compare k kx of
           LT -> let (lt :*: gt) = go k l in lt :*: link kx x gt r
           GT -> let (lt :*: gt) = go k r in link kx x l lt :*: gt
@@ -2663,28 +2796,48 @@ split !k0 t0 = toPair $ go k0 t0
 -- > splitLookup 4 (fromList [(5,"a"), (3,"b")]) == (singleton 3 "b", Nothing, singleton 5 "a")
 -- > splitLookup 5 (fromList [(5,"a"), (3,"b")]) == (singleton 3 "b", Just "a", empty)
 -- > splitLookup 6 (fromList [(5,"a"), (3,"b")]) == (fromList [(3,"b"), (5,"a")], Nothing, empty)
-
 splitLookup :: Ord k => k -> Map k a -> (Map k a,Maybe a,Map k a)
-splitLookup k m = case splitLookupS k m of
-  StrictTriple l mv r -> (l, mv, r)
+splitLookup k m = case go k m of
+     StrictTriple l mv r -> (l, mv, r)
+  where
+    go :: Ord k => k -> Map k a -> StrictTriple (Map k a) (Maybe a) (Map k a)
+    go !k t =
+      case t of
+        Tip            -> StrictTriple Tip Nothing Tip
+        Bin _ kx x l r -> case compare k kx of
+          LT -> let StrictTriple lt z gt = go k l
+                    !gt' = link kx x gt r
+                in StrictTriple lt z gt'
+          GT -> let StrictTriple lt z gt = go k r
+                    !lt' = link kx x l lt
+                in StrictTriple lt' z gt
+          EQ -> StrictTriple l (Just x) r
 #if __GLASGOW_HASKELL__
 {-# INLINABLE splitLookup #-}
 #endif
 
-splitLookupS :: Ord k => k -> Map k a -> StrictTriple (Map k a) (Maybe a) (Map k a)
-splitLookupS !k t =
-  case t of
-    Tip            -> StrictTriple Tip Nothing Tip
-    Bin _ kx x l r -> case compare k kx of
-      LT -> let StrictTriple lt z gt = splitLookupS k l
-                !gt' = link kx x gt r
-            in StrictTriple lt z gt'
-      GT -> let StrictTriple lt z gt = splitLookupS k r
-                !lt' = link kx x l lt
-            in StrictTriple lt' z gt
-      EQ -> StrictTriple l (Just x) r
-#if __GLASGOW_HASKELL__
-{-# INLINABLE splitLookupS #-}
+-- | A variant of 'splitLookup' that indicates only whether the
+-- key was present, rather than producing its value. This is used to
+-- implement 'intersection' to avoid allocating unnecessary 'Just'
+-- constructors.
+splitMember :: Ord k => k -> Map k a -> (Map k a,Bool,Map k a)
+splitMember k m = case go k m of
+     StrictTriple l mv r -> (l, mv, r)
+  where
+    go :: Ord k => k -> Map k a -> StrictTriple (Map k a) Bool (Map k a)
+    go !k t =
+      case t of
+        Tip            -> StrictTriple Tip False Tip
+        Bin _ kx x l r -> case compare k kx of
+          LT -> let StrictTriple lt z gt = go k l
+                    !gt' = link kx x gt r
+                in StrictTriple lt z gt'
+          GT -> let StrictTriple lt z gt = go k r
+                    !lt' = link kx x l lt
+                in StrictTriple lt' z gt
+          EQ -> StrictTriple l True r
+#if __GLASGOW_HASKELL__
+{-# INLINABLE splitMember #-}
 #endif
 
 data StrictTriple a b c = StrictTriple !a !b !c
index a837304..ba059ec 100644 (file)
@@ -325,6 +325,17 @@ import Data.Functor.Identity (Identity (..))
 -- > map (\ v -> undefined) m  ==  undefined      -- m is not empty
 -- > mapKeys (\ k -> undefined) m  ==  undefined  -- m is not empty
 
+-- [Note: Pointer equality for sharing]
+--
+-- We use pointer equality to enhance sharing between the arguments
+-- of some functions and their results. Notably, we use it
+-- for insert, delete, union, intersection, and difference. We do
+-- *not* use it for functions, like insertWith, unionWithKey,
+-- intersectionWith, etc., that allow the user to modify the elements.
+-- While we *could* do so, we would only get sharing under fairly
+-- narrow conditions and at a relatively high cost. It does not seem
+-- worth the price.
+
 {--------------------------------------------------------------------
   Query
 --------------------------------------------------------------------}
@@ -404,13 +415,37 @@ insert = go
 -- > insertWith (++) 5 "xxx" empty                         == singleton 5 "xxx"
 
 insertWith :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
-insertWith f = insertWithKey (\_ x' y' -> f x' y')
+insertWith = go
+  where
+    go :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
+    go _ !kx x Tip = singleton kx x
+    go f !kx x (Bin sy ky y l r) =
+        case compare kx ky of
+            LT -> balanceL ky y (go f kx x l) r
+            GT -> balanceR ky y l (go f kx x r)
+            EQ -> let !y' = f x y in Bin sy kx y' l r
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insertWith #-}
 #else
 {-# INLINE insertWith #-}
 #endif
 
+insertWithR :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
+insertWithR = go
+  where
+    go :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
+    go _ !kx x Tip = singleton kx x
+    go f !kx x (Bin sy ky y l r) =
+        case compare kx ky of
+            LT -> balanceL ky y (go f kx x l) r
+            GT -> balanceR ky y l (go f kx x r)
+            EQ -> let !y' = f y x in Bin sy ky y' l r
+#if __GLASGOW_HASKELL__
+{-# INLINABLE insertWithR #-}
+#else
+{-# INLINE insertWithR #-}
+#endif
+
 -- | /O(log n)/. Insert with a function, combining key, new value and old value.
 -- @'insertWithKey' f key value mp@
 -- will insert the pair (key, value) into @mp@ if key does
@@ -443,6 +478,25 @@ insertWithKey = go
 {-# INLINE insertWithKey #-}
 #endif
 
+insertWithKeyR :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> Map k a
+insertWithKeyR = go
+  where
+    go :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> Map k a
+    -- Forcing `kx` may look redundant, but it's possible `compare` will
+    -- be lazy.
+    go _ !kx x Tip = singleton kx x
+    go f kx x (Bin sy ky y l r) =
+        case compare kx ky of
+            LT -> balanceL ky y (go f kx x l) r
+            GT -> balanceR ky y l (go f kx x r)
+            EQ -> let !y' = f ky y x
+                  in Bin sy ky y' l r
+#if __GLASGOW_HASKELL__
+{-# INLINABLE insertWithKeyR #-}
+#else
+{-# INLINE insertWithKeyR #-}
+#endif
+
 -- | /O(log n)/. Combines insert operation with old value retrieval.
 -- The expression (@'insertLookupWithKey' f k x map@)
 -- is a pair where the first element is equal to (@'lookup' k map@)
@@ -800,7 +854,13 @@ unionsWith f ts
 -- > unionWith (++) (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "aA"), (7, "C")]
 
 unionWith :: Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
-unionWith f t1 t2 = mergeWithKey (\_ x1 x2 -> Just $ f x1 x2) id id t1 t2
+unionWith _f t1 Tip = t1
+unionWith f t1 (Bin _ k x Tip Tip) = insertWithR f k x t1
+unionWith f (Bin _ k x Tip Tip) t2 = insertWith f k x t2
+unionWith _f Tip t2 = t2
+unionWith f (Bin _ k1 x1 l1 r1) t2 = case splitLookup k1 t2 of
+  (l2, mb, r2) -> link k1 x1' (unionWith f l1 l2) (unionWith f r1 r2)
+    where !x1' = maybe x1 (f x1) mb
 #if __GLASGOW_HASKELL__
 {-# INLINABLE unionWith #-}
 #endif
@@ -812,7 +872,13 @@ unionWith f t1 t2 = mergeWithKey (\_ x1 x2 -> Just $ f x1 x2) id id t1 t2
 -- > unionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == fromList [(3, "b"), (5, "5:a|A"), (7, "C")]
 
 unionWithKey :: Ord k => (k -> a -> a -> a) -> Map k a -> Map k a -> Map k a
-unionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) id id t1 t2
+unionWithKey _f t1 Tip = t1
+unionWithKey f t1 (Bin _ k x Tip Tip) = insertWithKeyR f k x t1
+unionWithKey f (Bin _ k x Tip Tip) t2 = insertWithKey f k x t2
+unionWithKey _f Tip t2 = t2
+unionWithKey f (Bin _ k1 x1 l1 r1) t2 = case splitLookup k1 t2 of
+  (l2, mb, r2) -> link k1 x1' (unionWithKey f l1 l2) (unionWithKey f r1 r2)
+    where !x1' = maybe x1 (f k1 x1) mb
 #if __GLASGOW_HASKELL__
 {-# INLINABLE unionWithKey #-}
 #endif
@@ -862,7 +928,15 @@ differenceWithKey f t1 t2 = mergeWithKey f id (const Tip) t1 t2
 -- > intersectionWith (++) (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "aA"
 
 intersectionWith :: Ord k => (a -> b -> c) -> Map k a -> Map k b -> Map k c
-intersectionWith f t1 t2 = mergeWithKey (\_ x1 x2 -> Just $ f x1 x2) (const Tip) (const Tip) t1 t2
+intersectionWith f Tip _ = Tip
+intersectionWith f _ Tip = Tip
+intersectionWith f t1@(Bin _ k x1 l1 r1) t2 = case mb of
+    Just x2 -> let !x1' = f x1 x2 in link k x1' l1l2 r1r2
+    Nothing -> merge l1l2 r1r2
+  where
+    !(l2, mb, r2) = splitLookup k t2
+    !l1l2 = intersectionWith f l1 l2
+    !r1r2 = intersectionWith f r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE intersectionWith #-}
 #endif
@@ -873,7 +947,15 @@ intersectionWith f t1 t2 = mergeWithKey (\_ x1 x2 -> Just $ f x1 x2) (const Tip)
 -- > intersectionWithKey f (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "5:a|A"
 
 intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
-intersectionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) (const Tip) (const Tip) t1 t2
+intersectionWithKey f Tip _ = Tip
+intersectionWithKey f _ Tip = Tip
+intersectionWithKey f t1@(Bin _ k x1 l1 r1) t2 = case mb of
+    Just x2 -> let !x1' = f k x1 x2 in link k x1' l1l2 r1r2
+    Nothing -> merge l1l2 r1r2
+  where
+    !(l2, mb, r2) = splitLookup k t2
+    !l1l2 = intersectionWithKey f l1 l2
+    !r1r2 = intersectionWithKey f r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE intersectionWithKey #-}
 #endif
index 487db12..7a5c8bb 100644 (file)
@@ -702,10 +702,15 @@ partition :: (a -> Bool) -> Set a -> (Set a,Set a)
 partition p0 t0 = toPair $ go p0 t0
   where
     go _ Tip = (Tip :*: Tip)
-    go p (Bin _ x l r) = case (go p l, go p r) of
+    go p t@(Bin _ x l r) = case (go p l, go p r) of
       ((l1 :*: l2), (r1 :*: r2))
-        | p x       -> link x l1 r1 :*: merge l2 r2
-        | otherwise -> merge l1 r1 :*: link x l2 r2
+        | p x       -> (if l1 `ptrEq` l && r1 `ptrEq` r
+                        then t
+                        else link x l1 r1) :*: merge l2 r2
+        | otherwise -> merge l1 r1 :*:
+                       (if l2 `ptrEq` l && r2 `ptrEq` r
+                        then t
+                        else link x l2 r2)
 
 {----------------------------------------------------------------------
   Map
index 324ef40..ca89af5 100644 (file)
@@ -7,16 +7,20 @@ module Data.Utils.PtrEquality (ptrEq) where
 
 #ifdef __GLASGOW_HASKELL__
 import GHC.Exts ( reallyUnsafePtrEquality# )
+import Unsafe.Coerce (unsafeCoerce)
 #if __GLASGOW_HASKELL__ < 707
 import GHC.Exts ( (==#) )
 #else
 import GHC.Exts ( isTrue# )
 #endif
+#endif
 
 -- | Checks if two pointers are equal. Yes means yes;
 -- no means maybe. The values should be forced to at least
 -- WHNF before comparison to get moderately reliable results.
 ptrEq :: a -> a -> Bool
+
+#ifdef __GLASGOW_HASKELL__
 #if __GLASGOW_HASKELL__ < 707
 ptrEq x y = reallyUnsafePtrEquality# x y ==# 1#
 #else
@@ -24,7 +28,7 @@ ptrEq x y = isTrue# (reallyUnsafePtrEquality# x y)
 #endif
 
 #else
-ptrEq :: a -> a -> Bool
+-- Not GHC
 ptrEq _ _ = False
 #endif
 
index 0f25bbd..ffd838f 100644 (file)
@@ -6,7 +6,7 @@ import Data.Map.Strict as Data.Map
 import Data.Map.Lazy as Data.Map
 #endif
 
-import Control.Applicative (Const(Const, getConst), pure)
+import Control.Applicative (Const(Const, getConst), pure, (<$>), (<*>))
 import Data.Functor.Identity (Identity(runIdentity))
 import Data.Monoid
 import Data.Maybe hiding (mapMaybe)
@@ -26,6 +26,7 @@ import Test.HUnit hiding (Test, Testable)
 import Test.QuickCheck
 import Test.QuickCheck.Function (Fun (..), apply)
 import Test.QuickCheck.Poly (A)
+import Control.Arrow (first)
 
 default (Int)
 
@@ -137,6 +138,7 @@ main = defaultMain
          , testCase "minViewWithKey" test_minViewWithKey
          , testCase "maxViewWithKey" test_maxViewWithKey
          , testCase "valid" test_valid
+         , testProperty "unionWith3"           prop_unionWith3
          , testProperty "valid"                prop_valid
          , testProperty "insert to singleton"  prop_singleton
          , testProperty "insert"               prop_insert
@@ -236,6 +238,18 @@ instance (Enum k,Arbitrary a) => Arbitrary (Map k a) where
                                         ; return (bin (toEnum i) x l r)
                                         }
 
+-- A type with a peculiar Eq instance designed to make sure keys
+-- come from where they're supposed to.
+data OddEq a = OddEq Bool a deriving (Show)
+getOddEq :: OddEq a -> (Bool, a)
+getOddEq (OddEq b a) = (b, a)
+instance Arbitrary a => Arbitrary (OddEq a) where
+  arbitrary = OddEq <$> arbitrary <*> arbitrary
+instance Eq a => Eq (OddEq a) where
+  OddEq _ x == OddEq _ y = x == y
+instance Ord a => Ord (OddEq a) where
+  OddEq _ x `compare` OddEq _ y = x `compare` y
+
 ------------------------------------------------------------------------
 
 type UMap = Map Int ()
@@ -970,6 +984,22 @@ prop_unionWith t1 t2 = (union t1 t2 == unionWith (\_ y -> y) t2 t1)
 prop_unionWith2 :: IMap -> IMap -> Bool
 prop_unionWith2 t1 t2 = valid (unionWithKey (\_ x y -> x+y) t1 t2)
 
+prop_unionWith3 :: Fun (Int,Int) Int -> IMap -> IMap -> Property
+prop_unionWith3 f t1 t2 = valid uw .&&. uwUndone === uwEasyUndone
+  where
+    t1' :: Map (OddEq Int) Int
+    t1' = mapKeysMonotonic (OddEq False) t1
+    t2' :: Map (OddEq Int) Int
+    t2' = mapKeysMonotonic (OddEq True) t2
+    uw :: Map (OddEq Int) Int
+    uw = unionWith (apply2 f) t1' t2'
+    uwUndone :: [((Bool, Int), Int)]
+    uwUndone = first getOddEq <$> toList uw
+    uwEasy :: Map (OddEq Int) Int
+    uwEasy = List.foldl' (\t (k1, v1) -> insertWith (apply2 f) k1 v1 t) t2' (toList t1')
+    uwEasyUndone :: [((Bool, Int), Int)]
+    uwEasyUndone = first getOddEq <$> toList uwEasy
+
 prop_unionSum :: [(Int,Int)] -> [(Int,Int)] -> Bool
 prop_unionSum xs ys
   = sum (elems (unionWith (+) (fromListWith (+) xs) (fromListWith (+) ys)))
index 029110d..bc8c5c4 100644 (file)
@@ -1,3 +1,4 @@
+{-# LANGUAGE CPP #-}
 import qualified Data.IntSet as IntSet
 import Data.List (nub,sort)
 import qualified Data.List as List
@@ -10,6 +11,11 @@ import Test.Framework.Providers.HUnit
 import Test.Framework.Providers.QuickCheck2
 import Test.HUnit hiding (Test, Testable)
 import Test.QuickCheck
+import Test.QuickCheck.Function
+import Test.QuickCheck.Poly
+#if !MIN_VERSION_base(4,8,0)
+import Control.Applicative (Applicative (..), (<$>))
+#endif
 
 main :: IO ()
 main = defaultMain [ testCase "lookupLT" test_lookupLT
@@ -62,6 +68,8 @@ main = defaultMain [ testCase "lookupLT" test_lookupLT
                    , testProperty "prop_foldL" prop_foldL
                    , testProperty "prop_foldL'" prop_foldL'
                    , testProperty "prop_map" prop_map
+                   , testProperty "prop_map2" prop_map2
+                   , testProperty "prop_mapMonotonic" prop_mapMonotonic
                    , testProperty "prop_maxView" prop_maxView
                    , testProperty "prop_minView" prop_minView
                    , testProperty "prop_split" prop_split
@@ -71,6 +79,19 @@ main = defaultMain [ testCase "lookupLT" test_lookupLT
                    , testProperty "prop_filter" prop_filter
                    ]
 
+-- A type with a peculiar Eq instance designed to make sure keys
+-- come from where they're supposed to.
+data OddEq a = OddEq Bool a deriving (Show)
+
+getOddEq :: OddEq a -> (Bool, a)
+getOddEq (OddEq b a) = (b, a)
+instance Arbitrary a => Arbitrary (OddEq a) where
+  arbitrary = OddEq <$> arbitrary <*> arbitrary
+instance Eq a => Eq (OddEq a) where
+  OddEq _ x == OddEq _ y = x == y
+instance Ord a => Ord (OddEq a) where
+  OddEq _ x `compare` OddEq _ y = x `compare` y
+
 ----------------------------------------------------------------
 -- Unit tests
 ----------------------------------------------------------------
@@ -358,6 +379,12 @@ prop_foldL' s = foldl' (flip (:)) [] s == List.foldl' (flip (:)) [] (toList s)
 prop_map :: Set Int -> Bool
 prop_map s = map id s == s
 
+prop_map2 :: Fun Int Int -> Fun Int Int -> Set Int -> Property
+prop_map2 f g s = map (apply f) (map (apply g) s) === map (apply f . apply g) s
+
+prop_mapMonotonic :: Set Int -> Property
+prop_mapMonotonic s = mapMonotonic id s === s
+
 prop_maxView :: Set Int -> Bool
 prop_maxView s = case maxView s of
     Nothing -> null s