Make inplace fusion work on Streams rather than Bundles
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Sun, 7 Oct 2012 20:48:42 +0000 (20:48 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Sun, 7 Oct 2012 20:48:42 +0000 (20:48 +0000)
Data/Vector/Fusion/Bundle.hs
Data/Vector/Fusion/Bundle/Monadic.hs
Data/Vector/Generic.hs
Data/Vector/Generic/Mutable.hs
Data/Vector/Generic/New.hs

index a648ee7..355c9b9 100644 (file)
@@ -82,6 +82,7 @@ import Data.Vector.Fusion.Util
 import Data.Vector.Fusion.Stream.Monadic ( Stream(..), Step(..), SPEC(..) )
 import Data.Vector.Fusion.Bundle.Monadic ( Chunk(..) )
 import qualified Data.Vector.Fusion.Bundle.Monadic as M
+import qualified Data.Vector.Fusion.Stream.Monadic as S
 
 import Prelude hiding ( length, null,
                         replicate, (++),
@@ -107,18 +108,18 @@ type Bundle = M.Bundle Id
 -- | Alternative name for monadic streams
 type MBundle = M.Bundle
 
-inplace :: (forall m. Monad m => M.Bundle m v a -> M.Bundle m v b)
-        -> Bundle v a -> Bundle v b
+inplace :: (forall m. Monad m => S.Stream m a -> S.Stream m b)
+       -> (Size -> Size) -> Bundle v a -> Bundle v b
 {-# INLINE_FUSED inplace #-}
-inplace f s = s `seq` f s
+inplace f g b = b `seq` M.fromStream (f (M.elements b)) (g (M.size b))
 
 {-# RULES
 
 "inplace/inplace [Vector]"
-  forall (f :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         (g :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         s.
-  inplace f (inplace g s) = inplace (f . g) s
+  forall (f1 :: forall m. Monad m => S.Stream m a -> S.Stream m a)
+         (f2 :: forall m. Monad m => S.Stream m a -> S.Stream m a)
+         g1 g2 s.
+  inplace f1 g1 (inplace f2 g2 s) = inplace (f1 . f2) (g1 . g2) s
 
   #-}
 
index ad2215d..0548a48 100644 (file)
@@ -73,7 +73,7 @@ module Data.Vector.Fusion.Bundle.Monadic (
   -- * Conversions
   toList, fromList, fromListN, unsafeFromList,
   fromVector, reVector, fromVectors, concatVectors,
-  fromStream, chunks
+  fromStream, chunks, elements
 ) where
 
 import Data.Vector.Generic.Base
@@ -126,6 +126,10 @@ chunks :: Bundle m v a -> Stream m (Chunk v a)
 {-# INLINE chunks #-}
 chunks = sChunks
 
+elements :: Bundle m v a -> Stream m a
+{-# INLINE elements #-}
+elements = sElems
+
 -- | 'Size' hint of a 'Bundle'
 size :: Bundle m v a -> Size
 {-# INLINE size #-}
index 9483020..540ef11 100644 (file)
@@ -169,8 +169,10 @@ import qualified Data.Vector.Generic.New as New
 import           Data.Vector.Generic.New ( New )
 
 import qualified Data.Vector.Fusion.Bundle as Bundle
-import           Data.Vector.Fusion.Bundle ( Bundle, MBundle, Step(..), inplace, lift )
+import           Data.Vector.Fusion.Bundle ( Bundle, MBundle, Step(..), lift, inplace )
 import qualified Data.Vector.Fusion.Bundle.Monadic as MBundle
+import           Data.Vector.Fusion.Stream.Monadic ( Stream )
+import qualified Data.Vector.Fusion.Stream.Monadic as S
 import           Data.Vector.Fusion.Bundle.Size
 import           Data.Vector.Fusion.Util
 
@@ -962,12 +964,12 @@ indexed = unstream . Bundle.indexed . stream
 -- | /O(n)/ Map a function over a vector
 map :: (Vector v a, Vector v b) => (a -> b) -> v a -> v b
 {-# INLINE map #-}
-map f = unstream . inplace (MBundle.map f) . stream
+map f = unstream . inplace (S.map f) id . stream
 
 -- | /O(n)/ Apply a function to every element of a vector and its index
 imap :: (Vector v a, Vector v b) => (Int -> a -> b) -> v a -> v b
 {-# INLINE imap #-}
-imap f = unstream . inplace (MBundle.map (uncurry f) . MBundle.indexed)
+imap f = unstream . inplace (S.map (uncurry f) . S.indexed) id
                   . stream
 
 -- | Map a function over a vector and concatenate the results.
@@ -1230,15 +1232,14 @@ unzip6 xs = (map (\(a, b, c, d, e, f) -> a) xs,
 -- | /O(n)/ Drop elements that do not satisfy the predicate
 filter :: Vector v a => (a -> Bool) -> v a -> v a
 {-# INLINE filter #-}
-filter f = unstream . inplace (MBundle.filter f) . stream
+filter f = unstream . inplace (S.filter f) toMax . stream
 
 -- | /O(n)/ Drop elements that do not satisfy the predicate which is applied to
 -- values and their indices
 ifilter :: Vector v a => (Int -> a -> Bool) -> v a -> v a
 {-# INLINE ifilter #-}
 ifilter f = unstream
-          . inplace (MBundle.map snd . MBundle.filter (uncurry f)
-                                     . MBundle.indexed)
+          . inplace (S.map snd . S.filter (uncurry f) . S.indexed) toMax
           . stream
 
 -- | /O(n)/ Drop elements that do not satisfy the monadic predicate
@@ -1366,8 +1367,7 @@ findIndex f = Bundle.findIndex f . stream
 findIndices :: (Vector v a, Vector v Int) => (a -> Bool) -> v a -> v Int
 {-# INLINE findIndices #-}
 findIndices f = unstream
-              . inplace (MBundle.map fst . MBundle.filter (f . snd)
-                                         . MBundle.indexed)
+              . inplace (S.map fst . S.filter (f . snd) . S.indexed) toMax
               . stream
 
 -- | /O(n)/ Yield 'Just' the index of the first occurence of the given element or
@@ -1624,12 +1624,12 @@ sequence_ = mapM_ id
 --
 prescanl :: (Vector v a, Vector v b) => (a -> b -> a) -> a -> v b -> v a
 {-# INLINE prescanl #-}
-prescanl f z = unstream . inplace (MBundle.prescanl f z) . stream
+prescanl f z = unstream . inplace (S.prescanl f z) id . stream
 
 -- | /O(n)/ Prescan with strict accumulator
 prescanl' :: (Vector v a, Vector v b) => (a -> b -> a) -> a -> v b -> v a
 {-# INLINE prescanl' #-}
-prescanl' f z = unstream . inplace (MBundle.prescanl' f z) . stream
+prescanl' f z = unstream . inplace (S.prescanl' f z) id . stream
 
 -- | /O(n)/ Scan
 --
@@ -1641,12 +1641,12 @@ prescanl' f z = unstream . inplace (MBundle.prescanl' f z) . stream
 --
 postscanl :: (Vector v a, Vector v b) => (a -> b -> a) -> a -> v b -> v a
 {-# INLINE postscanl #-}
-postscanl f z = unstream . inplace (MBundle.postscanl f z) . stream
+postscanl f z = unstream . inplace (S.postscanl f z) id . stream
 
 -- | /O(n)/ Scan with strict accumulator
 postscanl' :: (Vector v a, Vector v b) => (a -> b -> a) -> a -> v b -> v a
 {-# INLINE postscanl' #-}
-postscanl' f z = unstream . inplace (MBundle.postscanl' f z) . stream
+postscanl' f z = unstream . inplace (S.postscanl' f z) id . stream
 
 -- | /O(n)/ Haskell-style scan
 --
@@ -1673,12 +1673,12 @@ scanl' f z = unstream . Bundle.scanl' f z . stream
 --
 scanl1 :: Vector v a => (a -> a -> a) -> v a -> v a
 {-# INLINE scanl1 #-}
-scanl1 f = unstream . inplace (MBundle.scanl1 f) . stream
+scanl1 f = unstream . inplace (S.scanl1 f) id . stream
 
 -- | /O(n)/ Scan over a non-empty vector with a strict accumulator
 scanl1' :: Vector v a => (a -> a -> a) -> v a -> v a
 {-# INLINE scanl1' #-}
-scanl1' f = unstream . inplace (MBundle.scanl1' f) . stream
+scanl1' f = unstream . inplace (S.scanl1' f) id . stream
 
 -- | /O(n)/ Right-to-left prescan
 --
@@ -1688,22 +1688,22 @@ scanl1' f = unstream . inplace (MBundle.scanl1' f) . stream
 --
 prescanr :: (Vector v a, Vector v b) => (a -> b -> b) -> b -> v a -> v b
 {-# INLINE prescanr #-}
-prescanr f z = unstreamR . inplace (MBundle.prescanl (flip f) z) . streamR
+prescanr f z = unstreamR . inplace (S.prescanl (flip f) z) id . streamR
 
 -- | /O(n)/ Right-to-left prescan with strict accumulator
 prescanr' :: (Vector v a, Vector v b) => (a -> b -> b) -> b -> v a -> v b
 {-# INLINE prescanr' #-}
-prescanr' f z = unstreamR . inplace (MBundle.prescanl' (flip f) z) . streamR
+prescanr' f z = unstreamR . inplace (S.prescanl' (flip f) z) id . streamR
 
 -- | /O(n)/ Right-to-left scan
 postscanr :: (Vector v a, Vector v b) => (a -> b -> b) -> b -> v a -> v b
 {-# INLINE postscanr #-}
-postscanr f z = unstreamR . inplace (MBundle.postscanl (flip f) z) . streamR
+postscanr f z = unstreamR . inplace (S.postscanl (flip f) z) id . streamR
 
 -- | /O(n)/ Right-to-left scan with strict accumulator
 postscanr' :: (Vector v a, Vector v b) => (a -> b -> b) -> b -> v a -> v b
 {-# INLINE postscanr' #-}
-postscanr' f z = unstreamR . inplace (MBundle.postscanl' (flip f) z) . streamR
+postscanr' f z = unstreamR . inplace (S.postscanl' (flip f) z) id . streamR
 
 -- | /O(n)/ Right-to-left Haskell-style scan
 scanr :: (Vector v a, Vector v b) => (a -> b -> b) -> b -> v a -> v b
@@ -1718,13 +1718,13 @@ scanr' f z = unstreamR . Bundle.scanl' (flip f) z . streamR
 -- | /O(n)/ Right-to-left scan over a non-empty vector
 scanr1 :: Vector v a => (a -> a -> a) -> v a -> v a
 {-# INLINE scanr1 #-}
-scanr1 f = unstreamR . inplace (MBundle.scanl1 (flip f)) . streamR
+scanr1 f = unstreamR . inplace (S.scanl1 (flip f)) id . streamR
 
 -- | /O(n)/ Right-to-left scan over a non-empty vector with a strict
 -- accumulator
 scanr1' :: Vector v a => (a -> a -> a) -> v a -> v a
 {-# INLINE scanr1' #-}
-scanr1' f = unstreamR . inplace (MBundle.scanl1' (flip f)) . streamR
+scanr1' f = unstreamR . inplace (S.scanl1' (flip f)) id . streamR
 
 -- Conversions - Lists
 -- ------------------------
@@ -1873,12 +1873,12 @@ unstream s = new (New.unstream s)
   clone (new p) = p
 
 "inplace [Vector]"
-  forall (f :: forall m. Monad m => MBundle m u a -> MBundle m u a) m.
-  New.unstream (inplace f (stream (new m))) = New.transform f m
+  forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m.
+  New.unstream (inplace f g (stream (new m))) = New.transform f g m
 
 "uninplace [Vector]"
-  forall (f :: forall m. Monad m => MBundle m u a -> MBundle m u a) m.
-  stream (new (New.transform f m)) = inplace f (stream (new m))
+  forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m.
+  stream (new (New.transform f g m)) = inplace f g (stream (new m))
 
  #-}
 
@@ -1915,12 +1915,12 @@ unstreamR s = new (New.unstreamR s)
   New.unstreamR (stream (new p)) = New.modify M.reverse p
 
 "inplace right [Vector]"
-  forall (f :: forall m. Monad m => MBundle m u a -> MBundle m u a) m.
-  New.unstreamR (inplace f (streamR (new m))) = New.transformR f m
+  forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m.
+  New.unstreamR (inplace f g (streamR (new m))) = New.transformR f g m
 
 "uninplace right [Vector]"
-  forall (f :: forall m. Monad m => MBundle m u a -> MBundle m u a) m.
-  streamR (new (New.transformR f m)) = inplace f (streamR (new m))
+  forall (f :: forall m. Monad m => Stream m a -> Stream m a) g m.
+  streamR (new (New.transformR f g m)) = inplace f g (streamR (new m))
 
  #-}
 
index 3cbebf1..43bffb2 100644 (file)
@@ -63,7 +63,8 @@ import qualified Data.Vector.Generic.Base as V
 import qualified Data.Vector.Fusion.Bundle      as Bundle
 import           Data.Vector.Fusion.Bundle      ( Bundle, MBundle, Chunk(..) )
 import qualified Data.Vector.Fusion.Bundle.Monadic as MBundle
-import qualified Data.Vector.Fusion.Stream.Monadic as MStream
+import           Data.Vector.Fusion.Stream.Monadic ( Stream )
+import qualified Data.Vector.Fusion.Stream.Monadic as Stream
 import           Data.Vector.Fusion.Bundle.Size
 import           Data.Vector.Fusion.Util        ( delay_inline )
 
@@ -239,9 +240,9 @@ unsafePrepend1 v i x
                     $ unsafeWrite v' i' x
                   return (v', i')
 
-mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MBundle m u a
+mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream m a
 {-# INLINE mstream #-}
-mstream v = v `seq` n `seq` (MBundle.unfoldrM get 0 `MBundle.sized` Exact n)
+mstream v = v `seq` n `seq` (Stream.unfoldrM get 0)
   where
     n = length v
 
@@ -251,12 +252,10 @@ mstream v = v `seq` n `seq` (MBundle.unfoldrM get 0 `MBundle.sized` Exact n)
           | otherwise = return $ Nothing
 
 fill :: (PrimMonad m, MVector v a)
-     => v (PrimState m) a
-     -> MBundle m u a
-     -> m (v (PrimState m) a)
+     => v (PrimState m) a -> Stream m a -> m (v (PrimState m) a)
 {-# INLINE fill #-}
 fill v s = v `seq` do
-                     n' <- MBundle.foldM put 0 s
+                     n' <- Stream.foldM put 0 s
                      return $ unsafeSlice 0 n' v
   where
     {-# INLINE_INNER put #-}
@@ -265,18 +264,15 @@ fill v s = v `seq` do
                   $ unsafeWrite v i x
                 return (i+1)
 
-transform :: (PrimMonad m, MVector v a)
-          => (MBundle m u a -> MBundle m u a)
-          -> v (PrimState m) a
-          -> m (v (PrimState m) a)
+transform
+  :: (PrimMonad m, MVector v a)
+  => (Stream m a -> Stream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
 {-# INLINE_FUSED transform #-}
 transform f v = fill v (f (mstream v))
 
-mstreamR :: (PrimMonad m, MVector v a)
-         => v (PrimState m) a
-         -> MBundle m u a
+mstreamR :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Stream m a
 {-# INLINE mstreamR #-}
-mstreamR v = v `seq` n `seq` (MBundle.unfoldrM get n `MBundle.sized` Exact n)
+mstreamR v = v `seq` n `seq` (Stream.unfoldrM get n)
   where
     n = length v
 
@@ -288,12 +284,10 @@ mstreamR v = v `seq` n `seq` (MBundle.unfoldrM get n `MBundle.sized` Exact n)
         j = i-1
 
 fillR :: (PrimMonad m, MVector v a)
-      => v (PrimState m) a
-      -> MBundle m u a
-      -> m (v (PrimState m) a)
+      => v (PrimState m) a -> Stream m a -> m (v (PrimState m) a)
 {-# INLINE fillR #-}
 fillR v s = v `seq` do
-                      i <- MBundle.foldM put n s
+                      i <- Stream.foldM put n s
                       return $ unsafeSlice i (n-i) v
   where
     n = length v
@@ -305,10 +299,9 @@ fillR v s = v `seq` do
       where
         j = i-1
 
-transformR :: (PrimMonad m, MVector v a)
-           => (MBundle m u a -> MBundle m u a)
-           -> v (PrimState m) a
-           -> m (v (PrimState m) a)
+transformR
+  :: (PrimMonad m, MVector v a)
+  => (Stream m a -> Stream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
 {-# INLINE_FUSED transformR #-}
 transformR f v = fillR v (f (mstreamR v))
 
@@ -419,7 +412,7 @@ vmunstreamMax s n
               f (basicUnsafeSlice i n v)
               return (i+n)
 
-      n' <- MStream.foldlM' copy 0 (MBundle.chunks s)
+      n' <- Stream.foldlM' copy 0 (MBundle.chunks s)
       return $ INTERNAL_CHECK(checkSlice) "munstreamMax" 0 n' n
              $ unsafeSlice 0 n' v
 
@@ -429,7 +422,7 @@ vmunstreamUnknown :: (PrimMonad m, V.Vector v a)
 vmunstreamUnknown s
   = do
       v <- unsafeNew 0
-      (v', n) <- MStream.foldlM copy (v,0) (MBundle.chunks s)
+      (v', n) <- Stream.foldlM copy (v,0) (MBundle.chunks s)
       return $ INTERNAL_CHECK(checkSlice) "munstreamUnknown" 0 n (length v')
              $ unsafeSlice 0 n v'
   where
index f2700be..e97565c 100644 (file)
@@ -26,6 +26,8 @@ import           Data.Vector.Generic.Base ( Vector, Mutable )
 
 import           Data.Vector.Fusion.Bundle ( Bundle, MBundle )
 import qualified Data.Vector.Fusion.Bundle as Bundle
+import           Data.Vector.Fusion.Stream.Monadic ( Stream )
+import           Data.Vector.Fusion.Bundle.Size
 
 import Control.Monad.Primitive
 import Control.Monad.ST ( ST )
@@ -65,23 +67,24 @@ unstream :: Vector v a => Bundle v a -> New v a
 {-# INLINE_FUSED unstream #-}
 unstream s = s `seq` New (MVector.vunstream s)
 
-transform :: Vector v a =>
-        (forall m. Monad m => MBundle m u a -> MBundle m u a) -> New v a -> New v a
+transform
+  :: Vector v a => (forall m. Monad m => Stream m a -> Stream m a)
+                -> (Size -> Size) -> New v a -> New v a
 {-# INLINE_FUSED transform #-}
-transform f (New p) = New (MVector.transform f =<< p)
+transform f (New p) = New (MVector.transform f =<< p)
 
 {-# RULES
 
 "transform/transform [New]"
-  forall (f :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         (g :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         p .
-  transform f (transform g p) = transform (f . g) p
+  forall (f1 :: forall m. Monad m => Stream m a -> Stream m a)
+         (f2 :: forall m. Monad m => Stream m a -> Stream m a)
+         g1 g2 p .
+  transform f1 g1 (transform f2 g2 p) = transform (f1 . f2) (g1 . g2) p
 
 "transform/unstream [New]"
-  forall (f :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         s.
-  transform f (unstream s) = unstream (f s)
+  forall (f :: forall m. Monad m => Stream m a -> Stream m a)
+         s.
+  transform f g (unstream s) = unstream (Bundle.inplace f g s)
 
  #-}
 
@@ -90,23 +93,25 @@ unstreamR :: Vector v a => Bundle v a -> New v a
 {-# INLINE_FUSED unstreamR #-}
 unstreamR s = s `seq` New (MVector.unstreamR s)
 
-transformR :: Vector v a =>
-        (forall m. Monad m => MBundle m u a -> MBundle m u a) -> New v a -> New v a
+transformR
+  :: Vector v a => (forall m. Monad m => Stream m a -> Stream m a)
+                -> (Size -> Size) -> New v a -> New v a
 {-# INLINE_FUSED transformR #-}
-transformR f (New p) = New (MVector.transformR f =<< p)
+transformR f (New p) = New (MVector.transformR f =<< p)
 
 {-# RULES
 
 "transformR/transformR [New]"
-  forall (f :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         (g :: forall m. Monad m => MBundle m v a -> MBundle m v a)
+  forall (f1 :: forall m. Monad m => Stream m a -> Stream m a)
+         (f2 :: forall m. Monad m => Stream m a -> Stream m a)
+         g1 g2
          p .
-  transformR f (transformR g p) = transformR (f . g) p
+  transformR f1 g1 (transformR f2 g2 p) = transformR (f1 . f2) (g1 . g2) p
 
 "transformR/unstreamR [New]"
-  forall (f :: forall m. Monad m => MBundle m v a -> MBundle m v a)
-         s.
-  transformR f (unstreamR s) = unstreamR (f s)
+  forall (f :: forall m. Monad m => Stream m a -> Stream m a)
+         s.
+  transformR f g (unstreamR s) = unstreamR (Bundle.inplace f g s)
 
  #-}