Implement proper monadic unstreaming for mutable vectors
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns, ScopedTypeVariables #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2010
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 -- * Class of mutable vector types
16 MVector(..),
17
18 -- * Operations on mutable vectors
19 length, overlaps, new, newWith, read, write, swap, clear, set, copy, grow,
20
21 slice, take, drop, init, tail,
22 unsafeSlice, unsafeInit, unsafeTail,
23
24 -- * Unsafe operations
25 unsafeNew, unsafeNewWith, unsafeRead, unsafeWrite, unsafeSwap,
26 unsafeCopy, unsafeGrow,
27
28 -- * Internal operations
29 unstream, unstreamR,
30 munstream, munstreamR,
31 transform, transformR,
32 fill, fillR,
33 unsafeAccum, accum, unsafeUpdate, update, reverse,
34 unstablePartition, unstablePartitionStream, partitionStream
35 ) where
36
37 import qualified Data.Vector.Fusion.Stream as Stream
38 import Data.Vector.Fusion.Stream ( Stream, MStream )
39 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
40 import Data.Vector.Fusion.Stream.Size
41
42 import Control.Monad.Primitive ( PrimMonad, PrimState )
43
44 import Prelude hiding ( length, reverse, map, read,
45 take, drop, init, tail )
46
47 #include "vector.h"
48
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- Minimal complete definition: 'basicLength', 'basicUnsafeSlice',
-- 'basicOverlaps', 'basicUnsafeNew', 'basicUnsafeRead' and
-- 'basicUnsafeWrite'. The remaining methods have default implementations
-- written in terms of these.
class MVector v a where
  -- | Length of the mutable vector. This method should not be
  -- called directly, use 'length' instead.
  basicLength :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. This method
  -- should not be called directly, use 'unsafeSlice' instead.
  basicUnsafeSlice :: Int -- ^ starting index
                   -> Int -- ^ length of the slice
                   -> v s a
                   -> v s a

  -- | Check whether two vectors overlap. This method should not be
  -- called directly, use 'overlaps' instead.
  basicOverlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. This method should not be
  -- called directly, use 'unsafeNew' instead.
  basicUnsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. This method should not be called directly, use
  -- 'unsafeNewWith' instead.
  basicUnsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. This method should not be
  -- called directly, use 'unsafeRead' instead.
  basicUnsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. This method should not be
  -- called directly, use 'unsafeWrite' instead.
  basicUnsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Reset all elements of the vector to some undefined value, clearing all
  -- references to external objects. This is usually a noop for unboxed
  -- vectors. This method should not be called directly, use 'clear' instead.
  basicClear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Set all elements of the vector to the given value. This method should
  -- not be called directly, use 'set' instead.
  basicSet :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This method should not
  -- be called directly, use 'unsafeCopy' instead.
  basicUnsafeCopy :: PrimMonad m => v (PrimState m) a -- ^ target
                                 -> v (PrimState m) a -- ^ source
                                 -> m ()

  -- | Grow a vector by the given number of elements. This method should not be
  -- called directly, use 'unsafeGrow' instead.
  basicUnsafeGrow :: PrimMonad m => v (PrimState m) a -> Int
                                 -> m (v (PrimState m) a)

  -- Default: allocate, then overwrite every slot with the initial value.
  {-# INLINE basicUnsafeNewWith #-}
  basicUnsafeNewWith n x
    = do
        v <- basicUnsafeNew n
        basicSet v x
        return v

  -- Default: do nothing.
  {-# INLINE basicClear #-}
  basicClear _ = return ()

  -- Default: write the value to every index from 0 to length - 1.
  {-# INLINE basicSet #-}
  basicSet v x = do_set 0
    where
      n = basicLength v

      do_set i | i < n = do
                           basicUnsafeWrite v i x
                           do_set (i+1)
               | otherwise = return ()

  -- Default: element-wise copy of the source into the target, index by
  -- index over the length of the source.
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeCopy dst src = do_copy 0
    where
      n = basicLength src

      do_copy i | i < n = do
                            x <- basicUnsafeRead src i
                            basicUnsafeWrite dst i x
                            do_copy (i+1)
                | otherwise = return ()

  -- Default: allocate a vector of n+by elements and copy the old contents
  -- into its first n slots; the trailing by slots are left as
  -- 'basicUnsafeNew' produced them.
  {-# INLINE basicUnsafeGrow #-}
  basicUnsafeGrow v by
    = do
        v' <- basicUnsafeNew (n+by)
        basicUnsafeCopy (basicUnsafeSlice 0 n v') v
        return v'
    where
      n = basicLength v
143
144 -- ------------------
145 -- Internal functions
146 -- ------------------
147
-- | Check whether two vectors overlap. Delegates to 'basicOverlaps'.
overlaps :: MVector v a => v s a -> v s a -> Bool
{-# INLINE overlaps #-}
overlaps = basicOverlaps
152
-- Store an element at position @ix@ of a growable buffer. If @ix@ is not
-- below the current length, the buffer is first enlarged with 'enlarge'.
-- The callers in this module pass the number of elements already stored
-- as @ix@. Returns the buffer to continue with, which may be a new vector.
unsafeAppend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a)
{-# INLINE_INNER unsafeAppend1 #-}
-- NOTE: The case distinction has to be on the outside because
-- GHC creates a join point for the unsafeWrite even when everything
-- is inlined. This is bad because with the join point, v isn't getting
-- unboxed.
unsafeAppend1 buf ix el
  | ix < length buf = do
      unsafeWrite buf ix el
      return buf
  | otherwise = do
      bigger <- enlarge buf
      INTERNAL_CHECK(checkIndex) "unsafeAppend1" ix (length bigger)
        $ unsafeWrite bigger ix el
      return bigger
169
-- Store an element just before position @i@ of a buffer that is filled
-- from the back; @i@ is the index of the first occupied slot. When @i@ is 0
-- there is no free slot, so the buffer is first grown at the front with
-- 'enlargeFront'. Returns the (possibly new) buffer together with the index
-- of the element just written, which becomes the new first occupied slot.
unsafePrepend1 :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> Int -> a -> m (v (PrimState m) a, Int)
{-# INLINE_INNER unsafePrepend1 #-}
unsafePrepend1 v i x
  | i /= 0 = do
      let i' = i-1
      unsafeWrite v i' x
      return (v, i')
  | otherwise = do
      -- 'enlargeFront' returns the grown vector and the index at which the
      -- old contents now start. Bind it to a fresh name rather than
      -- shadowing the parameter @i@.
      (v', start) <- enlargeFront v
      let i' = start-1
      INTERNAL_CHECK(checkIndex) "unsafePrepend1" i' (length v')
        $ unsafeWrite v' i' x
      return (v', i')
184
-- Stream the elements of a mutable vector from left to right. Each element
-- is read from the vector when the stream reaches it. The stream size is
-- 'Exact' (the vector's length).
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step k
      | k >= len  = return Nothing
      | otherwise = do
          y <- unsafeRead v k
          return (Just (y, k+1))
195
-- | Write the elements of a monadic stream into the given vector, starting
-- at index 0. Returns the prefix of the vector that was written. The vector
-- must be large enough for every element; this is only verified when
-- internal checks are enabled.
fill :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE fill #-}
fill v s = v `seq` (MStream.foldM put 0 s >>= finish)
  where
    finish count = return (unsafeSlice 0 count v)

    {-# INLINE_INNER put #-}
    put idx x = do
        INTERNAL_CHECK(checkIndex) "fill" idx (length v)
          $ unsafeWrite v idx x
        return (idx+1)
208
-- | Run a stream transformer over the elements of a mutable vector, left to
-- right, writing the output back into the same vector from index 0. Returns
-- the prefix holding the output. Because reads and writes share one
-- vector, the transformer should not emit an element before it has
-- consumed the input at that position.
transform :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = fill v $ f $ mstream v
213
-- Stream the elements of a mutable vector from right to left, starting
-- with the last element. The stream size is 'Exact' (the vector's length).
mstreamR :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstreamR #-}
mstreamR v = v `seq` MStream.sized (MStream.unfoldrM step len) (Exact len)
  where
    len = length v

    -- The state is one past the index to read next.
    {-# INLINE_INNER step #-}
    step k
      | k <= 0    = return Nothing
      | otherwise = do
          let k' = k-1
          y <- unsafeRead v k'
          return (Just (y, k'))
226
-- | Write the elements of a monadic stream into the given vector from right
-- to left, starting at the last index. Returns the suffix of the vector
-- that was written. The vector must be large enough for every element;
-- like 'fill', this is verified only when internal checks are enabled.
fillR :: (PrimMonad m, MVector v a)
      => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE fillR #-}
fillR v s = v `seq` do
    i <- MStream.foldM put n s
    return $ unsafeSlice i (n-i) v
  where
    n = length v

    -- The accumulator is the index of the most recently written slot.
    -- An overlong stream would drive it below 0, which the internal check
    -- below reports.
    {-# INLINE_INNER put #-}
    put i x = do
        INTERNAL_CHECK(checkIndex) "fillR" j n
          $ unsafeWrite v j x
        return j
      where
        j = i-1
242
-- | Right-to-left counterpart of 'transform'. Streams the vector from its
-- last element and writes the output back from the last index downwards.
-- Returns the suffix holding the output.
transformR :: (PrimMonad m, MVector v a)
  => (MStream m a -> MStream m a) -> v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE_STREAM transformR #-}
transformR f v = fillR v $ f $ mstreamR v
247
-- | Create a new mutable vector and fill it with elements from the pure
-- 'Stream'. The stream is lifted into the target monad and passed to
-- 'munstream'. The vector will grow exponentially if the maximum size of
-- the 'Stream' is unknown.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
-- NOTE: replace INLINE_STREAM by INLINE? (also in unstreamR)
{-# INLINE_STREAM unstream #-}
unstream s = munstream $ Stream.liftStream s
255
-- | Create a new mutable vector and fill it with elements from the monadic
-- stream.
--
-- * If the stream's size has an upper bound, a vector of that size is
--   allocated once.
-- * Otherwise the vector starts empty and grows exponentially.
munstream :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
{-# INLINE_STREAM munstream #-}
munstream s =
  case upperBound (MStream.size s) of
    Nothing  -> munstreamUnknown s
    Just cap -> munstreamMax s cap
264
-- FIXME: I can't think of how to prevent GHC from floating out
-- munstreamUnknown (formerly unstreamUnknown). That is bad because
-- SpecConstr then generates two specialisations: one for when it is called
-- from munstream (it doesn't know the shape of the vector) and one for when
-- the vector has grown. To see the problem, simply compile this:
--
-- fromList = Data.Vector.Unboxed.unstream . Stream.fromList
--
-- I'm not sure this still applies (19/04/2010).
274
-- Allocate a vector of @n@ elements, where @n@ is an upper bound on the
-- stream's length. Fills it left to right and returns the written prefix.
munstreamMax
  :: (PrimMonad m, MVector v a) => MStream m a -> Int -> m (v (PrimState m) a)
{-# INLINE munstreamMax #-}
munstreamMax s n = do
    buf <- INTERNAL_CHECK(checkLength) "munstreamMax" n
         $ unsafeNew n
    let store k x = do
          INTERNAL_CHECK(checkIndex) "munstreamMax" k n
            $ unsafeWrite buf k x
          return (k+1)
    count <- MStream.foldM' store 0 s
    return $ INTERNAL_CHECK(checkSlice) "munstreamMax" 0 count n
           $ unsafeSlice 0 count buf
289
-- Fill a vector from a stream whose size has no known upper bound.
-- Starts from an empty buffer and appends with 'unsafeAppend1', which
-- enlarges the buffer when it is full. Returns the prefix holding the
-- elements.
munstreamUnknown
  :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamUnknown #-}
munstreamUnknown s = do
    seed <- unsafeNew 0
    (buf, count) <- MStream.foldM put (seed, 0) s
    return $ INTERNAL_CHECK(checkSlice) "munstreamUnknown" 0 count (length buf)
           $ unsafeSlice 0 count buf
  where
    {-# INLINE_INNER put #-}
    put (acc, k) x = do
        acc' <- unsafeAppend1 acc k x
        return (acc', k+1)
304
-- | Create a new mutable vector and fill it with elements from the pure
-- 'Stream' from right to left, so that the first element of the stream
-- ends up last. The vector will grow exponentially if the maximum size of
-- the 'Stream' is unknown.
unstreamR :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
-- NOTE: replace INLINE_STREAM by INLINE? (also in unstream)
{-# INLINE_STREAM unstreamR #-}
unstreamR s = munstreamR $ Stream.liftStream s
312
-- | Create a new mutable vector and fill it with elements from the monadic
-- stream from right to left, so that the first element of the stream ends
-- up last.
--
-- * If the stream's size has an upper bound, a vector of that size is
--   allocated once.
-- * Otherwise the vector starts empty and grows exponentially at the front.
munstreamR :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
{-# INLINE_STREAM munstreamR #-}
munstreamR s =
  case upperBound (MStream.size s) of
    Nothing  -> munstreamRUnknown s
    Just cap -> munstreamRMax s cap
321
-- Allocate a vector of @n@ elements, where @n@ is an upper bound on the
-- stream's length. Fills it from the last index downwards and returns the
-- written suffix.
munstreamRMax
  :: (PrimMonad m, MVector v a) => MStream m a -> Int -> m (v (PrimState m) a)
{-# INLINE munstreamRMax #-}
munstreamRMax s n = do
    buf <- INTERNAL_CHECK(checkLength) "munstreamRMax" n
         $ unsafeNew n
    -- The accumulator is the index of the most recently written slot.
    let store k x = do
          let k' = k-1
          INTERNAL_CHECK(checkIndex) "munstreamRMax" k' n
            $ unsafeWrite buf k' x
          return k'
    start <- MStream.foldM' store n s
    return $ INTERNAL_CHECK(checkSlice) "munstreamRMax" start (n-start) n
           $ unsafeSlice start (n-start) buf
337
-- Fill a vector right to left from a stream whose size has no known upper
-- bound. Starts from an empty buffer and prepends with 'unsafePrepend1',
-- which grows the buffer at the front when no free slot is left. Returns
-- the suffix holding the elements.
munstreamRUnknown
  :: (PrimMonad m, MVector v a) => MStream m a -> m (v (PrimState m) a)
{-# INLINE munstreamRUnknown #-}
munstreamRUnknown s
  = do
      v <- unsafeNew 0
      (v', i) <- MStream.foldM put (v, 0) s
      -- i is the index of the first occupied slot; the slots from i to the
      -- end of v' hold the elements.
      let n = length v'
      return $ INTERNAL_CHECK(checkSlice) "munstreamRUnknown" i (n-i) n
             $ unsafeSlice i (n-i) v'
  where
    {-# INLINE_INNER put #-}
    put (v,i) x = unsafePrepend1 v i x
351
352 -- Length
353 -- ------
354
-- | Number of elements in the mutable vector (delegates to 'basicLength').
length :: MVector v a => v s a -> Int
{-# INLINE length #-}
length v = basicLength v
359
-- | Check whether the vector is empty, i.e. its length is 0.
null :: MVector v a => v s a -> Bool
{-# INLINE null #-}
null = (== 0) . length
364
365
366 -- Construction
367 -- ------------
368
-- | Create a mutable vector of the given length. The length is validated
-- when bounds checks are enabled. This function does not initialise the
-- elements; they are whatever 'unsafeNew' yields.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new len = BOUNDS_CHECK(checkLength) "new" len (unsafeNew len)
374
-- | Create a mutable vector of the given length with every element set to
-- the given initial value. The length is validated when bounds checks are
-- enabled.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith len x = BOUNDS_CHECK(checkLength) "newWith" len (unsafeNewWith len x)
381
-- | Create a mutable vector of the given length without validating the
-- length, except when unsafe checks are enabled. Delegates to
-- 'basicUnsafeNew'.
unsafeNew :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE unsafeNew #-}
unsafeNew len = UNSAFE_CHECK(checkLength) "unsafeNew" len (basicUnsafeNew len)
387
-- | Create a mutable vector of the given length with every element set to
-- the given initial value. The length is not validated, except when unsafe
-- checks are enabled. Delegates to 'basicUnsafeNewWith'.
unsafeNewWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE unsafeNewWith #-}
unsafeNewWith len x = UNSAFE_CHECK(checkLength) "unsafeNewWith" len
                    $ basicUnsafeNewWith len x
394
395
396 -- Growing
397 -- -------
398
-- | Grow a vector by the given number of elements. The number must be
-- positive; it is validated with @checkLength@ when bounds checks are
-- enabled. Returns the grown vector, which callers must use in place of
-- the original.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v extra = BOUNDS_CHECK(checkLength) "grow" extra (unsafeGrow v extra)
406
-- | Grow a vector at the front by the given number of elements. In the
-- result, the new slots come first and the old contents follow them. The
-- increment is validated with @checkLength@ when bounds checks are
-- enabled.
growFront :: (PrimMonad m, MVector v a)
          => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE growFront #-}
growFront v extra = BOUNDS_CHECK(checkLength) "growFront" extra
                  $ unsafeGrowFront v extra
412
-- | Number of elements by which 'enlarge' and 'enlargeFront' grow a vector:
-- its current length, but at least 1 so that an empty vector can grow.
enlarge_delta :: MVector v a => v s a -> Int
enlarge_delta v = max (length v) 1
414
-- | Grow a vector exponentially. The increment is 'enlarge_delta' (the
-- current length, but at least 1), so the capacity roughly doubles on each
-- call. Repeated appends therefore need only a logarithmic number of
-- reallocations.
enlarge :: (PrimMonad m, MVector v a)
        => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE enlarge #-}
enlarge v = unsafeGrow v (enlarge_delta v)
420
-- | Grow a vector at the front by 'enlarge_delta' elements. Returns the new
-- vector together with the index at which the old contents now start,
-- which equals the number of fresh slots in front of them.
enlargeFront :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> m (v (PrimState m) a, Int)
{-# INLINE enlargeFront #-}
enlargeFront v = unsafeGrowFront v extra >>= \v' -> return (v', extra)
  where
    extra = enlarge_delta v
429
-- | Grow a vector by the given number of elements. The number must be
-- positive, but this is only checked when unsafe checks are enabled.
-- Delegates to 'basicUnsafeGrow'.
unsafeGrow :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrow #-}
unsafeGrow v extra = UNSAFE_CHECK(checkLength) "unsafeGrow" extra
                   $ basicUnsafeGrow v extra
437
-- | Grow a vector at the front by the given number of elements. The
-- increment is only checked when unsafe checks are enabled. A vector of
-- @extra + length v@ elements is allocated and the old contents are copied
-- to its end. The @extra@ leading slots are left as 'basicUnsafeNew'
-- produced them.
unsafeGrowFront :: (PrimMonad m, MVector v a)
                => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE unsafeGrowFront #-}
unsafeGrowFront v extra
  = UNSAFE_CHECK(checkLength) "unsafeGrowFront" extra
  $ do
      v' <- basicUnsafeNew (extra + len)
      basicUnsafeCopy (basicUnsafeSlice extra len v') v
      return v'
  where
    len = length v
447
448 -- Accessing individual elements
449 -- -----------------------------
450
-- | Yield the element at the given position. The index is checked against
-- the length when bounds checks are enabled.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i = BOUNDS_CHECK(checkIndex) "read" i n (unsafeRead v i)
  where
    n = length v
456
-- | Replace the element at the given position. The index is checked
-- against the length when bounds checks are enabled.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = BOUNDS_CHECK(checkIndex) "write" i n (unsafeWrite v i x)
  where
    n = length v
462
-- | Swap the elements at the given positions. Both indices are checked
-- against the length, first @i@ then @j@, when bounds checks are enabled.
swap :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE swap #-}
swap v i j = BOUNDS_CHECK(checkIndex) "swap" i n
           $ BOUNDS_CHECK(checkIndex) "swap" j n
           $ unsafeSwap v i j
  where
    n = length v
469
-- | Replace the element at the given position and return the old element.
-- The index is checked against the length when bounds checks are enabled.
exchange :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m a
{-# INLINE exchange #-}
exchange v i x = BOUNDS_CHECK(checkIndex) "exchange" i (length v)
               $ unsafeExchange v i x
475
-- | Yield the element at the given position. No bounds checks are
-- performed unless unsafe checks are enabled. Delegates to
-- 'basicUnsafeRead'.
unsafeRead :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE unsafeRead #-}
unsafeRead v i = UNSAFE_CHECK(checkIndex) "unsafeRead" i n (basicUnsafeRead v i)
  where
    n = length v
481
-- | Replace the element at the given position. No bounds checks are
-- performed unless unsafe checks are enabled. Delegates to
-- 'basicUnsafeWrite'.
unsafeWrite :: (PrimMonad m, MVector v a)
            => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE unsafeWrite #-}
unsafeWrite v i x = UNSAFE_CHECK(checkIndex) "unsafeWrite" i n
                  $ basicUnsafeWrite v i x
  where
    n = length v
488
-- | Swap the elements at the given positions. No bounds checks are
-- performed unless unsafe checks are enabled. Both elements are read
-- before either is written.
unsafeSwap :: (PrimMonad m, MVector v a)
           => v (PrimState m) a -> Int -> Int -> m ()
{-# INLINE unsafeSwap #-}
unsafeSwap v i j = UNSAFE_CHECK(checkIndex) "unsafeSwap" i n
                 $ UNSAFE_CHECK(checkIndex) "unsafeSwap" j n
                 $ do
                     atI <- unsafeRead v i
                     atJ <- unsafeRead v j
                     unsafeWrite v i atJ
                     unsafeWrite v j atI
  where
    n = length v
500
-- | Replace the element at the given position and return the old element.
-- No bounds checks are performed unless unsafe checks are enabled.
unsafeExchange :: (PrimMonad m, MVector v a)
               => v (PrimState m) a -> Int -> a -> m a
{-# INLINE unsafeExchange #-}
unsafeExchange v i x = UNSAFE_CHECK(checkIndex) "unsafeExchange" i (length v)
                     $ do
                         y <- unsafeRead v i
                         unsafeWrite v i x
                         return y
511
512 -- Block operations
513 -- ----------------
514
-- | Reset all elements of the vector to some undefined value, clearing all
-- references to external objects. This is usually a noop for unboxed
-- vectors. Delegates to 'basicClear'.
clear :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE clear #-}
clear v = basicClear v
520
-- | Overwrite every element of the vector with the given value. Delegates
-- to 'basicSet'.
set :: (PrimMonad m, MVector v a) => v (PrimState m) a -> a -> m ()
{-# INLINE set #-}
set v x = basicSet v x
525
-- | Copy the source vector (second argument) into the target (first
-- argument). The two vectors must have the same length and may not
-- overlap. When bounds checks are enabled, overlap is checked first and
-- then the lengths.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> v (PrimState m) a -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" disjoint
  $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
  $ unsafeCopy dst src
  where
    disjoint   = not (overlaps dst src)
    sameLength = length dst == length src
536
-- | Copy the source vector into the target. The two vectors must have the
-- same length and may not overlap. This is only checked when unsafe checks
-- are enabled (length first, then overlap). Both vectors are forced before
-- 'basicUnsafeCopy' is called.
unsafeCopy :: (PrimMonad m, MVector v a) => v (PrimState m) a -- ^ target
                                         -> v (PrimState m) a -- ^ source
                                         -> m ()
{-# INLINE unsafeCopy #-}
unsafeCopy dst src
  = UNSAFE_CHECK(check) "unsafeCopy" "length mismatch" sameLength
  $ UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors" disjoint
  $ dst `seq` src `seq` basicUnsafeCopy dst src
  where
    sameLength = length dst == length src
    disjoint   = not (overlaps dst src)
548
549 -- Subvectors
550 -- ----------
551
-- | Yield the @n@ elements starting at index @i@ without copying them. The
-- slice bounds are checked when bounds checks are enabled.
slice :: MVector v a => Int -> Int -> v s a -> v s a
{-# INLINE slice #-}
slice i n v = BOUNDS_CHECK(checkSlice) "slice" i n len (unsafeSlice i n v)
  where
    len = length v
557
-- | Yield the first @n@ elements without copying. The count is clamped: a
-- negative count gives an empty vector and a count larger than the length
-- gives the whole vector.
take :: MVector v a => Int -> v s a -> v s a
{-# INLINE take #-}
take n v = unsafeSlice 0 count v
  where
    count = min (length v) (max 0 n)
561
-- | Drop the first @n@ elements without copying. A negative count drops
-- nothing, and a count larger than the length gives an empty vector.
drop :: MVector v a => Int -> v s a -> v s a
{-# INLINE drop #-}
drop n v = unsafeSlice start (len - start) v
  where
    len   = length v
    start = min len (max 0 n)
568
-- | All elements but the last, without copying. The vector must not be
-- empty; this is checked by 'slice' when bounds checks are enabled.
init :: MVector v a => v s a -> v s a
{-# INLINE init #-}
init v = slice 0 (len - 1) v
  where
    len = length v
572
-- | All elements but the first, without copying. The vector must not be
-- empty; this is checked by 'slice' when bounds checks are enabled.
tail :: MVector v a => v s a -> v s a
{-# INLINE tail #-}
tail v = slice 1 (len - 1) v
  where
    len = length v
576
-- | Yield a part of the mutable vector without copying it. The slice bounds
-- are only checked when unsafe checks are enabled. Delegates to
-- 'basicUnsafeSlice'.
unsafeSlice :: MVector v a => Int  -- ^ starting index
                           -> Int  -- ^ length of the slice
                           -> v s a
                           -> v s a
{-# INLINE unsafeSlice #-}
unsafeSlice i n v = UNSAFE_CHECK(checkSlice) "unsafeSlice" i n len
                  $ basicUnsafeSlice i n v
  where
    len = length v
586
-- | All elements but the last, without copying. The vector must not be
-- empty; this is only checked when unsafe checks are enabled.
unsafeInit :: MVector v a => v s a -> v s a
{-# INLINE unsafeInit #-}
unsafeInit v = unsafeSlice 0 (len - 1) v
  where
    len = length v
590
-- | All elements but the first, without copying. The vector must not be
-- empty; this is only checked when unsafe checks are enabled.
unsafeTail :: MVector v a => v s a -> v s a
{-# INLINE unsafeTail #-}
unsafeTail v = unsafeSlice 1 (len - 1) v
  where
    len = length v
594
-- | The first @n@ elements, without copying. The count is not clamped and
-- is only checked when unsafe checks are enabled.
unsafeTake :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeTake #-}
unsafeTake count v = unsafeSlice 0 count v
598
-- | Everything after the first @n@ elements, without copying. The count is
-- not clamped and is only checked when unsafe checks are enabled.
unsafeDrop :: MVector v a => Int -> v s a -> v s a
{-# INLINE unsafeDrop #-}
unsafeDrop count v = unsafeSlice count (len - count) v
  where
    len = length v
602
603 -- Permutations
604 -- ------------
605
-- | For each pair @(i, b)@ of the stream, in order, replace the element
-- @a@ at index @i@ with @f a b@. Each index is checked when bounds checks
-- are enabled.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i, b) = BOUNDS_CHECK(checkIndex) "accum" i (length v) (unsafeRead v i)
                 >>= \a -> unsafeWrite v i (f a b)
616
-- | For each pair @(i, x)@ of the stream, in order, write @x@ at index
-- @i@. Each index is checked when bounds checks are enabled.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update !v s = Stream.mapM_ upd s
  where
    n = length v

    {-# INLINE_INNER upd #-}
    upd (i, b) = BOUNDS_CHECK(checkIndex) "update" i n (unsafeWrite v i b)
625
-- | Like 'accum', but indices are only checked when unsafe checks are
-- enabled. For each pair @(i, b)@ of the stream, in order, the element @a@
-- at index @i@ is replaced with @f a b@. Check failures are reported
-- under this function's own name.
unsafeAccum :: (PrimMonad m, MVector v a)
            => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE unsafeAccum #-}
unsafeAccum f !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = do
                  a <- UNSAFE_CHECK(checkIndex) "unsafeAccum" i (length v)
                     $ unsafeRead v i
                  unsafeWrite v i (f a b)
636
-- | Like 'update', but indices are only checked when unsafe checks are
-- enabled. For each pair @(i, x)@ of the stream, in order, @x@ is written
-- at index @i@. Check failures are reported under this function's own
-- name.
unsafeUpdate :: (PrimMonad m, MVector v a)
             => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE unsafeUpdate #-}
unsafeUpdate !v s = Stream.mapM_ upd s
  where
    {-# INLINE_INNER upd #-}
    upd (i,b) = UNSAFE_CHECK(checkIndex) "unsafeUpdate" i (length v)
              $ unsafeWrite v i b
645
-- | Reverse the vector in place by swapping elements pairwise, moving from
-- both ends towards the middle.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = unsafeSwap v lo hi >> go (lo + 1) (hi - 1)
654
-- | Reorder the vector in place so that all elements satisfying the
-- predicate come first. Returns the number of such elements. The relative
-- order of elements is not preserved.
unstablePartition :: forall m v a. (PrimMonad m, MVector v a)
                  => (a -> Bool) -> v (PrimState m) a -> m Int
{-# INLINE unstablePartition #-}
unstablePartition f !v = from_left 0 (length v)
  where
    -- NOTE: GHC 6.10.4 panics without the signatures on from_left and
    -- from_right

    -- Invariant: elements in [0, lo) satisfy f; elements at positions
    -- >= hi do not. Here hi is an exclusive bound.
    from_left :: Int -> Int -> m Int
    from_left lo hi
      | lo == hi  = return lo
      | otherwise = do
          cur <- unsafeRead v lo
          if f cur then from_left (lo+1) hi
                   else from_right lo (hi-1)

    -- Here the element at lo is known to fail f, and hi is the inclusive
    -- index of the next element to examine from the back.
    from_right :: Int -> Int -> m Int
    from_right lo hi
      | lo == hi  = return lo
      | otherwise = do
          back <- unsafeRead v hi
          if not (f back)
            then from_right lo (hi-1)
            else do
              front <- unsafeRead v lo
              unsafeWrite v lo back
              unsafeWrite v hi front
              from_left (lo+1) hi
683
-- | Split a stream into two new mutable vectors. The first holds the
-- elements satisfying the predicate, the second the rest. The order of the
-- elements within the results is not guaranteed.
unstablePartitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionStream #-}
unstablePartitionStream f s =
  case upperBound (Stream.size s) of
    Nothing  -> partitionUnknown f s
    Just cap -> unstablePartitionMax f s cap
691
-- Partition a stream whose length is at most @n@ using a single buffer of
-- @n@ elements. Matching elements are written from the front and the others
-- from the back, so the second result holds the non-matching elements in
-- reverse stream order.
unstablePartitionMax :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> Int
        -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE unstablePartitionMax #-}
unstablePartitionMax f s n
  = do
      v <- INTERNAL_CHECK(checkLength) "unstablePartitionMax" n
         $ unsafeNew n
      let put (lo, hi) x
            | f x       = unsafeWrite v lo x >> return (lo+1, hi)
            | otherwise = let hi' = hi-1
                          in unsafeWrite v hi' x >> return (lo, hi')
          {-# INLINE_INNER put #-}
      (front, back) <- Stream.foldM' put (0, n) s
      return (unsafeSlice 0 front v, unsafeSlice back (n-back) v)
711
-- | Split a stream into two new mutable vectors. The first holds the
-- elements satisfying the predicate, the second the rest. Both keep the
-- relative order of the stream.
partitionStream :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionStream #-}
partitionStream f s =
  case upperBound (Stream.size s) of
    Nothing  -> partitionUnknown f s
    Just cap -> partitionMax f s cap
719
-- Stable partition of a stream whose length is at most @n@, using a single
-- buffer of @n@ elements. Matching elements are written from the front and
-- the others from the back. The back part is then reversed in place to
-- restore stream order. Check failures are reported under this function's
-- own name.
partitionMax :: (PrimMonad m, MVector v a)
  => (a -> Bool) -> Stream a -> Int -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionMax #-}
partitionMax f s n
  = do
      v <- INTERNAL_CHECK(checkLength) "partitionMax" n
         $ unsafeNew n

      let {-# INLINE_INNER put #-}
          put (i,j) x
            | f x = do
                unsafeWrite v i x
                return (i+1,j)

            | otherwise = let j' = j-1 in
                do
                  unsafeWrite v j' x
                  return (i,j')

      (i,j) <- Stream.foldM' put (0,n) s
      -- The front and back regions must not have crossed.
      INTERNAL_CHECK(check) "partitionMax" "invalid indices" (i <= j)
        $ return ()
      let l = unsafeSlice 0 i v
          r = unsafeSlice j (n-j) v
      -- Non-matching elements were stored back to front; restore their order.
      reverse r
      return (l,r)
746
-- Stable partition of a stream whose size has no known upper bound. Two
-- growable buffers start empty and are appended to with 'unsafeAppend1'
-- (one for matching elements, one for the rest). Returns the filled
-- prefix of each buffer.
partitionUnknown :: (PrimMonad m, MVector v a)
        => (a -> Bool) -> Stream a -> m (v (PrimState m) a, v (PrimState m) a)
{-# INLINE partitionUnknown #-}
partitionUnknown f s
  = do
      yes0 <- unsafeNew 0
      no0  <- unsafeNew 0
      (yes, nYes, no, nNo) <- Stream.foldM' put (yes0, 0, no0, 0) s
      INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nYes (length yes)
        $ INTERNAL_CHECK(checkSlice) "partitionUnknown" 0 nNo (length no)
        $ return (unsafeSlice 0 nYes yes, unsafeSlice 0 nNo no)
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v isn't getting
    -- unboxed.
    {-# INLINE_INNER put #-}
    put (ys, ny, ns, nn) x
      | f x = do
          ys' <- unsafeAppend1 ys ny x
          return (ys', ny+1, ns, nn)
      | otherwise = do
          ns' <- unsafeAppend1 ns nn x
          return (ys, ny, ns', nn+1)
771