Reimplement specialisation of monadic functions
[darcs-mirrors/vector.git] / Data / Vector / Generic.hs
index 0767d55..1bdcb63 100644 (file)
@@ -30,7 +30,7 @@ module Data.Vector.Generic (
   unsafeIndexM, unsafeHeadM, unsafeLastM,
 
   -- ** Extracting subvectors (slicing)
-  slice, init, tail, take, drop,
+  slice, init, tail, take, drop, splitAt,
   unsafeSlice, unsafeInit, unsafeTail, unsafeTake, unsafeDrop,
 
   -- * Construction
@@ -71,6 +71,9 @@ module Data.Vector.Generic (
 
   -- * Elementwise operations
 
+  -- ** Indexing
+  indexed,
+
   -- ** Mapping
   map, imap, concatMap,
 
@@ -170,7 +173,7 @@ import qualified Data.List as List
 import Prelude hiding ( length, null,
                         replicate, (++), concat,
                         head, last,
-                        init, tail, take, drop, reverse,
+                        init, tail, take, drop, splitAt, reverse,
                         map, concat, concatMap,
                         zipWith, zipWith3, zip, zip3, unzip, unzip3,
                         filter, takeWhile, dropWhile, span, break,
@@ -182,10 +185,18 @@ import Prelude hiding ( length, null,
                         mapM, mapM_ )
 
 import Data.Typeable ( Typeable1, gcast1 )
-import Data.Data ( Data, DataType, mkNorepType )
 
 #include "vector.h"
 
+import Data.Data ( Data, DataType )
+#if MIN_VERSION_base(4,2,0)
+import Data.Data ( mkNoRepType )
+#else
+import Data.Data ( mkNorepType )
+mkNoRepType :: String -> DataType
+mkNoRepType = mkNorepType
+#endif
+
 -- Length information
 -- ------------------
 
@@ -397,6 +408,20 @@ drop n v = unsafeSlice (delay_inline min n' len)
   where n' = max n 0
         len = length v
 
+-- | /O(1)/ Yield the first @n@ elements paired with the remainder without copying.
+--
+-- Note that @'splitAt' n v@ is equivalent to @('take' n v, 'drop' n v)@
+-- but slightly more efficient.
+{-# INLINE_STREAM splitAt #-}
+splitAt :: Vector v a => Int -> v a -> (v a, v a)
+splitAt n v = ( unsafeSlice 0 m v
+              , unsafeSlice m (delay_inline max 0 (len - n')) v
+              )
+    where
+      m   = delay_inline min n' len
+      n'  = max n 0
+      len = length v
+
 -- | /O(1)/ Yield a slice of the vector without copying. The vector must
 -- contain at least @i+n@ elements but this is not checked.
 unsafeSlice :: Vector v a => Int   -- ^ @i@ starting index
@@ -595,9 +620,8 @@ concat vs = unstream (Stream.flatten mk step (Exact n) (Stream.fromList vs))
 -- | /O(n)/ Execute the monadic action the given number of times and store the
 -- results in a vector.
 replicateM :: (Monad m, Vector v a) => Int -> m a -> m (v a)
--- FIXME: specialise for ST and IO?
 {-# INLINE replicateM #-}
-replicateM n m = fromListN n `Monad.liftM` Monad.replicateM n m
+replicateM n m = unstreamM (MStream.replicateM n m)
 
 -- | Execute the monadic action and freeze the resulting vector.
 --
@@ -852,6 +876,14 @@ modifyWithStream :: Vector v a
 {-# INLINE modifyWithStream #-}
 modifyWithStream p v s = new (New.modifyWithStream p (clone v) s)
 
+-- Indexing
+-- --------
+
+-- | /O(n)/ Pair each element in a vector with its index
+indexed :: (Vector v a, Vector v (Int,a)) => v a -> v (Int,a)
+{-# INLINE indexed #-}
+indexed = unstream . Stream.indexed . stream
+
 -- Mapping
 -- -------
 
@@ -879,7 +911,6 @@ concatMap f = concat . Stream.toList . Stream.map f . stream
 -- | /O(n)/ Apply the monadic action to all elements of the vector, yielding a
 -- vector of results
 mapM :: (Monad m, Vector v a, Vector v b) => (a -> m b) -> v a -> m (v b)
--- FIXME: specialise for ST and IO?
 {-# INLINE mapM #-}
 mapM f = unstreamM . Stream.mapM f . stream
 
@@ -1736,12 +1767,33 @@ unstreamR s = new (New.unstreamR s)
 
  #-}
 
-unstreamM :: (Vector v a, Monad m) => MStream m a -> m (v a)
+unstreamM :: (Monad m, Vector v a) => MStream m a -> m (v a)
 {-# INLINE_STREAM unstreamM #-}
 unstreamM s = do
                 xs <- MStream.toList s
                 return $ unstream $ Stream.unsafeFromList (MStream.size s) xs
 
+unstreamPrimM :: (PrimMonad m, Vector v a) => MStream m a -> m (v a)
+{-# INLINE_STREAM unstreamPrimM #-}
+unstreamPrimM s = M.munstream s >>= unsafeFreeze
+
+-- FIXME: the next two functions are only necessary for the specialisations
+unstreamPrimM_IO :: Vector v a => MStream IO a -> IO (v a)
+{-# INLINE unstreamPrimM_IO #-}
+unstreamPrimM_IO = unstreamPrimM
+
+unstreamPrimM_ST :: Vector v a => MStream (ST s) a -> ST s (v a)
+{-# INLINE unstreamPrimM_ST #-}
+unstreamPrimM_ST = unstreamPrimM
+
+{-# RULES
+
+"unstreamM[IO]" unstreamM = unstreamPrimM_IO
+"unstreamM[ST]" unstreamM = unstreamPrimM_ST
+
+ #-}
+
+
 -- Recycling support
 -- -----------------
 
@@ -1794,7 +1846,7 @@ gfoldl f z v = z fromList `f` toList v
 
 mkType :: String -> DataType
 {-# INLINE mkType #-}
-mkType = mkNorepType
+mkType = mkNoRepType
 
 dataCast :: (Vector v a, Data a, Typeable1 v, Typeable1 t)
          => (forall d. Data  d => c (t d)) -> Maybe  (c (v a))