Begin large refactoring of mutable vectors
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 4 Dec 2009 07:03:49 +0000 (07:03 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 4 Dec 2009 07:03:49 +0000 (07:03 +0000)
Mutable vectors are now parametrised by the state token. They are also
expected to live in the ST monad for the purposes of initialisating immutable
vectors.

Data/Vector/Generic.hs
Data/Vector/Generic/New.hs
Data/Vector/Primitive/Mutable.hs
Data/Vector/Storable.hs
Data/Vector/Storable/Mutable.hs

index 2e51749..0fff433 100644 (file)
@@ -88,6 +88,7 @@ import qualified Data.Vector.Fusion.Stream.Monadic as MStream
 import           Data.Vector.Fusion.Stream.Size
 import           Data.Vector.Fusion.Util
 
+import Control.Monad.ST ( ST )
 import Prelude hiding ( length, null,
                         replicate, (++),
                         head, last,
@@ -107,7 +108,7 @@ import Prelude hiding ( length, null,
 --
 class Vector v a where
   -- | Construct a pure vector from a monadic initialiser (not fusible!)
-  basicNew     :: (forall mv m. MVector mv m a => m (mv a)) -> v a
+  basicNew     :: (forall mv s. MVector mv (ST s) a => ST s (mv a)) -> v a
 
   -- | Length of the vector (not fusible!)
   basicLength      :: v a -> Int
index 176032a..8e90fa6 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE Rank2Types #-}
+{-# LANGUAGE Rank2Types, FlexibleContexts #-}
 
 -- |
 -- Module      : Data.Vector.Generic.New
@@ -24,14 +24,15 @@ import           Data.Vector.Generic.Mutable ( MVector, MVectorPure )
 import           Data.Vector.Fusion.Stream ( Stream, MStream )
 import qualified Data.Vector.Fusion.Stream as Stream
 
+import Control.Monad.ST ( ST )
 import Control.Monad  ( liftM )
 import Prelude hiding ( init, tail, take, drop, reverse, map, filter )
 
 #include "vector.h"
 
-newtype New a = New (forall m mv. MVector mv m a => m (mv a))
+newtype New a = New (forall mv s. MVector mv (ST s) a => ST s (mv a))
 
-run :: MVector mv m a => New a -> m (mv a)
+run :: MVector mv (ST s) a => New a -> ST s (mv a)
 {-# INLINE run #-}
 run (New p) = p
 
@@ -39,7 +40,7 @@ apply :: (forall mv a. MVectorPure mv a => mv a -> mv a) -> New a -> New a
 {-# INLINE apply #-}
 apply f (New p) = New (liftM f p)
 
-modify :: New a -> (forall m mv. MVector mv m a => mv a -> m ()) -> New a
+modify :: New a -> (forall mv s. MVector mv (ST s) a => mv a -> ST s ()) -> New a
 {-# INLINE modify #-}
 modify (New p) q = New (do { v <- p; q v; return v })
 
index 54d265a..8038c3c 100644 (file)
@@ -1,5 +1,5 @@
 {-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, ScopedTypeVariables,
-             TypeFamilies #-}
+             FlexibleContexts #-}
 
 -- |
 -- Module      : Data.Vector.Primitive.Mutable
@@ -46,13 +46,9 @@ instance Prim a => G.MVectorPure (MVector s) a where
       between x y z = x >= y && x < z
 
 
-instance (Prim a, PrimMonad m, PrimState m ~ s)
-           => G.MVector (MVector s) m a where
+instance Prim a => G.MVector (MVector s) (ST s) a where
   {-# INLINE unsafeNew #-}
-  unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n
-              $ do
-                  arr <- newByteArray (n * sizeOf (undefined :: a))
-                  return (MVector 0 n arr)
+  unsafeNew = unsafeNew_generic
 
   {-# INLINE unsafeRead #-}
   unsafeRead (MVector i n arr) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
@@ -65,3 +61,27 @@ instance (Prim a, PrimMonad m, PrimState m ~ s)
   {-# INLINE clear #-}
   clear _ = return ()
 
+instance Prim a => G.MVector (MVector RealWorld) IO a where
+  {-# INLINE unsafeNew #-}
+  unsafeNew  = unsafeNew_generic
+
+  {-# INLINE unsafeRead #-}
+  unsafeRead (MVector i n arr) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
+                                 $ readByteArray arr (i+j)
+
+  {-# INLINE unsafeWrite #-}
+  unsafeWrite (MVector i n arr) j x = UNSAFE_CHECK(checkIndex) "unsafeWrite" j n
+                                    $ writeByteArray arr (i+j) x
+
+  {-# INLINE clear #-}
+  clear _ = return ()
+
+unsafeNew_generic
+  :: forall m a. (PrimMonad m, Prim a, G.MVector (MVector (PrimState m)) m a)
+                        => Int -> m (MVector (PrimState m) a)
+{-# INLINE unsafeNew_generic #-}
+unsafeNew_generic n = UNSAFE_CHECK(checkLength) "unsafeNew" n $
+  do
+    arr <- newByteArray (n * sizeOf (undefined :: a))
+    return (MVector 0 n arr)
+
index 6749918..cf63af1 100644 (file)
@@ -72,7 +72,7 @@ import Data.Vector.Storable.Internal
 import Foreign.Storable
 import Foreign.ForeignPtr
 
-import System.IO.Unsafe ( unsafePerformIO )
+import Control.Monad.ST ( ST, runST )
 
 import Prelude hiding ( length, null,
                         replicate, (++),
@@ -102,9 +102,9 @@ instance (Show a, Storable a) => Show (Vector a) where
 
 instance Storable a => G.Vector Vector a where
   {-# INLINE basicNew #-}
-  basicNew init = unsafePerformIO (do
-                                     MVector i n p <- init
-                                     return (Vector i n p))
+  basicNew init = runST (do
+                           MVector i n p <- (id :: ST s (MVector s a) -> ST s (MVector s a)) init
+                           return (Vector i n p))
 
   {-# INLINE basicLength #-}
   basicLength (Vector _ n _) = n
index 39f4f89..37c96af 100644 (file)
@@ -20,14 +20,17 @@ import qualified Data.Vector.Generic.Mutable as G
 import Foreign.Storable
 import Foreign.ForeignPtr
 
+import Control.Monad.Primitive ( RealWorld )
+import Control.Monad.ST ( ST, unsafeIOToST )
+
 #include "vector.h"
 
--- | Mutable 'Storable'-based vectors in the 'IO' monad.
-data MVector a = MVector {-# UNPACK #-} !Int
-                         {-# UNPACK #-} !Int
-                         {-# UNPACK #-} !(ForeignPtr a)
+-- | Mutable 'Storable'-based vectors
+data MVector a = MVector {-# UNPACK #-} !Int
+                           {-# UNPACK #-} !Int
+                           {-# UNPACK #-} !(ForeignPtr a)
 
-instance G.MVectorPure MVector a where
+instance G.MVectorPure (MVector s) a where
   {-# INLINE length #-}
   length (MVector _ n _) = n
 
@@ -41,21 +44,47 @@ instance G.MVectorPure MVector a where
   overlaps (MVector i m p) (MVector j n q)
     = True
 
-instance Storable a => G.MVector MVector IO a where
+instance Storable a => G.MVector (MVector s) (ST s) a where
   {-# INLINE unsafeNew #-}
-  unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n
-              $ MVector 0 n `fmap` mallocForeignPtrArray n
+  unsafeNew n = unsafeIOToST (unsafeNewIO n)
 
   {-# INLINE unsafeRead #-}
-  unsafeRead (MVector i n p) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
-                               $ withForeignPtr p $ \ptr ->
-                                 peekElemOff ptr (i+j)
+  unsafeRead v i = unsafeIOToST (unsafeReadIO v i)
+    
+  {-# INLINE unsafeWrite #-}
+  unsafeWrite v i x = unsafeIOToST (unsafeWriteIO v i x)
+
+  {-# INLINE clear #-}
+  clear _ = return ()
+
+
+instance Storable a => G.MVector (MVector RealWorld) IO a where
+  {-# INLINE unsafeNew #-}
+  unsafeNew = unsafeNewIO
+
+  {-# INLINE unsafeRead #-}
+  unsafeRead = unsafeReadIO
      
   {-# INLINE unsafeWrite #-}
-  unsafeWrite (MVector i n p) j x = UNSAFE_CHECK(checkIndex) "unsafeWrite" j n
-                                  $ withForeignPtr p $ \ptr ->
-                                    pokeElemOff ptr (i+j) x 
+  unsafeWrite = unsafeWriteIO
 
   {-# INLINE clear #-}
   clear _ = return ()
 
+unsafeNewIO :: Storable a => Int -> IO (MVector s a)
+{-# unsafeNewIO #-}
+unsafeNewIO n = UNSAFE_CHECK(checkLength) "unsafeNew" n
+              $ MVector 0 n `fmap` mallocForeignPtrArray n
+
+unsafeReadIO :: Storable a => MVector s a -> Int -> IO a
+{-# INLINE unsafeReadIO #-}
+unsafeReadIO (MVector i n p) j = UNSAFE_CHECK(checkIndex) "unsafeRead" j n
+                               $ withForeignPtr p $ \ptr ->
+                                peekElemOff ptr (i+j)
+
+unsafeWriteIO :: Storable a => MVector s a -> Int -> a -> IO ()
+{-# INLINE unsafeWriteIO #-}
+unsafeWriteIO (MVector i n p) j x = UNSAFE_CHECK(checkIndex) "unsafeWrite" j n
+                                  $ withForeignPtr p $ \ptr ->
+                                    pokeElemOff ptr (i+j) x 
+