Use pointer equality to enhance sharing for Sets
authorDavid Feuer <David.Feuer@gmail.com>
Mon, 1 Aug 2016 17:10:00 +0000 (13:10 -0400)
committerDavid Feuer <David.Feuer@gmail.com>
Mon, 1 Aug 2016 17:10:00 +0000 (13:10 -0400)
Use pointer equality to avoid allocating new copies of existing
structures. This helps a number of benchmarks a *lot*. Unfortunately,
it hurts some others a little.

Data/Set/Base.hs
Data/Utils/PtrEquality.hs [new file with mode: 0644]
containers.cabal

index 1885be7..487db12 100644 (file)
@@ -212,6 +212,7 @@ import Control.DeepSeq (NFData(rnf))
 
 import Data.Utils.StrictFold
 import Data.Utils.StrictPair
+import Data.Utils.PtrEquality
 
 #if __GLASGOW_HASKELL__
 import GHC.Exts ( build )
@@ -485,10 +486,15 @@ insert = go
   where
     go :: Ord a => a -> Set a -> Set a
     go !x Tip = singleton x
-    go x (Bin sz y l r) = case compare x y of
-        LT -> balanceL y (go x l) r
-        GT -> balanceR y l (go x r)
-        EQ -> Bin sz x l r
+    go !x t@(Bin sz y l r) = case compare x y of
+        LT | l' `ptrEq` l -> t
+           | otherwise -> balanceL y l' r
+           where !l' = go x l
+        GT | r' `ptrEq` r -> t
+           | otherwise -> balanceR y l r'
+           where !r' = go x r
+        EQ | x `ptrEq` y -> t
+           | otherwise -> Bin sz x l r
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insert #-}
 #else
@@ -504,9 +510,13 @@ insertR = go
   where
     go :: Ord a => a -> Set a -> Set a
     go !x Tip = singleton x
-    go x t@(Bin _ y l r) = case compare x y of
-        LT -> balanceL y (go x l) r
-        GT -> balanceR y l (go x r)
+    go !x t@(Bin _ y l r) = case compare x y of
+        LT | l' `ptrEq` l -> t
+           | otherwise -> balanceL y l' r
+           where !l' = go x l
+        GT | r' `ptrEq` r -> t
+           | otherwise -> balanceR y l r'
+           where !r' = go x r
         EQ -> t
 #if __GLASGOW_HASKELL__
 {-# INLINABLE insertR #-}
@@ -522,9 +532,13 @@ delete = go
   where
     go :: Ord a => a -> Set a -> Set a
     go !_ Tip = Tip
-    go x (Bin _ y l r) = case compare x y of
-        LT -> balanceR y (go x l) r
-        GT -> balanceL y l (go x r)
+    go x t@(Bin _ y l r) = case compare x y of
+        LT | l' `ptrEq` l -> t
+           | otherwise -> balanceR y l' r
+           where !l' = go x l
+        GT | r' `ptrEq` r -> t
+           | otherwise -> balanceL y l r'
+           where !r' = go x r
         EQ -> glue l r
 #if __GLASGOW_HASKELL__
 {-# INLINABLE delete #-}
@@ -609,8 +623,12 @@ union t1 Tip  = t1
 union t1 (Bin _ x Tip Tip) = insertR x t1
 union (Bin _ x Tip Tip) t2 = insert x t2
 union Tip t2  = t2
-union (Bin _ x l r) t2 = case splitS x t2 of
-  (l2 :*: r2) -> link x (union l l2) (union r r2)
+union t1@(Bin _ x l1 r1) t2 = case splitS x t2 of
+  (l2 :*: r2)
+    | l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1 -> t1
+    | otherwise -> link x l1l2 r1r2
+    where !l1l2 = union l1 l2
+          !r1r2 = union r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE union #-}
 #endif
@@ -622,8 +640,12 @@ union (Bin _ x l r) t2 = case splitS x t2 of
 difference :: Ord a => Set a -> Set a -> Set a
 difference Tip _   = Tip
 difference t1 Tip  = t1
-difference t1 (Bin _ x l2 r2) = case splitS x t1 of
-   (l1 :*: r1) -> merge (difference l1 l2) (difference r1 r2)
+difference t1 (Bin _ x l2 r2) = case splitMember x t1 of
+   (l1, b, r1)
+     | not b && l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1 -> t1
+     | otherwise -> merge l1l2 r1r2
+     where !l1l2 = difference l1 l2
+           !r1r2 = difference r1 r2
 #if __GLASGOW_HASKELL__
 {-# INLINABLE difference #-}
 #endif
@@ -645,8 +667,10 @@ difference t1 (Bin _ x l2 r2) = case splitS x t1 of
 intersection :: Ord a => Set a -> Set a -> Set a
 intersection Tip _ = Tip
 intersection _ Tip = Tip
-intersection (Bin _ x l1 r1) t2
-  | b = link x l1l2 r1r2
+intersection t1@(Bin _ x l1 r1) t2
+  | b = if l1l2 `ptrEq` l1 && r1r2 `ptrEq` r1
+        then t1
+        else link x l1l2 r1r2
   | otherwise = merge l1l2 r1r2
   where
     !(l2, b, r2) = splitMember x t2
@@ -662,9 +686,14 @@ intersection (Bin _ x l1 r1) t2
 -- | /O(n)/. Filter all elements that satisfy the predicate.
 filter :: (a -> Bool) -> Set a -> Set a
 filter _ Tip = Tip
-filter p (Bin _ x l r)
-    | p x       = link x (filter p l) (filter p r)
-    | otherwise = merge (filter p l) (filter p r)
+filter p t@(Bin _ x l r)
+    | p x = if l `ptrEq` l' && r `ptrEq` r'
+            then t
+            else link x l' r'
+    | otherwise = merge l' r'
+    where
+      !l' = filter p l
+      !r' = filter p r
 
 -- | /O(n)/. Partition the set into two sets, one with all elements that satisfy
 -- the predicate and one with all elements that don't satisfy the predicate.
diff --git a/Data/Utils/PtrEquality.hs b/Data/Utils/PtrEquality.hs
new file mode 100644 (file)
index 0000000..5ab38fa
--- /dev/null
@@ -0,0 +1,26 @@
+{-# LANGUAGE CPP #-}
+#ifdef __GLASGOW_HASKELL__
+{-# LANGUAGE MagicHash #-}
+#endif
+
+module Data.Utils.PtrEquality (ptrEq) where
+
+#ifdef __GLASGOW_HASKELL__
+import GHC.Exts ( reallyUnsafePtrEquality# )
+
+-- | Checks if two pointers are equal. Yes means yes;
+-- no means maybe. The values should be forced to at least
+-- WHNF before comparison to get moderately reliable results.
+ptrEq :: a -> a -> Bool
+ptrEq x y = case reallyUnsafePtrEquality# x y of
+              1# -> True
+              _ -> False
+
+#else
+ptrEq :: a -> a -> Bool
+ptrEq _ _ = False
+#endif
+
+{-# INLINE ptrEq #-}
+
+infix 4 `ptrEq`
index a4200e3..fa593d0 100644 (file)
@@ -62,6 +62,7 @@ Library
         Data.Utils.StrictFold
         Data.Utils.StrictPair
         Data.Utils.StrictMaybe
+        Data.Utils.PtrEquality
 
     include-dirs: include