4ef0b748c796dd9d19d276590d649218c3480460
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns, ScopedTypeVariables #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2010
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, new, newWith, read, write, swap, clear, set, copy, grow,
20
21 slice, take, drop, init, tail,
22 unsafeSlice, unsafeInit, unsafeTail,
23
24 -- * Unsafe operations
25 unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite, unsafeSwap,
26 unsafeCopy, unsafeGrow,
27
28 -- * Internal operations
29 unstream, unstreamR,
30 transform, transformR,
31 munstream, munstreamR,
32 unsafeAccum, accum, unsafeUpdate, update, reverse,
33 unstablePartition, unstablePartitionStream, partitionStream
34 ) where
35
36 import qualified Data.Vector.Fusion.Stream as Stream
37 import Data.Vector.Fusion.Stream ( Stream, MStream )
38 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
39 import Data.Vector.Fusion.Stream.Size
40
41 import Control.Monad.Primitive ( PrimMonad, PrimState )
42
43 import Prelude hiding ( length, reverse, map, read,
44 take, drop, init, tail )
45
46 #include "vector.h"
47
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- Instances must define 'basicLength', 'basicUnsafeSlice', 'basicOverlaps',
-- 'basicUnsafeNew', 'basicUnsafeRead' and 'basicUnsafeWrite'. The remaining
-- methods have generic default implementations written in terms of those
-- and may be overridden with more efficient versions.
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: Int  -- ^ starting index
                   -> Int  -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors must not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                                 -> v (PrimState m) a   -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not be
  -- called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate with 'basicUnsafeNew', then fill with 'basicSet'.
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith n x
    = do
        v <- basicUnsafeNew n
        basicSet v x
        return v

  -- Default: does nothing.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write the value at every index, from 0 upwards.
  {-# INLINE basicSet #-}
  basicSet v x = do_set 0
    where
      n = basicLength v

      do_set i
        | i < n     = do
            basicUnsafeWrite v i x
            do_set (i+1)
        | otherwise = return ()

  -- Default: copy element by element from index 0 upwards. The number of
  -- elements copied is the length of the source; the target is assumed to
  -- be at least that long.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do_copy 0
    where
      n = basicLength src

      do_copy i
        | i < n     = do
            x <- basicUnsafeRead src i
            basicUnsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  -- Default: allocate a vector that is @by@ elements longer, copy the old
  -- contents into its first @n@ positions and leave the new tail as
  -- 'basicUnsafeNew' produced it.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by
    = do
        v' <- basicUnsafeNew (n+by)
        basicUnsafeCopy (basicUnsafeSlice 0 n v') v
        return v'
    where
      n = basicLength v
142
143 -- ------------------
144 -- Internal functions
145 -- ------------------
146
-- | Check whether two vectors overlap. This simply delegates to the
-- instance's 'basicOverlaps'.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps = basicOverlaps
151
-- | Write @x@ at index @i@ of @v@. When @i@ is at or beyond the end of @v@,
-- the vector is first enlarged (see 'enlarge') and the write goes to the
-- enlarged copy. Returns whichever vector was written to.
unsafeAppend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because
-- GHC creates a join point for the unsafeWrite even when everything
-- is inlined. This is bad because with the join point, v isn't getting
-- unboxed.
unsafeAppend1 v i x
  | i >= length v = do
      bigger <- enlarge v
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length bigger)
        $ unsafeWrite bigger i x
      return bigger
  | otherwise = do
      unsafeWrite v i x
      return v
168
-- | Write @x@ just in front of index @i@ of @v@ (that is, at @i-1@) and
-- return the vector written to together with the new front index.
--
-- When @i@ is 0 there is no room in front, so the vector is first grown at
-- the front (see 'enlargeFront') and the element is written immediately
-- before the old contents.
unsafePrepend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i /= 0 = do
      let i' = i-1
      unsafeWrite v i' x
      return (v, i')
  | otherwise = do
      -- 'enlargeFront' returns the grown vector and the number of free
      -- slots added at the front, which is where the old contents now
      -- start. It is bound to a fresh name instead of shadowing @i@.
      (v', front) <- enlargeFront v
      let i' = front-1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" i' (length v')
        $ unsafeWrite v' i' x
      return (v', i')
183
-- | Stream the elements of a mutable vector from left to right. The stream
-- carries an exact size hint equal to the length of the vector.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step k
      | k >= len  = return Nothing
      | otherwise = do
          e <- unsafeRead v k
          return (Just (e, k+1))
194
-- | Write the elements of a monadic stream into @v@, starting at index 0,
-- and return the prefix of @v@ that was filled. The stream must not yield
-- more elements than @v@ holds; this is only checked when internal checks
-- are enabled.
munstream :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstream #-}
munstream v s = v `seq` do
    filled <- MStream.foldM store 0 s
    return (unsafeSlice 0 filled v)
  where
    {-# INLINE_INNER store #-}
    store k e = do
      INTERNAL_CHECK(checkIndex) "munstream" k (length v)
        $ unsafeWrite v k e
      return (k+1)
207
-- | Apply a stream transformer to the elements of @v@ (read left to right)
-- and write the results back into @v@ from the front, returning the filled
-- prefix.
transform :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v $ f $ mstream v
212
-- | Stream the elements of a mutable vector from right to left (last
-- element first). The stream carries an exact size hint equal to the
-- length of the vector.
mstreamR :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstreamR #-}
mstreamR v = v `seq` MStream.sized (MStream.unfoldrM step len) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step k
      | prev < 0  = return Nothing
      | otherwise = do
          e <- unsafeRead v prev
          return (Just (e, prev))
      where
        prev = k - 1
225
-- | Write the elements of a monadic stream into @v@ from the end towards the
-- front (the first stream element goes to the last index) and return the
-- suffix of @v@ that was filled. The stream must not yield more elements
-- than @v@ holds; like 'munstream', this is only checked when internal
-- checks are enabled.
munstreamR :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamR #-}
munstreamR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n-i) v
  where
    n = length v

    {-# INLINE_INNER put #-}
    put i x = do
        -- Same internal index check as in 'munstream'; catches streams
        -- that yield more elements than the vector holds.
        INTERNAL_CHECK(checkIndex) "munstreamR" j n
          $ unsafeWrite v j x
        return j
      where
        j = i-1
241
-- | Like 'transform', but streams the elements of @v@ right to left and
-- writes the results from the end of @v@, returning the filled suffix.
transformR :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transformR #-}
transformR f v = munstreamR v $ f $ mstreamR v
246
-- | Create a new mutable vector and fill it, left to right, with the
-- elements of the 'Stream'.
--
-- If the stream's 'Size' hint has an upper bound, a vector of that length
-- is allocated once and the filled prefix is returned. Otherwise the vector
-- starts empty and its length is doubled whenever it runs out of room, so
-- it grows geometrically.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = case upperBound (Stream.size s) of
               Just n  -> unstreamMax     s n
               Nothing -> unstreamUnknown s
255
256 -- FIXME: I can't think of how to prevent GHC from floating out
257 -- unstreamUnknown. That is bad because SpecConstr then generates two
258 -- specialisations: one for when it is called from unstream (it doesn't know
259 -- the shape of the vector) and one for when the vector has grown. To see the
260 -- problem simply compile this:
261 --
262 -- fromList = Data.Vector.Unboxed.unstream . Stream.fromList
263 --
264 -- I'm not sure this still applies (19/04/2010)
265
-- | Fill a freshly allocated vector of length @n@ from the stream, left to
-- right, and return the filled prefix. @n@ must be an upper bound on the
-- number of stream elements; this is only checked when internal checks are
-- enabled.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
  buf <- INTERNAL_CHECK(checkLength) "unstreamMax" n
           $ unsafeNew n
  let store k e = do
        INTERNAL_CHECK(checkIndex) "unstreamMax" k n
          $ unsafeWrite buf k e
        return (k+1)
  count <- Stream.foldM' store 0 s
  return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 count n
         $ unsafeSlice 0 count buf
280
-- | Fill a vector from a stream of unknown size, left to right. Starts from
-- an empty vector and appends with 'unsafeAppend1', which enlarges the
-- vector as needed; returns the filled prefix.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
    v0 <- unsafeNew 0
    (buf, count) <- Stream.foldM append (v0, 0) s
    return $ INTERNAL_CHECK(checkSlice) "unstreamUnknown" 0 count (length buf)
           $ unsafeSlice 0 count buf
  where
    {-# INLINE_INNER append #-}
    append (acc, k) e = do
      acc' <- unsafeAppend1 acc k e
      return (acc', k+1)
295
-- | Create a new mutable vector and fill it with the elements of the
-- 'Stream' from right to left: the first stream element ends up at the
-- last index, so the result holds the stream in reverse order.
--
-- If the stream's 'Size' hint has an upper bound, a vector of that length
-- is allocated once and the filled suffix is returned. Otherwise the vector
-- starts empty and is doubled at the front whenever it runs out of room, so
-- it grows geometrically.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstreamR #-}
unstreamR s = case upperBound (Stream.size s) of
               Just n  -> unstreamRMax     s n
               Nothing -> unstreamRUnknown s
304
-- | Fill a freshly allocated vector of length @n@ from the stream, right to
-- left (the first stream element goes to index @n-1@), and return the
-- filled suffix. @n@ must be an upper bound on the number of stream
-- elements; this is only checked when internal checks are enabled.
unstreamRMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamRMax #-}
unstreamRMax s n = do
  buf <- INTERNAL_CHECK(checkLength) "unstreamRMax" n
           $ unsafeNew n
  let store k e = do
        let k' = k-1
        INTERNAL_CHECK(checkIndex) "unstreamRMax" k' n
          $ unsafeWrite buf k' e
        return k'
  start <- Stream.foldM' store n s
  return $ INTERNAL_CHECK(checkSlice) "unstreamRMax" start (n-start) n
         $ unsafeSlice start (n-start) buf
320
-- | Fill a vector from a stream of unknown size, right to left. Starts from
-- an empty vector and prepends with 'unsafePrepend1', which grows the
-- vector at the front as needed; returns the filled suffix.
unstreamRUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamRUnknown #-}
unstreamRUnknown s = do
    v0 <- unsafeNew 0
    (buf, start) <- Stream.foldM prepend (v0, 0) s
    let len = length buf
    return $ INTERNAL_CHECK(checkSlice) "unstreamRUnknown" start (len-start) len
           $ unsafeSlice start (len-start) buf
  where
    {-# INLINE_INNER prepend #-}
    prepend (acc, k) e = unsafePrepend1 acc k e
334
335 -- Length
336 -- ------
337
-- | Number of elements in the mutable vector.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v
342
-- | Test whether the mutable vector has no elements.
null :: MVector v a => v s a -> Bool
{-# INLINE null #-}
null v = 0 == length v
347
348
349 -- Construction
350 -- ------------
351
-- | Allocate a mutable vector of the given length, which is validated by a
-- bounds check. The elements are not initialised by this function.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new len = BOUNDS_CHECK(checkLength) "new" len
        $ unsafeNew len
357
-- | Allocate a mutable vector of the given length, which is validated by a
-- bounds check, with every element set to the supplied value.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith len val = BOUNDS_CHECK(checkLength) "newWith" len
                $ unsafeNewWith len val
364
-- | Allocate a mutable vector of the given length. The length is only
-- validated when unsafe checks are enabled.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew len = UNSAFE_CHECK(checkLength) "unsafeNew" len
              $ basicUnsafeNew len
370
-- | Allocate a mutable vector of the given length with every element set to
-- the supplied value. The length is only validated when unsafe checks are
-- enabled.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith len val = UNSAFE_CHECK(checkLength) "unsafeNewWith" len
                      $ basicUnsafeNewWith len val
377
378
379 -- Growing
380 -- -------
381
-- | Grow a vector by the given number of elements, returning the new vector.
-- The count is validated with the 'checkLength' bounds check (presumably
-- rejecting negative values; confirm in vector.h).
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v extra = BOUNDS_CHECK(checkLength) "grow" extra
             $ unsafeGrow v extra
389
-- | Grow a vector at the front by the given number of elements (validated by
-- a bounds check). The existing elements end up at the back of the new
-- vector; see 'unsafeGrowFront'.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v extra = BOUNDS_CHECK(checkLength) "growFront" extra
                  $ unsafeGrowFront v extra
395
-- | Number of elements by which 'enlarge' and 'enlargeFront' grow a vector:
-- its current length, or 1 if it is empty. Growing by this amount doubles
-- the length.
enlarge_delta :: MVector v a => v s a -> Int
{-# INLINE enlarge_delta #-}
enlarge_delta v = max (length v) 1
397
-- | Grow a vector geometrically: its length is doubled (an empty vector
-- grows to length 1). The old elements stay at the front. Repeated appends
-- therefore cause only logarithmically many reallocations.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = unsafeGrow v (enlarge_delta v)
403
-- | Grow a vector at the front, doubling its length (an empty vector grows
-- to length 1). Returns the new vector together with the number of free
-- slots added at the front, which is also the index of the first old
-- element.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = do
    bigger <- unsafeGrowFront v delta
    return (bigger, delta)
  where
    delta = enlarge_delta v
412
-- | Grow a vector by the given number of elements. The count is only
-- validated when unsafe checks are enabled.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v extra = UNSAFE_CHECK(checkLength) "unsafeGrow" extra
                   $ basicUnsafeGrow v extra
420
-- | Grow a vector by @by@ elements at the front. A new vector of length
-- @by + length v@ is allocated and the old elements are copied into its
-- last @length v@ positions; the first @by@ positions are not initialised
-- here. The count is only validated when unsafe checks are enabled.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by
  = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
  $ do
      bigger <- basicUnsafeNew (by+len)
      basicUnsafeCopy (basicUnsafeSlice by len bigger) v
      return bigger
  where
    len = length v
430
431 -- Accessing individual elements
432 -- -----------------------------
433
-- | Read the element at the given index, with a bounds check.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v ix = BOUNDS_CHECK(checkIndex) "read" ix len
          $ unsafeRead v ix
  where
    len = length v
439
-- | Overwrite the element at the given index, with a bounds check.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v ix val = BOUNDS_CHECK(checkIndex) "write" ix len
               $ unsafeWrite v ix val
  where
    len = length v
445
-- | Exchange the elements at two indices, bounds-checking both (first index
-- first).
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
swap v a b
  = BOUNDS_CHECK(checkIndex) "swap" a len
  $ BOUNDS_CHECK(checkIndex) "swap" b len
  $ unsafeSwap v a b
  where
    len = length v
452
-- | Replace the element at the given position and return the old element.
-- The index is bounds-checked.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
{-# INLINE exchange #-}
exchange v i x = BOUNDS_CHECK(checkIndex) "exchange" i (length v)
               $ unsafeExchange v i x
458
-- | Read the element at the given index. The index is only checked when
-- unsafe checks are enabled.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v ix = UNSAFE_CHECK(checkIndex) "unsafeRead" ix len
                $ basicUnsafeRead v ix
  where
    len = length v
464
-- | Overwrite the element at the given index. The index is only checked
-- when unsafe checks are enabled.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v ix val = UNSAFE_CHECK(checkIndex) "unsafeWrite" ix len
                     $ basicUnsafeWrite v ix val
  where
    len = length v
471
-- | Exchange the elements at two indices. The indices are only checked when
-- unsafe checks are enabled.
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE unsafeSwap #-}
unsafeSwap v i j
  = UNSAFE_CHECK(checkIndex) "unsafeSwap" i len
  $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j len
  $ do
      atI <- unsafeRead v i
      atJ <- unsafeRead v j
      unsafeWrite v i atJ
      unsafeWrite v j atI
  where
    len = length v
483
-- | Replace the element at the given position and return the old element.
-- No bounds checks are performed (the index is only checked when unsafe
-- checks are enabled).
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
                     $ do
                         y <- unsafeRead v i
                         unsafeWrite v i x
                         return y
494
495 -- Block operations
496 -- ----------------
497
-- | Reset every element to an unspecified value so that the vector no
-- longer references external objects; delegates to 'basicClear' (a no-op
-- by default).
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v
503
-- | Overwrite every element of the vector with the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v val = basicSet v val
508
-- | Copy the source vector into the target vector. Bounds checks verify
-- first that the two do not overlap and then that their lengths match.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
  $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
  $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
519
-- | Copy the source vector into the target vector. The vectors must have
-- equal length and must not overlap; both conditions are only checked when
-- unsafe checks are enabled. Both vectors are forced before copying.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src
  = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
  $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
  $ (dst `seq` src `seq` basicUnsafeCopy dst src)
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
531
532 -- Subvectors
533 -- ----------
534
-- | Take @n@ elements starting at index @i@ without copying. The slice
-- bounds are checked.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice start count v = BOUNDS_CHECK(checkSlice) "slice" start count len
                    $ unsafeSlice start count v
  where
    len = length v
540
-- | The first @n@ elements, without copying. @n@ is clamped to the range
-- @[0, length v]@, so this never fails.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = min (length v) (max 0 n)
544
-- | All but the first @n@ elements, without copying. @n@ is clamped to the
-- range @[0, length v]@, so this never fails.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start (len - start) v
  where
    len   = length v
    start = min len (max n 0)
551
-- | All elements except the last, without copying. Bounds-checked, so an
-- empty vector is rejected.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 (pred (length v)) v
555
-- | All elements except the first, without copying. Bounds-checked, so an
-- empty vector is rejected.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 (pred (length v)) v
559
-- | Take @n@ elements starting at index @i@ without copying. The slice
-- bounds are only checked when unsafe checks are enabled.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice start count v
  = UNSAFE_CHECK(checkSlice) "unsafeSlice" start count len
  $ basicUnsafeSlice start count v
  where
    len = length v
569
-- | All elements except the last, without copying or bounds checks.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 (pred (length v)) v
573
-- | All elements except the first, without copying or bounds checks.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 (pred (length v)) v
577
-- | The first @n@ elements, without copying or bounds checks.
unsafeTake :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeTake #-}
unsafeTake count = unsafeSlice 0 count
581
-- | All but the first @n@ elements, without copying or bounds checks.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeDrop #-}
unsafeDrop count v = unsafeSlice count rest v
  where
    rest = length v - count
585
586 -- Permutations
587 -- ------------
588
-- | For each pair @(i, b)@ in the stream, replace the element at @i@ with
-- @f old b@, where @old@ is its current value. Indices are bounds-checked.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    {-# INLINE_INNER step #-}
    step (k, b) = do
      old <- BOUNDS_CHECK(checkIndex) "accum" k (length v)
               $ unsafeRead v k
      unsafeWrite v k (f old b)
599
-- | For each pair @(i, x)@ in the stream, write @x@ at index @i@. Indices are
-- bounds-checked.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Stream.mapM_ store s
  where
    {-# INLINE_INNER store #-}
    store (k, e) = BOUNDS_CHECK(checkIndex) "update" k (length v)
                 $ unsafeWrite v k e
608
-- | For each pair @(i, b)@ in the stream, replace the element at @i@ with
-- @f old b@. Indices are only checked when unsafe checks are enabled; a
-- failing check is reported under this function's own name.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
      a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
             $ unsafeRead v i
      unsafeWrite v i (f a b)
619
-- | For each pair @(i, x)@ in the stream, write @x@ at index @i@. Indices are
-- only checked when unsafe checks are enabled; a failing check is reported
-- under this function's own name.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
              $ unsafeWrite v i b
628
-- | Reverse the vector in place by swapping elements from both ends
-- towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
          unsafeSwap v lo hi
          go (lo + 1) (hi - 1)
637
-- | Reorder @v@ in place so that every element satisfying @f@ precedes every
-- element that does not, and return the number of elements satisfying @f@.
-- The relative order within each group is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = scanFront 0 (length v)
  where
    -- NOTE: GHC 6.10.4 panics without the signatures on scanFront and
    -- scanBack
    --
    -- scanFront lo hi: indices below lo satisfy f, indices at or above hi
    -- do not; advance lo past matching elements.
    scanFront :: Int -> Int -> m Int
    scanFront lo hi
      | lo == hi  = return lo
      | otherwise = do
          e <- unsafeRead v lo
          case f e of
            True  -> scanFront (lo+1) hi
            False -> scanBack lo (hi-1)

    -- scanBack lo hi: the element at lo does not satisfy f and hi is the
    -- last unexamined index; retreat past non-matching elements and swap
    -- the first matching one into position lo.
    scanBack :: Int -> Int -> m Int
    scanBack lo hi
      | lo == hi  = return lo
      | otherwise = do
          e <- unsafeRead v hi
          case f e of
            False -> scanBack lo (hi-1)
            True  -> do
              other <- unsafeRead v lo
              unsafeWrite v lo e
              unsafeWrite v hi other
              scanFront (lo+1) hi
666
-- | Split a stream into the elements satisfying @f@ and the rest. With an
-- upper size bound both parts share one preallocated vector (see
-- 'unstablePartitionMax'); otherwise 'partitionUnknown' is used.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionStream #-}
unstablePartitionStream f s =
  case upperBound (Stream.size s) of
    Nothing -> partitionUnknown f s
    Just n  -> unstablePartitionMax f s n
674
-- | Partition a stream of at most @n@ elements into one fresh vector of
-- length @n@: matching elements are written from the front, the others from
-- the back (so the second result is in reverse stream order).
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n = do
    buf <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
             $ unsafeNew n
    let {-# INLINE_INNER place #-}
        place (front, back) e
          | f e       = do
              unsafeWrite buf front e
              return (front+1, back)
          | otherwise = do
              unsafeWrite buf (back-1) e
              return (front, back-1)
    (front, back) <- Stream.foldM' place (0, n) s
    return (unsafeSlice 0 front buf, unsafeSlice back (n-back) buf)
694
-- | Stably split a stream into the elements satisfying @f@ and the rest,
-- keeping stream order in both parts. Uses 'partitionMax' when the size has
-- an upper bound and 'partitionUnknown' otherwise.
partitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionStream #-}
partitionStream f s =
  case upperBound (Stream.size s) of
    Nothing -> partitionUnknown f s
    Just n  -> partitionMax f s n
702
-- | Stable partition of a stream with at most @n@ elements. Matching elements
-- are written left to right from the front of a fresh vector of length @n@.
-- The others are written right to left from the back, and that back part is
-- then reversed, so both results keep stream order. Returns
-- @(matching, nonMatching)@, two non-overlapping slices of the same vector.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Stream a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n
  = do
      -- Report a bad length under this function's own name.
      v <- INTERNAL_CHECK(checkLength) "partitionMax" n
             $ unsafeNew n

      let {-# INLINE_INNER put #-}
          put (i, j) x
            | f x       = do
                unsafeWrite v i x
                return (i+1, j)
            | otherwise = do
                let j' = j-1
                unsafeWrite v j' x
                return (i, j')

      (i, j) <- Stream.foldM' put (0, n) s
      INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
        $ return ()
      let l = unsafeSlice 0 i v
          r = unsafeSlice j (n-j) v
      -- The non-matching part was filled back to front; restore stream order.
      reverse r
      return (l, r)
729
-- | Stable partition of a stream of unknown size. Two vectors start empty;
-- each element is appended (with 'unsafeAppend1', which enlarges as needed)
-- to the first if it satisfies @f@, otherwise to the second. Returns the
-- filled prefixes of both.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s = do
    yes0 <- unsafeNew 0
    no0  <- unsafeNew 0
    (yes, nYes, no, nNo) <- Stream.foldM' step (yes0, 0, no0, 0) s
    INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
      $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
      $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
    {-# INLINE_INNER step #-}
    step (yes, nYes, no, nNo) e
      | f e       = do
          yes' <- unsafeAppend1 yes nYes e
          return (yes', nYes+1, no, nNo)
      | otherwise = do
          no' <- unsafeAppend1 no nNo e
          return (yes, nYes, no', nNo+1)
754