Add unstablePartition, span, break
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 9 Dec 2009 04:51:48 +0000 (04:51 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 9 Dec 2009 04:51:48 +0000 (04:51 +0000)
Data/Vector.hs
Data/Vector/Generic.hs
Data/Vector/Generic/Mutable.hs
Data/Vector/Primitive.hs
Data/Vector/Storable.hs
Data/Vector/Unboxed.hs

index 0db69d8..d4bfc44 100644 (file)
@@ -44,6 +44,7 @@ module Data.Vector (
 
   -- * Filtering
   filter, ifilter, takeWhile, dropWhile,
+  unstablePartition, span, break,
 
   -- * Searching
   elem, notElem, find, findIndex, findIndices, elemIndex, elemIndices,
@@ -90,7 +91,7 @@ import Prelude hiding ( length, null,
                         init, tail, take, drop, reverse,
                         map, concatMap,
                         zipWith, zipWith3, zip, zip3, unzip, unzip3,
-                        filter, takeWhile, dropWhile,
+                        filter, takeWhile, dropWhile, span, break,
                         elem, notElem,
                         foldl, foldl1, foldr, foldr1,
                         all, any, and, or, sum, product, minimum, maximum,
@@ -471,6 +472,25 @@ dropWhile :: (a -> Bool) -> Vector a -> Vector a
 {-# INLINE dropWhile #-}
 dropWhile = G.dropWhile
 
+-- | Split the vector in two parts, the first one containing those elements
+-- that satisfy the predicate and the second one those that don't. The order
+-- of the elements is not preserved.
+unstablePartition :: (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE unstablePartition #-}
+unstablePartition = G.unstablePartition
+
+-- | Split the vector into the longest prefix of elements that satisfy the
+-- predicate and the rest.
+span :: (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE span #-}
+span = G.span
+
+-- | Split the vector into the longest prefix of elements that do not satisfy
+-- the predicate and the rest.
+break :: (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE break #-}
+break = G.break
+
 -- Searching
 -- ---------
 
index d777f37..6ef1909 100644 (file)
@@ -48,6 +48,7 @@ module Data.Vector.Generic (
 
   -- * Filtering
   filter, ifilter, takeWhile, dropWhile,
+  unstablePartition, span, break,
 
   -- * Searching
   elem, notElem, find, findIndex, findIndices, elemIndex, elemIndices,
@@ -89,6 +90,7 @@ module Data.Vector.Generic (
 ) where
 
 import           Data.Vector.Generic.Mutable ( MVector )
+import qualified Data.Vector.Generic.Mutable as M
 
 import qualified Data.Vector.Generic.New as New
 import           Data.Vector.Generic.New ( New )
@@ -107,7 +109,7 @@ import Prelude hiding ( length, null,
                         init, tail, take, drop, reverse,
                         map, concatMap,
                         zipWith, zipWith3, zip, zip3, unzip, unzip3,
-                        filter, takeWhile, dropWhile,
+                        filter, takeWhile, dropWhile, span, break,
                         elem, notElem,
                         foldl, foldl1, foldr, foldr1,
                         all, any, and, or, sum, product, maximum, minimum,
@@ -757,6 +759,58 @@ dropWhile :: Vector v a => (a -> Bool) -> v a -> v a
 {-# INLINE dropWhile #-}
 dropWhile f = unstream . Stream.dropWhile f . stream
 
+-- | Split the vector in two parts, the first one containing those elements
+-- that satisfy the predicate and the second one those that don't. The order
+-- of the elements is not preserved.
+unstablePartition :: Vector v a => (a -> Bool) -> v a -> (v a, v a)
+{-# INLINE unstablePartition #-}
+unstablePartition f = unstablePartition_stream f . stream
+
+unstablePartition_stream
+  :: Vector v a => (a -> Bool) -> Stream a -> (v a, v a)
+{-# INLINE_STREAM unstablePartition_stream #-}
+unstablePartition_stream f s = s `seq` runST (
+  do
+    (mv1,mv2) <- M.unstablePartitionStream f s
+    v1 <- unsafeFreeze mv1
+    v2 <- unsafeFreeze mv2
+    return (v1,v2))
+
+unstablePartition_new :: Vector v a => (a -> Bool) -> New a -> (v a, v a)
+{-# INLINE_STREAM unstablePartition_new #-}
+unstablePartition_new f (New.New p) = runST (
+  do
+    mv <- p
+    i <- M.unstablePartition f mv
+    v <- unsafeFreeze mv
+    return (take i v, drop i v))
+
+{-# RULES
+
+"unstablePartition" forall f v p.
+  unstablePartition_stream f (stream (new' v p))
+    = unstablePartition_new f p
+
+  #-}
+
+
+-- FIXME: make span and break fusible
+
+-- | Split the vector into the longest prefix of elements that satisfy the
+-- predicate and the rest.
+span :: Vector v a => (a -> Bool) -> v a -> (v a, v a)
+{-# INLINE span #-}
+span f = break (not . f)
+
+-- | Split the vector into the longest prefix of elements that do not satisfy
+-- the predicate and the rest.
+break :: Vector v a => (a -> Bool) -> v a -> (v a, v a)
+{-# INLINE break #-}
+break f xs = case findIndex f xs of
+               Just i  -> (unsafeSlice xs 0 i, unsafeSlice xs i (length xs - i))
+               Nothing -> (xs, empty)
+    
+
 -- Searching
 -- ---------
 
index 7fbb521..846e530 100644 (file)
@@ -23,7 +23,8 @@ module Data.Vector.Generic.Mutable (
   unsafeCopy, unsafeGrow,
 
   -- * Internal operations
-  unstream, transform, unsafeAccum, accum, unsafeUpdate, update, reverse
+  unstream, transform, unsafeAccum, accum, unsafeUpdate, update, reverse,
+  unstablePartition, unstablePartitionStream
 ) where
 
 import qualified Data.Vector.Fusion.Stream      as Stream
@@ -267,6 +268,33 @@ grow :: (PrimMonad m, MVector v a)
 grow v by = BOUNDS_CHECK(checkLength) "grow" by
           $ unsafeGrow v by
 
+-- | Grow a vector logarithmically
+enlarge :: (PrimMonad m, MVector v a)
+                => v (PrimState m) a -> m (v (PrimState m) a)
+{-# INLINE enlarge #-}
+enlarge v = unsafeGrow v
+          $ max 1
+          $ double2Int
+          $ int2Double (length v) * gROWTH_FACTOR
+
+unsafeAppend1 :: (PrimMonad m, MVector v a)
+        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
+{-# INLINE_INNER unsafeAppend1 #-}
+    -- NOTE: The case distinction has to be on the outside because
+    -- GHC creates a join point for the unsafeWrite even when everything
+    -- is inlined. This is bad because with the join point, v isn't getting
+    -- unboxed.
+unsafeAppend1 v i x
+  | i < length v = do
+                     unsafeWrite v i x
+                     return v
+  | otherwise    = do
+                     v' <- enlarge v
+                     INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length v')
+                       $ unsafeWrite v' i x
+                     return v'
+
+
 mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
 {-# INLINE mstream #-}
 mstream v = v `seq` (MStream.unfoldrM get 0 `MStream.sized` Exact n)
@@ -361,26 +389,10 @@ unstreamUnknown s
       (v', n) <- Stream.foldM put (v, 0) s
       return $ slice v' 0 n
   where
-    -- NOTE: The case distinction has to be on the outside because
-    -- GHC creates a join point for the unsafeWrite even when everything
-    -- is inlined. This is bad because with the join point, v isn't getting
-    -- unboxed.
     {-# INLINE_INNER put #-}
-    put (v, i) x
-      | i < length v = do
-                         unsafeWrite v i x
-                         return (v, i+1)
-      | otherwise    = do
-                         v' <- enlarge v
-                         INTERNAL_CHECK(checkIndex) "unstreamMax" i (length v')
-                           $ unsafeWrite v' i x
-                         return (v', i+1)
-
-    {-# INLINE_INNER enlarge #-}
-    enlarge v = unsafeGrow v
-              $ max 1
-              $ double2Int
-              $ int2Double (length v) * gROWTH_FACTOR
+    put (v,i) x = do
+                    v' <- unsafeAppend1 v i x
+                    return (v',i+1)
 
 unsafeAccum :: (PrimMonad m, MVector v a)
             => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
@@ -434,3 +446,79 @@ reverse !v = reverse_loop 0 (length v - 1)
                                  reverse_loop (i + 1) (j - 1)
     reverse_loop _ _ = return ()
 
+unstablePartition :: (PrimMonad m, MVector v a)
+                  => (a -> Bool) -> v (PrimState m) a -> m Int
+{-# INLINE unstablePartition #-}
+unstablePartition f !v = from_left 0 (length v)
+  where
+    from_left i j
+      | i == j    = return i
+      | otherwise = do
+                      x <- unsafeRead v i
+                      if f x
+                        then from_left (i+1) j
+                        else from_right i (j-1)
+
+    from_right i j
+      | i == j    = return i
+      | otherwise = do
+                      x <- unsafeRead v j
+                      if f x
+                        then do
+                               y <- unsafeRead v i
+                               unsafeWrite v i x
+                               unsafeWrite v j y
+                               from_left (i+1) j
+                        else from_right i (j-1)
+
+unstablePartitionStream :: (PrimMonad m, MVector v a)
+        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
+{-# INLINE unstablePartitionStream #-}
+unstablePartitionStream f s
+  = case upperBound (Stream.size s) of
+      Just n  -> unstablePartitionMax f s n
+      Nothing -> partitionUnknown f s
+
+
+unstablePartitionMax :: (PrimMonad m, MVector v a)
+        => (a -> Bool) -> Stream a -> Int
+        -> m (v (PrimState m) a, v (PrimState m) a)
+{-# INLINE unstablePartitionMax #-}
+unstablePartitionMax f s n
+  = do
+      v <- new n
+      let {-# INLINE_INNER put #-}
+          put (i, j) x
+            | f x       = do
+                            unsafeWrite v i x
+                            return (i+1, j)
+            | otherwise = do
+                            unsafeWrite v (j-1) x
+                            return (i, j-1)
+                                
+      (i,j) <- Stream.foldM' put (0, n) s
+      return (slice v 0 i, slice v j (n-j))
+
+partitionUnknown :: (PrimMonad m, MVector v a)
+        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
+{-# INLINE partitionUnknown #-}
+partitionUnknown f s
+  = do
+      v1 <- new 0
+      v2 <- new 0
+      (v1', n1, v2', n2) <- Stream.foldM' put (v1, 0, v2, 0) s
+      return (slice v1' 0 n1, slice v2' 0 n2)
+  where
+    -- NOTE: The case distinction has to be on the outside because
+    -- GHC creates a join point for the unsafeWrite even when everything
+    -- is inlined. This is bad because with the join point, v isn't getting
+    -- unboxed.
+    {-# INLINE_INNER put #-}
+    put (v1, i1, v2, i2) x
+      | f x       = do
+                      v1' <- unsafeAppend1 v1 i1 x
+                      return (v1', i1+1, v2, i2)
+      | otherwise = do
+                      v2' <- unsafeAppend1 v2 i2 x
+                      return (v1, i1, v2', i2+1)
+
index 7eae262..ffde452 100644 (file)
@@ -40,6 +40,7 @@ module Data.Vector.Primitive (
 
   -- * Filtering
   filter, ifilter, takeWhile, dropWhile,
+  unstablePartition, span, break,
 
   -- * Searching
   elem, notElem, find, findIndex, findIndices, elemIndex, elemIndices,
@@ -87,7 +88,7 @@ import Prelude hiding ( length, null,
                         init, tail, take, drop, reverse,
                         map, concatMap,
                         zipWith, zipWith3, zip, zip3, unzip, unzip3,
-                        filter, takeWhile, dropWhile,
+                        filter, takeWhile, dropWhile, span, break,
                         elem, notElem,
                         foldl, foldl1, foldr, foldr1,
                         all, any, sum, product, minimum, maximum,
@@ -403,6 +404,25 @@ dropWhile :: Prim a => (a -> Bool) -> Vector a -> Vector a
 {-# INLINE dropWhile #-}
 dropWhile = G.dropWhile
 
+-- | Split the vector in two parts, the first one containing those elements
+-- that satisfy the predicate and the second one those that don't. The order
+-- of the elements is not preserved.
+unstablePartition :: Prim a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE unstablePartition #-}
+unstablePartition = G.unstablePartition
+
+-- | Split the vector into the longest prefix of elements that satisfy the
+-- predicate and the rest.
+span :: Prim a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE span #-}
+span = G.span
+
+-- | Split the vector into the longest prefix of elements that do not satisfy
+-- the predicate and the rest.
+break :: Prim a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE break #-}
+break = G.break
+
 -- Searching
 -- ---------
 
index cbaeaee..1bf6671 100644 (file)
@@ -40,6 +40,7 @@ module Data.Vector.Storable (
 
   -- * Filtering
   filter, ifilter, takeWhile, dropWhile,
+  unstablePartition, span, break,
 
   -- * Searching
   elem, notElem, find, findIndex, findIndices, elemIndex, elemIndices,
@@ -91,7 +92,7 @@ import Prelude hiding ( length, null,
                         init, tail, take, drop, reverse,
                         map, concatMap,
                         zipWith, zipWith3, zip, zip3, unzip, unzip3,
-                        filter, takeWhile, dropWhile,
+                        filter, takeWhile, dropWhile, span, break,
                         elem, notElem,
                         foldl, foldl1, foldr, foldr1,
                         all, any, and, or, sum, product, minimum, maximum,
@@ -438,6 +439,26 @@ dropWhile :: Storable a => (a -> Bool) -> Vector a -> Vector a
 {-# INLINE dropWhile #-}
 dropWhile = G.dropWhile
 
+-- | Split the vector in two parts, the first one containing those elements
+-- that satisfy the predicate and the second one those that don't. The order
+-- of the elements is not preserved.
+unstablePartition
+        :: Storable a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE unstablePartition #-}
+unstablePartition = G.unstablePartition
+
+-- | Split the vector into the longest prefix of elements that satisfy the
+-- predicate and the rest.
+span :: Storable a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE span #-}
+span = G.span
+
+-- | Split the vector into the longest prefix of elements that do not satisfy
+-- the predicate and the rest.
+break :: Storable a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE break #-}
+break = G.break
+
 -- Searching
 -- ---------
 
index b746268..96e0d58 100644 (file)
@@ -42,6 +42,7 @@ module Data.Vector.Unboxed (
 
   -- * Filtering
   filter, ifilter, takeWhile, dropWhile,
+  unstablePartition, span, break,
 
   -- * Searching
   elem, notElem, find, findIndex, findIndices, elemIndex, elemIndices,
@@ -86,7 +87,7 @@ import Prelude hiding ( length, null,
                         init, tail, take, drop, reverse,
                         map, concatMap,
                         zipWith, zipWith3, zip, zip3, unzip, unzip3,
-                        filter, takeWhile, dropWhile,
+                        filter, takeWhile, dropWhile, span, break,
                         elem, notElem,
                         foldl, foldl1, foldr, foldr1,
                         all, any, and, or, sum, product, minimum, maximum,
@@ -393,6 +394,25 @@ dropWhile :: Unbox a => (a -> Bool) -> Vector a -> Vector a
 {-# INLINE dropWhile #-}
 dropWhile = G.dropWhile
 
+-- | Split the vector in two parts, the first one containing those elements
+-- that satisfy the predicate and the second one those that don't. The order
+-- of the elements is not preserved.
+unstablePartition :: Unbox a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE unstablePartition #-}
+unstablePartition = G.unstablePartition
+
+-- | Split the vector into the longest prefix of elements that satisfy the
+-- predicate and the rest.
+span :: Unbox a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE span #-}
+span = G.span
+
+-- | Split the vector into the longest prefix of elements that do not satisfy
+-- the predicate and the rest.
+break :: Unbox a => (a -> Bool) -> Vector a -> (Vector a, Vector a)
+{-# INLINE break #-}
+break = G.break
+
 -- Searching
 -- ---------