bf70aafc5652e36b9e82e07027ff6f43419fef0b
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 module Data.Vector.Generic.Mutable (
15 MVector(..),
16
17 slice, new, newWith, read, write, copy, grow,
18 unstream, transform,
19 accum, update, reverse
20 ) where
21
22 import qualified Data.Vector.Fusion.Stream as Stream
23 import Data.Vector.Fusion.Stream ( Stream, MStream )
24 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
25 import Data.Vector.Fusion.Stream.Size
26
27 import Control.Monad.Primitive ( PrimMonad, PrimState )
28
29 import GHC.Float (
30 double2Int, int2Double
31 )
32
33 import Prelude hiding ( length, reverse, map, read )
34
35 #include "vector.h"
36
-- | Scale factor for enlarging a vector whose final length is unknown.
-- A full vector of length @l@ is extended by @max 1 (l * gROWTH_FACTOR)@
-- further elements, with the product truncated to an 'Int'.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
39
-- | Class of mutable vectors parametrised with a primitive state token.
--
-- 'length', 'unsafeSlice', 'overlaps', 'unsafeNew', 'unsafeRead',
-- 'unsafeWrite' and 'clear' have no default definition here and must be
-- supplied by each instance. The remaining methods ('unsafeNewWith', 'set',
-- 'unsafeCopy' and 'unsafeGrow') have element-by-element defaults built on
-- top of those, which instances may override with faster versions.
class MVector v a where
  -- | Length of the mutable vector
  length :: v s a -> Int

  -- | Yield a part of the mutable vector without copying it. No range checks!
  unsafeSlice :: v s a  -- ^ vector to slice
              -> Int    -- ^ starting index
              -> Int    -- ^ length of the slice
              -> v s a

  -- | Check whether two vectors overlap.
  overlaps :: v s a -> v s a -> Bool

  -- | Create a mutable vector of the given length. Length is not checked!
  unsafeNew :: PrimMonad m => Int -> m (v (PrimState m) a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. Length is not checked!
  unsafeNewWith :: PrimMonad m => Int -> a -> m (v (PrimState m) a)

  -- | Yield the element at the given position. Index is not checked!
  unsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a

  -- | Replace the element at the given position. Index is not checked!
  unsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()

  -- | Clear all references to external objects
  clear :: PrimMonad m => v (PrimState m) a -> m ()

  -- | Write the value at each position.
  set :: PrimMonad m => v (PrimState m) a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This is not checked!
  unsafeCopy :: PrimMonad m => v (PrimState m) a   -- ^ target
                            -> v (PrimState m) a   -- ^ source
                            -> m ()

  -- | Grow a vector by the given number of elements. The length is not
  -- checked!
  unsafeGrow :: PrimMonad m => v (PrimState m) a -> Int -> m (v (PrimState m) a)

  -- Default: allocate with 'unsafeNew', then fill every slot via 'set'.
  {-# INLINE unsafeNewWith #-}
  unsafeNewWith n x = UNSAFE_CHECK(checkLength) "unsafeNewWith" n
                    $ do
                        v <- unsafeNew n
                        set v x
                        return v

  -- Default: write the value at indices 0 .. length v - 1, in order.
  {-# INLINE set #-}
  set v x = do_set 0
    where
      n = length v

      do_set i
        | i < n = do
            unsafeWrite v i x
            do_set (i+1)
        | otherwise = return ()

  -- Default: copy element by element from index 0 upwards. The internal
  -- checks only verify the preconditions (no overlap, equal lengths).
  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src
    = UNSAFE_CHECK(check) "unsafeCopy" "overlapping vectors"
                          (not (dst `overlaps` src))
    $ UNSAFE_CHECK(check) "unsafeCopy" "length mismatch"
                          (length dst == length src)
    $ do_copy 0
    where
      n = length src

      do_copy i
        | i < n = do
            x <- unsafeRead src i
            unsafeWrite dst i x
            do_copy (i+1)
        | otherwise = return ()

  -- Default: allocate a new vector of length (n + by) and copy the old
  -- contents into its first n slots. The remaining slots are left as
  -- produced by 'unsafeNew'.
  {-# INLINE unsafeGrow #-}
  unsafeGrow v by = UNSAFE_CHECK(checkLength) "unsafeGrow" by
                  $ do
                      v' <- unsafeNew (n+by)
                      unsafeCopy (unsafeSlice v' 0 n) v
                      return v'
    where
      n = length v
123
-- | Yield @n@ elements of a mutable vector starting at index @i@, without
-- copying. This is the checked counterpart of 'unsafeSlice': the range
-- @i@, @n@ is checked against the length of the vector before slicing.
slice :: MVector v a => v s a -> Int -> Int -> v s a
{-# INLINE slice #-}
slice v i n = BOUNDS_CHECK(checkSlice) "slice" i n len (unsafeSlice v i n)
  where
    len = length v
130
-- | Allocate a mutable vector of the given length. This is the checked
-- counterpart of 'unsafeNew': the length goes through @checkLength@ first.
new :: (PrimMonad m, MVector v a) => Int -> m (v (PrimState m) a)
{-# INLINE new #-}
new len = BOUNDS_CHECK(checkLength) "new" len (unsafeNew len)
137
-- | Allocate a mutable vector of the given length with every element set to
-- the given value. This is the checked counterpart of 'unsafeNewWith': the
-- length goes through @checkLength@ first.
newWith :: (PrimMonad m, MVector v a) => Int -> a -> m (v (PrimState m) a)
{-# INLINE newWith #-}
newWith len x = BOUNDS_CHECK(checkLength) "newWith" len (unsafeNewWith len x)
144
-- | Read the element at the given index. This is the checked counterpart of
-- 'unsafeRead': the index is checked against the vector length first.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i = BOUNDS_CHECK(checkIndex) "read" i len (unsafeRead v i)
  where
    len = length v
150
-- | Overwrite the element at the given index. This is the checked
-- counterpart of 'unsafeWrite': the index is checked against the vector
-- length first.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = BOUNDS_CHECK(checkIndex) "write" i len (unsafeWrite v i x)
  where
    len = length v
157
-- | Copy the source vector into the target vector. This is the checked
-- counterpart of 'unsafeCopy': it first checks that the two vectors do not
-- overlap and that their lengths agree.
copy :: (PrimMonad m, MVector v a)
     => v (PrimState m) a   -- ^ target
     -> v (PrimState m) a   -- ^ source
     -> m ()
{-# INLINE copy #-}
copy dst src
  = BOUNDS_CHECK(check) "copy" "overlapping vectors" (not overlapping)
  $ BOUNDS_CHECK(check) "copy" "length mismatch" sameLength
  $ unsafeCopy dst src
  where
    overlapping = dst `overlaps` src
    sameLength  = length dst == length src
168
-- | Return a new vector that is the given number of elements longer than the
-- input, with the input copied to its front. This is the checked counterpart
-- of 'unsafeGrow': the increment goes through @checkLength@ first.
grow :: (PrimMonad m, MVector v a)
     => v (PrimState m) a -> Int -> m (v (PrimState m) a)
{-# INLINE grow #-}
grow v by = BOUNDS_CHECK(checkLength) "grow" by (unsafeGrow v by)
176
-- | Expose a mutable vector as a monadic stream that reads its elements from
-- index 0 upwards with 'unsafeRead'. The stream carries an exact size hint
-- equal to the length of the vector.
mstream :: (PrimMonad m, MVector v a) => v (PrimState m) a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step i
      | i >= len  = return Nothing
      | otherwise = do
          x <- unsafeRead v i
          return (Just (x, i + 1))
187
-- | Write the elements of the stream into the given vector, starting at
-- index 0, and return the prefix of the vector that was filled. Every write
-- index is checked against the vector length by an internal check.
internal_munstream :: (PrimMonad m, MVector v a)
                   => v (PrimState m) a -> MStream m a -> m (v (PrimState m) a)
{-# INLINE internal_munstream #-}
internal_munstream v s = v `seq` do
    filled <- MStream.foldM store 0 s
    return (slice v 0 filled)
  where
    {-# INLINE_INNER store #-}
    store i x = do
      INTERNAL_CHECK(checkIndex) "internal_munstream" i (length v)
        $ unsafeWrite v i x
      return (i + 1)
200
-- | Run a stream transformer over the elements of a mutable vector. The
-- output elements are written back into the same vector from index 0, and
-- the filled prefix is returned.
transform :: (PrimMonad m, MVector v a)
          => (MStream m a -> MStream m a)
          -> v (PrimState m) a
          -> m (v (PrimState m) a)
{-# INLINE_STREAM transform #-}
transform f v = internal_munstream v $ f $ mstream v
205
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
--
-- If the 'Size' hint of the 'Stream' provides an upper bound, a vector of
-- that size is allocated up front and the unused tail is sliced off at the
-- end. The hint does not need to be exact for this.
--
-- Only when no upper bound is known does the vector start empty. It is then
-- enlarged geometrically (by a factor controlled by 'gROWTH_FACTOR') each
-- time it fills up. As a result, the number of reallocations is logarithmic
-- in the number of elements.
unstream :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE_STREAM unstream #-}
unstream s = case upperBound (Stream.size s) of
               Just n  -> unstreamMax s n
               Nothing -> unstreamUnknown s
214
-- | Fill a freshly allocated vector of capacity @n@ with the elements of the
-- stream and return the filled prefix. @n@ is expected to bound the number
-- of elements the stream yields. Each write and the final slice are guarded
-- by internal checks.
unstreamMax
  :: (PrimMonad m, MVector v a) => Stream a -> Int -> m (v (PrimState m) a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
    buf <- new n
    let store k x = do
          INTERNAL_CHECK(checkIndex) "unstreamMax" k n
            $ unsafeWrite buf k x
          return (k + 1)
    count <- Stream.foldM' store 0 s
    return $ INTERNAL_CHECK(checkSlice) "unstreamMax" 0 count n
           $ slice buf 0 count
227
-- | Fill a mutable vector with the elements of a stream that has no known
-- upper bound on its size, and return the filled prefix.
--
-- The vector starts empty. Whenever it is full, it is enlarged by
-- @max 1 (length * gROWTH_FACTOR)@ elements, so its capacity grows
-- geometrically.
unstreamUnknown
  :: (PrimMonad m, MVector v a) => Stream a -> m (v (PrimState m) a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s
  = do
      v <- new 0
      (v', n) <- Stream.foldM put (v, 0) s
      return $ slice v' 0 n
  where
    -- NOTE: The case distinction has to be on the outside because
    -- GHC creates a join point for the unsafeWrite even when everything
    -- is inlined. This is bad because with the join point, v is not getting
    -- unboxed.
    {-# INLINE_INNER put #-}
    put (v, i) x
      | i < length v = do
          unsafeWrite v i x
          return (v, i+1)
      | otherwise = do
          v' <- enlarge v
          -- The check is labelled with this function name so that a
          -- failure is attributed to unstreamUnknown.
          INTERNAL_CHECK(checkIndex) "unstreamUnknown" i (length v')
            $ unsafeWrite v' i x
          return (v', i+1)

    -- Extend by max 1 (length v * gROWTH_FACTOR) elements. The max 1 keeps
    -- an empty vector from being "grown" by zero elements.
    {-# INLINE_INNER enlarge #-}
    enlarge v = unsafeGrow v
              $ max 1
              $ double2Int
              $ int2Double (length v) * gROWTH_FACTOR
257
-- | For every @(i, b)@ pair in the stream, in stream order, replace the
-- element at index @i@ with @f a b@, where @a@ is its current value. Reads
-- and writes go through the checked 'read' and 'write'.
accum :: (PrimMonad m, MVector v a)
      => (a -> b -> a) -> v (PrimState m) a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ step s
  where
    {-# INLINE_INNER step #-}
    step (i, b) = read v i >>= \old -> write v i (f old b)
267
-- | For every @(i, x)@ pair in the stream, overwrite the element at index
-- @i@ with @x@. When an index repeats, the later pair wins. This is 'accum'
-- with a combining function that ignores the old value.
update :: (PrimMonad m, MVector v a)
       => v (PrimState m) a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update = accum (\_ x -> x)
272
-- | Reverse the vector in place by swapping elements pairwise from both ends
-- towards the middle. Uses unchecked reads and writes, whose indices always
-- lie in range by construction.
reverse :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m ()
{-# INLINE reverse #-}
reverse !v = go 0 (length v - 1)
  where
    go lo hi
      | lo >= hi  = return ()
      | otherwise = do
          front <- unsafeRead v lo
          back  <- unsafeRead v hi
          unsafeWrite v lo back
          unsafeWrite v hi front
          go (lo + 1) (hi - 1)
284