Unify MVectorPure and MVector
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 4 Dec 2009 08:25:19 +0000 (08:25 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 4 Dec 2009 08:25:19 +0000 (08:25 +0000)
Data/Vector/Generic.hs
Data/Vector/Generic/Mutable.hs
Data/Vector/Generic/New.hs
Data/Vector/Mutable.hs
Data/Vector/Primitive/Mutable.hs
Data/Vector/Storable/Mutable.hs

index 0fff433..71f6b31 100644 (file)
@@ -108,7 +108,7 @@ import Prelude hiding ( length, null,
 --
 class Vector v a where
   -- | Construct a pure vector from a monadic initialiser (not fusible!)
-  basicNew     :: (forall mv s. MVector mv (ST s) a => ST s (mv a)) -> v a
+  basicNew     :: (forall mv s. MVector mv a => ST s (mv s a)) -> v a
 
   -- | Length of the vector (not fusible!)
   basicLength      :: v a -> Int
index af44348..bf70aaf 100644 (file)
@@ -12,7 +12,7 @@
 --
 
 module Data.Vector.Generic.Mutable (
-  MVectorPure(..), MVector(..),
+  MVector(..),
 
   slice, new, newWith, read, write, copy, grow,
   unstream, transform,
@@ -24,6 +24,8 @@ import           Data.Vector.Fusion.Stream      ( Stream, MStream )
 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
 import           Data.Vector.Fusion.Stream.Size
 
+import Control.Monad.Primitive ( PrimMonad, PrimState )
+
 import GHC.Float (
     double2Int, int2Double
   )
@@ -35,50 +37,47 @@ import Prelude hiding ( length, reverse, map, read )
 gROWTH_FACTOR :: Double
 gROWTH_FACTOR = 1.5
 
--- | Basic pure functions on mutable vectors
-class MVectorPure v a where
+-- | Class of mutable vectors parametrised with a primitive state token.
+--
+class MVector v a where
   -- | Length of the mutable vector
-  length           :: v a -> Int
+  length           :: v a -> Int
 
   -- | Yield a part of the mutable vector without copying it. No range checks!
-  unsafeSlice      :: v a -> Int  -- ^ starting index
-                          -> Int  -- ^ length of the slice
-                          -> v a
+  unsafeSlice      :: v a -> Int  -- ^ starting index
+                            -> Int  -- ^ length of the slice
+                            -> v s a
 
   -- Check whether two vectors overlap.
-  overlaps         :: v a -> v a -> Bool
+  overlaps         :: v s a -> v s a -> Bool
 
--- | Class of mutable vectors. The type @m@ is the monad in which the mutable
--- vector can be transformed and @a@ is the type of elements.
---
-class (Monad m, MVectorPure v a) => MVector v m a where
   -- | Create a mutable vector of the given length. Length is not checked!
-  unsafeNew        :: Int -> m (v a)
+  unsafeNew        :: PrimMonad m => Int -> m (v (PrimState m) a)
 
   -- | Create a mutable vector of the given length and fill it with an
   -- initial value. Length is not checked!
-  unsafeNewWith    :: Int -> a -> m (v a)
+  unsafeNewWith    :: PrimMonad m => Int -> a -> m (v (PrimState m) a)
 
   -- | Yield the element at the given position. Index is not checked!
-  unsafeRead       :: v a -> Int -> m a
+  unsafeRead       :: PrimMonad m => v (PrimState m) a -> Int -> m a
 
   -- | Replace the element at the given position. Index is not checked!
-  unsafeWrite      :: v a -> Int -> a -> m ()
+  unsafeWrite      :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()
 
   -- | Clear all references to external objects
-  clear            :: v a -> m ()
+  clear            :: PrimMonad m => v (PrimState m) a -> m ()
 
   -- | Write the value at each position.
-  set              :: v a -> a -> m ()
+  set              :: PrimMonad m => v (PrimState m) a -> a -> m ()
 
   -- | Copy a vector. The two vectors may not overlap. This is not checked!
-  unsafeCopy       :: v a   -- ^ target
-                   -> v a   -- ^ source
-                   -> m ()
+  unsafeCopy       :: PrimMonad m => v (PrimState m) a   -- ^ target
+                                  -> v (PrimState m) a   -- ^ source
+                                  -> m ()
 
   -- | Grow a vector by the given number of elements. The length is not
   -- checked!
-  unsafeGrow       :: v a -> Int -> m (v a)
+  unsafeGrow :: PrimMonad m => v (PrimState m) a -> Int -> m (v (PrimState m) a)
 
   {-# INLINE unsafeNewWith #-}
   unsafeNewWith n x = UNSAFE_CHECK(checkLength) "unsafeNewWith" n
@@ -124,41 +123,42 @@ class (Monad m, MVectorPure v a) => MVector v m a where
 
 -- | Yield a part of the mutable vector without copying it. Safer version of
 -- 'unsafeSlice'.
-slice :: MVectorPure v a => v a -> Int -> Int -> v a
+slice :: MVector v a => v s a -> Int -> Int -> v s a
 {-# INLINE slice #-}
 slice v i n = BOUNDS_CHECK(checkSlice) "slice" i n (length v)
             $ unsafeSlice v i n
 
 -- | Create a mutable vector of the given length. Safer version of
 -- 'unsafeNew'.
-new :: MVector v m a => Int -> m (v a)
+new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
 {-# INLINE new #-}
 new n = BOUNDS_CHECK(checkLength) "new" n
       $ unsafeNew n
 
 -- | Create a mutable vector of the given length and fill it with an
 -- initial value. Safer version of 'unsafeNewWith'.
-newWith :: MVector v m a => Int -> a -> m (v a)
+newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
 {-# INLINE newWith #-}
 newWith n x = BOUNDS_CHECK(checkLength) "newWith" n
             $ unsafeNewWith n x
 
 -- | Yield the element at the given position. Safer version of 'unsafeRead'.
-read :: MVector v m a => v a -> Int -> m a
+read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
 {-# INLINE read #-}
 read v i = BOUNDS_CHECK(checkIndex) "read" i (length v)
          $ unsafeRead v i
 
 -- | Replace the element at the given position. Safer version of
 -- 'unsafeWrite'.
-write :: MVector v m a => v a -> Int -> a -> m ()
+write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
 {-# INLINE write #-}
 write v i x = BOUNDS_CHECK(checkIndex) "write" i (length v)
             $ unsafeWrite v i x
 
 -- | Copy a vector. The two vectors may not overlap. Safer version of
 -- 'unsafeCopy'.
-copy :: MVector v m a => v a -> v a -> m ()
+copy :: (PrimMonad m, MVector v a)
+                => v (PrimState m) a -> v (PrimState m) a -> m ()
 {-# INLINE copy #-}
 copy dst src = BOUNDS_CHECK(check) "copy" "overlapping vectors"
                                           (not (dst `overlaps` src))
@@ -168,12 +168,13 @@ copy dst src = BOUNDS_CHECK(check) "copy" "overlapping vectors"
 
 -- | Grow a vector by the given number of elements. Safer version of
 -- 'unsafeGrow'.
-grow :: MVector v m a => v a -> Int -> m (v a)
+grow :: (PrimMonad m, MVector v a)
+                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
 {-# INLINE grow #-}
 grow v by = BOUNDS_CHECK(checkLength) "grow" by
           $ unsafeGrow v by
 
-mstream :: MVector v m a => v a -> MStream m a
+mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
 {-# INLINE mstream #-}
 mstream v = v `seq` (MStream.unfoldrM get 0 `MStream.sized` Exact n)
   where
@@ -184,7 +185,8 @@ mstream v = v `seq` (MStream.unfoldrM get 0 `MStream.sized` Exact n)
                            return $ Just (x, i+1)
           | otherwise = return $ Nothing
 
-internal_munstream :: MVector v m a => v a -> MStream m a -> m (v a)
+internal_munstream :: (PrimMonad m, MVector v a)
+        => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
 {-# INLINE internal_munstream #-}
 internal_munstream v s = v `seq` do
                                    n' <- MStream.foldM put 0 s
@@ -196,20 +198,22 @@ internal_munstream v s = v `seq` do
                   $ unsafeWrite v i x
                 return (i+1)
 
-transform :: MVector v m a => (MStream m a -> MStream m a) -> v a -> m (v a)
+transform :: (PrimMonad m, MVector v a)
+  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
 {-# INLINE_STREAM transform #-}
 transform f v = internal_munstream v (f (mstream v))
 
 -- | Create a new mutable vector and fill it with elements from the 'Stream'.
 -- The vector will grow logarithmically if the 'Size' hint of the 'Stream' is
 -- inexact.
-unstream :: MVector v m a => Stream a -> m (v a)
+unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
 {-# INLINE_STREAM unstream #-}
 unstream s = case upperBound (Stream.size s) of
                Just n  -> unstreamMax     s n
                Nothing -> unstreamUnknown s
 
-unstreamMax :: MVector v m a => Stream a -> Int -> m (v a)
+unstreamMax
+  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
 {-# INLINE unstreamMax #-}
 unstreamMax s n
   = do
@@ -221,7 +225,8 @@ unstreamMax s n
       n' <- Stream.foldM' put 0 s
       return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 n' n $ slice v 0 n'
 
-unstreamUnknown :: MVector v m a => Stream a -> m (v a)
+unstreamUnknown
+  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
 {-# INLINE unstreamUnknown #-}
 unstreamUnknown s
   = do
@@ -250,7 +255,8 @@ unstreamUnknown s
               $ double2Int
               $ int2Double (length v) * gROWTH_FACTOR
 
-accum :: MVector v m a => (a -> b -> a) -> v a -> Stream (Int, b) -> m ()
+accum :: (PrimMonad m, MVector v a)
+        => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
 {-# INLINE accum #-}
 accum f !v s = Stream.mapM_ upd s
   where
@@ -259,11 +265,12 @@ accum f !v s = Stream.mapM_ upd s
                   a <- read v i
                   write v i (f a b)
 
-update :: MVector v m a => v a -> Stream (Int, a) -> m ()
+update :: (PrimMonad m, MVector v a)
+                        => v (PrimState m) a -> Stream (Int, a) -> m ()
 {-# INLINE update #-}
 update = accum (const id)
 
-reverse :: MVector v m a => v a -> m ()
+reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
 {-# INLINE reverse #-}
 reverse !v = reverse_loop 0 (length v - 1)
   where
index 8e90fa6..4f77169 100644 (file)
@@ -19,7 +19,7 @@ module Data.Vector.Generic.New (
 ) where
 
 import qualified Data.Vector.Generic.Mutable as MVector
-import           Data.Vector.Generic.Mutable ( MVector, MVectorPure )
+import           Data.Vector.Generic.Mutable ( MVector )
 
 import           Data.Vector.Fusion.Stream ( Stream, MStream )
 import qualified Data.Vector.Fusion.Stream as Stream
@@ -30,17 +30,17 @@ import Prelude hiding ( init, tail, take, drop, reverse, map, filter )
 
 #include "vector.h"
 
-newtype New a = New (forall mv s. MVector mv (ST s) a => ST s (mv a))
+newtype New a = New (forall mv s. MVector mv a => ST s (mv s a))
 
-run :: MVector mv (ST s) a => New a -> ST s (mv a)
+run :: MVector mv a => New a -> ST s (mv s a)
 {-# INLINE run #-}
 run (New p) = p
 
-apply :: (forall mv a. MVectorPure mv a => mv a -> mv a) -> New a -> New a
+apply :: (forall mv s a. MVector mv a => mv s a -> mv s a) -> New a -> New a
 {-# INLINE apply #-}
 apply f (New p) = New (liftM f p)
 
-modify :: New a -> (forall mv s. MVector mv (ST s) a => mv a -> ST s ()) -> New a
+modify :: New a -> (forall mv s. MVector mv a => mv s a -> ST s ()) -> New a
 {-# INLINE modify #-}
 modify (New p) q = New (do { v <- p; q v; return v })
 
index 148a8b2..5851a74 100644 (file)
@@ -30,7 +30,7 @@ data MVector s a = MVector {-# UNPACK #-} !Int
 type IOVector = MVector RealWorld
 type STVector s = MVector s
 
-instance G.MVectorPure (MVector s) a where
+instance G.MVector MVector a where
   length (MVector _ n _) = n
   unsafeSlice (MVector i n arr) j m
     = UNSAFE_CHECK(checkSlice) "unsafeSlice" j m n
@@ -43,8 +43,6 @@ instance G.MVectorPure (MVector s) a where
     where
       between x y z = x >= y && x < z
 
-
-instance (PrimMonad m, PrimState m ~ s) => G.MVector (MVector s) m a where
   {-# INLINE unsafeNew #-}
   unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n
               $ do
index 8038c3c..7382c31 100644 (file)
@@ -32,7 +32,7 @@ data MVector s a = MVector {-# UNPACK #-} !Int
 type IOVector = MVector RealWorld
 type STVector s = MVector s
 
-instance Prim a => G.MVectorPure (MVector s) a where
+instance Prim a => G.MVector MVector a where
   length (MVector _ n _) = n
   unsafeSlice (MVector i n arr) j m
     = UNSAFE_CHECK(checkSlice) "unsafeSlice" j m n
@@ -45,10 +45,10 @@ instance Prim a => G.MVectorPure (MVector s) a where
     where
       between x y z = x >= y && x < z
 
-
-instance Prim a => G.MVector (MVector s) (ST s) a where
   {-# INLINE unsafeNew #-}
-  unsafeNew = unsafeNew_generic
+  unsafeNew n = do
+                  arr <- newByteArray (n * sizeOf (undefined :: a))
+                  return (MVector 0 n arr)
 
   {-# INLINE unsafeRead #-}
   unsafeRead (MVector i n arr) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
@@ -61,27 +61,3 @@ instance Prim a => G.MVector (MVector s) (ST s) a where
   {-# INLINE clear #-}
   clear _ = return ()
 
-instance Prim a => G.MVector (MVector RealWorld) IO a where
-  {-# INLINE unsafeNew #-}
-  unsafeNew  = unsafeNew_generic
-
-  {-# INLINE unsafeRead #-}
-  unsafeRead (MVector i n arr) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
-                                 $ readByteArray arr (i+j)
-
-  {-# INLINE unsafeWrite #-}
-  unsafeWrite (MVector i n arr) j x = UNSAFE_CHECK(checkIndex) "unsafeWrite" j n
-                                    $ writeByteArray arr (i+j) x
-
-  {-# INLINE clear #-}
-  clear _ = return ()
-
-unsafeNew_generic
-  :: forall m a. (PrimMonad m, Prim a, G.MVector (MVector (PrimState m)) m a)
-                        => Int -> m (MVector (PrimState m) a)
-{-# INLINE unsafeNew_generic #-}
-unsafeNew_generic n = UNSAFE_CHECK(checkLength) "unsafeNew" n $
-  do
-    arr <- newByteArray (n * sizeOf (undefined :: a))
-    return (MVector 0 n arr)
-
index 37c96af..4366bd1 100644 (file)
@@ -20,8 +20,7 @@ import qualified Data.Vector.Generic.Mutable as G
 import Foreign.Storable
 import Foreign.ForeignPtr
 
-import Control.Monad.Primitive ( RealWorld )
-import Control.Monad.ST ( ST, unsafeIOToST )
+import Control.Monad.Primitive ( unsafePrimToPrim )
 
 #include "vector.h"
 
@@ -30,7 +29,7 @@ data MVector s a = MVector {-# UNPACK #-} !Int
                            {-# UNPACK #-} !Int
                            {-# UNPACK #-} !(ForeignPtr a)
 
-instance G.MVectorPure (MVector s) a where
+instance Storable a => G.MVector MVector a where
   {-# INLINE length #-}
   length (MVector _ n _) = n
 
@@ -44,47 +43,23 @@ instance G.MVectorPure (MVector s) a where
   overlaps (MVector i m p) (MVector j n q)
     = True
 
-instance Storable a => G.MVector (MVector s) (ST s) a where
   {-# INLINE unsafeNew #-}
-  unsafeNew n = unsafeIOToST (unsafeNewIO n)
+  unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n
+              $ unsafePrimToPrim
+              $ MVector 0 n `fmap` mallocForeignPtrArray n
 
   {-# INLINE unsafeRead #-}
-  unsafeRead v i = unsafeIOToST (unsafeReadIO v i)
-    
-  {-# INLINE unsafeWrite #-}
-  unsafeWrite v i x = unsafeIOToST (unsafeWriteIO v i x)
+  unsafeRead (MVector i n p) j
+    = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
+    $ unsafePrimToPrim
+    $ withForeignPtr p $ \ptr -> peekElemOff ptr (i+j)
 
-  {-# INLINE clear #-}
-  clear _ = return ()
-
-
-instance Storable a => G.MVector (MVector RealWorld) IO a where
-  {-# INLINE unsafeNew #-}
-  unsafeNew = unsafeNewIO
-
-  {-# INLINE unsafeRead #-}
-  unsafeRead = unsafeReadIO
-     
   {-# INLINE unsafeWrite #-}
-  unsafeWrite = unsafeWriteIO
+  unsafeWrite (MVector i n p) j x
+    = UNSAFE_CHECK(checkIndex) "unsafeWrite" j n
+    $ unsafePrimToPrim
+    $ withForeignPtr p $ \ptr -> pokeElemOff ptr (i+j) x 
 
   {-# INLINE clear #-}
   clear _ = return ()
 
-unsafeNewIO :: Storable a => Int -> IO (MVector s a)
-{-# unsafeNewIO #-}
-unsafeNewIO n = UNSAFE_CHECK(checkLength) "unsafeNew" n
-              $ MVector 0 n `fmap` mallocForeignPtrArray n
-
-unsafeReadIO :: Storable a => MVector s a -> Int -> IO a
-{-# INLINE unsafeReadIO #-}
-unsafeReadIO (MVector i n p) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
-                               $ withForeignPtr p $ \ptr ->
-                                peekElemOff ptr (i+j)
-
-unsafeWriteIO :: Storable a => MVector s a -> Int -> a -> IO ()
-{-# INLINE unsafeWriteIO #-}
-unsafeWriteIO (MVector i n p) j x = UNSAFE_CHECK(checkIndex) "unsafeWrite" j n
-                                  $ withForeignPtr p $ \ptr ->
-                                    pokeElemOff ptr (i+j) x 
-