dph-lifted-vseg: eliminate sharing in arrays during zipl
[packages/dph.git] / dph-lifted-reference / Data / Array / Parallel / PArray.hs
1
2 module Data.Array.Parallel.PArray
3 ( PArray(..)
4 , valid
5 , nf
6
7 -- * Constructors
8 , empty
9 , singleton, singletonl
10 , replicate, replicatel, replicates
11 , append, appendl
12 , concat, concatl
13
14 -- * Projections
15 , length, lengthl
16 , index, indexl
17 , extract
18
19 -- * Pack and Combine
20 , pack, packl
21 , packByTag
22 , combine2)
23 where
24 import Data.Array.Parallel.Base (Tag)
25 import Data.Vector (Vector)
26 import qualified Data.Array.Parallel.Unlifted as U
27 import qualified Data.Array.Parallel.Array as A
28 import qualified Data.Vector as V
29 import Control.Monad
30 import GHC.Exts
31 import Prelude
32 hiding (replicate, length, concat)
33
-- | Abort with an error whose message names this module and the failing
--   function, e.g. @die "pack" "oops"@ raises
--   @"Data.Array.Parallel.PArray: pack oops"@.
die :: String -> String -> a
die fn str = error $ "Data.Array.Parallel.PArray: " ++ fn ++ " " ++ str
35
-- | Parallel Arrays.
--
--   The 'Int#' field caches the number of elements; constructors in this
--   module set it so that it agrees with the length of the 'Vector' that
--   holds the elements.
data PArray a
  = PArray Int# (Vector a)
  deriving (Eq, Show)
