Replace type families by GADTs for associating a monad with a mutable vector
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 11 Jul 2008 16:22:01 +0000 (16:22 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 11 Jul 2008 16:22:01 +0000 (16:22 +0000)
This is mostly to work around #2440, but it's perhaps also more consistent.

Data/Vector.hs
Data/Vector/Base.hs
Data/Vector/Base/Mutable.hs
Data/Vector/Mutable.hs
Data/Vector/Unboxed.hs
Data/Vector/Unboxed/Mutable.hs

index 583156d..f1f1ece 100644 (file)
@@ -19,12 +19,14 @@ data Vector a = Vector {-# UNPACK #-} !Int
 
 instance Base Vector a where
   {-# INLINE create #-}
-  create init = runST (do
-      Mut.Vector i n marr# <- init
-      ST (\s# -> case unsafeFreezeArray# marr# s# of
-                   (# s2#, arr# #) -> (# s2#, Vector i n arr# #)
-         )
-    )
+  create init = runST (do_create init)
+    where
+      do_create :: ST s (Mut.Vector (ST s) a) -> ST s (Vector a)
+      do_create init = do
+                         Mut.Vector i n marr# <- init
+                         ST (\s# -> case unsafeFreezeArray# marr# s# of
+                              (# s2#, arr# #) -> (# s2#, Vector i n arr# #)
+                            )
 
   {-# INLINE length #-}
   length (Vector _ n _) = n
index ae88f17..f404709 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE TypeFamilies, FlexibleContexts, RankNTypes, MultiParamTypeClasses, BangPatterns, CPP #-}
+{-# LANGUAGE Rank2Types, MultiParamTypeClasses, BangPatterns, CPP #-}
 
 #include "phases.h"
 
@@ -14,7 +14,7 @@ import           Data.Vector.Stream.Size
 import Prelude hiding ( length, map, zipWith, sum )
 
 class Base v a where
-  create       :: (forall mv. Mut.Base mv a => Mut.Trans mv (mv a)) -> v a
+  create       :: (forall mv m. Mut.Base mv m a => m (mv m a)) -> v a
 
   length       :: v a -> Int
   unsafeSlice  :: v a -> Int -> Int -> v a
index f6cc57b..f28490c 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE TypeFamilies, FlexibleContexts, MultiParamTypeClasses #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
 module Data.Vector.Base.Mutable (
   Base(..),
 
@@ -21,23 +21,21 @@ import Prelude hiding ( length, read )
 gROWTH_FACTOR :: Double
 gROWTH_FACTOR = 1.5
 
-class Monad (Trans v) => Base v a where
-  type Trans   v :: * -> *
+class Monad m => Base v m a where
+  length           :: v m a -> Int
+  unsafeSlice      :: v m a -> Int -> Int -> v m a
 
-  length           :: v a -> Int
-  unsafeSlice      :: v a -> Int -> Int -> v a
+  unsafeNew        :: Int -> m (v m a)
+  unsafeNewWith    :: Int -> a -> m (v m a)
 
-  unsafeNew        :: Int -> Trans v (v a)
-  unsafeNewWith    :: Int -> a -> Trans v (v a)
+  unsafeRead       :: v m a -> Int -> m a
+  unsafeWrite      :: v m a -> Int -> a -> m ()
 
-  unsafeRead       :: v a -> Int -> Trans v a
-  unsafeWrite      :: v a -> Int -> a -> Trans v ()
+  set              :: v m a -> a -> m ()
+  unsafeCopy       :: v m a -> v m a -> m ()
+  unsafeGrow       :: v m a -> Int -> m (v m a)
 
-  set              :: v a -> a -> Trans v ()
-  unsafeCopy       :: v a -> v a -> Trans v ()
-  unsafeGrow       :: v a -> Int -> Trans v (v a)
-
-  overlaps         :: v a -> v a -> Bool
+  overlaps         :: v m a -> v m a -> Bool
 
   {-# INLINE unsafeNewWith #-}
   unsafeNewWith n x = do
@@ -74,49 +72,49 @@ class Monad (Trans v) => Base v a where
     where
       n = length v
 
-inBounds :: Base v a => v a -> Int -> Bool
+inBounds :: Base v m a => v m a -> Int -> Bool
 {-# INLINE inBounds #-}
 inBounds v i = i >= 0 && i < length v
 
-slice :: Base v a => v a -> Int -> Int -> v a
+slice :: Base v m a => v m a -> Int -> Int -> v m a
 {-# INLINE slice #-}
 slice v i n = assert (i >=0 && n >= 0 && i+n <= length v)
             $ unsafeSlice v i n
 
-new :: (Base v a, m ~ Trans v) => Int -> m (v a)
+new :: Base v m a => Int -> m (v m a)
 {-# INLINE new #-}
 new n = assert (n >= 0) $ unsafeNew n
 
-newWith :: (Base v a, m ~ Trans v) => Int -> a -> m (v a)
+newWith :: Base v m a => Int -> a -> m (v m a)
 {-# INLINE newWith #-}
 newWith n x = assert (n >= 0) $ unsafeNewWith n x
 
-read :: (Base v a, m ~ Trans v) => v a -> Int -> m a
+read :: Base v m a => v m a -> Int -> m a
 {-# INLINE read #-}
 read v i = assert (inBounds v i) $ unsafeRead v i
 
-write :: (Base v a, m ~ Trans v) => v a -> Int -> a -> m ()
+write :: Base v m a => v m a -> Int -> a -> m ()
 {-# INLINE write #-}
 write v i x = assert (inBounds v i) $ unsafeWrite v i x
 
-copy :: (Base v a, m ~ Trans v) => v a -> v a -> m ()
+copy :: Base v m a => v m a -> v m a -> m ()
 {-# INLINE copy #-}
 copy dst src = assert (not (dst `overlaps` src) && length dst == length src)
              $ unsafeCopy dst src
 
-grow :: (Base v a, m ~ Trans v) => v a -> Int -> m (v a)
+grow :: Base v m a => v m a -> Int -> m (v m a)
 {-# INLINE grow #-}
 grow v by = assert (by >= 0)
           $ unsafeGrow v by
 
 
-unstream :: (Base v a, m ~ Trans v) => Stream a -> m (v a)
+unstream :: Base v m a => Stream a -> m (v m a)
 {-# INLINE unstream #-}
 unstream s = case upperBound (Stream.size s) of
                Just n  -> unstreamMax     s n
                Nothing -> unstreamUnknown s
 
-unstreamMax :: (Base v a, m ~ Trans v) => Stream a -> Int -> m (v a)
+unstreamMax :: Base v m a => Stream a -> Int -> m (v m a)
 {-# INLINE unstreamMax #-}
 unstreamMax s n
   = do
@@ -125,7 +123,7 @@ unstreamMax s n
       n' <- Stream.foldM put 0 s
       return $ slice v 0 n'
 
-unstreamUnknown :: (Base v a, m ~ Trans v) => Stream a -> m (v a)
+unstreamUnknown :: Base v m a => Stream a -> m (v m a)
 {-# INLINE unstreamUnknown #-}
 unstreamUnknown s
   = do
index be2b360..26515f9 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE MagicHash, UnboxedTuples, TypeFamilies, MultiParamTypeClasses, FlexibleInstances #-}
+{-# LANGUAGE MagicHash, UnboxedTuples, MultiParamTypeClasses, GADTs, FlexibleInstances #-}
 
 module Data.Vector.Mutable ( Vector(..) )
 where
@@ -12,13 +12,13 @@ import GHC.ST   ( ST(..) )
 
 import GHC.Base ( Int(..) )
 
-data Vector s a = Vector {-# UNPACK #-} !Int
-                         {-# UNPACK #-} !Int
-                                        (MutableArray# s a)
-
-instance Base.Base (Vector s) a where
-  type Base.Trans (Vector s) = ST s
+data Vector m a where
+  Vector :: {-# UNPACK #-} !Int
+         -> {-# UNPACK #-} !Int
+         -> MutableArray# s a
+         -> Vector (ST s) a
 
+instance Base.Base Vector (ST s) a where
   length (Vector _ n _) = n
   unsafeSlice (Vector i _ arr#) j m = Vector (i+j) m arr#
 
@@ -43,11 +43,11 @@ instance Base.Base (Vector s) a where
     where
       between x y z = x >= y && x < z
 
-unsafeNew :: Int -> ST s (Vector s a)
+unsafeNew :: Int -> ST s (Vector (ST s) a)
 {-# INLINE unsafeNew #-}
 unsafeNew n = unsafeNewWith n (error "Data.Vector.Mutable: uninitialised elemen t")
 
-unsafeNewWith :: Int -> a -> ST s (Vector s a)
+unsafeNewWith :: Int -> a -> ST s (Vector (ST s) a)
 {-# INLINE unsafeNewWith #-}
 unsafeNewWith (I# n#) x = ST (\s# ->
     case newArray# n# x s# of
index bd2250d..a6ffcc9 100644 (file)
@@ -20,12 +20,14 @@ data Vector a = Vector {-# UNPACK #-} !Int
 
 instance Unbox a => Base Vector a where
   {-# INLINE create #-}
-  create init = runST (do
-      Mut.Vector i n marr# <- init
-      ST (\s# -> case unsafeFreezeByteArray# marr# s# of
-                   (# s2#, arr# #) -> (# s2#, Vector i n arr# #)
-         )
-    )
+  create init = runST (do_create init)
+    where
+      do_create :: ST s (Mut.Vector (ST s) a) -> ST s (Vector a)
+      do_create init = do
+                         Mut.Vector i n marr# <- init
+                         ST (\s# -> case unsafeFreezeByteArray# marr# s# of
+                              (# s2#, arr# #) -> (# s2#, Vector i n arr# #)
+                            )
 
   {-# INLINE length #-}
   length (Vector _ n _) = n
index ed02898..434369b 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE MagicHash, UnboxedTuples, TypeFamilies, MultiParamTypeClasses, FlexibleInstances, ScopedTypeVariables #-}
+{-# LANGUAGE MagicHash, UnboxedTuples, MultiParamTypeClasses, FlexibleInstances, GADTs, ScopedTypeVariables #-}
 
 module Data.Vector.Unboxed.Mutable ( Vector(..) )
 where
@@ -13,14 +13,13 @@ import GHC.ST   ( ST(..) )
 
 import GHC.Base ( Int(..) )
 
-data Vector s a = Vector {-# UNPACK #-} !Int
-                         {-# UNPACK #-} !Int
-                                        (MutableByteArray# s)
-
-
-instance Unbox a => Base.Base (Vector s) a where
-  type Base.Trans (Vector s) = ST s
+data Vector m a where
+   Vector :: {-# UNPACK #-} !Int
+          -> {-# UNPACK #-} !Int
+          -> MutableByteArray# s
+          -> Vector (ST s) a
 
+instance Unbox a => Base.Base Vector (ST s) a where
   length (Vector _ n _) = n
   unsafeSlice (Vector i _ arr#) j m = Vector (i+j) m arr#