Clean up interface to mutable vectors
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, slice, new, newWith, read, write, clear, set, copy, grow,
20
21 -- * Unsafe operations
22 unsafeSlice, unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite,
23 unsafeCopy, unsafeGrow,
24
25 -- * Internal operations
26 unstream, transform, accum, update, reverse
27 ) where
28
29 import qualified Data.Vector.Fusion.Stream as Stream
30 import Data.Vector.Fusion.Stream ( Stream, MStream )
31 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
32 import Data.Vector.Fusion.Stream.Size
33
34 import Control.Monad.Primitive ( PrimMonad, PrimState )
35
36 import GHC.Float (
37 double2Int, int2Double
38 )
39
40 import Prelude hiding ( length, reverse, map, read )
41
42 #include "vector.h"
43
-- | Multiplier used when an unbounded 'unstream' runs out of room: the
-- buffer is grown by this multiple of its current length.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
46
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- Instances must define 'basicLength', 'basicUnsafeSlice', 'basicOverlaps',
-- 'basicUnsafeNew', 'basicUnsafeRead' and 'basicUnsafeWrite'. The remaining
-- methods have generic default definitions built from those (and from the
-- module-level 'length' and 'set'), which instances may override.
--
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: v s a -> Int  -- ^ starting index
                            -> Int  -- ^ length of the slice
                            -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                                 -> v (PrimState m) a   -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not be
  -- called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate with 'basicUnsafeNew', then fill every slot via 'set'
  -- (which dispatches to 'basicSet', so an overridden 'basicSet' is used).
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith n x
    = do
        v <- basicUnsafeNew n
        set v x
        return v

  -- Default: do nothing.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write the value into each index from 0 to length - 1.
  {-# INLINE basicSet #-}
  basicSet v x = do_set 0
    where
      n = length v

      do_set i
        | i < n = do
            basicUnsafeWrite v i x
            do_set (i+1)
        | otherwise = return ()

  -- Default: element-by-element copy of every index of the source into the
  -- same index of the target.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do_copy 0
    where
      n = length src

      do_copy i
        | i < n = do
            x <- basicUnsafeRead src i
            basicUnsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  -- Default: allocate a vector of length @n + by@ and copy the old contents
  -- into its first @n@ slots; the trailing @by@ slots are left as
  -- 'basicUnsafeNew' produced them.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by
    = do
        v' <- basicUnsafeNew (n+by)
        basicUnsafeCopy (basicUnsafeSlice v' 0 n) v
        return v'
    where
      n = length v
139
140
141
-- | Yield a part of the mutable vector without copying it. The bounds are
-- only validated by the UNSAFE_CHECK hook from vector.h (presumably inactive
-- in normal builds).
unsafeSlice :: MVector v a => v s a -> Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice vec start len
  = UNSAFE_CHECK(checkSlice) "unsafeSlice" start len (length vec)
      (basicUnsafeSlice vec start len)
150
151
-- | Create a mutable vector of the given length. The length is only
-- validated by the UNSAFE_CHECK hook, not by a regular bounds check.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew len
  = UNSAFE_CHECK(checkLength) "unsafeNew" len (basicUnsafeNew len)
157
-- | Create a mutable vector of the given length with every element set to
-- the supplied value. The length is only validated by the UNSAFE_CHECK hook.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith len val
  = UNSAFE_CHECK(checkLength) "unsafeNewWith" len (basicUnsafeNewWith len val)
164
-- | Yield the element at the given position. The index is only validated by
-- the UNSAFE_CHECK hook, not by a regular bounds check.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead vec ix
  = UNSAFE_CHECK(checkIndex) "unsafeRead" ix (length vec) (basicUnsafeRead vec ix)
170
-- | Replace the element at the given position. The index is only validated
-- by the UNSAFE_CHECK hook, not by a regular bounds check.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite vec ix val
  = UNSAFE_CHECK(checkIndex) "unsafeWrite" ix (length vec)
      (basicUnsafeWrite vec ix val)
177
178
-- | Copy a vector. The two vectors must have the same length and may not
-- overlap. Both requirements are only asserted through UNSAFE_CHECK, length
-- first and overlap second.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src
  = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
  $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
  $ basicUnsafeCopy dst src
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
190
-- | Grow a vector by the given number of elements. The number must be
-- positive, but this is only asserted through the UNSAFE_CHECK hook.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow vec extra
  = UNSAFE_CHECK(checkLength) "unsafeGrow" extra (basicUnsafeGrow vec extra)
198
-- | Number of elements in the mutable vector (delegates to 'basicLength').
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length vec = basicLength vec
203
-- | Check whether two vectors overlap (delegates to 'basicOverlaps').
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps = basicOverlaps
208
-- | Yield a part of the mutable vector without copying it. The requested
-- range is validated with BOUNDS_CHECK before delegating to 'unsafeSlice'.
slice :: MVector v a => v s a
                     -> Int    -- ^ starting index
                     -> Int    -- ^ length of the slice
                     -> v s a
{-# INLINE slice #-}
slice vec start len
  = BOUNDS_CHECK(checkSlice) "slice" start len (length vec)
      (unsafeSlice vec start len)
214
-- | Create a mutable vector of the given length. The length is validated
-- with BOUNDS_CHECK before delegating to 'unsafeNew'.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new len = BOUNDS_CHECK(checkLength) "new" len (unsafeNew len)
220
-- | Create a mutable vector of the given length with every element set to
-- the supplied value. The length is validated with BOUNDS_CHECK.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith len val
  = BOUNDS_CHECK(checkLength) "newWith" len (unsafeNewWith len val)
227
-- | Yield the element at the given position. The index is validated with
-- BOUNDS_CHECK before delegating to 'unsafeRead'.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read vec ix = BOUNDS_CHECK(checkIndex) "read" ix (length vec) (unsafeRead vec ix)
233
-- | Replace the element at the given position. The index is validated with
-- BOUNDS_CHECK before delegating to 'unsafeWrite'.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write vec ix val
  = BOUNDS_CHECK(checkIndex) "write" ix (length vec) (unsafeWrite vec ix val)
239
-- | Reset all elements of the vector to some undefined value, clearing all
-- references to external objects. This is usually a noop for unboxed
-- vectors. Delegates to 'basicClear'.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear vec = basicClear vec
245
-- | Overwrite every element of the vector with the given value (delegates
-- to 'basicSet').
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set vec val = basicSet vec val
250
-- | Copy a vector. The two vectors must have the same length and may not
-- overlap. Both conditions are verified with BOUNDS_CHECK (overlap first,
-- then length) before delegating to 'unsafeCopy'.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a   -- ^ target
     -> v (PrimState m) a   -- ^ source
     -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
  $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
  $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
261
-- | Grow a vector by the given number of elements. The number must be
-- positive; it is validated with BOUNDS_CHECK before delegating to
-- 'unsafeGrow'.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow vec extra = BOUNDS_CHECK(checkLength) "grow" extra (unsafeGrow vec extra)
269
-- | View a mutable vector as a monadic stream of its elements, read with
-- 'unsafeRead' in index order from 0 up to @length v - 1@. The stream
-- carries an 'Exact' size hint equal to the vector's length.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step k
      | k >= len  = return Nothing
      | otherwise = unsafeRead v k >>= \e -> return (Just (e, k+1))
280
-- | Write the elements of a monadic stream into consecutive positions of the
-- given mutable vector, starting at index 0, and return the prefix slice
-- covering exactly the elements written. The vector must be large enough for
-- the whole stream; each write is guarded only by INTERNAL_CHECK.
internal_munstream :: (PrimMonad m, MVector v a)
                   => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE internal_munstream #-}
internal_munstream v s = v `seq` do
    written <- MStream.foldM put 0 s
    return (slice v 0 written)
  where
    {-# INLINE_INNER put #-}
    put i x = do
      INTERNAL_CHECK(checkIndex) "internal_munstream" i (length v)
        (unsafeWrite v i x)
      return (i+1)
293
-- | Run a stream transformer over the contents of a mutable vector and write
-- the results back into the same vector starting at index 0. Returns the
-- prefix slice holding the transformed elements. Presumably @f@ must not
-- produce more elements than the vector holds (only guarded internally).
transform :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = internal_munstream v $ f $ mstream v
298
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
-- If the stream's 'Size' hint has an upper bound, a buffer of that size is
-- allocated up front ('unstreamMax'). Otherwise the buffer starts empty and
-- is enlarged geometrically as elements arrive ('unstreamUnknown'), so only a
-- logarithmic number of reallocations is needed.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s) (upperBound (Stream.size s))
307
-- | Fill a freshly allocated vector of length @n@ (an upper bound on the
-- stream's size) from the stream, then return the prefix slice that was
-- actually written. Overrunning @n@ is guarded only by INTERNAL_CHECK.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
  buf <- new n
  let put i x = do
        INTERNAL_CHECK(checkIndex) "unstreamMax" i n
          (unsafeWrite buf i x)
        return (i+1)
  count <- Stream.foldM' put 0 s
  return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 count n
         $ slice buf 0 count
320
-- | Build a mutable vector from a 'Stream' whose size has no known upper
-- bound. Starts from an empty vector and, whenever it runs out of room,
-- grows it via 'unsafeGrow' by @max 1 (length * 'gROWTH_FACTOR')@ elements,
-- so the capacity increases geometrically. The result is the prefix slice of
-- the final buffer covering exactly the elements written.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s
  = do
      v <- new 0
      (v', n) <- Stream.foldM put (v, 0) s
      return $ slice v' 0 n
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v is not getting
    -- unboxed.
    {-# INLINE_INNER put #-}
    put (v, i) x
      | i < length v = do
          unsafeWrite v i x
          return (v, i+1)
      | otherwise = do
          v' <- enlarge v
          -- Label the check with this function name so failures are
          -- attributed to the right place.
          INTERNAL_CHECK(checkIndex) "unstreamUnknown" i (length v')
            $ unsafeWrite v' i x
          return (v', i+1)

    -- Grow by gROWTH_FACTOR times the current length, but always by at
    -- least one element so that enlarging an empty vector makes progress.
    {-# INLINE_INNER enlarge #-}
    enlarge v = unsafeGrow v
              $ max 1
              $ double2Int
              $ int2Double (length v) * gROWTH_FACTOR
350
-- | For each @(i, b)@ in the stream, in order, replace element @i@ of the
-- vector with @f a b@, where @a@ is its current value. Indices are
-- bounds-checked through 'read' and 'write'.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    {-# INLINE_INNER step #-}
    step (i, b) = read v i >>= \a -> write v i (f a b)
360
-- | For each @(i, a)@ in the stream, overwrite element @i@ of the vector
-- with @a@. Pairs are applied in stream order, so a later pair for the same
-- index wins.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update v s = accum (\_ x -> x) v s
365
-- | Reverse the elements of a mutable vector in place by swapping pairs
-- taken from both ends, moving inward until the indices meet.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
          front <- unsafeRead v lo
          back  <- unsafeRead v hi
          unsafeWrite v lo back
          unsafeWrite v hi front
          go (lo + 1) (hi - 1)
377