Fusion rules for in-place map
[darcs-mirrors/vector.git] / Data / Vector / MVector.hs
1 {-# LANGUAGE MultiParamTypeClasses #-}
2 -- |
3 -- Module : Data.Vector.MVector
4 -- Copyright : (c) Roman Leshchinskiy 2008
5 -- License : BSD-style
6 --
7 -- Maintainer : rl@cse.unsw.edu.au
8 -- Stability : experimental
9 -- Portability : non-portable
10 --
11 -- Generic interface to mutable vectors
12 --
13
14 #include "phases.h"
15
16 module Data.Vector.MVector (
17 MVectorPure(..), MVector(..),
18
19 slice, new, newWith, read, write, copy, grow, unstream, map, update
20 ) where
21
22 import qualified Data.Vector.Stream as Stream
23 import Data.Vector.Stream ( Stream )
24 import Data.Vector.Stream.Size
25
26 import Control.Monad.ST ( ST )
27 import Control.Exception ( assert )
28
29 import GHC.Float (
30 double2Int, int2Double
31 )
32
33 import Prelude hiding ( length, map, read )
34
-- | Scale factor used by 'unstreamUnknown' when a vector runs out of room.
-- The number of elements added is the current length times this factor,
-- truncated to an 'Int' and clamped to at least one.
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
37
-- | Basic pure functions on mutable vectors.
--
-- None of these methods run in a monad: they only query a vector's extent or
-- produce a view of it.
class MVectorPure v a where
  -- | Length of the mutable vector
  length :: v a -> Int

  -- | Yield a part of the mutable vector without copying it. No range checks!
  unsafeSlice :: v a -> Int  -- ^ starting index
                     -> Int  -- ^ length of the slice
                     -> v a

  -- | Check whether two vectors overlap. 'copy' uses this to assert that its
  -- source and target are disjoint.
  overlaps :: v a -> v a -> Bool
50
-- | Class of mutable vectors. The type @m@ is the monad in which the mutable
-- vector can be transformed and @a@ is the type of elements.
--
-- An instance must provide 'unsafeNew', 'unsafeRead' and 'unsafeWrite'. The
-- remaining methods have element-by-element default implementations built on
-- top of those three (plus 'length' and 'unsafeSlice' from 'MVectorPure').
class (Monad m, MVectorPure v a) => MVector v m a where
  -- | Create a mutable vector of the given length. Length is not checked!
  unsafeNew :: Int -> m (v a)

  -- | Create a mutable vector of the given length and fill it with an
  -- initial value. Length is not checked!
  unsafeNewWith :: Int -> a -> m (v a)

  -- | Yield the element at the given position. Index is not checked!
  unsafeRead :: v a -> Int -> m a

  -- | Replace the element at the given position. Index is not checked!
  unsafeWrite :: v a -> Int -> a -> m ()

  -- | Write the value at each position.
  set :: v a -> a -> m ()

  -- | Copy a vector. The two vectors may not overlap. This is not checked!
  unsafeCopy :: v a   -- ^ target
             -> v a   -- ^ source
             -> m ()

  -- | Grow a vector by the given number of elements. The length is not
  -- checked!
  unsafeGrow :: v a -> Int -> m (v a)

  -- Default: allocate with 'unsafeNew', then fill every slot with 'set'.
  {-# INLINE unsafeNewWith #-}
  unsafeNewWith n x = unsafeNew n >>= \fresh -> set fresh x >> return fresh

  -- Default: walk the indices 0 .. length v - 1, writing @x@ at each one.
  {-# INLINE set #-}
  set v x = fill 0
    where
      len = length v

      fill j = if j < len
                 then unsafeWrite v j x >> fill (j + 1)
                 else return ()

  -- Default: element-wise copy driven by the *source* length; only the
  -- first @length src@ slots of @dst@ are written.
  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src = transfer 0
    where
      count = length src

      transfer j
        | j >= count = return ()
        | otherwise  = unsafeRead src j >>= unsafeWrite dst j >> transfer (j + 1)

  -- Default: allocate a larger vector and copy the old contents into its
  -- prefix; the trailing @by@ slots are left as 'unsafeNew' produced them.
  {-# INLINE unsafeGrow #-}
  unsafeGrow v by = do
    let oldLen = length v
    bigger <- unsafeNew (oldLen + by)
    unsafeCopy (unsafeSlice bigger 0 oldLen) v
    return bigger
114
-- | Test whether the index is valid for the vector, i.e. whether
-- @0 <= i < length v@.
inBounds :: MVectorPure v a => v a -> Int -> Bool
{-# INLINE inBounds #-}
inBounds v i = not (i < 0 || i >= length v)
119
-- | Yield a part of the mutable vector without copying it. Safer version of
-- 'unsafeSlice': the start index and length are checked with 'assert' to be
-- non-negative and to stay within @length v@.
slice :: MVectorPure v a => v a -> Int -> Int -> v a
{-# INLINE slice #-}
slice v i n = assert validRange (unsafeSlice v i n)
  where
    validRange = i >= 0 && n >= 0 && i + n <= length v
126
-- | Create a mutable vector of the given length. Safer version of
-- 'unsafeNew' that asserts the requested length is non-negative.
new :: MVector v m a => Int -> m (v a)
{-# INLINE new #-}
new n = assert (0 <= n) (unsafeNew n)
132
-- | Create a mutable vector of the given length and fill it with an
-- initial value. Safer version of 'unsafeNewWith' that asserts the
-- requested length is non-negative.
newWith :: MVector v m a => Int -> a -> m (v a)
{-# INLINE newWith #-}
newWith n x = assert (0 <= n) (unsafeNewWith n x)
138
-- | Yield the element at the given position. Safer version of 'unsafeRead':
-- the index is checked against 'inBounds' with 'assert'.
read :: MVector v m a => v a -> Int -> m a
{-# INLINE read #-}
read v i = assert indexOk (unsafeRead v i)
  where
    indexOk = inBounds v i
143
-- | Replace the element at the given position. Safer version of
-- 'unsafeWrite': the index is checked against 'inBounds' with 'assert'.
write :: MVector v m a => v a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x = assert indexOk (unsafeWrite v i x)
  where
    indexOk = inBounds v i
149
-- | Copy a vector. The two vectors may not overlap. Safer version of
-- 'unsafeCopy': asserts that target and source are disjoint and have the
-- same length.
copy :: MVector v m a => v a -> v a -> m ()
{-# INLINE copy #-}
copy dst src = assert compatible (unsafeCopy dst src)
  where
    compatible = not (dst `overlaps` src) && length dst == length src
156
-- | Grow a vector by the given number of elements. Safer version of
-- 'unsafeGrow' that asserts the increment is non-negative.
grow :: MVector v m a => v a -> Int -> m (v a)
{-# INLINE grow #-}
grow v by = assert (0 <= by) (unsafeGrow v by)
163
164
-- | Create a new mutable vector and fill it with elements from the 'Stream'.
--
-- If the stream's 'Size' hint has an upper bound, a vector of that capacity
-- is allocated up front ('unstreamMax'). Otherwise the vector starts empty
-- and is grown geometrically as elements arrive ('unstreamUnknown').
unstream :: MVector v m a => Stream a -> m (v a)
{-# INLINE_STREAM unstream #-}
unstream s = maybe (unstreamUnknown s) (unstreamMax s) (upperBound (Stream.size s))
173
-- | Fill a freshly allocated vector of capacity @n@ from the stream and
-- return the prefix that was actually written. @n@ is expected to be an
-- upper bound on the stream's length; each write goes through the
-- index-checked 'write'.
unstreamMax :: MVector v m a => Stream a -> Int -> m (v a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
  buf <- new n
  let step j el = write buf j el >> return (j + 1)
  filled <- Stream.foldM step 0 s
  return (slice buf 0 filled)
182
-- | Build a vector from a stream whose length has no known upper bound.
--
-- Starts from an empty vector. Whenever the next index is not below the
-- current length, the vector is grown by
-- @max 1 (double2Int (length * gROWTH_FACTOR))@ elements. At the end, the
-- prefix holding the written elements is returned.
unstreamUnknown :: MVector v m a => Stream a -> m (v a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
  seed <- new 0
  (buf, count) <- Stream.foldM push (seed, 0) s
  return (slice buf 0 count)
  where
    -- Append one element at index @j@, growing the buffer first if needed.
    {-# INLINE push #-}
    push (vec, j) el = do
      vec' <- ensureRoom vec j
      unsafeWrite vec' j el
      return (vec', j + 1)

    -- Return the buffer unchanged if @j@ is already a valid slot, otherwise
    -- a grown copy. The @max 1@ guarantees progress from an empty buffer.
    {-# INLINE ensureRoom #-}
    ensureRoom vec j =
      if j < length vec
        then return vec
        else unsafeGrow vec
               (max 1 (double2Int (int2Double (length vec) * gROWTH_FACTOR)))
203
-- | Apply a function to every element of a mutable vector, in place.
--
-- Each index @i@ in @[0, length v)@ is visited in increasing order: the
-- element is read, passed through @f@, and the result written back to the
-- same position.
map :: MVector v m a => (a -> a) -> v a -> m ()
{-# INLINE map #-}
map f v = map_loop 0
  where
    n = length v

    -- The guard keeps @i@ strictly below @n@, so the unchecked accessors are
    -- safe here (mirroring the default 'set' and 'unsafeCopy' loops). The
    -- previous version used @i <= n@, which touched index @n@ (one past the
    -- end). It also never recursed, so only index 0 was ever mapped.
    map_loop i
      | i < n = do
          x <- unsafeRead v i
          unsafeWrite v i (f x)
          map_loop (i + 1)
      | otherwise = return ()
214
-- | For every @(index, value)@ pair produced by the stream, store @value@ at
-- @index@ in @v@. Each store goes through the index-checked 'write'.
update :: MVector v m a => v a -> Stream (Int, a) -> m ()
{-# INLINE update #-}
update v s = Stream.mapM_ (uncurry (write v)) s
221