Right-to-left streaming/unstreaming
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, new, newWith, read, write, clear, set, copy, grow,
20
21 slice, take, drop, init, tail,
22 unsafeSlice, unsafeInit, unsafeTail,
23
24 -- * Unsafe operations
25 unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite,
26 unsafeCopy, unsafeGrow,
27
28 -- * Internal operations
29 unstream, transform, unstreamR, transformR,
30 unsafeAccum, accum, unsafeUpdate, update, reverse,
31 unstablePartition, unstablePartitionStream
32 ) where
33
34 import qualified Data.Vector.Fusion.Stream as Stream
35 import Data.Vector.Fusion.Stream ( Stream, MStream )
36 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
37 import Data.Vector.Fusion.Stream.Size
38
39 import Control.Monad.Primitive ( PrimMonad, PrimState )
40
41 import GHC.Float (
42 double2Int, int2Double
43 )
44
45 import Prelude hiding ( length, reverse, map, read,
46 take, drop, init, tail )
47
48 #include "vector.h"
49
-- | Multiplier applied to a vector's current length to compute how many
-- extra slots 'enlarge_delta' (and hence 'enlarge' and 'enlargeFront') add.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
52
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- Instances must provide 'basicLength', 'basicUnsafeSlice', 'basicOverlaps',
-- 'basicUnsafeNew', 'basicUnsafeRead' and 'basicUnsafeWrite'. The remaining
-- methods have default implementations written in terms of those.
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: Int   -- ^ starting index
                   -> Int   -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                                 -> v (PrimState m) a   -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not be
  -- called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate with basicUnsafeNew, then fill every slot via set.
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith n x
    = do
        v <- basicUnsafeNew n
        set v x
        return v

  -- Default: nothing to clear.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write the value at every index from 0 to length - 1.
  {-# INLINE basicSet #-}
  basicSet v x = do_set 0
    where
      n = length v

      do_set i
        | i < n     = do
            basicUnsafeWrite v i x
            do_set (i+1)
        | otherwise = return ()

  -- Default: element-by-element copy over the length of the source.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do_copy 0
    where
      n = length src

      do_copy i
        | i < n     = do
            x <- basicUnsafeRead src i
            basicUnsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  -- Default: allocate a larger vector and copy the old contents into its
  -- first (length v) slots.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by
    = do
        v' <- basicUnsafeNew (n+by)
        basicUnsafeCopy (basicUnsafeSlice 0 n v') v
        return v'
    where
      n = length v
147
-- | Allocate a mutable vector of the given length. The length is not
-- validated here beyond the optional UNSAFE_CHECK hook.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew n = UNSAFE_CHECK(checkLength) "unsafeNew" n allocated
  where
    allocated = basicUnsafeNew n
153
-- | Allocate a mutable vector of the given length with every element set to
-- the supplied value. The length is not validated here beyond the optional
-- UNSAFE_CHECK hook.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith n x = UNSAFE_CHECK(checkLength) "unsafeNewWith" n filled
  where
    filled = basicUnsafeNewWith n x
160
-- | Read the element at the given index. The index is not bounds-checked
-- beyond the optional UNSAFE_CHECK hook.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i = UNSAFE_CHECK(checkIndex) "unsafeRead" i len (basicUnsafeRead v i)
  where
    len = length v
166
-- | Overwrite the element at the given index. The index is not bounds-checked
-- beyond the optional UNSAFE_CHECK hook.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x = UNSAFE_CHECK(checkIndex) "unsafeWrite" i len store
  where
    len   = length v
    store = basicUnsafeWrite v i x
173
174
-- | Copy the source vector into the target. Both must have the same length
-- and must not overlap; neither condition is enforced beyond the optional
-- UNSAFE_CHECK hooks.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src
  = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
  $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
  $ basicUnsafeCopy dst src
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
186
-- | Grow a vector at the back by the given number of elements. The number
-- must be positive; this is not enforced beyond the optional UNSAFE_CHECK
-- hook.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v n = UNSAFE_CHECK(checkLength) "unsafeGrow" n grown
  where
    grown = basicUnsafeGrow v n
194
-- | Grow a vector at the front by @by@ elements. The result has length
-- @by + length v@. The old contents are copied into its last @length v@
-- slots, and the first @by@ slots are not written here. @by@ is not
-- validated beyond the optional UNSAFE_CHECK hook.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by grown
  where
    grown = do
      let n = length v
      bigger <- basicUnsafeNew (by+n)
      basicUnsafeCopy (basicUnsafeSlice by n bigger) v
      return bigger
204
-- | Number of elements in the mutable vector.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v
209
-- | Check whether two vectors overlap, i.e. share any storage.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps = basicOverlaps
214
-- | Allocate a mutable vector of the given length, subject to the
-- BOUNDS_CHECK length check.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new n = BOUNDS_CHECK(checkLength) "new" n (unsafeNew n)
220
-- | Allocate a mutable vector of the given length with every element set to
-- the supplied value, subject to the BOUNDS_CHECK length check.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith n x = BOUNDS_CHECK(checkLength) "newWith" n filled
  where
    filled = unsafeNewWith n x
227
-- | Read the element at the given index, subject to the BOUNDS_CHECK
-- index check.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i = BOUNDS_CHECK(checkIndex) "read" i len (unsafeRead v i)
  where
    len = length v
233
-- | Overwrite the element at the given index, subject to the BOUNDS_CHECK
-- index check.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = BOUNDS_CHECK(checkIndex) "write" i len store
  where
    len   = length v
    store = unsafeWrite v i x
239
-- | Reset every element of the vector to some undefined value so that it
-- holds no references to external objects. For unboxed vectors this is
-- usually a noop.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v
245
-- | Overwrite every element of the vector with the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
250
-- | Copy the source vector into the target. Both must have the same length
-- and must not overlap; BOUNDS_CHECK verifies overlap first, then length.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
  $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
  $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
261
-- | Grow a vector at the back by the given number of elements. The number
-- must be positive, which is verified by the BOUNDS_CHECK length check.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by (unsafeGrow v by)
269
-- | Number of slots to add when enlarging a vector. It is the current length
-- scaled by 'gROWTH_FACTOR' and converted with 'double2Int', but never less
-- than 1, so an empty vector still gains room.
enlarge_delta :: MVector v a => v s a -> Int
{-# INLINE enlarge_delta #-}
enlarge_delta v = max 1
                $ double2Int
                $ int2Double (length v) * gROWTH_FACTOR
273
-- | Grow a vector at the back by 'enlarge_delta' slots. Because the delta is
-- proportional to the current length, repeated enlargement grows the vector
-- geometrically.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = unsafeGrow v delta
  where
    delta = enlarge_delta v
279
-- | Grow a vector at the front by 'enlarge_delta' slots. Returns the new
-- vector together with the number of slots added, which is also the index
-- at which the old contents now begin.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = unsafeGrowFront v room >>= \grown -> return (grown, room)
  where
    room = enlarge_delta v
288
-- | Yield the @n@ elements starting at index @i@ without copying them,
-- subject to the BOUNDS_CHECK slice check.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice i n v = BOUNDS_CHECK(checkSlice) "slice" i n len (unsafeSlice i n v)
  where
    len = length v
294
-- | Yield the first @n@ elements without copying. A negative @n@ gives an
-- empty slice; an @n@ beyond the length gives the whole vector.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = min (length v) (max 0 n)
298
-- | Yield all but the first @n@ elements without copying. A negative @n@
-- drops nothing; an @n@ beyond the length gives an empty slice.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start (len - start) v
  where
    len   = length v
    start = min len (max 0 n)
305
-- | Yield all but the last element without copying. An empty vector is
-- rejected by the bounds check in 'slice'.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 lastIx v
  where
    lastIx = length v - 1
309
-- | Yield all but the first element without copying. An empty vector is
-- rejected by the bounds check in 'slice'.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 rest v
  where
    rest = length v - 1
313
-- | Yield @n@ elements starting at index @i@ without copying them. The
-- range is not validated beyond the optional UNSAFE_CHECK hook.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n len sliced
  where
    len    = length v
    sliced = basicUnsafeSlice i n v
323
-- | Yield all but the last element without copying. No bounds check is
-- made, so the vector must not be empty.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 lastIx v
  where
    lastIx = length v - 1
327
-- | Yield all but the first element without copying. No bounds check is
-- made, so the vector must not be empty.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 rest v
  where
    rest = length v - 1
331
-- | Store @x@ at index @i@. When @i@ is not below the current length, the
-- vector is first grown once with 'enlarge'. Returns the vector that now
-- holds the element: either the argument itself or the enlarged copy.
-- NOTE(review): a single 'enlarge' is only guaranteed to make room for
-- i == length v; callers presumably never pass a larger index.
unsafeAppend1 :: (PrimMonad m, MVector v a)
              => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because
-- GHC creates a join point for the unsafeWrite even when everything
-- is inlined. This is bad because with the join point, v isn't getting
-- unboxed.
unsafeAppend1 v i x
  | i < length v = do
      unsafeWrite v i x
      return v
  | otherwise = do
      bigger <- enlarge v
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length bigger)
        $ unsafeWrite bigger i x
      return bigger
348
-- | Store @x@ in the slot just before index @i@ (at @i - 1@). When @i@ is 0
-- there is no free slot, so the vector is first grown at the front with
-- 'enlargeFront' and the element goes into the last newly added slot.
-- Returns the vector holding the element and the element's index.
unsafePrepend1 :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i /= 0 = do
      let slot = i - 1
      unsafeWrite v slot x
      return (v, slot)
  | otherwise = do
      (grown, room) <- enlargeFront v
      let slot = room - 1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" slot (length grown)
        $ unsafeWrite grown slot x
      return (grown, slot)
363
364
-- | Stream the elements of a mutable vector from index 0 to the end,
-- with an exact size hint equal to its length.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact n)
  where
    n = length v

    {-# INLINE_INNER step #-}
    step i
      | i >= n    = return Nothing
      | otherwise = do
          x <- unsafeRead v i
          return (Just (x, i + 1))
375
-- | Write the elements of a monadic stream into an existing mutable vector,
-- starting at index 0, and return the prefix that was written. The stream
-- must not yield more than @length v@ elements; only the INTERNAL_CHECK
-- hook verifies this.
munstream :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstream #-}
munstream v s = v `seq` (MStream.foldM store 0 s >>= finish)
  where
    finish count = return (unsafeSlice 0 count v)

    {-# INLINE_INNER store #-}
    store i x = do
      INTERNAL_CHECK(checkIndex) "munstream" i (length v)
        $ unsafeWrite v i x
      return (i + 1)
388
-- | Run a stream transformer over the contents of a mutable vector, writing
-- the results back into the same vector from index 0. Returns the prefix
-- that holds the result.
-- NOTE(review): reads and writes share one buffer, so presumably the
-- transformer must never emit more elements than it has consumed; confirm.
transform :: (PrimMonad m, MVector v a)
          => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v transformed
  where
    transformed = f (mstream v)
393
-- | Stream the elements of a mutable vector from the last index down to 0,
-- with an exact size hint equal to its length. The unfold state is one past
-- the index that will be read next.
mrstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mrstream #-}
mrstream v = v `seq` MStream.sized (MStream.unfoldrM step n) (Exact n)
  where
    n = length v

    {-# INLINE_INNER step #-}
    step i
      | prev < 0  = return Nothing
      | otherwise = do
          x <- unsafeRead v prev
          return (Just (x, prev))
      where
        prev = i - 1
406
-- | Write the elements of a monadic stream into an existing mutable vector
-- from the back: the first element goes at @length v - 1@, the next at
-- @length v - 2@, and so on. Returns the suffix that was written. The
-- stream must not yield more than @length v@ elements; the INTERNAL_CHECK
-- hook verifies this, mirroring 'munstream'.
munstreamR :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamR #-}
munstreamR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n-i) v
  where
    n = length v

    -- i is the index of the most recently written slot (initially n), so
    -- the next element goes one slot earlier.
    {-# INLINE_INNER put #-}
    put i x = do
        INTERNAL_CHECK(checkIndex) "munstreamR" j n
          $ unsafeWrite v j x
        return j
      where
        j = i-1
422
-- | Right-to-left counterpart of 'transform': stream the vector from the back
-- with 'mrstream', apply the transformer, and write the results back from
-- the end with 'munstreamR'. Returns the suffix that holds the result.
-- NOTE(review): as with 'transform', presumably the transformer must not run
-- ahead of its input since both directions share one buffer; confirm.
transformR :: (PrimMonad m, MVector v a)
           => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transformR #-}
transformR f v = munstreamR v transformed
  where
    transformed = f (mrstream v)
427
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
-- If the 'Size' hint gives an upper bound, a buffer of that size is
-- allocated once. Otherwise the vector grows logarithmically as elements
-- arrive.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s) (upperBound (Stream.size s))
436
-- | Fill a freshly allocated vector of size @n@ (an upper bound on the
-- stream length) from index 0 upwards and return the filled prefix.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
    v <- INTERNAL_CHECK(checkLength) "unstreamMax" n
         $ unsafeNew n
    let store i x = do
          INTERNAL_CHECK(checkIndex) "unstreamMax" i n
            $ unsafeWrite v i x
          return (i + 1)
    filled <- Stream.foldM' store 0 s
    return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 filled n
           $ unsafeSlice 0 filled v
451
-- | Fill a vector from a stream of unknown length. The vector starts empty,
-- grows with 'unsafeAppend1' as elements arrive, and the filled prefix is
-- returned.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s
  = do
      v <- unsafeNew 0
      (v', n) <- Stream.foldM put (v, 0) s
      -- unsafeAppend1 grows the vector before writing past its end, so n
      -- never exceeds length v'. The INTERNAL_CHECK documents that
      -- invariant, and a bounds-checked 'slice' would only repeat it
      -- (matching unstreamMax, which also uses unsafeSlice here).
      return $ INTERNAL_CHECK(checkSlice) "unstreamUnknown" 0 n (length v')
             $ unsafeSlice 0 n v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = do
        v' <- unsafeAppend1 v i x
        return (v',i+1)
466
-- | Create a new mutable vector and fill it with elements from the 'Stream',
-- writing from the back of the vector towards the front. The first stream
-- element ends up last, so the result holds the stream in reverse order.
-- If the 'Size' hint gives an upper bound, a buffer of that size is
-- allocated once. Otherwise the vector grows logarithmically at the front
-- as elements arrive.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstreamR #-}
unstreamR s = case upperBound (Stream.size s) of
                Just n  -> unstreamRMax s n
                Nothing -> unstreamRUnknown s
475
-- | Right-to-left counterpart of 'unstreamMax'. Allocate a vector of size
-- @n@ (an upper bound on the stream length), write the elements at indices
-- @n - 1@, @n - 2@, and so on, and return the filled suffix.
unstreamRMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamRMax #-}
unstreamRMax s n = do
    v <- INTERNAL_CHECK(checkLength) "unstreamRMax" n
         $ unsafeNew n
    -- i is the front of the filled region (initially n, one past the last
    -- valid slot). Each element must go at i - 1, and i - 1 becomes the new
    -- front. Writing at i and returning i would put every element out of
    -- bounds at index n and leave the result empty.
    let put i x = do
          let i' = i - 1
          INTERNAL_CHECK(checkIndex) "unstreamRMax" i' n
            $ unsafeWrite v i' x
          return i'
    i <- Stream.foldM' put n s
    return $ INTERNAL_CHECK(checkSlice) "unstreamRMax" i (n-i) n
           $ unsafeSlice i (n-i) v
491
-- | Right-to-left counterpart of 'unstreamUnknown'. The vector starts empty
-- and grows at the front with 'unsafePrepend1' as elements arrive. The
-- suffix starting at the final front index is returned.
unstreamRUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamRUnknown #-}
unstreamRUnknown s
  = do
      v <- unsafeNew 0
      (v', i) <- Stream.foldM put (v, 0) s
      let n = length v'
      -- unsafePrepend1 never returns a negative front index, so the slice is
      -- always in range. The INTERNAL_CHECK documents that invariant, and a
      -- bounds-checked 'slice' would only repeat it (matching unstreamRMax).
      return $ INTERNAL_CHECK(checkSlice) "unstreamRUnknown" i (n-i) n
             $ unsafeSlice i (n-i) v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = unsafePrepend1 v i x
505
-- | For each @(i, b)@ in the stream, replace the element at index @i@ with
-- @f old b@. Indices are not bounds-checked beyond the optional
-- UNSAFE_CHECK hook, which reports failures under this function's name.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
        a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
           $ unsafeRead v i
        unsafeWrite v i (f a b)
516
-- | For each @(i, b)@ in the stream, replace the element at index @i@ with
-- @f old b@. The read of each index is subject to the BOUNDS_CHECK index
-- check.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    {-# INLINE_INNER step #-}
    step (i, b) =
      BOUNDS_CHECK(checkIndex) "accum" i (length v) (unsafeRead v i)
        >>= \old -> unsafeWrite v i (f old b)
527
-- | For each @(i, x)@ in the stream, overwrite the element at index @i@
-- with @x@. Indices are not bounds-checked beyond the optional
-- UNSAFE_CHECK hook, which reports failures under this function's name.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
              $ unsafeWrite v i b
536
-- | For each @(i, x)@ in the stream, overwrite the element at index @i@
-- with @x@, subject to the BOUNDS_CHECK index check.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Stream.mapM_ store s
  where
    n = length v

    {-# INLINE_INNER store #-}
    store (i, b) = BOUNDS_CHECK(checkIndex) "update" i n (unsafeWrite v i b)
545
-- | Reverse the vector in place by swapping elements pairwise, moving inward
-- from both ends until the indices meet.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
          front <- unsafeRead v lo
          back  <- unsafeRead v hi
          unsafeWrite v lo back
          unsafeWrite v hi front
          go (lo + 1) (hi - 1)
557
-- | Reorder the vector in place so that all elements satisfying the predicate
-- come first, and return how many there are. The relative order of
-- elements is not preserved.
unstablePartition :: (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = scanFront 0 (length v)
  where
    -- Invariant: indices below lo satisfy f; indices at or above hi do not.
    scanFront lo hi
      | lo == hi  = return lo
      | otherwise = do
          x <- unsafeRead v lo
          if f x
            then scanFront (lo + 1) hi
            else scanBack lo (hi - 1)

    -- The element at lo fails f. Search downwards from hi (inclusive) for an
    -- element that satisfies f and swap it into position lo.
    scanBack lo hi
      | lo == hi  = return lo
      | otherwise = do
          y <- unsafeRead v hi
          if f y
            then do
              z <- unsafeRead v lo
              unsafeWrite v lo y
              unsafeWrite v hi z
              scanFront (lo + 1) hi
            else scanBack lo (hi - 1)
582
-- | Split a stream into two new mutable vectors: the elements satisfying the
-- predicate and the rest. If the 'Size' hint gives an upper bound, a single
-- buffer of that size is used. The order of elements within a group is not
-- guaranteed to match the stream.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionStream #-}
unstablePartitionStream f s
  = maybe (partitionUnknown f s) (unstablePartitionMax f s)
          (upperBound (Stream.size s))
590
591
-- | Partition a stream whose length is at most @n@ using one buffer of size
-- @n@. Matching elements are written from the front; the others are written
-- from the back, so they appear in reverse stream order. Returns the two
-- filled slices.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n = do
    v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
         $ unsafeNew n
    let {-# INLINE_INNER place #-}
        place (front, back) x
          | f x = do
              unsafeWrite v front x
              return (front + 1, back)
          | otherwise = do
              let back' = back - 1
              unsafeWrite v back' x
              return (front, back')
    (lo, hi) <- Stream.foldM' place (0, n) s
    return (unsafeSlice 0 lo v, unsafeSlice hi (n - hi) v)
611
-- | Partition a stream of unknown length into two separately grown vectors,
-- appending with 'unsafeAppend1'. Stream order is preserved within each
-- group.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s = do
    yes0 <- unsafeNew 0
    no0  <- unsafeNew 0
    (yes, nYes, no, nNo) <- Stream.foldM' route (yes0, 0, no0, 0) s
    INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
      $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
      $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
    {-# INLINE_INNER route #-}
    route (ys, i, ns, j) x
      | f x = do
          ys' <- unsafeAppend1 ys i x
          return (ys', i + 1, ns, j)
      | otherwise = do
          ns' <- unsafeAppend1 ns j x
          return (ys, i, ns', j + 1)
636