Vector combinators
[darcs-mirrors/vector.git] / Data / Vector / Base / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses #-}
2 module Data.Vector.Base.Mutable (
3 Base(..),
4
5 slice, new, newWith, read, write, copy, grow, unstream
6 ) where
7
8 import qualified Data.Vector.Stream as Stream
9 import Data.Vector.Stream ( Stream )
10 import Data.Vector.Stream.Size
11
12 import Control.Monad.ST ( ST )
13 import Control.Exception ( assert )
14
15 import GHC.Float (
16 double2Int, int2Double
17 )
18
19 import Prelude hiding ( length, read )
20
-- | Multiplier applied to a vector's current length to decide how many extra
-- elements to request when a vector that is being filled from a stream of
-- unknown size runs out of room (see 'unstreamUnknown').
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 1.5
23
-- | Basic operations on mutable vectors of type @v@ with elements of type @a@
-- that are manipulated in the monad @m@.
--
-- Instances have to supply 'length', 'unsafeSlice', 'unsafeNew',
-- 'unsafeRead', 'unsafeWrite' and 'overlaps'; the remaining methods have
-- default definitions written in terms of those. The @unsafe@ methods are
-- not guarded by the assertions that the checked wrappers in this module
-- perform.
class Monad m => Base v m a where
  -- | Number of elements in the vector.
  length           :: v m a -> Int

  -- | @unsafeSlice v i n@ selects the part of @v@ given by the starting index
  -- @i@ and the element count @n@. The default 'unsafeGrow' relies on writes
  -- made through the slice being visible in the vector it was taken from.
  unsafeSlice      :: v m a -> Int -> Int -> v m a

  -- | Create a vector of the given length.
  unsafeNew        :: Int -> m (v m a)

  -- | Create a vector of the given length with every element set to the
  -- given value.
  unsafeNewWith    :: Int -> a -> m (v m a)

  -- | Fetch the element at the given index.
  unsafeRead       :: v m a -> Int -> m a

  -- | Store a value at the given index.
  unsafeWrite      :: v m a -> Int -> a -> m ()

  -- | Overwrite every element of the vector with the given value.
  set              :: v m a -> a -> m ()

  -- | @unsafeCopy dst src@ copies the elements of @src@ into @dst@.
  unsafeCopy       :: v m a -> v m a -> m ()

  -- | @unsafeGrow v by@ produces a vector that is @by@ elements longer than
  -- @v@ and starts with the contents of @v@.
  unsafeGrow       :: v m a -> Int -> m (v m a)

  -- | Used by 'copy' to reject an overlapping source and destination.
  -- NOTE(review): presumably reports whether the two vectors share any
  -- storage -- confirm against the instances.
  overlaps         :: v m a -> v m a -> Bool

  -- Default: allocate, then fill with 'set'.
  {-# INLINE unsafeNewWith #-}
  unsafeNewWith len x = unsafeNew len >>= \vec -> set vec x >> return vec

  -- Default: write the value to each index from 0 up to the length.
  {-# INLINE set #-}
  set vec x = go 0
    where
      len = length vec

      go ix
        | ix >= len = return ()
        | otherwise = unsafeWrite vec ix x >> go (ix + 1)

  -- Default: element-by-element copy, driven by the length of the source.
  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src = go 0
    where
      len = length src

      go ix
        | ix >= len = return ()
        | otherwise = unsafeRead src ix >>= unsafeWrite dst ix >> go (ix + 1)

  -- Default: allocate the longer vector and copy the old elements into its
  -- leading slice; the extra elements are left as 'unsafeNew' produced them.
  {-# INLINE unsafeGrow #-}
  unsafeGrow vec by = do
      bigger <- unsafeNew (len + by)
      unsafeCopy (unsafeSlice bigger 0 len) vec
      return bigger
    where
      len = length vec
74
-- | Test whether an index refers to an element of the vector, i.e. whether it
-- is non-negative and smaller than the vector's length.
inBounds :: Base v m a => v m a -> Int -> Bool
{-# INLINE inBounds #-}
inBounds v i
  | i < 0     = False
  | otherwise = i < length v
78
-- | @slice v i n@ yields the part of @v@ that starts at index @i@ and spans
-- @n@ elements, by delegating to 'unsafeSlice'.
--
-- The arguments are validated with 'assert': @i@ and @n@ must be non-negative
-- and the requested range must fit inside @v@. The check only fires when
-- assertions are enabled in the build.
slice :: Base v m a => v m a -> Int -> Int -> v m a
{-# INLINE slice #-}
slice v i n = assert (i >= 0 && n >= 0 && n <= length v - i)
              -- The upper bound is phrased as @n <= length v - i@ instead of
              -- @i+n <= length v@: the sum of two large 'Int's can wrap
              -- around to a negative number and slip past the check, whereas
              -- the difference cannot overflow once @i >= 0@ is known.
            $ unsafeSlice v i n
83
-- | Create a vector of the given length via 'unsafeNew', after asserting
-- that the length is not negative.
new :: Base v m a => Int -> m (v m a)
{-# INLINE new #-}
new n = assert (0 <= n) (unsafeNew n)
87
-- | Create a vector of the given length with every element set to the given
-- value via 'unsafeNewWith', after asserting that the length is not negative.
newWith :: Base v m a => Int -> a -> m (v m a)
{-# INLINE newWith #-}
newWith n x = assert (0 <= n) (unsafeNewWith n x)
91
-- | Fetch the element at the given index via 'unsafeRead', after asserting
-- that the index is within the bounds of the vector.
read :: Base v m a => v m a -> Int -> m a
{-# INLINE read #-}
read v i = assert (v `inBounds` i) (unsafeRead v i)
95
-- | Store a value at the given index via 'unsafeWrite', after asserting that
-- the index is within the bounds of the vector.
write :: Base v m a => v m a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = assert (v `inBounds` i) (unsafeWrite v i x)
99
-- | @copy dst src@ copies the contents of @src@ into @dst@ via 'unsafeCopy'.
--
-- An assertion requires that the two vectors do not overlap and that they
-- have the same length.
copy :: Base v m a => v m a -> v m a -> m ()
{-# INLINE copy #-}
copy dst src = assert compatible (unsafeCopy dst src)
  where
    compatible = not (overlaps dst src) && length dst == length src
104
-- | @grow v by@ enlarges @v@ by @by@ elements via 'unsafeGrow', after
-- asserting that the increment is not negative.
grow :: Base v m a => v m a -> Int -> m (v m a)
{-# INLINE grow #-}
grow v by = assert (0 <= by) (unsafeGrow v by)
109
110
-- | Build a mutable vector from the elements of a stream.
--
-- When the stream's size hint provides an upper bound the work is handed to
-- 'unstreamMax' with that bound; otherwise 'unstreamUnknown' is used.
unstream :: Base v m a => Stream a -> m (v m a)
{-# INLINE unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s) bound
  where
    bound = upperBound (Stream.size s)
116
-- | Build a mutable vector from a stream for which an upper bound on the
-- number of elements is known.
--
-- A vector of @n@ elements is allocated up front, the stream's elements are
-- written to consecutive indices starting at 0, and the result is the slice
-- covering exactly the elements that were written.
unstreamMax :: Base v m a => Stream a -> Int -> m (v m a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
  vec <- new n
  -- The fold's accumulator is the next free index, so its final value is the
  -- number of elements stored.
  let store ix x = write vec ix x >> return (ix + 1)
  used <- Stream.foldM store 0 s
  return (slice vec 0 used)
125
-- | Build a mutable vector from a stream for which no upper bound on the
-- number of elements is available.
--
-- Starts with an empty vector and enlarges it whenever the next element would
-- not fit. The result is the slice covering exactly the elements that were
-- written.
unstreamUnknown :: Base v m a => Stream a -> m (v m a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
    v0 <- new 0
    (filled, used) <- Stream.foldM append (v0, 0) s
    return (slice filled 0 used)
  where
    -- The accumulator pairs the current vector with the next free index.
    {-# INLINE append #-}
    append (vec, ix) x = do
      vec' <- ensureRoom vec ix
      unsafeWrite vec' ix x
      return (vec', ix + 1)

    -- Hand back the vector unchanged while the index still fits; otherwise
    -- grow it by its current length times 'gROWTH_FACTOR' (rounded by
    -- 'double2Int'), but always by at least one element so that an empty
    -- vector can grow too.
    {-# INLINE ensureRoom #-}
    ensureRoom vec ix
      | ix < length vec = return vec
      | otherwise       = unsafeGrow vec (max 1 extra)
      where
        extra = double2Int (int2Double (length vec) * gROWTH_FACTOR)
146