Use basic* functions in default methods of MVector
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, new, newWith, read, write, clear, set, copy, grow,
20
21 slice, take, drop, init, tail,
22 unsafeSlice, unsafeInit, unsafeTail,
23
24 -- * Unsafe operations
25 unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite,
26 unsafeCopy, unsafeGrow,
27
28 -- * Internal operations
29 unstream, transform, unstreamR, transformR,
30 unsafeAccum, accum, unsafeUpdate, update, reverse,
31 unstablePartition, unstablePartitionStream
32 ) where
33
34 import qualified Data.Vector.Fusion.Stream as Stream
35 import Data.Vector.Fusion.Stream ( Stream, MStream )
36 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
37 import Data.Vector.Fusion.Stream.Size
38
39 import Control.Monad.Primitive ( PrimMonad, PrimState )
40
41 import GHC.Float (
42 double2Int, int2Double
43 )
44
45 import Prelude hiding ( length, reverse, map, read,
46 take, drop, init, tail )
47
48 #include "vector.h"
49
-- | Multiplier applied to the current length of a vector to obtain the
-- number of extra slots added when 'enlarge' or 'enlargeFront' grows it
-- (see 'enlarge_delta', which never adds fewer than one slot).
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
52
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- Instances must implement 'basicLength', 'basicUnsafeSlice',
-- 'basicOverlaps', 'basicUnsafeNew', 'basicUnsafeRead' and
-- 'basicUnsafeWrite'. The remaining methods have default implementations
-- built on top of those and may be overridden.
--
-- None of the @basic*@ methods should be called directly; each method
-- names the wrapper function to use instead.
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: Int  -- ^ starting index
                   -> Int  -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                                 -> v (PrimState m) a   -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not
  -- be called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate with 'basicUnsafeNew', then fill every slot with
  -- 'basicSet'.
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith n x
    = do
        v <- basicUnsafeNew n
        basicSet v x
        return v

  -- Default: do nothing. Vectors whose elements can refer to other heap
  -- objects should override this to drop those references.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write the value into every index from 0 to the length minus
  -- one, one element at a time.
  {-# INLINE basicSet #-}
  basicSet v x = do_set 0
    where
      n = basicLength v

      do_set i
        | i < n     = do
            basicUnsafeWrite v i x
            do_set (i+1)
        | otherwise = return ()

  -- Default: copy element by element from the source to the same index of
  -- the target. Only the first @basicLength src@ indices are touched, so the
  -- target must be at least that long.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do_copy 0
    where
      n = basicLength src

      do_copy i
        | i < n     = do
            x <- basicUnsafeRead src i
            basicUnsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  -- Default: allocate a vector @by@ elements longer, copy the old contents
  -- into its first @n@ slots and leave the remaining @by@ slots as produced
  -- by 'basicUnsafeNew'.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by
    = do
        v' <- basicUnsafeNew (n+by)
        basicUnsafeCopy (basicUnsafeSlice 0 n v') v
        return v'
    where
      n = basicLength v
147
148
-- | Allocate a mutable vector of the given length. This function writes no
-- elements itself, and the length is not validated beyond the UNSAFE_CHECK
-- hook configured in vector.h.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew len = UNSAFE_CHECK(checkLength) "unsafeNew" len (basicUnsafeNew len)
154
-- | Allocate a mutable vector of the given length with every slot set to the
-- supplied value. The length is not validated beyond the UNSAFE_CHECK hook.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith len x
  = UNSAFE_CHECK(checkLength) "unsafeNewWith" len (basicUnsafeNewWith len x)
161
-- | Read the element stored at the given index. The index is not
-- bounds-checked beyond the UNSAFE_CHECK hook.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead vec idx
  = UNSAFE_CHECK(checkIndex) "unsafeRead" idx (length vec)
      (basicUnsafeRead vec idx)
167
-- | Overwrite the element stored at the given index. The index is not
-- bounds-checked beyond the UNSAFE_CHECK hook.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite vec idx val
  = UNSAFE_CHECK(checkIndex) "unsafeWrite" idx (length vec)
      (basicUnsafeWrite vec idx val)
174
175
-- | Copy the source vector into the target. Both must have the same length
-- and must not overlap; the UNSAFE_CHECK hooks are the only verification
-- (length first, then overlap).
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a   -- ^ target
                                         -> v (PrimState m) a   -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src
  = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
  $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
  $ basicUnsafeCopy dst src
  where
    sameLength = length dst == length src
    disjoint   = not (dst `overlaps` src)
187
-- | Grow a vector by the given number of elements (see 'basicUnsafeGrow').
-- The number must not be negative; only the UNSAFE_CHECK hook verifies it.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow vec extra
  = UNSAFE_CHECK(checkLength) "unsafeGrow" extra (basicUnsafeGrow vec extra)
195
-- | Like 'unsafeGrow', but the new slots are added at the front. The old
-- contents are copied into the last @length v@ slots of the result, and the
-- first @by@ slots are left as produced by 'basicUnsafeNew'.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v by
  = UNSAFE_CHECK(checkLength) "unsafeGrowFront" by
  $ do
      bigger <- basicUnsafeNew (by + len)
      basicUnsafeCopy (basicUnsafeSlice by len bigger) v
      return bigger
  where
    len = length v
205
-- | Number of elements in the mutable vector.
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v
210
-- | Check whether two vectors overlap. 'copy' and 'unsafeCopy' use this to
-- reject overlapping source and target vectors.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps = basicOverlaps
215
-- | Allocate a mutable vector of the given length. The BOUNDS_CHECK hook
-- validates the length before delegating to 'unsafeNew'.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new len = BOUNDS_CHECK(checkLength) "new" len (unsafeNew len)
221
-- | Allocate a mutable vector of the given length with every slot set to the
-- supplied value. The BOUNDS_CHECK hook validates the length before
-- delegating to 'unsafeNewWith'.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith len x = BOUNDS_CHECK(checkLength) "newWith" len (unsafeNewWith len x)
228
-- | Read the element stored at the given index. The BOUNDS_CHECK hook
-- validates the index against the vector length.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read vec idx
  = BOUNDS_CHECK(checkIndex) "read" idx (length vec) (unsafeRead vec idx)
234
-- | Overwrite the element stored at the given index. The BOUNDS_CHECK hook
-- validates the index against the vector length.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write vec idx val
  = BOUNDS_CHECK(checkIndex) "write" idx (length vec) (unsafeWrite vec idx val)
240
-- | Reset all elements of the vector to some undefined value so that it no
-- longer refers to external objects. For unboxed vectors this is usually a
-- no-op.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v
246
-- | Overwrite every element of the vector with the given value.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
251
-- | Copy the source vector into the target. The BOUNDS_CHECK hooks first
-- reject overlapping vectors and then vectors of different lengths.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
  $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
  $ unsafeCopy dst src
  where
    disjoint   = not (dst `overlaps` src)
    sameLength = length dst == length src
262
-- | Grow a vector by the given number of elements. The BOUNDS_CHECK hook
-- validates the number before delegating to 'unsafeGrow'.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by (unsafeGrow v by)
270
-- | Number of elements by which 'enlarge' and 'enlargeFront' grow a vector:
-- the current length scaled by 'gROWTH_FACTOR' and truncated to an 'Int',
-- but never less than 1, so that an empty vector still gains room.
enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = max 1
                $ double2Int
                $ int2Double (length v) * gROWTH_FACTOR
274
-- | Grow a vector by 'enlarge_delta' elements, an amount proportional to its
-- current length (and at least one). Repeated enlargement therefore needs
-- only logarithmically many reallocations.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = unsafeGrow v delta
  where
    delta = enlarge_delta v
280
-- | Grow a vector at the front by 'enlarge_delta' elements. Returns the new
-- vector together with the number of slots added, which is also the index
-- at which the old contents now start.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = do
  let added = enlarge_delta v
  bigger <- unsafeGrowFront v added
  return (bigger, added)
289
-- | Yield the part of the mutable vector that starts at index @i@ and has
-- length @n@, without copying. The BOUNDS_CHECK hook validates the range.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice i n v
  = BOUNDS_CHECK(checkSlice) "slice" i n (length v) (unsafeSlice i n v)
295
-- | Yield the first @n@ elements without copying. A negative @n@ gives an
-- empty slice, and an @n@ beyond the length gives the whole vector.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = min (length v) (max 0 n)
299
-- | Drop the first @n@ elements without copying. A negative @n@ drops
-- nothing, and an @n@ at least as large as the length gives an empty slice.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start (len - start) v
  where
    len   = length v
    start = min len (max 0 n)
306
-- | All elements except the last, without copying. For an empty vector the
-- requested length is -1, and rejecting it is left to the bounds check in
-- 'slice'.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 (len - 1) v
  where
    len = length v
310
-- | All elements except the first, without copying. For an empty vector the
-- requested length is -1, and rejecting it is left to the bounds check in
-- 'slice'.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 (len - 1) v
  where
    len = length v
314
-- | Yield the part of the mutable vector that starts at index @i@ and has
-- length @n@, without copying. The range is not validated beyond the
-- UNSAFE_CHECK hook.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v
  = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n (length v)
      (basicUnsafeSlice i n v)
324
-- | All elements except the last, without copying. Only the UNSAFE_CHECK
-- hook in 'unsafeSlice' guards against an empty vector.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 (len - 1) v
  where
    len = length v
328
-- | All elements except the first, without copying. Only the UNSAFE_CHECK
-- hook in 'unsafeSlice' guards against an empty vector.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 (len - 1) v
  where
    len = length v
332
-- | Store @x@ at index @i@, first enlarging the vector (see 'enlarge') when
-- @i@ is not below its current length. Returns the vector that now holds
-- the element: either the input itself or its enlarged replacement.
unsafeAppend1 :: (PrimMonad m, MVector v a)
              => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because GHC creates a
-- join point for the unsafeWrite even when everything is inlined. This is
-- bad because with the join point, v is not getting unboxed.
unsafeAppend1 v i x
  | i < length v = do
      unsafeWrite v i x
      return v
  | otherwise = do
      grown <- enlarge v
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" i (length grown)
        $ unsafeWrite grown i x
      return grown
349
-- | Store @x@ just before index @i@, at @i-1@. When @i@ is 0 the vector is
-- first grown at the front (see 'enlargeFront'). Returns the vector that
-- holds the element together with the index it was written to.
unsafePrepend1 :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i /= 0 = do
      let slot = i - 1
      unsafeWrite v slot x
      return (v, slot)
  | otherwise = do
      (grown, added) <- enlargeFront v
      let slot = added - 1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" slot (length grown)
        $ unsafeWrite grown slot x
      return (grown, slot)
364
365
-- | Monadic stream that reads the elements of the vector from index 0
-- upwards. The vector is forced first, and the stream carries an exact size
-- hint equal to the vector length.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` (MStream.unfoldrM step 0 `MStream.sized` Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step i
      | i < len   = do
          x <- unsafeRead v i
          return (Just (x, i + 1))
      | otherwise = return Nothing
376
-- | Write the elements of the stream into the vector from index 0 upwards
-- and return the filled prefix. The vector must be long enough; only the
-- INTERNAL_CHECK hook verifies the indices.
munstream :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstream #-}
munstream v s = v `seq` do
    filled <- MStream.foldM step 0 s
    return (unsafeSlice 0 filled v)
  where
    {-# INLINE_INNER step #-}
    step i x = do
      INTERNAL_CHECK(checkIndex) "munstream" i (length v)
        $ unsafeWrite v i x
      return (i + 1)
389
-- | Run a monadic stream transformer over the elements of the vector and
-- write its output back into the same vector from index 0, returning the
-- filled prefix. The output must not be longer than the vector.
transform :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v $ f $ mstream v
394
-- | Monadic stream that reads the elements of the vector from the last index
-- down to 0. The vector is forced first, and the stream carries an exact
-- size hint equal to the vector length.
mrstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mrstream #-}
mrstream v = v `seq` (MStream.unfoldrM step len `MStream.sized` Exact len)
  where
    len = length v

    -- The seed is one past the index to read next.
    {-# INLINE_INNER step #-}
    step i = let j = i - 1
             in if j >= 0
                  then do
                    x <- unsafeRead v j
                    return (Just (x, j))
                  else return Nothing
407
-- | Write the elements of the stream into the vector from the end downwards,
-- so the first stream element goes to the last index, and return the filled
-- suffix. The vector must be long enough. As in 'munstream', each index is
-- verified by the INTERNAL_CHECK hook.
munstreamR :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamR #-}
munstreamR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n-i) v
  where
    n = length v

    -- The accumulator is the start of the filled suffix. Each element moves
    -- it one slot to the left and is written there.
    {-# INLINE_INNER put #-}
    put i x = do
        INTERNAL_CHECK(checkIndex) "munstreamR" j n
          $ unsafeWrite v j x
        return j
      where
        j = i-1
423
-- | Like 'transform', but the vector is traversed from the end. Elements are
-- read last-to-first, the output is written back from the last index
-- downwards, and the filled suffix is returned.
transformR :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transformR #-}
transformR f v = munstreamR v $ f $ mrstream v
428
-- | Create a new mutable vector holding the elements of the 'Stream'.
-- If the stream size has an upper bound, a vector of that size is allocated
-- once and the filled prefix is returned. Otherwise the vector starts empty
-- and is enlarged geometrically (see 'enlarge') as elements arrive.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s)
                   (upperBound (Stream.size s))
437
-- | Fill a freshly allocated vector of length @n@ from index 0 upwards and
-- return the filled prefix. The stream must yield at most @n@ elements.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
    v <- INTERNAL_CHECK(checkLength) "unstreamMax" n
         $ unsafeNew n
    let store i x = do
          INTERNAL_CHECK(checkIndex) "unstreamMax" i n
            $ unsafeWrite v i x
          return (i + 1)
    written <- Stream.foldM' store 0 s
    return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 written n
           $ unsafeSlice 0 written v
452
-- | Fill a vector from index 0 upwards when the stream size has no upper
-- bound. Starts from an empty vector and appends each element with
-- 'unsafeAppend1', which enlarges the vector when it runs out of room.
-- Returns the filled prefix.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s
  = do
      v <- unsafeNew 0
      (v', n) <- Stream.foldM put (v, 0) s
      -- Every index below n was written through 'unsafeAppend1', which grows
      -- the vector as needed, so 0 <= n <= length v' always holds. The range
      -- needs no second bounds check; use 'unsafeSlice' as 'unstreamMax'
      -- does and leave validation to the INTERNAL_CHECK hook.
      return $ INTERNAL_CHECK(checkSlice) "unstreamUnknown" 0 n (length v')
             $ unsafeSlice 0 n v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = do
      v' <- unsafeAppend1 v i x
      return (v',i+1)
467
-- | Create a new mutable vector from the elements of the 'Stream', filling it
-- from the back so that the first stream element ends up last. If the
-- stream size has an upper bound, a vector of that size is allocated once
-- and the filled suffix is returned. Otherwise the vector is grown at the
-- front (see 'enlargeFront') as elements arrive.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstreamR #-}
unstreamR s = maybe (unstreamRUnknown s) (unstreamRMax s)
                    (upperBound (Stream.size s))
476
-- | Fill a freshly allocated vector of length @n@ from the back. The first
-- stream element is stored at index @n-1@, the next at @n-2@, and so on.
-- Returns the filled suffix. The stream must yield at most @n@ elements.
unstreamRMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamRMax #-}
unstreamRMax s n = do
    v <- INTERNAL_CHECK(checkLength) "unstreamRMax" n
         $ unsafeNew n
    -- The accumulator is the start of the filled suffix. Each element moves
    -- it one slot to the left, and the element is written at the NEW
    -- position, so the index written and the index returned are both i'.
    let put i x = do
          let i' = i - 1
          INTERNAL_CHECK(checkIndex) "unstreamRMax" i' n
            $ unsafeWrite v i' x
          return i'
    i <- Stream.foldM' put n s
    return $ INTERNAL_CHECK(checkSlice) "unstreamRMax" i (n-i) n
           $ unsafeSlice i (n-i) v
492
-- | Fill a vector from the back when the stream size has no upper bound.
-- Starts from an empty vector and prepends each element with
-- 'unsafePrepend1', which grows the vector at the front when it runs out of
-- room. Returns the filled suffix.
unstreamRUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamRUnknown #-}
unstreamRUnknown s
  = do
      v <- unsafeNew 0
      (v', i) <- Stream.foldM put (v, 0) s
      let n = length v'
      -- 'unsafePrepend1' only ever returns the index it just wrote, so
      -- 0 <= i <= n always holds. The range needs no second bounds check;
      -- use 'unsafeSlice' as 'unstreamRMax' does and leave validation to the
      -- INTERNAL_CHECK hook.
      return $ INTERNAL_CHECK(checkSlice) "unstreamRUnknown" i (n-i) n
             $ unsafeSlice i (n-i) v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = unsafePrepend1 v i x
506
-- | Update the vector in place from a stream of index/value pairs, in stream
-- order. For each pair @(i, b)@ the element at @i@ is replaced by @f@
-- applied to its old value and @b@. Indices are not bounds-checked beyond
-- the UNSAFE_CHECK hook, which reports failures under the name of this
-- function (not under the name of the checked variant).
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
      a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
             $ unsafeRead v i
      unsafeWrite v i (f a b)
517
-- | Update the vector in place from a stream of index/value pairs, in stream
-- order. For each pair @(i, b)@ the element at @i@ is replaced by @f@
-- applied to its old value and @b@. The BOUNDS_CHECK hook validates each
-- index before it is read.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    {-# INLINE_INNER step #-}
    step (i, b) = do
      old <- BOUNDS_CHECK(checkIndex) "accum" i (length v) (unsafeRead v i)
      unsafeWrite v i (f old b)
528
-- | Overwrite elements of the vector in place from a stream of index/value
-- pairs, in stream order, so a later pair for the same index wins. Indices
-- are not bounds-checked beyond the UNSAFE_CHECK hook, which reports
-- failures under the name of this function.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
              $ unsafeWrite v i b
537
-- | Overwrite elements of the vector in place from a stream of index/value
-- pairs, in stream order, so a later pair for the same index wins. The
-- BOUNDS_CHECK hook validates each index before writing.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Stream.mapM_ store s
  where
    {-# INLINE_INNER store #-}
    store (i, x) = BOUNDS_CHECK(checkIndex) "update" i (length v)
                     (unsafeWrite v i x)
546
-- | Reverse the elements of the vector in place, swapping pairs from both
-- ends towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
          front <- unsafeRead v lo
          back  <- unsafeRead v hi
          unsafeWrite v lo back
          unsafeWrite v hi front
          go (lo + 1) (hi - 1)
558
-- | Reorder the vector in place so that all elements satisfying the
-- predicate come first, and return how many such elements there are. The
-- relative order of the elements need not be preserved.
--
-- Invariant: indices below @i@ hold elements satisfying @f@, and indices
-- from @j@ upwards hold elements that do not.
unstablePartition :: (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = scanLeft 0 (length v)
  where
    -- Advance from the left while elements satisfy the predicate. Here @j@
    -- is the exclusive end of the unclassified region.
    scanLeft i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v i
          if f x
            then scanLeft (i+1) j
            else scanRight i (j-1)

    -- The element at @i@ fails the predicate, and @j@ is the next candidate
    -- from the right. When a satisfying element is found, swap the two.
    scanRight i j
      | i == j    = return i
      | otherwise = do
          x <- unsafeRead v j
          if f x
            then do
              y <- unsafeRead v i
              unsafeWrite v i x
              unsafeWrite v j y
              scanLeft (i+1) j
            else scanRight i (j-1)
583
-- | Split a stream into the elements that satisfy the predicate and those
-- that do not, returned as two mutable vectors. The order within each part
-- need not match the stream order.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionStream #-}
unstablePartitionStream f s
  = maybe (partitionUnknown f s) (unstablePartitionMax f s)
          (upperBound (Stream.size s))
591
592
-- | Partition a stream of at most @n@ elements using one freshly allocated
-- vector of length @n@. Satisfying elements are written from the front and
-- the others from the back (so the latter end up in reverse stream order).
-- Returns the two filled slices.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n = do
    v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
         $ unsafeNew n
    let {-# INLINE_INNER place #-}
        place (front, back) x
          | f x = do
              unsafeWrite v front x
              return (front + 1, back)
          | otherwise = do
              let back' = back - 1
              unsafeWrite v back' x
              return (front, back')
    (i, j) <- Stream.foldM' place (0, n) s
    return (unsafeSlice 0 i v, unsafeSlice j (n - j) v)
612
-- | Partition a stream with no upper bound on its size into two separately
-- grown vectors: one for the elements satisfying the predicate and one for
-- the rest. Each vector is extended with 'unsafeAppend1', and the filled
-- prefixes are returned.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s = do
    yes0 <- unsafeNew 0
    no0  <- unsafeNew 0
    (yes, nYes, no, nNo) <- Stream.foldM' route (yes0, 0, no0, 0) s
    INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
      $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
      $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    -- NOTE: The case distinction has to be on the outside because GHC
    -- creates a join point for the unsafeWrite even when everything is
    -- inlined. This is bad because with the join point, v is not getting
    -- unboxed.
    {-# INLINE_INNER route #-}
    route (yes, i, no, j) x
      | f x = do
          yes' <- unsafeAppend1 yes i x
          return (yes', i + 1, no, j)
      | otherwise = do
          no' <- unsafeAppend1 no j x
          return (yes, i, no', j + 1)
637