8291486681c345b013932d755096409947c27239
[darcs-mirrors/vector.git] / Data / Vector / MVector.hs
1 {-# LANGUAGE MultiParamTypeClasses #-}
2 -- |
3 -- Module : Data.Vector.MVector
4 -- Copyright : (c) Roman Leshchinskiy 2008
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 #include "phases.h"
15
16 module Data.Vector.MVector (
17 MVectorPure(..), MVector(..),
18
19 slice, new, newWith, read, write, copy, grow,
20 unstream, transform,
21 accum, update, reverse
22 ) where
23
24 import qualified Data.Vector.Fusion.Stream as Stream
25 import Data.Vector.Fusion.Stream ( Stream, MStream )
26 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
27 import Data.Vector.Fusion.Stream.Size
28
29 import Control.Monad.ST ( ST )
30 import Control.Exception ( assert )
31
32 import GHC.Float (
33 double2Int, int2Double
34 )
35
36 import Prelude hiding ( length, reverse, map, read )
37
-- | Growth factor for vectors whose final size is not known in advance.
-- When such a vector runs out of room, it is grown by
-- @max 1 (truncate (fromIntegral len * gROWTH_FACTOR))@ extra elements,
-- where @len@ is its current length.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
40
-- | Basic pure functions on mutable vectors. None of these methods read or
-- write the vector's elements; they only inspect its extent.
class MVectorPure v a where
  -- | Length of the mutable vector
  length :: v a -> Int

  -- | Yield a part of the mutable vector without copying it. No range checks!
  unsafeSlice :: v a -> Int  -- ^ starting index
                     -> Int  -- ^ length of the slice
                     -> v a

  -- | Check whether two vectors overlap. This is used to check the
  -- non-overlap precondition of 'copy'.
  overlaps :: v a -> v a -> Bool
53
-- | Class of mutable vectors. The type @m@ is the monad in which the mutable
-- vector can be transformed and @a@ is the type of elements.
--
-- 'unsafeNew', 'unsafeRead', 'unsafeWrite' and 'clear' have no default
-- definition. 'unsafeNewWith', 'set', 'unsafeCopy' and 'unsafeGrow' default
-- to element-by-element loops built on the other methods.
class (Monad m, MVectorPure v a) => MVector v m a where
  -- | Create a mutable vector of the given length. Length is not checked!
  unsafeNew :: Int -> m (v a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. Length is not checked!
  unsafeNewWith :: Int -> a -> m (v a)

  -- | Yield the element at the given position. Index is not checked!
  unsafeRead :: v a -> Int -> m a

  -- | Replace the element at the given position. Index is not checked!
  unsafeWrite :: v a -> Int -> a -> m ()

  -- | Clear all references to external objects
  clear :: v a -> m ()

  -- | Write the value at each position.
  set :: v a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This is not checked!
  unsafeCopy :: v a   -- ^ target
             -> v a   -- ^ source
             -> m ()

  -- | Grow a vector by the given number of elements. The length is not
  -- checked!
  unsafeGrow :: v a -> Int -> m (v a)

  -- Default: allocate, then fill every slot with the initial value.
  {-# INLINE unsafeNewWith #-}
  unsafeNewWith n x = unsafeNew n >>= \v -> set v x >> return v

  -- Default: write the value at indices 0 .. length v - 1.
  {-# INLINE set #-}
  set v x = fill 0
    where
      len = length v

      fill i
        | i >= len  = return ()
        | otherwise = unsafeWrite v i x >> fill (i + 1)

  -- Default: copy as many elements as the source holds, one at a time.
  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src = go 0
    where
      count = length src

      go i
        | i >= count = return ()
        | otherwise  = do
            y <- unsafeRead src i
            unsafeWrite dst i y
            go (i + 1)

  -- Default: allocate a larger vector and copy the old contents into its
  -- front; the extra @by@ slots are left as 'unsafeNew' produced them.
  {-# INLINE unsafeGrow #-}
  unsafeGrow v by = do
      v' <- unsafeNew (old + by)
      unsafeCopy (unsafeSlice v' 0 old) v
      return v'
    where
      old = length v
120
-- | Test whether the index is valid for the vector, i.e. whether it lies
-- in the half-open range @[0, length v)@.
inBounds :: MVectorPure v a => v a -> Int -> Bool
{-# INLINE inBounds #-}
inBounds v i = 0 <= i && i < length v
125
-- | Yield a part of the mutable vector without copying it. Safer version of
-- 'unsafeSlice': the requested range is checked with 'assert', so the check
-- is only performed when assertions are enabled.
slice :: MVectorPure v a => v a -> Int -> Int -> v a
{-# INLINE slice #-}
slice v i n = assert inRange (unsafeSlice v i n)
  where
    inRange = i >= 0 && n >= 0 && i + n <= length v
132
-- | Create a mutable vector of the given length. Safer version of
-- 'unsafeNew': a negative length fails the 'assert' (when assertions are
-- enabled).
new :: MVector v m a => Int -> m (v a)
{-# INLINE new #-}
new n = assert (0 <= n) (unsafeNew n)
138
-- | Create a mutable vector of the given length and fill it with an
-- initial value. Safer version of 'unsafeNewWith': a negative length fails
-- the 'assert' (when assertions are enabled).
newWith :: MVector v m a => Int -> a -> m (v a)
{-# INLINE newWith #-}
newWith n x = assert (0 <= n) (unsafeNewWith n x)
144
-- | Yield the element at the given position. Safer version of 'unsafeRead':
-- the index is checked with 'inBounds' under 'assert'.
read :: MVector v m a => v a -> Int -> m a
{-# INLINE read #-}
read v i = assert (v `inBounds` i) (unsafeRead v i)
149
-- | Replace the element at the given position. Safer version of
-- 'unsafeWrite': the index is checked with 'inBounds' under 'assert'.
write :: MVector v m a => v a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = assert (v `inBounds` i) (unsafeWrite v i x)
155
-- | Copy a vector. The two vectors may not overlap. Safer version of
-- 'unsafeCopy': non-overlap and equal lengths are checked under 'assert'.
copy :: MVector v m a => v a    -- ^ target
                      -> v a    -- ^ source
                      -> m ()
{-# INLINE copy #-}
copy dst src = assert precondition (unsafeCopy dst src)
  where
    precondition = not (dst `overlaps` src) && length dst == length src
162
-- | Grow a vector by the given number of elements. Safer version of
-- 'unsafeGrow': a negative increment fails the 'assert' (when assertions
-- are enabled).
grow :: MVector v m a => v a -> Int -> m (v a)
{-# INLINE grow #-}
grow v by = assert (0 <= by) (unsafeGrow v by)
169
-- | Stream the elements of a mutable vector from index 0 upwards. The
-- stream carries an 'Exact' size hint equal to the vector's length. The
-- vector is forced to WHNF before the stream is built.
mstream :: MVector v m a => v a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE step #-}
    step i
      | i >= len  = return Nothing
      | otherwise = do
          x <- unsafeRead v i
          return (Just (x, i + 1))
180
-- | Write the elements of a monadic stream into the given vector, starting
-- at index 0, and yield the prefix slice that was filled. Each store goes
-- through 'write', so overrunning the vector is caught only by its
-- assertion. The vector is forced to WHNF first.
munstream :: MVector v m a => v a -> MStream m a -> m (v a)
{-# INLINE munstream #-}
munstream v s = v `seq` (MStream.foldM store 0 s >>= \filled -> return (slice v 0 filled))
  where
    store i x = write v i x >> return (i + 1)
188
-- | Run a stream transformer over the contents of a mutable vector and
-- write the results back into the same vector from index 0. Yields the
-- prefix slice that holds the results.
--
-- NOTE(review): reads and writes interleave on the same vector, so
-- presumably @f@ must never emit more elements than it has consumed so
-- far (otherwise unread input would be overwritten) — confirm with callers.
transform :: MVector v m a => (MStream m a -> MStream m a) -> v a -> m (v a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v $ f $ mstream v
192
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
--
-- If 'upperBound' yields a bound for the stream's 'Size' hint, a vector of
-- that length is allocated once. It is then trimmed to the number of
-- elements the stream actually produced. Only when no upper bound is known
-- does the vector start empty and grow geometrically as elements arrive.
-- Growing geometrically means only a logarithmic number of reallocations.
unstream :: MVector v m a => Stream a -> m (v a)
{-# INLINE_STREAM unstream #-}
unstream s = case upperBound (Stream.size s) of
    Just n  -> unstreamMax s n
    Nothing -> unstreamUnknown s
201
-- | Fill a freshly allocated vector of length @n@ from the 'Stream' and
-- yield the prefix that was actually written. @n@ must be an upper bound on
-- the number of elements produced; overruns are caught only by the
-- assertion in 'write'.
unstreamMax :: MVector v m a => Stream a -> Int -> m (v a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
    v <- new n
    filled <- Stream.foldM (store v) 0 s
    return (slice v 0 filled)
  where
    store v i x = write v i x >> return (i + 1)
210
-- | Fill a vector from a 'Stream' whose size is not known in advance.
-- Starts from an empty vector and grows it whenever the next index would
-- fall outside it, then yields the prefix that was actually written.
unstreamUnknown :: MVector v m a => Stream a -> m (v a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
    v0 <- new 0
    (v, len) <- Stream.foldM step (v0, 0) s
    return (slice v 0 len)
  where
    -- Make room for index @i@ (possibly reallocating), then store @x@ there.
    {-# INLINE step #-}
    step (acc, i) x = do
        acc' <- ensureRoom acc i
        unsafeWrite acc' i x
        return (acc', i + 1)

    {-# INLINE ensureRoom #-}
    ensureRoom acc i
      | i < length acc = return acc
      | otherwise      = unsafeGrow acc (extraFor (length acc))

    -- Number of extra slots to add to a vector of capacity @cap@: at least
    -- one, so growth always starts from the empty vector.
    -- NOTE(review): this grows *by* cap * gROWTH_FACTOR (to ~2.5x), not *to*
    -- 1.5x — confirm which was intended.
    extraFor cap = max 1 (double2Int (int2Double cap * gROWTH_FACTOR))
231
-- | For every @(i, b)@ in the stream, in order, replace the element at
-- index @i@ with @f old b@. Indices are checked only by the assertions in
-- 'read' and 'write'.
accum :: MVector v m a => (a -> b -> a) -> v a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f v s = Stream.mapM_ combine s
  where
    {-# INLINE combine #-}
    combine (i, b) = read v i >>= \old -> write v i (f old b)
240
-- | For every @(i, x)@ in the stream, in order, overwrite the element at
-- index @i@ with @x@. This is 'accum' with a combining function that
-- ignores the old value.
update :: MVector v m a => v a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update = accum (flip const)
244
-- | Reverse the mutable vector in place by swapping elements pairwise from
-- both ends towards the middle. Empty and singleton vectors are left
-- untouched.
reverse :: MVector v m a => v a -> m ()
{-# INLINE reverse #-}
reverse v = swapLoop 0 (length v - 1)
  where
    swapLoop lo hi
      | lo >= hi  = return ()
      | otherwise = do
          front <- unsafeRead v lo
          back  <- unsafeRead v hi
          unsafeWrite v lo back
          unsafeWrite v hi front
          swapLoop (lo + 1) (hi - 1)
256