Add restrictKeys and withoutKeys
authorDavid Feuer <David.Feuer@gmail.com>
Tue, 26 Jul 2016 04:05:06 +0000 (00:05 -0400)
committerDavid Feuer <David.Feuer@gmail.com>
Tue, 26 Jul 2016 05:17:36 +0000 (01:17 -0400)
* Add `restrictKeys` and `withoutKeys` to `Data.Map` and
`Data.IntMap`.

* Add tests for the defining properties of these operations.

12 files changed:
Data/IntMap/Base.hs
Data/IntMap/Lazy.hs
Data/IntMap/Strict.hs
Data/Map/Base.hs
Data/Map/Lazy.hs
Data/Map/Strict.hs
Data/Set/Base.hs
Data/Utils/StrictMaybe.hs [new file with mode: 0644]
changelog.md
containers.cabal
tests/intmap-properties.hs
tests/map-properties.hs

index 6fcca50..2c1cf00 100644 (file)
@@ -164,6 +164,8 @@ module Data.IntMap.Base (
     -- * Filter
     , filter
     , filterWithKey
+    , restrictKeys
+    , withoutKeys
     , partition
     , partitionWithKey
 
@@ -958,6 +960,49 @@ differenceWithKey :: (Key -> a -> b -> Maybe a) -> IntMap a -> IntMap b -> IntMa
 differenceWithKey f m1 m2
   = mergeWithKey f id (const Nil) m1 m2
 
+-- | Remove all the keys in a given set from a map.
+--
+-- @
+-- m `withoutKeys` s = 'filterWithKey' (\k _ -> k `'IntSet.notMember'` s) m
+-- @
+--
+-- @since 0.5.8
+withoutKeys :: IntMap a -> IntSet.IntSet -> IntMap a
+withoutKeys = go
+  where
+    go t1@(Bin p1 m1 l1 r1) t2@(IntSet.Bin p2 m2 l2 r2)
+      | shorter m1 m2  = merge1
+      | shorter m2 m1  = merge2
+      | p1 == p2       = bin p1 m1 (go l1 l2) (go r1 r2)
+      | otherwise      = t1
+      where
+        merge1 | nomatch p2 p1 m1  = t1
+               | zero p2 m1        = binCheckLeft p1 m1 (go l1 t2) r1
+               | otherwise         = binCheckRight p1 m1 l1 (go r1 t2)
+        merge2 | nomatch p1 p2 m2  = t1
+               | zero p1 m2        = bin p2 m2 (go t1 l2) Nil
+               | otherwise         = bin p2 m2 Nil (go t1 r2)
+
+    go t1'@(Bin _ _ _ _) t2'@(IntSet.Tip k2' _) = merge t2' k2' t1'
+      where merge t2 k2 t1@(Bin p1 m1 l1 r1) | nomatch k2 p1 m1 = t1
+                                             | zero k2 m1 = binCheckLeft p1 m1 (merge t2 k2 l1) r1
+                                             | otherwise  = binCheckRight p1 m1 l1 (merge t2 k2 r1)
+            merge _ k2 t1@(Tip k1 _) | k1 == k2 = Nil
+                                     | otherwise = t1
+            merge _ _  Nil = Nil
+
+    go t1@(Bin _ _ _ _) IntSet.Nil = t1
+
+    go t1'@(Tip k1' _) t2' = merge t1' k1' t2'
+      where merge t1 k1 (IntSet.Bin p2 m2 l2 r2) | nomatch k1 p2 m2 = t1
+                                                 | zero k1 m2 = bin p2 m2 (merge t1 k1 l2) Nil
+                                                 | otherwise  = bin p2 m2 Nil (merge t1 k1 r2)
+            merge t1 k1 (IntSet.Tip k2 _) | k1 == k2 = Nil
+                                          | otherwise = t1
+            merge t1 _  IntSet.Nil = t1
+
+    go Nil _ = Nil
+
 
 {--------------------------------------------------------------------
   Intersection
@@ -970,6 +1015,50 @@ intersection :: IntMap a -> IntMap b -> IntMap a
 intersection m1 m2
   = mergeWithKey' bin const (const Nil) (const Nil) m1 m2
 
+-- | /O(n+m)/. The restriction of a map to the keys in a set.
+--
+-- @
+-- m `restrictKeys` s = 'filterWithKey' (\k _ -> k `'IntSet.member'` s) m
+-- @
+--
+-- @since 0.5.8
+restrictKeys :: IntMap a -> IntSet.IntSet -> IntMap a
+restrictKeys = go
+  where
+    go t1@(Bin p1 m1 l1 r1) t2@(IntSet.Bin p2 m2 l2 r2)
+      | shorter m1 m2  = merge1
+      | shorter m2 m1  = merge2
+      | p1 == p2       = bin p1 m1 (go l1 l2) (go r1 r2)
+      | otherwise      = Nil
+      where
+        merge1 | nomatch p2 p1 m1  = Nil
+               | zero p2 m1        = bin p1 m1 (go l1 t2) Nil
+               | otherwise         = bin p1 m1 Nil (go r1 t2)
+        merge2 | nomatch p1 p2 m2  = Nil
+               | zero p1 m2        = bin p2 m2 (go t1 l2) Nil
+               | otherwise         = bin p2 m2 Nil (go t1 r2)
+
+    go t1'@(Bin _ _ _ _) t2'@(IntSet.Tip k2' _) = merge t2' k2' t1'
+      where merge t2 k2 (Bin p1 m1 l1 r1) | nomatch k2 p1 m1 = Nil
+                                          | zero k2 m1 = bin p1 m1 (merge t2 k2 l1) Nil
+                                          | otherwise  = bin p1 m1 Nil (merge t2 k2 r1)
+            merge _ k2 t1@(Tip k1 _) | k1 == k2 = t1
+                                     | otherwise = Nil
+            merge _ _  Nil = Nil
+
+    go (Bin _ _ _ _) IntSet.Nil = Nil
+
+    go t1'@(Tip k1' _) t2' = merge t1' k1' t2'
+      where merge t1 k1 (IntSet.Bin p2 m2 l2 r2)
+              | nomatch k1 p2 m2 = Nil
+              | zero k1 m2 = bin p2 m2 (merge t1 k1 l2) Nil
+              | otherwise  = bin p2 m2 Nil (merge t1 k1 r2)
+            merge t1 k1 (IntSet.Tip k2 _) | k1 == k2 = t1
+                                          | otherwise = Nil
+            merge _ _  IntSet.Nil = Nil
+
+    go Nil _ = Nil
+
 -- | /O(n+m)/. The intersection with a combining function.
 --
 -- > intersectionWith (++) (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 5 "aA"
index 8283017..de4d122 100644 (file)
@@ -169,6 +169,8 @@ module Data.IntMap.Lazy (
     -- * Filter
     , IM.filter
     , filterWithKey
+    , restrictKeys
+    , withoutKeys
     , partition
     , partitionWithKey
 
index d020e9f..3ec5610 100644 (file)
@@ -176,6 +176,8 @@ module Data.IntMap.Strict (
     -- * Filter
     , filter
     , filterWithKey
+    , restrictKeys
+    , withoutKeys
     , partition
     , partitionWithKey
 
index 792206b..1576848 100644 (file)
@@ -224,6 +224,8 @@ module Data.Map.Base (
     -- * Filter
     , filter
     , filterWithKey
+    , restrictKeys
+    , withoutKeys
     , partition
     , partitionWithKey
 
@@ -309,8 +311,10 @@ import Data.Typeable
 import Prelude hiding (lookup, map, filter, foldr, foldl, null)
 
 import qualified Data.Set.Base as Set
+import Data.Set.Base (Set)
 import Data.Utils.StrictFold
 import Data.Utils.StrictPair
+import Data.Utils.StrictMaybe
 import Data.Utils.BitQueue
 #if DEFINE_ALTERF_FALLBACK
 import Data.Utils.BitUtil (wordSize)
@@ -1578,7 +1582,7 @@ unionWithKey f t1 t2 = mergeWithKey (\k x1 x2 -> Just $ f k x1 x2) id id t1 t2
 -- > difference (fromList [(5, "a"), (3, "b")]) (fromList [(5, "A"), (7, "C")]) == singleton 3 "b"
 
 difference :: Ord k => Map k a -> Map k b -> Map k a
-difference Tip _   = Tip
+difference Tip !_   = Tip
 difference t1 Tip  = t1
 difference t1 t2   = hedgeDiff NothingS NothingS t1 t2
 #if __GLASGOW_HASKELL__
@@ -1595,6 +1599,32 @@ hedgeDiff blo bhi t (Bin _ kx _ l r) = merge (hedgeDiff blo bmi (trim blo bmi t)
 {-# INLINABLE hedgeDiff #-}
 #endif
 
+-- | Remove all keys in a 'Set' from a 'Map'.
+--
+-- @
+-- m `withoutKeys` s = 'filterWithKey' (\k _ -> k `'Set.notMember'` s) m
+-- @
+--
+-- @since 0.5.8
+withoutKeys :: Ord k => Map k a -> Set k -> Map k a
+withoutKeys Tip !_ = Tip
+withoutKeys m Set.Tip = m
+withoutKeys m s = hedgeWithout NothingS NothingS m s
+#if __GLASGOW_HASKELL__
+{-# INLINABLE withoutKeys #-}
+#endif
+
+hedgeWithout :: Ord a => MaybeS a -> MaybeS a -> Map a b -> Set a -> Map a b
+hedgeWithout _ _ Tip _ = Tip
+hedgeWithout blo bhi (Bin _ kx x l r) Set.Tip = link kx x (filterGt blo l) (filterLt bhi r)
+hedgeWithout blo bhi t (Set.Bin _ kx l r) =
+  merge (hedgeWithout blo bmi (trim blo bmi t) l)
+        (hedgeWithout bmi bhi (trim bmi bhi t) r)
+  where bmi = JustS kx
+#if __GLASGOW_HASKELL__
+{-# INLINABLE hedgeWithout #-}
+#endif
+
 -- | /O(n+m)/. Difference with a combining function.
 -- When two equal keys are
 -- encountered, the combining function is applied to the values of these keys.
@@ -1660,6 +1690,32 @@ hedgeInt blo bhi (Bin _ kx x l r) t2 = let l' = hedgeInt blo bmi l (trim blo bmi
 {-# INLINABLE hedgeInt #-}
 #endif
 
+-- | Restrict a 'Map' to only those keys found in a 'Set'.
+--
+-- @
+-- m `restrictKeys` s = 'filterWithKey' (\k _ -> k `'Set.member'` s) m
+-- @
+--
+-- @since 0.5.8
+restrictKeys :: Ord k => Map k a -> Set k -> Map k a
+restrictKeys Tip _ = Tip
+restrictKeys _ Set.Tip = Tip
+restrictKeys t1 t2 = hedgeRestr NothingS NothingS t1 t2
+#if __GLASGOW_HASKELL__
+{-# INLINABLE restrictKeys #-}
+#endif
+
+hedgeRestr :: Ord k => MaybeS k -> MaybeS k -> Map k a -> Set k -> Map k a
+hedgeRestr _ _ _   Set.Tip = Tip
+hedgeRestr _ _ Tip _ = Tip
+hedgeRestr blo bhi (Bin _ kx x l r) t2 = let l' = hedgeRestr blo bmi l (Set.trim blo bmi t2)
+                                             r' = hedgeRestr bmi bhi r (Set.trim bmi bhi t2)
+                                       in if kx `Set.member` t2 then link kx x l' r' else merge l' r'
+  where bmi = JustS kx
+#if __GLASGOW_HASKELL__
+{-# INLINABLE hedgeRestr #-}
+#endif
+
 -- | /O(n+m)/. Intersection with a combining function.  The implementation uses
 -- an efficient /hedge/ algorithm comparable with /hedge-union/.
 --
@@ -2613,8 +2669,6 @@ fromDistinctDescList ((kx0, x0) : xs0) = go (1 :: Int) (Bin 1 kx0 x0 Tip Tip) xs
                         was found in the tree.
 --------------------------------------------------------------------}
 
-data MaybeS a = NothingS | JustS !a
-
 {--------------------------------------------------------------------
   [trim blo bhi t] trims away all subtrees that surely contain no
   values between the range [blo] to [bhi]. The returned tree is either
index b1ef990..62921ff 100644 (file)
@@ -173,6 +173,8 @@ module Data.Map.Lazy (
     -- * Filter
     , M.filter
     , filterWithKey
+    , restrictKeys
+    , withoutKeys
     , partition
     , partitionWithKey
 
index 2258931..301f9f3 100644 (file)
@@ -181,6 +181,8 @@ module Data.Map.Strict
     -- * Filter
     , filter
     , filterWithKey
+    , restrictKeys
+    , withoutKeys
     , partition
     , partitionWithKey
 
index 8aabd08..573507d 100644 (file)
@@ -194,6 +194,9 @@ module Data.Set.Base (
             , balanced
             , link
             , merge
+
+            -- Used by Data.Map.Base
+            , trim
             ) where
 
 import Prelude hiding (filter,foldl,foldr,null,map)
@@ -211,6 +214,7 @@ import Control.DeepSeq (NFData(rnf))
 
 import Data.Utils.StrictFold
 import Data.Utils.StrictPair
+import Data.Utils.StrictMaybe
 
 #if __GLASGOW_HASKELL__
 import GHC.Exts ( build )
@@ -1047,8 +1051,6 @@ instance NFData a => NFData (Set a) where
                         was found in the tree.
 --------------------------------------------------------------------}
 
-data MaybeS a = NothingS | JustS !a
-
 {--------------------------------------------------------------------
   [trim blo bhi t] trims away all subtrees that surely contain no
   values between the range [blo] to [bhi]. The returned tree is either
diff --git a/Data/Utils/StrictMaybe.hs b/Data/Utils/StrictMaybe.hs
new file mode 100644 (file)
index 0000000..e0f6fec
--- /dev/null
@@ -0,0 +1,21 @@
+module Data.Utils.StrictMaybe (MaybeS (..), maybeS, toMaybe, toMaybeS) where
+import Data.Foldable (Foldable (..))
+import Data.Monoid (Monoid (..))
+
+data MaybeS a = NothingS | JustS !a
+
+instance Foldable MaybeS where
+  foldMap _ NothingS = mempty
+  foldMap f (JustS a) = f a
+
+maybeS :: r -> (a -> r) -> MaybeS a -> r
+maybeS n _ NothingS = n
+maybeS _ j (JustS a) = j a
+
+toMaybe :: MaybeS a -> Maybe a
+toMaybe NothingS = Nothing
+toMaybe (JustS a) = Just a
+
+toMaybeS :: Maybe a -> MaybeS a
+toMaybeS Nothing = NothingS
+toMaybeS (Just a) = JustS a
index a714e79..0ae21cc 100644 (file)
@@ -23,7 +23,8 @@
 
 ### New exports and instances
 
-  * Add `alterF` for `Data.Map` and `Data.IntMap`.
+  * Add `alterF`, `restrictKeys`, and `withoutKeys` to `Data.Map`
+    and `Data.IntMap`.
 
   * Add `fromDescList`, `fromDescListWith`, `fromDescListWithKey`,
     and `fromDistinctDescList` to `Data.Map`.
index 71f17d8..a4200e3 100644 (file)
@@ -61,6 +61,7 @@ Library
         Data.Utils.BitUtil
         Data.Utils.StrictFold
         Data.Utils.StrictPair
+        Data.Utils.StrictMaybe
 
     include-dirs: include
 
index 3e6cc5b..21ee9f6 100644 (file)
@@ -16,7 +16,7 @@ import qualified Prelude (map)
 
 import Data.List (nub,sort)
 import qualified Data.List as List
-import qualified Data.IntSet
+import qualified Data.IntSet as IntSet
 import Test.Framework
 import Test.Framework.Providers.HUnit
 import Test.Framework.Providers.QuickCheck2
@@ -506,13 +506,13 @@ test_assocs = do
 
 test_keysSet :: Assertion
 test_keysSet = do
-    keysSet (fromList [(5,"a"), (3,"b")]) @?= Data.IntSet.fromList [3,5]
-    keysSet (empty :: UMap) @?= Data.IntSet.empty
+    keysSet (fromList [(5,"a"), (3,"b")]) @?= IntSet.fromList [3,5]
+    keysSet (empty :: UMap) @?= IntSet.empty
 
 test_fromSet :: Assertion
 test_fromSet = do
-   fromSet (\k -> replicate k 'a') (Data.IntSet.fromList [3, 5]) @?= fromList [(5,"aaaaa"), (3,"aaa")]
-   fromSet undefined Data.IntSet.empty @?= (empty :: IMap)
+   fromSet (\k -> replicate k 'a') (IntSet.fromList [3, 5]) @?= fromList [(5,"aaaaa"), (3,"aaa")]
+   fromSet undefined IntSet.empty @?= (empty :: IMap)
 
 ----------------------------------------------------------------
 -- Lists
@@ -803,6 +803,18 @@ prop_intersectionWithKeyModel xs ys
           ys' = List.nubBy ((==) `on` fst) ys
           f k l r = k + 2 * l + 3 * r
 
+prop_restrictKeys :: IMap -> IMap -> Property
+prop_restrictKeys m s0 = m `restrictKeys` s === filterWithKey (\k _ -> k `IntSet.member` s) m
+  where
+    s = keysSet s0
+    restricted = restrictKeys m s
+
+prop_withoutKeys :: IMap -> IMap -> Property
+prop_withoutKeys m s0 = m `withoutKeys` s === filterWithKey (\k _ -> k `IntSet.notMember` s) m
+  where
+    s = keysSet s0
+    reduced = withoutKeys m s
+
 prop_mergeWithKeyModel :: [(Int,Int)] -> [(Int,Int)] -> Bool
 prop_mergeWithKeyModel xs ys
   = and [ testMergeWithKey f keep_x keep_y
@@ -1055,9 +1067,9 @@ prop_foldl' n ys = length ys > 0 ==>
 
 prop_keysSet :: [(Int, Int)] -> Bool
 prop_keysSet xs =
-  keysSet (fromList xs) == Data.IntSet.fromList (List.map fst xs)
+  keysSet (fromList xs) == IntSet.fromList (List.map fst xs)
 
 prop_fromSet :: [(Int, Int)] -> Bool
 prop_fromSet ys =
   let xs = List.nubBy ((==) `on` fst) ys
-  in fromSet (\k -> fromJust $ List.lookup k xs) (Data.IntSet.fromList $ List.map fst xs) == fromList xs
+  in fromSet (\k -> fromJust $ List.lookup k xs) (IntSet.fromList $ List.map fst xs) == fromList xs
index 4c03c76..0f25bbd 100644 (file)
@@ -18,7 +18,7 @@ import qualified Prelude (map)
 
 import Data.List (nub,sort)
 import qualified Data.List as List
-import qualified Data.Set
+import qualified Data.Set as Set
 import Test.Framework
 import Test.Framework.Providers.HUnit
 import Test.Framework.Providers.QuickCheck2
@@ -159,7 +159,9 @@ main = defaultMain
          , testProperty "union sum"            prop_unionSum
          , testProperty "difference"           prop_difference
          , testProperty "difference model"     prop_differenceModel
+         , testProperty "withoutKeys"          prop_withoutKeys
          , testProperty "intersection"         prop_intersection
+         , testProperty "restrictKeys"         prop_restrictKeys
          , testProperty "intersection model"   prop_intersectionModel
          , testProperty "intersectionWith"     prop_intersectionWith
          , testProperty "intersectionWithModel" prop_intersectionWithModel
@@ -593,13 +595,13 @@ test_assocs = do
 
 test_keysSet :: Assertion
 test_keysSet = do
-    keysSet (fromList [(5,"a"), (3,"b")]) @?= Data.Set.fromList [3,5]
-    keysSet (empty :: UMap) @?= Data.Set.empty
+    keysSet (fromList [(5,"a"), (3,"b")]) @?= Set.fromList [3,5]
+    keysSet (empty :: UMap) @?= Set.empty
 
 test_fromSet :: Assertion
 test_fromSet = do
-   fromSet (\k -> replicate k 'a') (Data.Set.fromList [3, 5]) @?= fromList [(5,"aaaaa"), (3,"aaa")]
-   fromSet undefined Data.Set.empty @?= (empty :: IMap)
+   fromSet (\k -> replicate k 'a') (Set.fromList [3, 5]) @?= fromList [(5,"aaaaa"), (3,"aaa")]
+   fromSet undefined Set.empty @?= (empty :: IMap)
 
 ----------------------------------------------------------------
 -- Lists
@@ -981,6 +983,18 @@ prop_differenceModel xs ys
   = sort (keys (difference (fromListWith (+) xs) (fromListWith (+) ys)))
     == sort ((List.\\) (nub (Prelude.map fst xs)) (nub (Prelude.map fst ys)))
 
+prop_restrictKeys :: IMap -> IMap -> Property
+prop_restrictKeys m s0 = valid restricted .&&. (m `restrictKeys` s === filterWithKey (\k _ -> k `Set.member` s) m)
+  where
+    s = keysSet s0
+    restricted = restrictKeys m s
+
+prop_withoutKeys :: IMap -> IMap -> Property
+prop_withoutKeys m s0 = valid reduced .&&. (m `withoutKeys` s === filterWithKey (\k _ -> k `Set.notMember` s) m)
+  where
+    s = keysSet s0
+    reduced = withoutKeys m s
+
 prop_intersection :: IMap -> IMap -> Bool
 prop_intersection t1 t2 = valid (intersection t1 t2)
 
@@ -1289,9 +1303,9 @@ prop_foldl' n ys = length ys > 0 ==>
 
 prop_keysSet :: [(Int, Int)] -> Bool
 prop_keysSet xs =
-  keysSet (fromList xs) == Data.Set.fromList (List.map fst xs)
+  keysSet (fromList xs) == Set.fromList (List.map fst xs)
 
 prop_fromSet :: [(Int, Int)] -> Bool
 prop_fromSet ys =
   let xs = List.nubBy ((==) `on` fst) ys
-  in fromSet (\k -> fromJust $ List.lookup k xs) (Data.Set.fromList $ List.map fst xs) == fromList xs
+  in fromSet (\k -> fromJust $ List.lookup k xs) (Set.fromList $ List.map fst xs) == fromList xs