Add the dreaded concatMap
[darcs-mirrors/vector.git] / Data / Vector / Fusion / Stream / Monadic.hs
index 8546190..ade0c0c 100644 (file)
@@ -15,7 +15,7 @@
 #include "phases.h"
 
 module Data.Vector.Fusion.Stream.Monadic (
-  Stream(..),
+  Stream(..), Step(..),
 
   -- * Size hints
   size, sized,
@@ -46,16 +46,19 @@ module Data.Vector.Fusion.Stream.Monadic (
   foldl', foldlM', foldl1', foldl1M',
   foldr, foldrM, foldr1, foldr1M,
 
+  -- * Specialised folds
+  and, or, concatMap, concatMapM,
+
   -- * Unfolding
-  unfold, unfoldM,
+  unfoldr, unfoldrM,
 
   -- * Scans
   prescanl, prescanlM, prescanl', prescanlM',
 
+  -- * Conversions
   toList, fromList
 ) where
 
-import Data.Vector.Fusion.Stream.Step
 import Data.Vector.Fusion.Stream.Size
 
 import Control.Monad  ( liftM )
@@ -66,9 +69,16 @@ import Prelude hiding ( length, null,
                         map, mapM, mapM_, zipWith,
                         filter, takeWhile, dropWhile,
                         elem, notElem,
-                        foldl, foldl1, foldr, foldr1 )
+                        foldl, foldl1, foldr, foldr1,
+                        and, or, concatMap )
 import qualified Prelude
 
+-- | Result of taking a single step in a stream
+data Step s a = Yield a s  -- ^ a new element and a new seed
+              | Skip    s  -- ^ just a new seed
+              | Done       -- ^ end of stream
+
+-- | Monadic streams
 data Stream m a = forall s. Stream (s -> m (Step s a)) s Size
 
 -- | 'Size' hint of a 'Stream'
@@ -187,7 +197,17 @@ last (Stream step s _) = last_loop0 s
 -- | Element at the given position
 (!!) :: Monad m => Stream m a -> Int -> m a
 {-# INLINE (!!) #-}
-s !! i = head (drop i s)
+Stream step s _ !! i | i < 0     = errorNegativeIndex "!!"
+                     | otherwise = loop s i
+  where
+    loop s i = i `seq`
+               do
+                 r <- step s
+                 case r of
+                   Yield x s' | i == 0    -> return x
+                              | otherwise -> loop s' (i-1)
+                   Skip    s'             -> loop s' i
+                   Done                   -> errorIndexOutOfRange "!!"
 
 -- Substreams
 -- ----------
@@ -209,7 +229,7 @@ init (Stream step s sz) = Stream step' (Nothing, s) (sz - 1)
                            case r of
                              Yield x s' -> Skip (Just x,  s')
                              Skip    s' -> Skip (Nothing, s')
-                             Done       -> Done  -- FIXME: should be an error
+                             Done       -> errorEmptyStream "init"
                          ) (step s)
 
     step' (Just x,  s) = liftM (\r -> 
@@ -229,7 +249,7 @@ tail (Stream step s sz) = Stream step' (Left s) (sz - 1)
                         case r of
                           Yield x s' -> Skip (Right s')
                           Skip    s' -> Skip (Left  s')
-                          Done       -> Done    -- FIXME: should be error?
+                          Done       -> errorEmptyStream "tail"
                       ) (step s)
 
     step' (Right s) = liftM (\r ->
@@ -608,18 +628,66 @@ foldr1M f (Stream step s _) = foldr1M_go0 s
                           Skip    s' -> foldr1M_go1 x s'
                           Done       -> return x
 
+-- Specialised folds
+-- -----------------
+
+and :: Monad m => Stream m Bool -> m Bool
+and (Stream step s _) = and_go s
+  where
+    and_go s = do
+                 r <- step s
+                 case r of
+                   Yield False _  -> return False
+                   Yield True  s' -> and_go s'
+                   Skip        s' -> and_go s'
+                   Done           -> return True
+
+or :: Monad m => Stream m Bool -> m Bool
+or (Stream step s _) = or_go s
+  where
+    or_go s = do
+                r <- step s
+                case r of
+                  Yield False s' -> or_go s'
+                  Yield True  _  -> return True
+                  Skip        s' -> or_go s'
+                  Done           -> return False
+
+concatMap :: Monad m => (a -> Stream m b) -> Stream m a -> Stream m b
+{-# INLINE concatMap #-}
+concatMap f = concatMapM (return . f)
+
+concatMapM :: Monad m => (a -> m (Stream m b)) -> Stream m a -> Stream m b
+{-# INLINE_STREAM concatMapM #-}
+concatMapM f (Stream step s _) = Stream concatMap_go (Left s) Unknown
+  where
+    concatMap_go (Left s) = do
+        r <- step s
+        case r of
+            Yield a s' -> do
+                b_stream <- f a
+                return $ Skip (Right (b_stream, s'))
+            Skip    s' -> return $ Skip (Left s')
+            Done       -> return Done
+    concatMap_go (Right (Stream inner_step inner_s sz, s)) = do
+        r <- inner_step inner_s
+        case r of
+            Yield b inner_s' -> return $ Yield b (Right (Stream inner_step inner_s' sz, s))
+            Skip    inner_s' -> return $ Skip (Right (Stream inner_step inner_s' sz, s))
+            Done             -> return $ Skip (Left s)
+
 -- Unfolding
 -- ---------
 
 -- | Unfold
-unfold :: Monad m => (s -> Maybe (a, s)) -> s -> Stream m a
-{-# INLINE_STREAM unfold #-}
-unfold f = unfoldM (return . f)
+unfoldr :: Monad m => (s -> Maybe (a, s)) -> s -> Stream m a
+{-# INLINE_STREAM unfoldr #-}
+unfoldr f = unfoldrM (return . f)
 
 -- | Unfold with a monadic function
-unfoldM :: Monad m => (s -> m (Maybe (a, s))) -> s -> Stream m a
-{-# INLINE_STREAM unfoldM #-}
-unfoldM f s = Stream step s Unknown
+unfoldrM :: Monad m => (s -> m (Maybe (a, s))) -> s -> Stream m a
+{-# INLINE_STREAM unfoldrM #-}
+unfoldrM f s = Stream step s Unknown
   where
     {-# INLINE step #-}
     step s = liftM (\r ->
@@ -689,7 +757,16 @@ fromList xs = Stream step xs Unknown
     step []     = return Done
 
 
+streamError :: String -> String -> a
+streamError fn msg = error $ "Data.Vector.Fusion.Stream.Monadic."
+                             Prelude.++ fn Prelude.++ ": " Prelude.++ msg
+
 errorEmptyStream :: String -> a
-errorEmptyStream s = error $ "Data.Vector.Fusion.Stream.Monadic."
-                        Prelude.++ s Prelude.++ ": empty stream"
+errorEmptyStream fn = streamError fn "empty stream"
+
+errorNegativeIndex :: String -> a
+errorNegativeIndex fn = streamError fn "negative index"
+
+errorIndexOutOfRange :: String -> a
+errorIndexOutOfRange fn = streamError fn "index out of range"