Trim imports
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns, ScopedTypeVariables #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, new, newWith, read, write, swap, clear, set, copy, grow,
20
21 slice, take, drop, init, tail,
22 unsafeSlice, unsafeInit, unsafeTail,
23
24 -- * Unsafe operations
25 unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite, unsafeSwap,
26 unsafeCopy, unsafeGrow,
27
28 -- * Internal operations
29 unstream, unstreamR,
30 transform, transformR,
31 munstream, munstreamR,
32 unsafeAccum, accum, unsafeUpdate, update, reverse,
33 unstablePartition, unstablePartitionStream, partitionStream
34 ) where
35
36 import qualified Data.Vector.Fusion.Stream as Stream
37 import Data.Vector.Fusion.Stream ( Stream, MStream )
38 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
39 import Data.Vector.Fusion.Stream.Size
40
41 import Control.Monad.Primitive ( PrimMonad, PrimState )
42
43 import Prelude hiding ( length, reverse, map, read,
44 take, drop, init, tail )
45
46 #include "vector.h"
47
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- An instance must define 'basicLength', 'basicUnsafeSlice',
-- 'basicOverlaps', 'basicUnsafeNew', 'basicUnsafeRead' and
-- 'basicUnsafeWrite'. The remaining methods have element-by-element
-- defaults built from those and may be overridden with faster versions.
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: Int -- ^ starting index
                   -> Int -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a -- ^ target
                                 -> v (PrimState m) a -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not be
  -- called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate with basicUnsafeNew, then fill every slot with basicSet.
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith n x
    = do
        v <- basicUnsafeNew n
        basicSet v x
        return v

  -- Default: do nothing. Instances whose elements can hold references
  -- (presumably boxed vectors) are expected to override this.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write x to every index from 0 up to the length, one at a time.
  {-# INLINE basicSet #-}
  basicSet v x = do_set 0
    where
      n = basicLength v

      do_set i | i < n = do
                           basicUnsafeWrite v i x
                           do_set (i+1)
               | otherwise = return ()

  -- Default: copy element by element, iterating over the length of src and
  -- writing each element to the same index of dst.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do_copy 0
    where
      n = basicLength src

      do_copy i | i < n = do
                            x <- basicUnsafeRead src i
                            basicUnsafeWrite dst i x
                            do_copy (i+1)
                | otherwise = return ()

  -- Default: allocate a vector of length n+by and copy the old contents into
  -- its first n slots; the trailing by slots are left as basicUnsafeNew
  -- produced them.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by
    = do
        v' <- basicUnsafeNew (n+by)
        basicUnsafeCopy (basicUnsafeSlice 0 n v') v
        return v'
    where
      n = basicLength v
142
143 -- ------------------
144 -- Internal functions
145 -- ------------------
146
-- | Check whether two vectors overlap, i.e. share any underlying storage
-- as reported by the instance's 'basicOverlaps'.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps = basicOverlaps
151
-- Write @x@ at index @i@ of @v@ and return the vector that received the
-- write. If @i@ is not below @length v@, the vector is first enlarged with
-- 'enlarge' and the write goes into the enlarged vector, which is returned
-- in place of @v@.
--
-- NOTE(review): a single 'enlarge' adds at least one slot, so this only
-- suffices when @i <= length v@; the callers in this module append at
-- consecutive indices starting from 0.
unsafeAppend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because
-- GHC creates a join point for the unsafeWrite even when everything
-- is inlined. This is bad because with the join point, v isn't getting
-- unboxed.
unsafeAppend1 v i x
  | i < length v = do
                     unsafeWrite v i x
                     return v
  | otherwise    = do
                     v' <- enlarge v
                     INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length v')
                       $ unsafeWrite v' i x
                     return v'
168
-- Store @x@ in the slot just before index @i@ and return the vector that
-- received it together with the index of that slot. When @i@ is 0 there is
-- no free slot at the front, so the vector is first grown at the front via
-- 'enlargeFront' and @x@ goes into the last of the newly opened slots.
unsafePrepend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i /= 0    = do
                  let slot = i-1
                  unsafeWrite v slot x
                  return (v, slot)
  | otherwise = do
                  (grown, front) <- enlargeFront v
                  let slot = front-1
                  INTERNAL_CHECK(checkIndex) "unsafePrepend1" slot (length grown)
                    $ unsafeWrite grown slot x
                  return (grown, slot)
183
-- Stream the elements of a mutable vector from the first index to the last,
-- reading each one with 'unsafeRead' as the stream is consumed. The size
-- hint is 'Exact' (the vector's length). @v@ is forced before the stream is
-- built.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` (MStream.unfoldrM get 0 `MStream.sized` Exact n)
  where
    n = length v

    {-# INLINE_INNER get #-}
    -- Yield the element at i and continue from i+1; stop once i reaches n.
    get i | i < n     = do x <- unsafeRead v i
                           return $ Just (x, i+1)
          | otherwise = return $ Nothing
194
-- Write the elements of a monadic stream into @v@, starting at index 0, and
-- return the prefix of @v@ that was written. Apart from the INTERNAL_CHECK
-- below nothing guards against the stream producing more than @length v@
-- elements; callers must ensure it does not.
munstream :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstream #-}
munstream v s = v `seq` do
                  n' <- MStream.foldM put 0 s
                  return $ unsafeSlice 0 n' v
  where
    {-# INLINE_INNER put #-}
    -- Store x at the current position i and advance to i+1.
    put i x = do
                INTERNAL_CHECK(checkIndex) "munstream" i (length v)
                  $ unsafeWrite v i x
                return (i+1)
207
-- | Run a stream transformer over the contents of a mutable vector in place.
-- The elements are streamed out from index 0, passed through @f@, and the
-- results are written back starting at index 0. The returned vector is the
-- prefix of the input that received the results.
transform :: (PrimMonad m, MVector v a)
          => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v $ f $ mstream v
212
-- Stream the elements of a mutable vector from the last index down to 0,
-- reading each one with 'unsafeRead' as the stream is consumed. The size
-- hint is 'Exact' (the vector's length). @v@ is forced before the stream is
-- built.
mrstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mrstream #-}
mrstream v = v `seq` (MStream.unfoldrM get n `MStream.sized` Exact n)
  where
    n = length v

    {-# INLINE_INNER get #-}
    -- The state i is one past the next index to read: yield the element at
    -- i-1 and make i-1 the new state; stop when i-1 would be negative.
    get i | j >= 0    = do x <- unsafeRead v j
                           return $ Just (x,j)
          | otherwise = return Nothing
      where
        j = i-1
225
-- Write the elements of a monadic stream into @v@ from the back: the first
-- element goes to index @length v - 1@, the next to the index before it,
-- and so on. Returns the suffix of @v@ that was written, which therefore
-- holds the stream's elements in reverse order. The stream must not produce
-- more than @length v@ elements.
munstreamR :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamR #-}
munstreamR v s = v `seq` do
                   i <- MStream.foldM put n s
                   return $ unsafeSlice i (n-i) v
  where
    n = length v

    {-# INLINE_INNER put #-}
    -- Step back one slot and store x there. The write is index-checked
    -- (when internal checks are on) just like the writes in munstream.
    put i x = do
                INTERNAL_CHECK(checkIndex) "munstreamR" j n
                  $ unsafeWrite v j x
                return j
      where
        j = i-1
241
-- | Back-to-front counterpart of 'transform'. The elements are streamed out
-- starting from the last index, passed through @f@, and the results are
-- written back starting at the last index. The returned vector is the
-- suffix of the input that received the results.
transformR :: (PrimMonad m, MVector v a)
           => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transformR #-}
transformR f v = munstreamR v $ f $ mrstream v
246
-- | Create a new mutable vector and fill it with the elements of the
-- 'Stream', in stream order.
--
-- If the 'Stream' has an upper bound on its size, a vector of that length
-- is allocated once and the filled prefix is returned. Otherwise the vector
-- starts empty and is enlarged whenever it is full. Each enlargement adds as
-- many slots as the vector already has (at least one), so it grows
-- geometrically.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = case upperBound (Stream.size s) of
               Just n  -> unstreamMax     s n
               Nothing -> unstreamUnknown s
255
256 -- FIXME: I can't think of how to prevent GHC from floating out
257 -- unstreamUnknown. That is bad because SpecConstr then generates two
258 -- specialisations: one for when it is called from unstream (it doesn't know
259 -- the shape of the vector) and one for when the vector has grown. To see the
260 -- problem simply compile this:
261 --
262 -- fromList = Data.Vector.Unboxed.unstream . Stream.fromList
263 --
264
-- Fill a freshly allocated vector of length @n@ with the stream's elements,
-- in stream order, and return the prefix that was written. @n@ is the upper
-- bound that 'unstream' takes from the stream's size hint. Only the
-- INTERNAL_CHECKs guard against the stream producing more than @n@ elements.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n
  = do
      v <- INTERNAL_CHECK(checkLength) "unstreamMax" n
           $ unsafeNew n
      -- Store x at i and return the next write position.
      let put i x = do
                      INTERNAL_CHECK(checkIndex) "unstreamMax" i n
                        $ unsafeWrite v i x
                      return (i+1)
      -- n' is the number of elements the stream actually produced.
      n' <- Stream.foldM' put 0 s
      return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 n' n
             $ unsafeSlice 0 n' v
279
-- Fill a vector with the stream's elements when no upper bound on the
-- stream's size is known. Starts from an empty vector and appends with
-- 'unsafeAppend1', which enlarges the vector whenever it is full, then
-- returns the prefix holding the @n@ elements written.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s
  = do
      v <- unsafeNew 0
      (v', n) <- Stream.foldM put (v, 0) s
      return $ INTERNAL_CHECK(checkSlice) "unstreamUnknown" 0 n (length v')
             $ unsafeSlice 0 n v'
  where
    {-# INLINE_INNER put #-}
    -- Accumulator: the current (possibly reallocated) vector and the number
    -- of elements written so far, which is also the next write index.
    put (v,i) x = do
                    v' <- unsafeAppend1 v i x
                    return (v',i+1)
294
-- | Create a new mutable vector and fill it with the elements of the
-- 'Stream' from right to left. The first element of the stream goes into
-- the last slot, so the result holds the elements in reverse stream order.
--
-- If the 'Stream' has an upper bound on its size, a vector of that length
-- is allocated once and the filled suffix is returned. Otherwise the vector
-- starts empty and is enlarged at the front whenever it is full. Each
-- enlargement adds as many slots as the vector already has (at least one),
-- so it grows geometrically.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstreamR #-}
unstreamR s = case upperBound (Stream.size s) of
                Just n  -> unstreamRMax     s n
                Nothing -> unstreamRUnknown s
303
-- Right-to-left counterpart of 'unstreamMax'. Allocates a vector of length
-- @n@ and writes the stream's elements starting at the last slot, moving
-- towards the front. Returns the filled suffix @[i, n)@, which holds the
-- elements in reverse stream order.
unstreamRMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamRMax #-}
unstreamRMax s n
  = do
      v <- INTERNAL_CHECK(checkLength) "unstreamRMax" n
           $ unsafeNew n
      -- i is the index of the most recently written slot (n initially);
      -- store x just before it and return the new index.
      let put i x = do
                      let i' = i-1
                      INTERNAL_CHECK(checkIndex) "unstreamRMax" i' n
                        $ unsafeWrite v i' x
                      return i'
      i <- Stream.foldM' put n s
      return $ INTERNAL_CHECK(checkSlice) "unstreamRMax" i (n-i) n
             $ unsafeSlice i (n-i) v
319
-- Right-to-left counterpart of 'unstreamUnknown'. Starts from an empty
-- vector and prepends each element with 'unsafePrepend1', which grows the
-- vector at the front when no free slot is left. The accumulator is the
-- current vector and the index of the most recently written slot. The
-- result is the suffix starting at that index.
unstreamRUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamRUnknown #-}
unstreamRUnknown s
  = do
      v <- unsafeNew 0
      (v', i) <- Stream.foldM put (v, 0) s
      let n = length v'
      return $ INTERNAL_CHECK(checkSlice) "unstreamRUnknown" i (n-i) n
             $ unsafeSlice i (n-i) v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = unsafePrepend1 v i x
333
334 -- Length
335 -- ------
336
-- | Number of elements in the mutable vector.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v

-- | 'True' exactly when the vector has no elements.
null :: MVector v a => v s a -> Bool
{-# INLINE null #-}
null = (== 0) . length
346
347
348 -- Construction
349 -- ------------
350
-- | Allocate a mutable vector of the given length. The length goes through
-- the bounds-check guard (@checkLength@) first.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new n = BOUNDS_CHECK(checkLength) "new" n (unsafeNew n)

-- | Allocate a mutable vector of the given length with every element set to
-- the given value. The length goes through the bounds-check guard
-- (@checkLength@) first.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith n x = BOUNDS_CHECK(checkLength) "newWith" n (unsafeNewWith n x)

-- | Allocate a mutable vector of the given length. The length is not checked
-- (beyond the optional unsafe-check guard).
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n (basicUnsafeNew n)

-- | Allocate a mutable vector of the given length with every element set to
-- the given value. The length is not checked (beyond the optional
-- unsafe-check guard).
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith n x = UNSAFE_CHECK(checkLength) "unsafeNewWith" n (basicUnsafeNewWith n x)
376
377
378 -- Growing
379 -- -------
380
-- | Return a copy of the vector with @by@ additional elements at the end.
-- The number must be positive; it goes through the bounds-check guard
-- (@checkLength@) first.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by (unsafeGrow v by)

-- | Like 'grow', but the @by@ additional elements are added at the front,
-- so the original contents end up at the back of the result. The number
-- goes through the bounds-check guard (@checkLength@) first.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v by = BOUNDS_CHECK(checkLength) "growFront" by (unsafeGrowFront v by)
394
-- Number of elements by which 'enlarge' and 'enlargeFront' grow a vector:
-- its current length, but at least 1 so that an empty vector can grow too.
-- Growing by the current length doubles the vector's size.
enlarge_delta :: MVector v a => v s a -> Int
{-# INLINE enlarge_delta #-}
enlarge_delta v = max (length v) 1
396
-- | Grow a vector geometrically. Adds as many elements as it currently holds
-- (at least one, so an empty vector can grow), i.e. doubles its length. The
-- new elements are added at the end, as with 'unsafeGrow'.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = unsafeGrow v (enlarge_delta v)
402
-- | Front-growing counterpart of 'enlarge'. The vector is grown at the front
-- by 'enlarge_delta' elements. Returns the new vector together with the
-- number of elements added, which is also the index where the old contents
-- now start.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v =
    unsafeGrowFront v delta >>= \v' -> return (v', delta)
  where
    delta = enlarge_delta v
411
-- | Return a copy of the vector with @n@ additional elements, delegating to
-- 'basicUnsafeGrow'. The number must be positive, but this is not checked
-- (beyond the optional unsafe-check guard).
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v n = UNSAFE_CHECK(checkLength) "unsafeGrow" n (basicUnsafeGrow v n)
419
-- | Grow a vector at the front by @by@ elements without checking @by@
-- (beyond the optional unsafe-check guard). Allocates a fresh vector of
-- length @by + length v@ and copies the old contents into its last
-- @length v@ slots. The first @by@ slots are left as 'basicUnsafeNew'
-- produced them.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
                     $ do
                         fresh <- basicUnsafeNew (by + old_len)
                         basicUnsafeCopy (basicUnsafeSlice by old_len fresh) v
                         return fresh
  where
    old_len = length v
429
430 -- Accessing individual elements
431 -- -----------------------------
432
-- | Read the element at the given position. The index is first checked
-- against the vector's length.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i = BOUNDS_CHECK(checkIndex) "read" i (length v) (unsafeRead v i)

-- | Overwrite the element at the given position. The index is first checked
-- against the vector's length.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = BOUNDS_CHECK(checkIndex) "write" i (length v) (unsafeWrite v i x)

-- | Exchange the elements at two positions. Both indices are checked against
-- the vector's length, @i@ before @j@.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
swap v i j = BOUNDS_CHECK(checkIndex) "swap" i n
           $ BOUNDS_CHECK(checkIndex) "swap" j n
           $ unsafeSwap v i j
  where
    n = length v
451
-- | Replace the element at the given position and return the old element.
-- The index is first checked against the vector's length.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
{-# INLINE exchange #-}
exchange v i x = BOUNDS_CHECK(checkIndex) "exchange" i (length v)
               $ unsafeExchange v i x
457
-- | Read the element at the given position. No bounds checks are performed
-- (beyond the optional unsafe-check guard).
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i = UNSAFE_CHECK(checkIndex) "unsafeRead" i (length v) (basicUnsafeRead v i)

-- | Overwrite the element at the given position. No bounds checks are
-- performed (beyond the optional unsafe-check guard).
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x = UNSAFE_CHECK(checkIndex) "unsafeWrite" i (length v) (basicUnsafeWrite v i x)

-- | Exchange the elements at two positions. No bounds checks are performed
-- (beyond the optional unsafe-check guards, @i@ before @j@).
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE unsafeSwap #-}
unsafeSwap v i j = UNSAFE_CHECK(checkIndex) "unsafeSwap" i n
                 $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j n
                 $ do
                     atI <- unsafeRead v i
                     atJ <- unsafeRead v j
                     unsafeWrite v i atJ
                     unsafeWrite v j atI
  where
    n = length v
482
-- | Replace the element at the given position and return the old element.
-- No bounds checks are performed (beyond the optional unsafe-check guard).
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
                     $ do
                         y <- unsafeRead v i
                         unsafeWrite v i x
                         return y
493
494 -- Block operations
495 -- ----------------
496
-- | Reset every element of the vector to some undefined value so that it no
-- longer references external objects. For unboxed vectors this is usually a
-- no-op.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v

-- | Overwrite every element of the vector with the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
507
-- | Copy the contents of the second vector (source) into the first (target).
-- The vectors must not overlap and must have the same length. Both
-- conditions are checked, overlap first.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE copy #-}
copy dst src = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
             $ BOUNDS_CHECK(check) "copy" "length mismatch" same_length
             $ unsafeCopy dst src
  where
    disjoint    = not (dst `overlaps` src)
    same_length = length dst == length src
518
-- | Copy the contents of the second vector (source) into the first (target).
-- The vectors must have the same length and must not overlap. Neither is
-- checked beyond the optional unsafe-check guards (length first). Both
-- vectors are forced before 'basicUnsafeCopy' runs.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a -- ^ target
                                         -> v (PrimState m) a -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" same_length
                   $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
                   $ (dst `seq` src `seq` basicUnsafeCopy dst src)
  where
    same_length = length dst == length src
    disjoint    = not (dst `overlaps` src)
530
531 -- Subvectors
532 -- ----------
533
-- | View @n@ elements of the mutable vector starting at index @i@, without
-- copying. The slice bounds are checked against the vector's length.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice i n v = BOUNDS_CHECK(checkSlice) "slice" i n (length v) (unsafeSlice i n v)
539
-- | The first @n@ elements, without copying. @n@ is clamped to the range
-- @[0, length v]@, so this never fails.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = max n 0 `min` length v

-- | All but the first @n@ elements, without copying. @n@ is clamped to the
-- range @[0, length v]@, so this never fails.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start (total - start) v
  where
    total = length v
    start = max n 0 `min` total

-- | All but the last element, without copying. Goes through 'slice', so an
-- empty vector is rejected by its bounds check.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 (len - 1) v
  where
    len = length v

-- | All but the first element, without copying. Goes through 'slice', so an
-- empty vector is rejected by its bounds check.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 (len - 1) v
  where
    len = length v
558
-- | View part of the mutable vector without copying it. No bounds checks
-- are performed (beyond the optional unsafe-check guard).
unsafeSlice :: MVector v a => Int -- ^ starting index
                           -> Int -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n (length v) (basicUnsafeSlice i n v)

-- | All but the last element, without copying or bounds checks.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 (len - 1) v
  where
    len = length v

-- | All but the first element, without copying or bounds checks.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 (len - 1) v
  where
    len = length v

-- | The first @n@ elements, without copying or bounds checks.
unsafeTake :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeTake #-}
unsafeTake n = unsafeSlice 0 n

-- | All but the first @n@ elements, without copying or bounds checks.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeDrop #-}
unsafeDrop n v = unsafeSlice n rest v
  where
    rest = length v - n
584
585 -- Permutations
586 -- ------------
587
-- | For every pair @(i, b)@ in the stream, replace the element @a@ at index
-- @i@ with @f a b@. Each index is checked against the vector's length before
-- it is read.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (ix, y) = BOUNDS_CHECK(checkIndex) "accum" ix (length v) (unsafeRead v ix)
                  >>= \old -> unsafeWrite v ix (f old y)
598
-- | For every pair @(i, b)@ in the stream, overwrite the element at index
-- @i@ with @b@. Each index is checked against the vector's length.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (ix, y) = BOUNDS_CHECK(checkIndex) "update" ix (length v) (unsafeWrite v ix y)
607
-- | Like 'accum', but indices are not bounds-checked (beyond the optional
-- unsafe-check guard). For every pair @(i, b)@ in the stream, the element
-- @a@ at index @i@ is replaced with @f a b@.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    -- The check is labelled with this function's own name so that a
    -- failure is attributed to unsafeAccum rather than accum.
    upd (i,b) = do
                  a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
                         $ unsafeRead v i
                  unsafeWrite v i (f a b)
618
-- | Like 'update', but indices are not bounds-checked (beyond the optional
-- unsafe-check guard). For every pair @(i, b)@ in the stream, the element at
-- index @i@ is overwritten with @b@.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    -- The check is labelled with this function's own name so that a
    -- failure is attributed to unsafeUpdate rather than accum.
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
              $ unsafeWrite v i b
627
-- | Reverse the vector in place by swapping mirrored pairs of elements,
-- working from both ends towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
                      unsafeSwap v lo hi
                      go (lo + 1) (hi - 1)
636
-- | Reorder the elements of @v@ in place so that every element satisfying
-- @f@ precedes every element that does not. Returns the number of elements
-- satisfying @f@, i.e. the index where the second group starts. The
-- original order within each group is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = from_left 0 (length v)
  where
    -- NOTE: GHC 6.10.4 panics without the signatures on from_left and
    -- from_right
    --
    -- from_left i j: indices [0, i) hold elements satisfying f and [j, n)
    -- hold elements that do not. Scan forward from i. On the first element
    -- that fails f, switch to from_right to look for a partner to swap with.
    from_left :: Int -> Int -> m Int
    from_left i j
      | i == j    = return i
      | otherwise = do
                      x <- unsafeRead v i
                      if f x
                        then from_left (i+1) j
                        else from_right i (j-1)

    -- from_right i j: the element at i fails f, [j+1, n) fail f, and j is
    -- the next index to inspect. Scan backward. On an element satisfying f,
    -- swap it with the one at i and resume from_left just past i.
    from_right :: Int -> Int -> m Int
    from_right i j
      | i == j    = return i
      | otherwise = do
                      x <- unsafeRead v j
                      if f x
                        then do
                               y <- unsafeRead v i
                               unsafeWrite v i x
                               unsafeWrite v j y
                               from_left (i+1) j
                        else from_right i (j-1)
665
-- | Split the elements of a 'Stream' into two new mutable vectors: those
-- satisfying the predicate and those that do not. The order of elements
-- within the halves is not guaranteed. When the stream's size has an upper
-- bound, both halves share one allocation and the rejected elements come
-- out reversed.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionStream #-}
unstablePartitionStream f s
  = maybe (partitionUnknown f s)
          (unstablePartitionMax f s)
          (upperBound (Stream.size s))
673
-- Unstable partition of a stream with at most @n@ elements into a single
-- vector of length @n@. Elements satisfying @f@ are written from the front
-- (in stream order); the others are written from the back (so in reverse
-- stream order). Returns the filled prefix @[0, i)@ and suffix @[j, n)@.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n
  = do
      v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
           $ unsafeNew n
      -- Accumulator (i, j): i is the next free slot at the front, j is one
      -- past the next free slot at the back.
      let {-# INLINE_INNER put #-}
          put (i, j) x
            | f x       = do
                            unsafeWrite v i x
                            return (i+1, j)
            | otherwise = do
                            unsafeWrite v (j-1) x
                            return (i, j-1)

      (i,j) <- Stream.foldM' put (0, n) s
      return (unsafeSlice 0 i v, unsafeSlice j (n-j) v)
693
-- | Split the elements of a 'Stream' into two new mutable vectors: those
-- satisfying the predicate and those that do not. Both halves keep the
-- elements in stream order.
partitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionStream #-}
partitionStream f s
  = maybe (partitionUnknown f s)
          (partitionMax f s)
          (upperBound (Stream.size s))
701
-- Stable partition of a stream with at most @n@ elements. A single vector of
-- length @n@ is allocated. Elements satisfying @f@ are written from the
-- front; the others are written from the back, so they arrive reversed and
-- are reversed in place at the end. Both halves therefore keep stream
-- order. The results are the prefix @[0, i)@ and suffix @[j, n)@ of that
-- vector.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Stream a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n
  = do
      -- Checks are labelled with this function's own name.
      v <- INTERNAL_CHECK(checkLength) "partitionMax" n
           $ unsafeNew n

      -- Accumulator (i, j): i is the next free slot at the front, j is one
      -- past the next free slot at the back.
      let {-# INLINE_INNER put #-}
          put (i,j) x
            | f x       = do
                            unsafeWrite v i x
                            return (i+1,j)

            | otherwise = let j' = j-1 in
                          do
                            unsafeWrite v j' x
                            return (i,j')

      (i,j) <- Stream.foldM' put (0,n) s
      INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
        $ return ()
      let l = unsafeSlice 0 i v
          r = unsafeSlice j (n-j) v
      -- The rejected elements were written back to front; restore their order.
      reverse r
      return (l,r)
728
-- Stable partition of a stream with no known size bound. Elements
-- satisfying @f@ are appended to one vector and the rest to another, each
-- via 'unsafeAppend1' (which enlarges the vector when it is full). Both
-- results keep stream order and are the filled prefixes of their vectors.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s
  = do
      v1 <- unsafeNew 0
      v2 <- unsafeNew 0
      -- Accumulator: each vector together with its count of elements
      -- written so far.
      (v1', n1, v2', n2) <- Stream.foldM' put (v1, 0, v2, 0) s
      INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 n1 (length v1')
        $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 n2 (length v2')
        $ return (unsafeSlice 0 n1 v1', unsafeSlice 0 n2 v2')
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
    -- NOTE(review): this note matches the one on unsafeAppend1; here the
    -- outer case distinction is the @f x@ guard in put. Confirm it still
    -- applies.
    {-# INLINE_INNER put #-}
    put (v1, i1, v2, i2) x
      | f x       = do
                      v1' <- unsafeAppend1 v1 i1 x
                      return (v1', i1+1, v2, i2)
      | otherwise = do
                      v2' <- unsafeAppend1 v2 i2 x
                      return (v1, i1, v2', i2+1)
753