57888d9644be8a4dcea81286b05dd4411b667868
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns, ScopedTypeVariables #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, new, newWith, read, write, swap, clear, set, copy, grow,
20
21 slice, take, drop, init, tail,
22 unsafeSlice, unsafeInit, unsafeTail,
23
24 -- * Unsafe operations
25 unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite, unsafeSwap,
26 unsafeCopy, unsafeGrow,
27
28 -- * Internal operations
29 unstream, unstreamR,
30 transform, transformR,
31 munstream, munstreamR,
32 unsafeAccum, accum, unsafeUpdate, update, reverse,
33 unstablePartition, unstablePartitionStream, partitionStream
34 ) where
35
36 import qualified Data.Vector.Fusion.Stream as Stream
37 import Data.Vector.Fusion.Stream ( Stream, MStream )
38 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
39 import Data.Vector.Fusion.Stream.Size
40
41 import Control.Monad.Primitive ( PrimMonad, PrimState )
42
43 import GHC.Float (
44 double2Int, int2Double
45 )
46
47 import Prelude hiding ( length, reverse, map, read,
48 take, drop, init, tail )
49
50 #include "vector.h"
51
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- 'basicLength', 'basicUnsafeSlice', 'basicOverlaps', 'basicUnsafeNew',
-- 'basicUnsafeRead' and 'basicUnsafeWrite' have no default definitions; the
-- remaining methods fall back to element-by-element loops built from them.
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: Int  -- ^ starting index
                   -> Int  -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                                 -> v (PrimState m) a   -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not be
  -- called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate, then overwrite every slot via basicSet.
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith len x = do
    fresh <- basicUnsafeNew len
    basicSet fresh x
    return fresh

  -- Default: there is nothing to release.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write x into every index, from 0 upwards.
  {-# INLINE basicSet #-}
  basicSet v x = fill 0
    where
      len = basicLength v

      fill k
        | k >= len  = return ()
        | otherwise = basicUnsafeWrite v k x >> fill (k + 1)

  -- Default: copy element by element over the indices of the source.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = go 0
    where
      len = basicLength src

      go k
        | k >= len  = return ()
        | otherwise = do
            basicUnsafeRead src k >>= basicUnsafeWrite dst k
            go (k + 1)

  -- Default: allocate a vector that is @by@ elements longer and copy the old
  -- contents into its front; the trailing slots are left as basicUnsafeNew
  -- produced them.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by = do
    let len = basicLength v
    bigger <- basicUnsafeNew (len + by)
    basicUnsafeCopy (basicUnsafeSlice 0 len bigger) v
    return bigger
146
147 -- ------------------
148 -- Internal functions
149 -- ------------------
150
-- | Check whether two mutable vectors overlap. This is a thin wrapper around
-- the class method 'basicOverlaps'.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps v w = basicOverlaps v w
155
-- | Write @x@ at index @idx@ and return the vector that now holds it. If
-- @idx@ is past the current length, the vector is first enlarged with
-- 'enlarge' and the new vector is returned instead. Only one enlargement is
-- performed, so callers are expected to fill indices one at a time.
unsafeAppend1 :: (PrimMonad m, MVector v a)
              => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because
-- GHC creates a join point for the unsafeWrite even when everything
-- is inlined. This is bad because with the join point, v isn't getting
-- unboxed.
unsafeAppend1 vec idx x
  | idx < length vec = unsafeWrite vec idx x >> return vec
  | otherwise        = do
      bigger <- enlarge vec
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" idx (length bigger)
        $ unsafeWrite bigger idx x
      return bigger
172
-- | Store @x@ just before position @idx@ (i.e. at @idx - 1@) and return the
-- vector together with the new front index. When @idx@ is already 0, the
-- vector is first grown at the front with 'enlargeFront'. That call also
-- returns how many slots were opened, which is where the old contents now
-- begin.
unsafePrepend1 :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 vec idx x
  | idx /= 0  = do
      let front = idx - 1
      unsafeWrite vec front x
      return (vec, front)
  | otherwise = do
      (bigger, room) <- enlargeFront vec
      let front = room - 1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" front (length bigger)
        $ unsafeWrite bigger front x
      return (bigger, front)
187
-- | A monadic stream that reads the vector's elements from left to right.
-- The stream carries an exact size hint equal to the vector's length.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step k
      | k >= len  = return Nothing
      | otherwise = do
          x <- unsafeRead v k
          return (Just (x, k + 1))
198
-- | Write the elements of a monadic stream into the given vector, starting
-- at index 0, and return the prefix that was written. The vector is never
-- grown; indices are only verified by the internal check.
munstream :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstream #-}
munstream v s = v `seq` do
    written <- MStream.foldM store 0 s
    return (unsafeSlice 0 written v)
  where
    {-# INLINE_INNER store #-}
    store k x = do
      INTERNAL_CHECK(checkIndex) "munstream" k (length v)
        $ unsafeWrite v k x
      return (k + 1)
211
-- | Run a stream transformer over the vector's elements and write the results
-- back into the same vector from the left. Returns the filled prefix. Results
-- are stored without growing the vector, so the transformer presumably must
-- not produce more elements than it receives.
transform :: (PrimMonad m, MVector v a)
          => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v transformed
  where
    transformed = f (mstream v)
216
-- | A monadic stream that reads the vector's elements from right to left.
-- The stream carries an exact size hint equal to the vector's length.
mrstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mrstream #-}
mrstream v = v `seq` MStream.sized (MStream.unfoldrM step len) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step k
      | k <= 0    = return Nothing
      | otherwise = do
          let k' = k - 1
          x <- unsafeRead v k'
          return (Just (x, k'))
229
-- | Write the elements of a monadic stream into the given vector from right
-- to left: the first element goes into the last slot, the next one before
-- it, and so on. Returns the suffix of the vector that was written. The
-- vector is never grown.
munstreamR :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamR #-}
munstreamR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n-i) v
  where
    n = length v

    {-# INLINE_INNER put #-}
    put i x = do
      let j = i - 1
      -- Same internal check as in munstream: a stream that yields more
      -- elements than the vector has slots would otherwise write at index -1.
      INTERNAL_CHECK(checkIndex) "munstreamR" j n
        $ unsafeWrite v j x
      return j
245
-- | Like 'transform', but the vector is read and refilled from right to left.
-- Returns the suffix that was filled.
transformR :: (PrimMonad m, MVector v a)
           => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transformR #-}
transformR f v = munstreamR v transformed
  where
    transformed = f (mrstream v)
250
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
-- If the stream's size hint has an upper bound, a vector of that size is
-- allocated up front. Otherwise the vector starts empty and doubles in size
-- whenever it runs out of room.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = case upperBound (Stream.size s) of
  Nothing    -> unstreamUnknown s
  Just bound -> unstreamMax s bound
259
260 -- FIXME: I can't think of how to prevent GHC from floating out
261 -- unstreamUnknown. That is bad because SpecConstr then generates two
262 -- specialisations: one for when it is called from unstream (it doesn't know
263 -- the shape of the vector) and one for when the vector has grown. To see the
264 -- problem simply compile this:
265 --
266 -- fromList = Data.Vector.Unboxed.unstream . Stream.fromList
267 --
268
-- | Fill a freshly allocated vector of @n@ elements from the front, where @n@
-- is an upper bound on the stream's length. Returns the prefix that was
-- actually written.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
    v <- INTERNAL_CHECK(checkLength) "unstreamMax" n
           $ unsafeNew n
    let store k x = do
          INTERNAL_CHECK(checkIndex) "unstreamMax" k n
            $ unsafeWrite v k x
          return (k + 1)
    count <- Stream.foldM' store 0 s
    return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 count n
           $ unsafeSlice 0 count v
283
-- | Fill a vector from a stream of unknown length. The vector starts empty
-- and is enlarged by 'unsafeAppend1' whenever it is full. Returns the prefix
-- that holds the stream's elements.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
    v0 <- unsafeNew 0
    (grown, count) <- Stream.foldM push (v0, 0) s
    return $ INTERNAL_CHECK(checkSlice) "unstreamUnknown" 0 count (length grown)
           $ unsafeSlice 0 count grown
  where
    {-# INLINE_INNER push #-}
    push (vec, k) x = do
      vec' <- unsafeAppend1 vec k x
      return (vec', k + 1)
298
-- | Create a new mutable vector and fill it with elements from the 'Stream'
-- from right to left. The first element of the stream ends up in the last
-- slot, so the result holds the stream in reverse order. If the stream's
-- size hint has no upper bound, the vector starts empty and is grown at the
-- front, doubling in size whenever it runs out of room.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstreamR #-}
unstreamR s = case upperBound (Stream.size s) of
  Nothing    -> unstreamRUnknown s
  Just bound -> unstreamRMax s bound
307
-- | Fill a freshly allocated vector of @n@ elements from the back, where @n@
-- is an upper bound on the stream's length. Returns the suffix that was
-- actually written.
unstreamRMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamRMax #-}
unstreamRMax s n = do
    v <- INTERNAL_CHECK(checkLength) "unstreamRMax" n
           $ unsafeNew n
    let store k x = do
          let k' = k - 1
          INTERNAL_CHECK(checkIndex) "unstreamRMax" k' n
            $ unsafeWrite v k' x
          return k'
    front <- Stream.foldM' store n s
    return $ INTERNAL_CHECK(checkSlice) "unstreamRMax" front (n - front) n
           $ unsafeSlice front (n - front) v
323
-- | Fill a vector from the back with a stream of unknown length. The vector
-- starts empty and is grown at the front by 'unsafePrepend1' whenever the
-- front index reaches 0. Returns the suffix that holds the stream's elements.
unstreamRUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamRUnknown #-}
unstreamRUnknown s = do
    v0 <- unsafeNew 0
    (grown, front) <- Stream.foldM push (v0, 0) s
    let len = length grown
    return $ INTERNAL_CHECK(checkSlice) "unstreamRUnknown" front (len - front) len
           $ unsafeSlice front (len - front) grown
  where
    {-# INLINE_INNER push #-}
    push (vec, k) x = unsafePrepend1 vec k x
337
338 -- Length
339 -- ------
340
-- | Number of elements in the mutable vector.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v

-- | Check whether the vector is empty.
null :: MVector v a => v s a -> Bool
{-# INLINE null #-}
null v = case length v of
  0 -> True
  _ -> False
350
351
352 -- Construction
353 -- ------------
354
-- | Create a mutable vector of the given length. The length is validated by
-- the bounds check.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new len = BOUNDS_CHECK(checkLength) "new" len
        $ unsafeNew len

-- | Create a mutable vector of the given length, with every element set to
-- the given initial value. The length is validated by the bounds check.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith len x = BOUNDS_CHECK(checkLength) "newWith" len
              $ unsafeNewWith len x

-- | Create a mutable vector of the given length. The length is not checked
-- apart from the optional unsafe-check assertion.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew len = UNSAFE_CHECK(checkLength) "unsafeNew" len
              $ basicUnsafeNew len

-- | Create a mutable vector of the given length, with every element set to
-- the given initial value. The length is not checked apart from the optional
-- unsafe-check assertion.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith len x = UNSAFE_CHECK(checkLength) "unsafeNewWith" len
                    $ basicUnsafeNewWith len x
380
381
382 -- Growing
383 -- -------
384
-- | Grow a vector by the given number of elements. The increment is
-- validated with the length check when bounds checking is enabled.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v delta = BOUNDS_CHECK(checkLength) "grow" delta
             $ unsafeGrow v delta

-- | Grow a vector by the given number of elements, opening the new space at
-- the front so that the old contents end up at the back of the result.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v delta = BOUNDS_CHECK(checkLength) "growFront" delta
                  $ unsafeGrowFront v delta
398
-- | Number of elements by which 'enlarge' and 'enlargeFront' grow a vector:
-- its current length, or 1 when it is empty. The size therefore doubles on
-- each enlargement.
enlarge_delta :: MVector v a => v s a -> Int
{-# INLINE enlarge_delta #-}
enlarge_delta v = max (length v) 1
400
-- | Grow a vector geometrically: it gains as many elements as it already
-- has, or one element if it is empty.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = unsafeGrow v delta
  where
    delta = enlarge_delta v

-- | Like 'enlarge', but the new space is opened at the front. Also returns
-- the number of slots that were added, which is the index at which the old
-- contents now begin.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = unsafeGrowFront v delta >>= \bigger -> return (bigger, delta)
  where
    delta = enlarge_delta v
415
-- | Grow a vector by the given number of elements. The increment should not
-- be negative; it is only validated by the unsafe-check length test.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v delta = UNSAFE_CHECK(checkLength) "unsafeGrow" delta
                   $ basicUnsafeGrow v delta

-- | Grow a vector by the given number of elements at the front. A new vector
-- is allocated and the old contents are copied to its back, starting at
-- index @by@. The increment is only validated by the unsafe-check length
-- test.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
                     $ do
                         let len = length v
                         bigger <- basicUnsafeNew (by + len)
                         basicUnsafeCopy (basicUnsafeSlice by len bigger) v
                         return bigger
433
434 -- Accessing individual elements
435 -- -----------------------------
436
-- | Yield the element at the given position. The index is bounds-checked.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i = BOUNDS_CHECK(checkIndex) "read" i n
         $ unsafeRead v i
  where
    n = length v

-- | Replace the element at the given position. The index is bounds-checked.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = BOUNDS_CHECK(checkIndex) "write" i n
            $ unsafeWrite v i x
  where
    n = length v

-- | Swap the elements at the given positions. Both indices are
-- bounds-checked, the first one first.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
swap v i j = BOUNDS_CHECK(checkIndex) "swap" i n
           $ BOUNDS_CHECK(checkIndex) "swap" j n
           $ unsafeSwap v i j
  where
    n = length v

-- | Replace the element at the given position and return the old element.
-- The index is bounds-checked.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
{-# INLINE exchange #-}
exchange v i x = BOUNDS_CHECK(checkIndex) "exchange" i n
               $ unsafeExchange v i x
  where
    n = length v
461
-- | Yield the element at the given position. No bounds checks are performed
-- apart from the optional unsafe-check assertion.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i = UNSAFE_CHECK(checkIndex) "unsafeRead" i n
               $ basicUnsafeRead v i
  where
    n = length v

-- | Replace the element at the given position. No bounds checks are performed
-- apart from the optional unsafe-check assertion.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x = UNSAFE_CHECK(checkIndex) "unsafeWrite" i n
                  $ basicUnsafeWrite v i x
  where
    n = length v

-- | Swap the elements at the given positions. No bounds checks are performed
-- apart from the optional unsafe-check assertions.
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE unsafeSwap #-}
unsafeSwap v i j = UNSAFE_CHECK(checkIndex) "unsafeSwap" i n
                 $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j n
                 $ do
                     xi <- unsafeRead v i
                     xj <- unsafeRead v j
                     unsafeWrite v i xj
                     unsafeWrite v j xi
  where
    n = length v

-- | Replace the element at the given position and return the old element.
-- No bounds checks are performed apart from the optional unsafe-check
-- assertion.
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
                     $ unsafeRead v i >>= \old -> unsafeWrite v i x >> return old
497
498 -- Block operations
499 -- ----------------
500
-- | Reset all elements of the vector to some undefined value, clearing all
-- references to external objects. This is usually a noop for unboxed vectors.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v

-- | Overwrite every element of the vector with the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
511
-- | Copy the source vector into the target. The two vectors must have the
-- same length and may not overlap. Both conditions are bounds-checked, the
-- overlap first.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a   -- ^ target
     -> v (PrimState m) a   -- ^ source
     -> m ()
{-# INLINE copy #-}
copy dst src = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
             $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
             $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src

-- | Copy the source vector into the target. The two vectors must have the
-- same length and may not overlap; this is only verified by the unsafe-check
-- assertions. Both vectors are forced before 'basicUnsafeCopy' runs.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
                   $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
                   $ dst `seq` src `seq` basicUnsafeCopy dst src
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
534
535 -- Subvectors
536 -- ----------
537
-- | Yield the @n@ elements starting at index @i@ without copying. The slice
-- is bounds-checked against the vector's length.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice i n v = BOUNDS_CHECK(checkSlice) "slice" i n len
            $ unsafeSlice i n v
  where
    len = length v

-- | Yield at most the first @n@ elements without copying. @n@ is clamped to
-- lie between 0 and the vector's length.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = min (max n 0) (length v)

-- | Drop at most the first @n@ elements without copying. A negative @n@
-- drops nothing, and an @n@ past the end yields an empty slice.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start rest v
  where
    k     = max n 0
    len   = length v
    start = min len k
    rest  = max 0 (len - k)

-- | All but the last element, without copying. An empty vector is caught by
-- the slice bounds check.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 (len - 1) v
  where
    len = length v

-- | All but the first element, without copying. An empty vector is caught by
-- the slice bounds check.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 (len - 1) v
  where
    len = length v
562
-- | Yield a part of the mutable vector without copying it. No bounds checks
-- are performed apart from the optional unsafe-check assertion.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n len
                  $ basicUnsafeSlice i n v
  where
    len = length v

-- | All but the last element, without copying or bounds checks.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 (len - 1) v
  where
    len = length v

-- | All but the first element, without copying or bounds checks.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 (len - 1) v
  where
    len = length v

-- | The first @n@ elements, without copying or bounds checks.
unsafeTake :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeTake #-}
unsafeTake n v = unsafeSlice 0 n v

-- | Everything after the first @n@ elements, without copying or bounds
-- checks.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeDrop #-}
unsafeDrop n v = unsafeSlice n rest v
  where
    rest = length v - n
588
589 -- Permutations
590 -- ------------
591
-- | For every pair @(i, b)@ in the stream, replace the element @a@ at index
-- @i@ with @f a b@. Indices are bounds-checked on read.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    len = length v

    {-# INLINE_INNER step #-}
    step (i, b) = do
      old <- BOUNDS_CHECK(checkIndex) "accum" i len
             $ unsafeRead v i
      unsafeWrite v i (f old b)
602
-- | For every pair @(i, x)@ in the stream, write @x@ at index @i@. Indices
-- are bounds-checked.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Stream.mapM_ step s
  where
    len = length v

    {-# INLINE_INNER step #-}
    step (i, x) = BOUNDS_CHECK(checkIndex) "update" i len
                $ unsafeWrite v i x
611
-- | Like 'accum': for every pair @(i, b)@ in the stream, replace the element
-- @a@ at index @i@ with @f a b@. Indices are only validated by the
-- unsafe-check assertion.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
      -- Report failures under this function's own name (was "accum").
      a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
           $ unsafeRead v i
      unsafeWrite v i (f a b)
622
-- | Like 'update': for every pair @(i, x)@ in the stream, write @x@ at index
-- @i@. Indices are only validated by the unsafe-check assertion.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    -- Report failures under this function's own name (was "accum").
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
              $ unsafeWrite v i b
631
-- | Reverse the elements of the mutable vector in place by swapping pairs
-- from both ends towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = unsafeSwap v lo hi >> go (lo + 1) (hi - 1)
640
-- | Reorder the vector in place so that all elements satisfying the
-- predicate come before those that do not, and return the number of
-- elements that satisfy it. The relative order of elements is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = scanLeft 0 (length v)
  where
    -- NOTE: GHC 6.10.4 panics without the signatures on scanLeft and
    -- scanRight

    -- Invariant: indices [0, lo) satisfy f and indices [hi, n) do not.
    scanLeft :: Int -> Int -> m Int
    scanLeft lo hi
      | lo == hi  = return lo
      | otherwise = do
          x <- unsafeRead v lo
          if f x
            then scanLeft (lo + 1) hi
            else scanRight lo (hi - 1)

    -- Here the element at lo fails f and hi is the next index to inspect
    -- from the right.
    scanRight :: Int -> Int -> m Int
    scanRight lo hi
      | lo == hi  = return lo
      | otherwise = do
          x <- unsafeRead v hi
          if not (f x)
            then scanRight lo (hi - 1)
            else do
              y <- unsafeRead v lo
              unsafeWrite v lo x
              unsafeWrite v hi y
              scanLeft (lo + 1) hi
669
-- | Split a stream into two new vectors: the elements that satisfy the
-- predicate and the ones that do not. When the stream's size has an upper
-- bound, the second vector holds its elements in reverse stream order.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionStream #-}
unstablePartitionStream f s = case upperBound (Stream.size s) of
  Nothing    -> partitionUnknown f s
  Just bound -> unstablePartitionMax f s bound
677
-- | Partition a stream whose length is bounded by @n@ using a single buffer
-- of size @n@. Matching elements are written from the front and the others
-- from the back, so the second result is in reverse stream order.
unstablePartitionMax :: (PrimMonad m, MVector v a)
                     => (a -> Bool) -> Stream a -> Int
                     -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n = do
    buf <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
           $ unsafeNew n
    let {-# INLINE_INNER place #-}
        place (lo, hi) x
          | f x       = do
              unsafeWrite buf lo x
              return (lo + 1, hi)
          | otherwise = do
              unsafeWrite buf (hi - 1) x
              return (lo, hi - 1)

    (lo, hi) <- Stream.foldM' place (0, n) s
    return (unsafeSlice 0 lo buf, unsafeSlice hi (n - hi) buf)
697
-- | Split a stream into two new vectors: the elements that satisfy the
-- predicate and the ones that do not. Both vectors keep the stream order.
partitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionStream #-}
partitionStream f s = case upperBound (Stream.size s) of
  Nothing    -> partitionUnknown f s
  Just bound -> partitionMax f s bound
705
-- | Stable partition of a stream whose length is bounded by @n@. Elements
-- that satisfy the predicate are written from the front of a single buffer
-- of size @n@, and the others from the back. The back part is then reversed
-- in place, so both halves keep the stream order.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Stream a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n = do
    -- Label the check with this function's name (was "unstablePartitionMax").
    v <- INTERNAL_CHECK(checkLength) "partitionMax" n
         $ unsafeNew n

    let {-# INLINE_INNER put #-}
        put (i, j) x
          | f x       = do
              unsafeWrite v i x
              return (i + 1, j)
          | otherwise = do
              let j' = j - 1
              unsafeWrite v j' x
              return (i, j')

    (i, j) <- Stream.foldM' put (0, n) s
    INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
      $ return ()
    let l = unsafeSlice 0 i v
        r = unsafeSlice j (n - j) v
    -- The right part was filled back to front; restore stream order.
    reverse r
    return (l, r)
732
-- | Stable partition of a stream of unknown length. Each half is
-- accumulated in its own vector, which starts empty and is enlarged by
-- 'unsafeAppend1' as needed.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s = do
    v1 <- unsafeNew 0
    v2 <- unsafeNew 0
    (yes, nYes, no, nNo) <- Stream.foldM' put (v1, 0, v2, 0) s
    INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
      $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
      $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
    {-# INLINE_INNER put #-}
    put (acc1, k1, acc2, k2) x
      | f x       = do
          acc1' <- unsafeAppend1 acc1 k1 x
          return (acc1', k1 + 1, acc2, k2)
      | otherwise = do
          acc2' <- unsafeAppend1 acc2 k2 x
          return (acc1, k1, acc2', k2 + 1)
757