40
41
-- Array Instances ------------------------------------------------------------
-- | Generic array interface for 'PArray'. The methods work on the underlying
--   'Vector'; the cached length is only consulted by 'append', which adds the
--   two cached lengths, and is (re)computed by 'fromVector'.
instance A.Array PArray a where
  length (PArray _ elems)
    = V.length elems

  index (PArray _ elems) i
    = (V.!) elems i

  append (PArray lenA# elemsA) (PArray lenB# elemsB)
    = PArray (lenA# +# lenB#) (elemsA V.++ elemsB)

  toVector (PArray _ elems)
    = elems

  fromVector elems
    = let !(I# len#) = V.length elems
      in  PArray len# elems
59
60
-- | Lift a function on elements to a function on arrays by applying it to
--   every element. The result keeps the length of the input.
lift1 :: (a -> b) -> PArray a -> PArray b
lift1 f arr
  = case arr of
      PArray len# elems -> PArray len# (V.map f elems)
65
66
-- | Lift a binary function on elements to a function on pairs of arrays,
--   combining corresponding elements. Both arrays must have the same
--   (cached) length, otherwise this fails via 'die', reporting both lengths.
lift2 :: (a -> b -> c) -> PArray a -> PArray b -> PArray c
lift2 f (PArray n1# vec1) (PArray n2# vec2)
  | I# n1# /= I# n2#
  = die "lift2" $ unlines
      [ "length mismatch"
      , "  first length  = " ++ show (I# n1#)
      , "  second length = " ++ show (I# n2#) ]

  | otherwise
  = PArray n1# $ V.zipWith f vec1 vec2
75
76
-- Basics ---------------------------------------------------------------------
-- | Check that an array has a valid internal representation: the cached
--   length must equal the number of elements actually stored in the vector.
valid :: PArray a -> Bool
valid (PArray n# vec)
  = I# n# == V.length vec
81
-- | Force an array towards normal form.
--
--   No 'NFData' constraint is available here, so this evaluates the array
--   and each of its elements to weak head normal form only; structure
--   inside the elements is not forced.
nf :: PArray a -> ()
nf (PArray _ vec)
  = V.foldl' (\acc x -> x `seq` acc) () vec
85
86
-- Constructors ----------------------------------------------------------------
-- | O(1). The array that contains no elements.
empty :: PArray a
empty = PArray 0# (V.fromList [])
91
92
-- | O(1). Produce an array holding exactly the given element.
singleton :: a -> PArray a
singleton = PArray 1# . V.singleton
96
97
-- | O(n). Wrap every element of an array in its own singleton array.
singletonl :: PArray a -> PArray (PArray a)
singletonl arr = lift1 singleton arr
101
102
-- | O(n). Define an array of the given size, that maps all elements to the
--   same value. A negative size yields an empty array, as with the Prelude's
--   'Prelude.replicate'.
replicate :: Int -> a -> PArray a
replicate n x
  = let -- Clamp the cached length so it agrees with the vector, which
        -- 'V.replicate' leaves empty for a negative size.
        !(I# len#) = max 0 n
    in  PArray len# (V.replicate n x)
107
108
-- | O(sum lengths). Lifted replicate: the i-th result array repeats the
--   i-th element as many times as the i-th count.
replicatel :: PArray Int -> PArray a -> PArray (PArray a)
replicatel counts elems = lift2 replicate counts elems
112
113
-- | O(sum lengths). Segmented replicate: element @i@ of the array is repeated
--   as many times as the length of segment @i@ of the descriptor, and the
--   repeated runs are concatenated in order. The cached length of the result
--   is taken from the descriptor's total element count.
replicates :: U.Segd -> PArray a -> PArray a
replicates segd (PArray n# vec)
  | arrLen == segdLen
  = let !(I# total#) = U.elementsSegd segd
        counts       = V.convert (U.lengthsSegd segd)
        runs         = V.zipWith V.replicate counts vec
    in  PArray total# (V.concat (V.toList runs))

  | otherwise
  = die "replicates" $ unlines
      [ "segd length mismatch"
      , " segd length = " ++ show segdLen
      , " array length = " ++ show arrLen ]
  where
    arrLen  = I# n#
    segdLen = U.lengthSegd segd
129
130
-- | Append two arrays: the elements of the first are followed by those of
--   the second, and the cached lengths are added.
append :: PArray a -> PArray a -> PArray a
append (PArray lenA# elemsA) (PArray lenB# elemsB)
  = PArray (lenA# +# lenB#) (V.concat [elemsA, elemsB])
135
136
-- | Lifted append: appends corresponding pairs of inner arrays.
appendl :: PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
appendl xss yss = lift2 append xss yss
140
141
-- | Concatenation: flatten an array of arrays into a single array. The
--   cached length of the result is recomputed from the flattened vector.
concat :: PArray (PArray a) -> PArray a
concat (PArray _ inner)
  = let flat = V.concatMap A.toVector inner
    in  case V.length flat of
          I# len# -> PArray len# flat
148
149
-- | Lifted concatenation: flatten each inner array of arrays independently.
concatl :: PArray (PArray (PArray a)) -> PArray (PArray a)
concatl xsss = lift1 concat xsss
153
154 -----------------------------------------------------------
155 -- TODO: unconcat
156 -----------------------------------------------------------
157
158 -----------------------------------------------------------
159 -- TODO: nestUSegd
160 -----------------------------------------------------------
161
162
-- Projections ----------------------------------------------------------------
-- | Take the length of an array, as recorded in its cached length field.
length :: PArray a -> Int
length arr = case arr of PArray len# _ -> I# len#
167
168
-- | Take the length of each inner array.
lengthl :: PArray (PArray a) -> PArray Int
lengthl arrs = lift1 length arrs
172
173
-- | Lookup a single element from the source array. The position is checked
--   by 'V.!' against the stored vector, not against the cached length.
index :: PArray a -> Int -> a
index (PArray _ elems) i
  = (V.!) elems i
178
179
-- | Lookup several elements from several source arrays: the i-th result is
--   the element of the i-th array at the i-th index.
indexl :: PArray (PArray a) -> PArray Int -> PArray a
indexl arrs ixs = lift2 index arrs ixs
183
184
-- | Extract a range of elements from an array: @len@ elements starting at
--   position @start@, taken as a 'V.slice' of the underlying vector.
extract :: PArray a -> Int -> Int -> PArray a
extract (PArray _ elems) start len
  = case len of
      I# len# -> PArray len# (V.slice start len elems)
189
190
-- Pack and Combine -----------------------------------------------------------
-- | Select the elements of an array that have their tag set to True.
--   The flag array must have the same cached length as the data array.
pack :: PArray a -> PArray Bool -> PArray a
pack (PArray n1# xs) (PArray n2# bs)
  | dataLen == flagLen
  = let kept = V.ifilter keep xs
    in  case V.length kept of
          I# len# -> PArray len# kept

  | otherwise
  = die "pack" $ unlines
      [ "array length mismatch"
      , " data length = " ++ show dataLen
      , " flags length = " ++ show flagLen ]
  where
    dataLen  = I# n1#
    flagLen  = I# n2#
    -- Keep the element at position i exactly when flag i is set.
    keep i _ = bs V.! i
205
-- | Lifted pack: pack each inner array with its corresponding flag array.
packl :: PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
packl arrs flags = lift2 pack arrs flags
209
210
-- | Filter an array based on some tags: keep exactly the elements whose tag
--   equals the given one. There must be one tag per element.
packByTag :: PArray a -> U.Array Tag -> Tag -> PArray a
packByTag (PArray n1# xs) tags tag
  | dataLen == tagCount
  = let kept = V.ifilter matches xs
    in  case V.length kept of
          I# len# -> PArray len# kept

  | otherwise
  = die "packByTag" $ unlines
      [ "array length mismatch"
      , " data length = " ++ show dataLen
      , " flags length = " ++ show tagCount ]
  where
    dataLen     = I# n1#
    tagCount    = U.length tags
    -- The element at position i survives when its tag is the requested one.
    matches i _ = (tags U.!: i) == tag
224
225
-- | Combine two arrays based on a selector. Walking the selector's tags in
--   order, a tag of 0 takes the next element of the first array and a tag
--   of 1 takes the next element of the second array.
--
--   Fails via 'die' if the number of tags differs from the total number of
--   elements, if a tag is neither 0 nor 1, or if the tags ask for more
--   elements from one array than it holds.
combine2 :: U.Sel2 -> PArray a -> PArray a -> PArray a
combine2 tags (PArray _ vec1) (PArray _ vec2)
  | tagCount /= dataLen
  = die "combine2" $ unlines
      [ "selector length mismatch"
      , "  selector length = " ++ show tagCount
      , "  data length     = " ++ show dataLen ]

  | otherwise
  = let vec3     = V.fromList
                 $ go (V.toList $ V.convert tagVec)
                      (V.toList vec1)
                      (V.toList vec2)
        !(I# n') = V.length vec3
    in  PArray n' vec3
  where
    tagVec   = U.tagsSel2 tags
    tagCount = U.length tagVec
    -- Use the real vector lengths, since those are what 'go' consumes.
    dataLen  = V.length vec1 + V.length vec2

    -- Merge the two element lists as directed by the tags.
    go [] [] []             = []
    go (0 : bs) (x : xs) ys = x : go bs xs ys
    go (1 : bs) xs (y : ys) = y : go bs xs ys
    -- Previously a non-exhaustive pattern: an invalid tag, or tags whose
    -- per-array counts do not match the array lengths.
    go _ _ _
      = die "combine2" "selector tags do not match the source arrays"