-- Generic Vector framework
-- [darcs-mirrors/vector.git] / Data / Vector / Base / Mutable.hs
1 {-# LANGUAGE TypeFamilies, FlexibleContexts, MultiParamTypeClasses #-}
2 module Data.Vector.Base.Mutable (
3 Base(..),
4
5 slice, new, newWith, read, write, copy, grow, unstream
6 ) where
7
8 import qualified Data.Vector.Stream as Stream
9 import Data.Vector.Stream ( Stream )
10 import Data.Vector.Stream.Size
11
12 import Control.Monad.ST ( ST )
13 import Control.Exception ( assert )
14
15 import GHC.Float (
16 double2Int, int2Double
17 )
18
19 import Prelude hiding ( length, read )
20
-- | Scale factor used by 'unstreamUnknown' when its buffer is full: the
-- buffer is grown by @max 1 (length * gROWTH_FACTOR)@ extra elements
-- (converted to an 'Int' with 'double2Int').
gROWTH_FACTOR :: Double
gROWTH_FACTOR = 3 / 2
23
-- | Interface for mutable vectors @v@ holding elements of type @a@.
--
-- Every operation that touches the contents runs in the monad @'Trans' v@.
-- The checked wrappers in this module ('slice', 'new', 'newWith', 'read',
-- 'write', 'copy', 'grow') 'assert' their preconditions and then delegate to
-- the corresponding @unsafe@ method. Callers of the @unsafe@ methods directly
-- are responsible for those preconditions themselves.
--
-- 'unsafeNewWith', 'set', 'unsafeCopy' and 'unsafeGrow' have default
-- implementations written in terms of the other methods. 'length',
-- 'unsafeSlice', 'unsafeNew', 'unsafeRead', 'unsafeWrite' and 'overlaps'
-- have no defaults and must be supplied by each instance.
class Monad (Trans v) => Base v a where
  -- | The monad in which operations on @v@ run.
  type Trans v :: * -> *

  -- | Number of elements. Valid indices are @0 .. length v - 1@
  -- (see 'inBounds').
  length :: v a -> Int

  -- | @unsafeSlice v i n@ is the @n@-element subvector of @v@ that starts at
  -- index @i@. Bounds are not checked here (see 'slice').
  --
  -- The default 'unsafeGrow' writes through such a slice and expects those
  -- writes to show up in the parent vector, so slices must share storage.
  unsafeSlice :: v a -> Int -> Int -> v a

  -- | Allocate a vector of @n@ elements. This class does not specify the
  -- initial contents; the default 'unsafeNewWith' fills them with 'set'.
  unsafeNew :: Int -> Trans v (v a)

  -- | Allocate a vector of @n@ elements, every one set to the given value.
  unsafeNewWith :: Int -> a -> Trans v (v a)

  -- | Read the element at an index without a bounds check (see 'read').
  unsafeRead :: v a -> Int -> Trans v a

  -- | Overwrite the element at an index without a bounds check
  -- (see 'write').
  unsafeWrite :: v a -> Int -> a -> Trans v ()

  -- | Overwrite every element of the vector with the given value.
  set :: v a -> a -> Trans v ()

  -- | @unsafeCopy dst src@ copies every element of @src@ into the front of
  -- @dst@. Neither the lengths nor overlap are checked here (see 'copy').
  -- The default copies index by index, from the front.
  unsafeCopy :: v a -> v a -> Trans v ()

  -- | @unsafeGrow v by@ returns a vector with @by@ more elements than @v@,
  -- whose first @length v@ elements are copied from @v@ (see 'grow').
  -- The default allocates a fresh vector with 'unsafeNew'. The trailing
  -- @by@ elements are whatever 'unsafeNew' left there.
  unsafeGrow :: v a -> Int -> Trans v (v a)

  -- | Whether two vectors overlap. 'copy' requires this to be 'False' for
  -- its arguments.
  -- NOTE(review): presumably reports shared underlying storage; confirm
  -- against the instances.
  overlaps :: v a -> v a -> Bool

  {-# INLINE unsafeNewWith #-}
  unsafeNewWith len val =
    unsafeNew len >>= \fresh -> set fresh val >> return fresh

  {-# INLINE set #-}
  set vec val = fill 0
    where
      len = length vec

      -- Write val at ix, then continue with the next index until the end.
      fill ix
        | ix >= len = return ()
        | otherwise = unsafeWrite vec ix val >> fill (ix + 1)

  {-# INLINE unsafeCopy #-}
  unsafeCopy dst src = step 0
    where
      len = length src

      -- Move one element from src to dst, front to back.
      step ix
        | ix >= len = return ()
        | otherwise = unsafeRead src ix >>= unsafeWrite dst ix >> step (ix + 1)

  {-# INLINE unsafeGrow #-}
  unsafeGrow vec extra = do
    let old = length vec
    bigger <- unsafeNew (old + extra)
    -- Fill the front of the new vector through a slice of exactly the old size.
    unsafeCopy (unsafeSlice bigger 0 old) vec
    return bigger
76
-- | Whether @i@ is a valid index of @v@, i.e. @0 <= i < length v@.
inBounds :: Base v a => v a -> Int -> Bool
{-# INLINE inBounds #-}
inBounds vec ix = 0 <= ix && ix < length vec
80
-- | Checked 'unsafeSlice'. @slice v i n@ asserts that @i@ and @n@ are
-- non-negative and that the range @[i, i+n)@ fits inside @v@.
-- The check is an 'assert', so it is skipped when assertions are disabled.
slice :: Base v a => v a -> Int -> Int -> v a
{-# INLINE slice #-}
slice vec start len =
  assert (start >= 0 && len >= 0 && start + len <= length vec)
         (unsafeSlice vec start len)
85
-- | Checked 'unsafeNew'. Asserts that the requested size is non-negative.
new :: (Base v a, m ~ Trans v) => Int -> m (v a)
{-# INLINE new #-}
new len = assert (len >= 0) (unsafeNew len)
89
-- | Checked 'unsafeNewWith'. Asserts that the requested size is
-- non-negative.
newWith :: (Base v a, m ~ Trans v) => Int -> a -> m (v a)
{-# INLINE newWith #-}
newWith len val = assert (len >= 0) (unsafeNewWith len val)
93
-- | Checked 'unsafeRead'. Asserts that the index satisfies 'inBounds'.
read :: (Base v a, m ~ Trans v) => v a -> Int -> m a
{-# INLINE read #-}
read vec ix = assert (inBounds vec ix) (unsafeRead vec ix)
97
-- | Checked 'unsafeWrite'. Asserts that the index satisfies 'inBounds'.
write :: (Base v a, m ~ Trans v) => v a -> Int -> a -> m ()
{-# INLINE write #-}
write vec ix val = assert (inBounds vec ix) (unsafeWrite vec ix val)
101
-- | Checked 'unsafeCopy'. @copy dst src@ asserts two things:
--
-- * the vectors do not overlap (per 'overlaps');
-- * they have the same length.
--
-- It then copies @src@ into @dst@.
copy :: (Base v a, m ~ Trans v) => v a -> v a -> m ()
{-# INLINE copy #-}
copy dst src = assert (disjoint && sameSize) (unsafeCopy dst src)
  where
    disjoint = not (dst `overlaps` src)
    sameSize = length dst == length src
106
-- | Checked 'unsafeGrow'. Asserts that the number of extra elements is
-- non-negative.
grow :: (Base v a, m ~ Trans v) => v a -> Int -> m (v a)
{-# INLINE grow #-}
grow vec extra = assert (extra >= 0) (unsafeGrow vec extra)
111
112
-- | Build a mutable vector from a stream.
--
-- If the stream's size hint gives an upper bound ('upperBound'), that many
-- elements are allocated up front ('unstreamMax'). Otherwise the buffer is
-- grown on demand ('unstreamUnknown').
unstream :: (Base v a, m ~ Trans v) => Stream a -> m (v a)
{-# INLINE unstream #-}
unstream s =
  maybe (unstreamUnknown s) (unstreamMax s) (upperBound (Stream.size s))
118
-- | Fill a freshly allocated @n@-element vector from the stream.
--
-- Returns the prefix that was actually written. @n@ must be at least the
-- number of elements the stream yields. Each store goes through 'write',
-- so writing past the end trips an assertion when assertions are enabled.
unstreamMax :: (Base v a, m ~ Trans v) => Stream a -> Int -> m (v a)
{-# INLINE unstreamMax #-}
unstreamMax s n = do
  mv  <- new n
  len <- Stream.foldM (\ix el -> write mv ix el >> return (ix + 1)) 0 s
  return (slice mv 0 len)
127
-- | Build a vector from a stream that has no known upper bound on its size.
--
-- Starts from an empty vector. Before each write, if the next index is not
-- below the current length, the buffer is grown with 'unsafeGrow' by
-- @max 1 (length * gROWTH_FACTOR)@ elements. The result is the prefix
-- holding the elements actually written.
unstreamUnknown :: (Base v a, m ~ Trans v) => Stream a -> m (v a)
{-# INLINE unstreamUnknown #-}
unstreamUnknown s = do
  start        <- new 0
  (buf, count) <- Stream.foldM step (start, 0) s
  return (slice buf 0 count)
  where
    -- Store one element at index ix, enlarging the buffer first if needed.
    {-# INLINE step #-}
    step (buf, ix) el = do
      room <- ensureRoom buf ix
      unsafeWrite room ix el
      return (room, ix + 1)

    -- Keep the buffer while ix is still a valid index; otherwise grow it.
    {-# INLINE ensureRoom #-}
    ensureRoom buf ix
      | ix >= cap = unsafeGrow buf extra
      | otherwise = return buf
      where
        cap   = length buf
        -- max 1 so that an empty buffer still gains a slot.
        extra = max 1 (double2Int (int2Double cap * gROWTH_FACTOR))
148