Comments
[darcs-mirrors/vector.git] / Data / Vector / Generic / Mutable.hs
1 {-# LANGUAGE MultiParamTypeClasses, BangPatterns #-}
2 -- |
3 -- Module : Data.Vector.Generic.Mutable
4 -- Copyright : (c) Roman Leshchinskiy 2008-2009
5 -- License : BSD-style
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 #include "phases.h"
15
16 module Data.Vector.Generic.Mutable (
17 MVectorPure(..), MVector(..),
18
19 slice, new, newWith, read, write, copy, grow,
20 unstream, transform,
21 accum, update, reverse
22 ) where
23
24 import qualified Data.Vector.Fusion.Stream as Stream
25 import Data.Vector.Fusion.Stream ( Stream, MStream )
26 import qualified Data.Vector.Fusion.Stream.Monadic as MStream
27 import Data.Vector.Fusion.Stream.Size
28
29 import Control.Exception ( assert )
30
31 import GHC.Float (
32 double2Int, int2Double
33 )
34
35 import Prelude hiding ( length, reverse, map, read )
36
-- | Multiplier applied to a vector's current length to decide how many
-- elements to add when a vector of unknown final size (see 'unstream') is
-- full and has to be enlarged.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
39
-- | Basic pure functions on mutable vectors
class MVectorPure v a where
  -- | Length of the mutable vector
  length :: v a -> Int

  -- | Yield a part of the mutable vector without copying it. No range checks!
  unsafeSlice :: v a  -- ^ vector to slice
              -> Int  -- ^ starting index
              -> Int  -- ^ length of the slice
              -> v a

  -- | Check whether two vectors overlap. 'copy' requires its two arguments
  -- not to overlap.
  overlaps :: v a -> v a -> Bool
52
-- | Class of mutable vectors. The type @m@ is the monad in which the mutable
-- vector can be transformed and @a@ is the type of elements.
--
-- Only 'unsafeNew', 'unsafeRead', 'unsafeWrite' and 'clear' have no default
-- implementation; the other methods default to definitions built on them.
class (Monad m, MVectorPure v a) => MVector v m a where
  -- | Create a mutable vector of the given length. Length is not checked!
  unsafeNew :: Int -> m (v a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. Length is not checked!
  unsafeNewWith :: Int -> a -> m (v a)

  -- | Yield the element at the given position. Index is not checked!
  unsafeRead :: v a -> Int -> m a

  -- | Replace the element at the given position. Index is not checked!
  unsafeWrite :: v a -> Int -> a -> m ()

  -- | Clear all references to external objects
  clear :: v a -> m ()

  -- | Write the value at each position.
  set :: v a -> a -> m ()

  -- | Copy a vector. The two vectors must not overlap. This is not checked!
  unsafeCopy :: v a  -- ^ target
             -> v a  -- ^ source
             -> m ()

  -- | Grow a vector by the given number of elements. The length is not
  -- checked!
  unsafeGrow :: v a -> Int -> m (v a)

  -- Default: allocate, then overwrite every slot with the initial value.
  {-# INLINE unsafeNewWith #-}
  unsafeNewWith n x = unsafeNew n >>= \fresh -> set fresh x >> return fresh

  -- Default: write @x@ at indices 0 .. length-1, one at a time.
  {-# INLINE set #-}
  set v x = fill 0
    where
      len = length v
      fill k
        | k >= len  = return ()
        | otherwise = unsafeWrite v k x >> fill (k + 1)

  -- Default: element-wise copy driven by the length of the source.
  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src = step 0
    where
      len = length src
      step k
        | k >= len  = return ()
        | otherwise = do
            e <- unsafeRead src k
            unsafeWrite dst k e
            step (k + 1)

  -- Default: allocate a longer vector and copy the old contents into its
  -- prefix. The trailing @by@ slots are not written here.
  {-# INLINE unsafeGrow #-}
  unsafeGrow v by = do
    bigger <- unsafeNew (oldLen + by)
    unsafeCopy (unsafeSlice bigger 0 oldLen) v
    return bigger
    where
      oldLen = length v
119
-- | Test whether the index is valid for the vector, i.e. lies in the
-- half-open range @[0, length v)@.
inBounds :: MVectorPure v a => v a -> Int -> Bool
{-# INLINE inBounds #-}
inBounds v i = 0 <= i && i < length v
124
-- | Yield a part of the mutable vector without copying it. Safer version of
-- 'unsafeSlice': the range is validated with 'assert', so the check is
-- skipped when assertions are disabled.
slice :: MVectorPure v a => v a -> Int -> Int -> v a
{-# INLINE slice #-}
slice v i n = assert validRange (unsafeSlice v i n)
  where
    validRange = i >= 0 && n >= 0 && i + n <= length v
131
-- | Create a mutable vector of the given length. Safer version of
-- 'unsafeNew': a negative length trips an 'assert'.
new :: MVector v m a => Int -> m (v a)
{-# INLINE new #-}
new n = assert nonNegative (unsafeNew n)
  where
    nonNegative = n >= 0
137
-- | Create a mutable vector of the given length and fill it with an
-- initial value. Safer version of 'unsafeNewWith': a negative length trips
-- an 'assert'.
newWith :: MVector v m a => Int -> a -> m (v a)
{-# INLINE newWith #-}
newWith n x = assert nonNegative (unsafeNewWith n x)
  where
    nonNegative = n >= 0
143
-- | Yield the element at the given position. Safer version of 'unsafeRead':
-- the index is checked with 'inBounds' inside an 'assert'.
read :: MVector v m a => v a -> Int -> m a
{-# INLINE read #-}
read v i = assert indexOk (unsafeRead v i)
  where
    indexOk = inBounds v i
148
-- | Replace the element at the given position. Safer version of
-- 'unsafeWrite': the index is checked with 'inBounds' inside an 'assert'.
write :: MVector v m a => v a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = assert indexOk (unsafeWrite v i x)
  where
    indexOk = inBounds v i
154
-- | Copy a vector. The two vectors must not overlap and must have the same
-- length. Safer version of 'unsafeCopy': both conditions are checked with
-- an 'assert'.
copy :: MVector v m a
     => v a  -- ^ target
     -> v a  -- ^ source
     -> m ()
{-# INLINE copy #-}
copy dst src = assert precondition (unsafeCopy dst src)
  where
    precondition = not (overlaps dst src) && length dst == length src
161
-- | Grow a vector by the given number of elements. Safer version of
-- 'unsafeGrow': a negative increment trips an 'assert'.
grow :: MVector v m a => v a -> Int -> m (v a)
{-# INLINE grow #-}
grow v by = assert nonNegative (unsafeGrow v by)
  where
    nonNegative = by >= 0
168
-- | View a mutable vector as a monadic stream that reads its elements from
-- index 0 up to (but excluding) its length. The stream's size is reported
-- as exactly that length.
mstream :: MVector v m a => v a -> MStream m a
{-# INLINE mstream #-}
mstream v = v `seq` MStream.sized (MStream.unfoldrM step 0) (Exact len)
  where
    len = length v

    {-# INLINE_INNER step #-}
    step i
      | i >= len  = return Nothing
      | otherwise = do
          x <- unsafeRead v i
          return (Just (x, i + 1))
179
-- | Write the elements of a monadic stream into the given vector, starting
-- at index 0, and return the prefix that was filled. The vector must be
-- long enough; this is only checked by the assertion in 'write'.
munstream :: MVector v m a => v a -> MStream m a -> m (v a)
{-# INLINE munstream #-}
munstream v s = v `seq` do
    filled <- MStream.foldM store 0 s
    return (slice v 0 filled)
  where
    {-# INLINE_INNER store #-}
    store i x = write v i x >> return (i + 1)
188
-- | Read a mutable vector as a stream, apply the stream transformer @f@,
-- and write the result back into the same vector from index 0. Returns the
-- prefix that was written.
--
-- NOTE(review): since the output is written into the vector being read,
-- presumably @f@ must never emit more elements than it has consumed so far;
-- confirm with the callers of 'transform'.
transform :: MVector v m a => (MStream m a -> MStream m a) -> v a -> m (v a)
{-# INLINE_STREAM transform #-}
transform f v = munstream v $ f $ mstream v
192
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
--
-- If the 'Size' hint of the 'Stream' gives an upper bound (exact or not),
-- a vector of that length is allocated up front and then trimmed to the
-- number of elements actually produced. Only when no upper bound is known
-- does the vector start empty and grow geometrically: each time it fills
-- up, it gains 'gROWTH_FACTOR' times its current length (at least one
-- element). As a result, only a logarithmic number of reallocations happen.
unstream :: MVector v m a => Stream a -> m (v a)
{-# INLINE_STREAM unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s)
                   (upperBound (Stream.size s))
201
-- | Fill a vector from a 'Stream' whose length is at most @n@. A vector of
-- length @n@ is allocated, filled from index 0, and the written prefix is
-- returned. Overrunning @n@ is only caught by the assertion in 'write'.
unstreamMax :: MVector v m a => Stream a -> Int -> m (v a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
  buf <- new n
  let store i x = write buf i x >> return (i + 1)
  count <- Stream.foldM' store 0 s
  return $ slice buf 0 count
210
-- | Fill a new mutable vector from a 'Stream' whose size has no known upper
-- bound. The vector starts empty; whenever it is full, it is enlarged by
-- 'gROWTH_FACTOR' times its current length (at least one element). The
-- result is the prefix that holds the streamed elements.
unstreamUnknown :: MVector v m a => Stream a -> m (v a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
    start <- new 0
    (buf, count) <- Stream.foldM push (start, 0) s
    return $ slice buf 0 count
  where
    -- Append one element at position @k@, enlarging the buffer first if
    -- needed; the accumulator carries the (possibly new) buffer and count.
    {-# INLINE_INNER push #-}
    push (buf, k) x = do
      buf' <- ensureRoom buf k
      unsafeWrite buf' k x
      return (buf', k + 1)

    {-# INLINE_INNER ensureRoom #-}
    ensureRoom buf k
      | k < cap   = return buf
      | otherwise = unsafeGrow buf extra
      where
        cap   = length buf
        extra = max 1 (double2Int (int2Double cap * gROWTH_FACTOR))
231
-- | For every pair @(i, b)@ in the stream, replace the element @a@ at index
-- @i@ of the vector with @f a b@. Indices are checked only by the
-- assertions in 'read' and 'write'.
accum :: MVector v m a => (a -> b -> a) -> v a -> Stream (Int, b) -> m ()
{-# INLINE accum #-}
accum f !v s = Stream.mapM_ combine s
  where
    {-# INLINE_INNER combine #-}
    combine (i, b) = read v i >>= \a -> write v i (f a b)
240
-- | For every pair @(i, x)@ in the stream, overwrite the element at index
-- @i@ of the vector with @x@. This is 'accum' with a combining function
-- that ignores the old value.
update :: MVector v m a => v a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update = accum (\_ x -> x)
244
-- | Reverse the elements of a mutable vector in place by swapping pairs
-- from both ends, moving inwards until the indices meet.
reverse :: MVector v m a => v a -> m ()
{-# INLINE reverse #-}
reverse !v = swapFrom 0 (length v - 1)
  where
    swapFrom lo hi
      | lo >= hi  = return ()
      | otherwise = do
          front <- unsafeRead v lo
          back  <- unsafeRead v hi
          unsafeWrite v lo back
          unsafeWrite v hi front
          swapFrom (lo + 1) (hi - 1)
256