Implement indexing operations on 'Set'
authorPatrick Palka <patrick@parcs.ath.cx>
Sun, 4 Nov 2012 19:49:24 +0000 (14:49 -0500)
committerMilan Straka <fox@ucw.cz>
Sun, 11 Nov 2012 21:23:01 +0000 (22:23 +0100)
Data/Set.hs
Data/Set/Base.hs
tests/set-properties.hs

index 9a32f16..d9029be 100644 (file)
@@ -81,6 +81,12 @@ module Data.Set (
             , split
             , splitMember
 
+            -- * Indexed
+            , lookupIndex
+            , findIndex
+            , elemAt
+            , deleteAt
+
             -- * Map
             , S.map
             , mapMonotonic
index 9790163..59c74eb 100644 (file)
@@ -127,6 +127,12 @@ module Data.Set.Base (
             , split
             , splitMember
 
+            -- * Indexed
+            , lookupIndex
+            , findIndex
+            , elemAt
+            , deleteAt
+
             -- * Map
             , map
             , mapMonotonic
@@ -197,6 +203,7 @@ import Data.Data
 -- want the compilers to be compiled by as many compilers as possible.
 #define STRICT_1_OF_2(fn) fn arg _ | arg `seq` False = undefined
 #define STRICT_1_OF_3(fn) fn arg _ _ | arg `seq` False = undefined
+#define STRICT_2_OF_3(fn) fn _ arg _ | arg `seq` False = undefined
 
 {--------------------------------------------------------------------
   Operators
@@ -1034,6 +1041,96 @@ splitMember x (Bin _ y l r)
 #endif
 
 {--------------------------------------------------------------------
+  Indexing
+--------------------------------------------------------------------}
+
+-- | /O(log n)/. Return the /index/ of an element. The index is a number from
+-- /0/ up to, but not including, the 'size' of the set. Calls 'error' when
+-- the element is not a 'member' of the set.
+--
+-- > findIndex 2 (fromList [5,3])    Error: element is not in the set
+-- > findIndex 3 (fromList [5,3]) == 0
+-- > findIndex 5 (fromList [5,3]) == 1
+-- > findIndex 6 (fromList [5,3])    Error: element is not in the set
+
+-- See Note: Type of local 'go' function
+findIndex :: Ord a => a -> Set a -> Int
+findIndex = go 0
+  where
+    go :: Ord a => Int -> a -> Set a -> Int
+    STRICT_1_OF_3(go)
+    STRICT_2_OF_3(go)
+    go _   _ Tip  = error "Set.findIndex: element is not in the set"
+    go idx x (Bin _ kx l r) = case compare x kx of
+      LT -> go idx x l
+      GT -> go (idx + size l + 1) x r
+      EQ -> idx + size l
+#if __GLASGOW_HASKELL__ >= 700
+{-# INLINABLE findIndex #-}
+#endif
+
+-- | /O(log n)/. Lookup the /index/ of an element. The index is a number from
+-- /0/ up to, but not including, the 'size' of the set.
+--
+-- > isJust   (lookupIndex 2 (fromList [5,3])) == False
+-- > fromJust (lookupIndex 3 (fromList [5,3])) == 0
+-- > fromJust (lookupIndex 5 (fromList [5,3])) == 1
+-- > isJust   (lookupIndex 6 (fromList [5,3])) == False
+
+-- See Note: Type of local 'go' function
+lookupIndex :: Ord a => a -> Set a -> Maybe Int
+lookupIndex = go 0
+  where
+    go :: Ord a => Int -> a -> Set a -> Maybe Int
+    STRICT_1_OF_3(go)
+    STRICT_2_OF_3(go)
+    go _   _ Tip  = Nothing
+    go idx x (Bin _ kx l r) = case compare x kx of
+      LT -> go idx x l
+      GT -> go (idx + size l + 1) x r
+      EQ -> Just $! idx + size l
+#if __GLASGOW_HASKELL__ >= 700
+{-# INLINABLE lookupIndex #-}
+#endif
+
+-- | /O(log n)/. Retrieve an element by /index/. Calls 'error' when an
+-- invalid index is used.
+--
+-- > elemAt 0 (fromList [5,3]) == 3
+-- > elemAt 1 (fromList [5,3]) == 5
+-- > elemAt 2 (fromList [5,3])    Error: index out of range
+
+elemAt :: Int -> Set a -> a
+STRICT_1_OF_2(elemAt)
+elemAt _ Tip = error "Set.elemAt: index out of range"
+elemAt i (Bin _ x l r)
+  = case compare i sizeL of
+      LT -> elemAt i l
+      GT -> elemAt (i-sizeL-1) r
+      EQ -> x
+  where
+    sizeL = size l
+
+-- | /O(log n)/. Delete the element at /index/.
+--
+-- > deleteAt 0    (fromList [5,3]) == singleton 5
+-- > deleteAt 1    (fromList [5,3]) == singleton 3
+-- > deleteAt 2    (fromList [5,3])    Error: index out of range
+-- > deleteAt (-1) (fromList [5,3])    Error: index out of range
+
+deleteAt :: Int -> Set a -> Set a
+deleteAt i t = i `seq`
+  case t of
+    Tip -> error "Set.deleteAt: index out of range"
+    Bin _ x l r -> case compare i sizeL of
+      LT -> balanceR x (deleteAt i l) r
+      GT -> balanceL x l (deleteAt (i-sizeL-1) r)
+      EQ -> glue l r
+      where
+        sizeL = size l
+
+
+{--------------------------------------------------------------------
   Utility functions that maintain the balance properties of the tree.
   All constructors assume that all values in [l] < [x] and all values
   in [r] > [x], and that [l] and [r] are valid trees.
index 08c6b20..56e0b70 100644 (file)
@@ -2,6 +2,7 @@ import qualified Data.IntSet as IntSet
 import Data.List (nub,sort)
 import qualified Data.List as List
 import Data.Monoid (mempty)
+import Data.Maybe
 import Data.Set
 import Prelude hiding (lookup, null, map, filter, foldr, foldl)
 import Test.Framework
@@ -36,6 +37,10 @@ main = defaultMain [ testCase "lookupLT" test_lookupLT
                    , testProperty "prop_Diff" prop_Diff
                    , testProperty "prop_IntValid" prop_IntValid
                    , testProperty "prop_Int" prop_Int
+                   , testCase "lookupIndex" test_lookupIndex
+                   , testCase "findIndex" test_findIndex
+                   , testCase "elemAt" test_elemAt
+                   , testCase "deleteAt" test_deleteAt
                    , testProperty "prop_Ordered" prop_Ordered
                    , testProperty "prop_List" prop_List
                    , testProperty "prop_DescList" prop_DescList
@@ -231,6 +236,32 @@ prop_Int xs ys = toAscList (intersection (fromList xs) (fromList ys))
                  == List.sort (nub ((List.intersect) (xs)  (ys)))
 
 {--------------------------------------------------------------------
+  Indexed
+--------------------------------------------------------------------}
+
+test_lookupIndex :: Assertion
+test_lookupIndex = do
+    isJust   (lookupIndex 2 (fromList [5,3])) @?= False
+    fromJust (lookupIndex 3 (fromList [5,3])) @?= 0
+    fromJust (lookupIndex 5 (fromList [5,3])) @?= 1
+    isJust   (lookupIndex 6 (fromList [5,3])) @?= False
+
+test_findIndex :: Assertion
+test_findIndex = do
+    findIndex 3 (fromList [5,3]) @?= 0
+    findIndex 5 (fromList [5,3]) @?= 1
+
+test_elemAt :: Assertion
+test_elemAt = do
+    elemAt 0 (fromList [5,3]) @?= 3
+    elemAt 1 (fromList [5,3]) @?= 5
+
+test_deleteAt :: Assertion
+test_deleteAt = do
+    deleteAt 0 (fromList [5,3]) @?= singleton 5
+    deleteAt 1 (fromList [5,3]) @?= singleton 3
+
+{--------------------------------------------------------------------
   Lists
 --------------------------------------------------------------------}
 prop_Ordered :: Property