Reorganise the way the lifted reference implementation works
[packages/dph.git] / dph-lifted-boxed / Data / Array / Parallel / PArray.hs
1
2 -- | Unvectorised parallel arrays.
3 --
4 -- * These operators may be used directly by unvectorised client programs.
5 --
6 -- * They are also used by the "Data.Array.Parallel.Lifted.Combinators"
7 -- module to define the closure converted versions that vectorised code
8 -- uses.
9 --
-- * In general, the operators here are unsafe and do no bounds checks of
--   their own. Lifted operators built with `lift2` or `lift3` check that
--   their argument arrays have the same length, but not every other
--   operator does.
13 --
14 -- TODO: check lengths properly in functions like zip, extracts
15 --
16 module Data.Array.Parallel.PArray
17 ( PArray(..), PA
18 , valid
19 , nf
20
21 -- * Constructors
22 , empty
23 , singleton, singletonl
24 , replicate, replicatel, replicates, replicates'
25 , append, appendl
26 , concat, concatl
27 , unconcat
28 , nestUSegd
29
30 -- * Projections
31 , length, lengthl
32 , index, indexl
33 , extract, extracts, extracts'
34 , slice, slicel
35 , takeUSegd
36
37 -- * Pack and Combine
38 , pack, packl
39 , packByTag
40 , combine2
41
42 -- * Enumerations
43 , enumFromTo, enumFromTol
44
45 -- * Tuples
46 , zip, zipl
47 , unzip, unzipl
48
49 -- * Conversions
50 , fromVector, toVector
51 , fromList, toList
52 , fromUArray, toUArray
53 , fromUArray2)
54 where
55 import Data.Array.Parallel.PArray.PData
56 import Data.Array.Parallel.PArray.PRepr
57 import Data.Array.Parallel.Base (Tag)
58 import Data.Vector (Vector)
59 import qualified Data.Array.Parallel.Unlifted as U
60 import qualified Data.Array.Parallel.Array as A
61 import qualified Data.Vector as V
62 import Control.Monad
63 import GHC.Exts
64 import qualified Prelude as P
65 import Prelude hiding
66 ( replicate, length, concat
67 , enumFromTo
68 , zip, unzip)
69
-- | Abort with an error whose message names this module, then the function
--   that failed, then a description of the problem.
die :: String -> String -> a
die fn str = error $ "Data.Array.Parallel.PArray: " ++ fn ++ " " ++ str
71
72
73 -- Array Instances ------------------------------------------------------------
instance PA a => A.Array PArray a where
  -- Every array is reported as valid.
  valid = const True

  -- Defined directly: writing @singleton = A.singleton@ here would refer
  -- back to this very method at type @PArray a@ and never terminate.
  singleton x
    = PArray 1# $ fromVectorPA $ V.singleton x

  -- Length of the underlying element vector.
  length (PArray _ vec)
    = V.length $ toVectorPA vec

  -- Element lookup in the underlying vector.
  index (PArray _ vec) ix
    = toVectorPA vec V.! ix

  -- Concatenate the element vectors; the stored lengths add up.
  append (PArray n1# xs) (PArray n2# ys)
    = PArray (n1# +# n2#)
    $ fromVectorPA (toVectorPA xs V.++ toVectorPA ys)

  toVector (PArray _ vec)
    = toVectorPA vec

  -- Record the vector's length alongside the converted data.
  fromVector vec
    = case V.length vec of
        I# n# -> PArray n# (fromVectorPA vec)
94
95
-- | Lift a unary function so that it is applied to every element of an
--   array. The result keeps the length of the argument array.
lift1 :: (PA a, PA b)
      => (a -> b) -> PArray a -> PArray b
lift1 f (PArray n# vec)
  = let mapped = V.map f (toVectorPA vec)
    in  PArray n# (fromVectorPA mapped)
103
104
-- | Lift a binary function so that it combines two arrays elementwise.
--   Both arguments must have the same length; otherwise this fails via
--   `die` with a "length mismatch" message.
lift2 :: (PA a, PA b, PA c)
      => (a -> b -> c) -> PArray a -> PArray b -> PArray c
lift2 f (PArray n1# vec1) (PArray n2# vec2)
  | sameLength = PArray n1# (fromVectorPA combined)
  | otherwise  = die "lift2" "length mismatch"
  where
    sameLength = I# n1# == I# n2#
    combined   = V.zipWith f (toVectorPA vec1) (toVectorPA vec2)
118
119
-- | Lift a ternary function so that it combines three arrays elementwise.
--   All three arguments must have the same length; otherwise this fails
--   via `die` with a "length mismatch" message.
lift3 :: (PA a, PA b, PA c, PA d)
      => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
lift3 f (PArray n1# vec1) (PArray n2# vec2) (PArray n3# vec3)
  | mismatch  = die "lift3" "length mismatch"
  | otherwise = PArray n1# (fromVectorPA combined)
  where
    len1     = I# n1#
    mismatch = len1 /= I# n2# || len1 /= I# n3#
    combined = V.zipWith3 f
                 (toVectorPA vec1)
                 (toVectorPA vec2)
                 (toVectorPA vec3)
135
136
137 -- Basics ---------------------------------------------------------------------
-- | Check that an array has a valid internal representation.
--   This implementation checks nothing and accepts every array.
valid :: PArray a -> Bool
valid = const True
141
-- | Intended to force an array to normal form.
--   NOTE: this implementation never inspects its argument, so it forces
--   nothing and simply returns @()@.
nf :: PArray a -> ()
nf = const ()
145
146
147 -- Constructors ----------------------------------------------------------------
-- | O(1). An array with no elements.
empty :: PA a => PArray a
empty = PArray 0# (fromVectorPA V.empty)
151
152
-- | O(1). An array holding exactly one element.
singleton :: PA a => a -> PArray a
singleton x
  = PArray 1# (fromVectorPA (V.singleton x))
156
157
-- | O(n). Wrap every element of an array in its own singleton array.
singletonl :: PA a => PArray a -> PArray (PArray a)
singletonl arr = lift1 singleton arr
161
162
-- | O(n). An array of the given length in which every element is the same
--   value. A negative length yields the empty array.
replicate :: PA a => Int -> a -> PArray a
replicate n x
  -- Clamp the count: `V.replicate` already returns an empty vector for a
  -- negative count, so the stored length must be 0 to agree with it.
  = case max 0 n of
      n'@(I# n#) -> PArray n# $ fromVectorPA $ V.replicate n' x
167
168
-- | O(sum lengths). Lifted replicate: each count is paired with the value
--   at the same position, which is replicated that many times.
replicatel :: PA a => PArray Int -> PArray a -> PArray (PArray a)
replicatel counts xs = lift2 replicate counts xs
172
173
-- | O(sum lengths). Segmented replicate: element @i@ of the array is
--   repeated as many times as the length of segment @i@ of the descriptor.
--   Fails if the descriptor does not have one segment per array element.
replicates :: PA a => U.Segd -> PArray a -> PArray a
replicates segd (PArray n# pdata)
  | arrLen /= segdLen
  = die "replicates" $ unlines
      [ "segd length mismatch"
      , " segd length = "  ++ show segdLen
      , " array length = " ++ show arrLen ]

  | otherwise
  = case U.elementsSegd segd of
      I# total# ->
        PArray total#
          $ fromVectorPA
          $ V.concat
          $ V.toList
          $ V.zipWith V.replicate
              (V.convert $ U.lengthsSegd segd)
              (toVectorPA pdata)
  where
    arrLen  = I# n#
    segdLen = U.lengthSegd segd
190
191
-- | O(sum lengths). Segmented replicate driven by an array of repetition
--   counts. The counts are turned into a `U.Segd`, which is then handed to
--   `replicates`.
replicates' :: PA a => PArray Int -> PArray a -> PArray a
replicates' (PArray _ reps) arr
  = replicates segd arr
  where
    segd = U.lengthsToSegd (V.convert (toVectorPA reps))
197
198
-- | Append two arrays. The stored length of the result is the sum of the
--   stored lengths of the arguments.
append :: PA a => PArray a -> PArray a -> PArray a
append (PArray n1# xs) (PArray n2# ys)
  = PArray (n1# +# n2#) (fromVectorPA joined)
  where
    joined = toVectorPA xs V.++ toVectorPA ys
204
205
-- | Lifted append: append corresponding pairs of inner arrays.
appendl :: PA a => PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
appendl xss yss = lift2 append xss yss
209
210
-- | Concatenate an array of arrays into a single flat array.
concat :: PA a => PArray (PArray a) -> PArray a
concat (PArray _ xss)
  = let flat     = V.concatMap A.toVector (toVectorPA xss)
        !(I# n#) = V.length flat
    in  PArray n# (fromVectorPA flat)
217
218
-- | Lifted concatenation: flatten each element of the outer array.
concatl :: PA a => PArray (PArray (PArray a)) -> PArray (PArray a)
concatl xsss = lift1 concat xsss
222
223
-- | Impose the nesting structure of the first array on the flat second
--   array. Only the segment shape of the first argument is used.
unconcat :: (PA a, PA b) => PArray (PArray a) -> PArray b -> PArray (PArray b)
unconcat template flat
  = let segd = takeUSegd template
    in  nestUSegd segd flat
228
229
-- | Create a nested array from a segment descriptor and some flat data.
--   Segment @i@ of the result is the slice of the flat data starting at the
--   descriptor's index @i@ with the descriptor's length @i@.
--   The descriptor must describe exactly as many elements as the flat data
--   array holds; otherwise this fails with `error`.
nestUSegd :: PA a => U.Segd -> PArray a -> PArray (PArray a)
nestUSegd segd (PArray n# pdata)
  | U.elementsSegd segd == I# n#
  , I# n2# <- U.lengthSegd segd
  = let -- Convert the flat data once and share it between all segments,
        -- instead of converting it again for every segment.
        flat = toVectorPA pdata
        segment start len@(I# len#)
          = PArray len# $ fromVectorPA $ V.slice start len flat
    in  PArray n2#
          $ fromVectorPA
          $ V.zipWith segment
              (V.convert $ U.indicesSegd segd)
              (V.convert $ U.lengthsSegd segd)

  | otherwise
  = error $ unlines
      [ "Data.Array.Parallel.PArray.nestUSegd: number of elements defined by "
          ++ "segment descriptor and data array do not match"
      , " length of segment descriptor = " ++ show (U.elementsSegd segd)
      , " length of data array = " ++ show (I# n#) ]
{-# NOINLINE nestUSegd #-}
251
252
253 -- Projections ----------------------------------------------------------------
-- | O(1). The number of elements in an array, read from its stored length.
length :: PA a => PArray a -> Int
length (PArray n# _)
  = I# n#
257
258
-- | Lifted length: the length of each inner array.
lengthl :: PA a => PArray (PArray a) -> PArray Int
lengthl arrs = lift1 length arrs
262
263
-- | Look up a single element of an array by position.
index :: PA a => PArray a -> Int -> a
index (PArray _ arr) ix
  = let elems = toVectorPA arr
    in  elems V.! ix
268
269
-- | Lifted indexing: look up one element in each of several arrays, using
--   the position at the same index of the array of indices.
indexl :: PA a => PArray (PArray a) -> PArray Int -> PArray a
indexl arrs ixs = lift2 index arrs ixs
273
274
-- | Extract the @len@ elements of an array that start at position @start@.
extract :: PA a => PArray a -> Int -> Int -> PArray a
extract (PArray _ vec) start len@(I# len#)
  = let window = V.slice start len (toVectorPA vec)
    in  PArray len# (fromVectorPA window)
281
282
-- | Segmented extract. For each segment of the `U.SSegd`, take the slice
--   given by the segment's start and length from the source array it names,
--   then concatenate all the slices in segment order.
extracts :: PA a => Vector (PArray a) -> U.SSegd -> PArray a
extracts arrs ssegd
  = concat (fromVector slices)
  where
    slices = V.zipWith3 grab
               (V.convert $ U.sourcesSSegd ssegd)
               (V.convert $ U.startsSSegd ssegd)
               (V.convert $ U.lengthsSSegd ssegd)
    grab src start len = extract (arrs V.! src) start len
293
294
-- | Wrapper for `extracts` that takes arrays of sources, starts and lengths
--   of the segments, and uses them to build the `U.SSegd`.
extracts'
  :: PA a
  => Vector (PArray a)
  -> PArray Int   -- ^ id of source array for each segment.
  -> PArray Int   -- ^ starting index of each segment in its source array.
  -> PArray Int   -- ^ length of each segment.
  -> PArray a
extracts' arrs (PArray _ sources) (PArray _ starts) (PArray _ lengths)
  = extracts arrs ssegd
  where
    segd  = U.lengthsToSegd (V.convert (toVectorPA lengths))
    ssegd = U.mkSSegd (V.convert (toVectorPA starts))
                      (V.convert (toVectorPA sources))
                      segd
310
311
-- | Extract a range of elements from an array.
--   Like `extract`, but takes the start and length before the array.
slice :: PA a => Int -> Int -> PArray a -> PArray a
slice start len = \arr -> extract arr start len
317
318
-- | Lifted slice: take one slice from each inner array, using the start
--   index and length at the same position. The three argument arrays must
--   have the same length, which `lift3` checks.
slicel :: PA a => PArray Int -> PArray Int -> PArray (PArray a) -> PArray (PArray a)
slicel starts lens arrs = lift3 slice starts lens arrs
324
325
-- | Take the segment descriptor of a nested array, built from the lengths
--   of its inner arrays. This can cause index space overflow if the number
--   of elements in the result cannot be represented by a single machine word.
takeUSegd :: PA a => PArray (PArray a) -> U.Segd
takeUSegd (PArray _ pdata)
  = let lens = V.map length (toVectorPA pdata)
    in  U.lengthsToSegd (V.convert lens)
335
336
337 -- Pack and Combine -----------------------------------------------------------
-- | Select the elements of an array whose flag at the same position is True.
--   The data and flags arrays must have the same length; otherwise this
--   fails via `die`.
pack :: PA a => PArray a -> PArray Bool -> PArray a
pack (PArray n1# xs) (PArray n2# bs)
  | I# n1# /= I# n2#
  = die "pack" $ unlines
      [ "array length mismatch"
      , " data length = " ++ show (I# n1#)
      , " flags length = " ++ show (I# n2#) ]

  | otherwise
  = let -- Convert the flags once, rather than once per element inside the
        -- filtering predicate.
        flags    = toVectorPA bs
        xs'      = V.ifilter (\i _ -> flags V.! i) $ toVectorPA xs
        !(I# n') = V.length xs'
    in  PArray n' $ fromVectorPA xs'
351
-- | Lifted pack: pack each inner array with the flags array at the same
--   position.
packl :: PA a => PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
packl xss bss = lift2 pack xss bss
355
356
-- | Keep the elements of an array whose tag (at the same position in the
--   tags array) equals the given tag. The tags array must be as long as the
--   data array; otherwise this fails via `die`.
packByTag :: PA a => PArray a -> U.Array Tag -> Tag -> PArray a
packByTag (PArray n1# xs) tags tag
  | dataLen /= tagsLen
  = die "packByTag" $ unlines
      [ "array length mismatch"
      , " data length = "  ++ show dataLen
      , " flags length = " ++ show tagsLen ]

  | otherwise
  = let kept     = V.ifilter (\i _ -> tags U.!: i == tag) (toVectorPA xs)
        !(I# m#) = V.length kept
    in  PArray m# (fromVectorPA kept)
  where
    dataLen = I# n1#
    tagsLen = U.length tags
370
371
-- | Combine two arrays based on a selector. The selector's tags are walked
--   in order: tag 0 takes the next element of the first array and tag 1 the
--   next element of the second. Any other tag, or tags that do not line up
--   with the two arrays, is an error.
combine2 :: PA a => U.Sel2 -> PArray a -> PArray a -> PArray a
combine2 tags (PArray _ pdata1) (PArray _ pdata2)
  = let
        go [] [] [] = []
        go (0 : bs) (x : xs) ys = x : go bs xs ys
        go (1 : bs) xs (y : ys) = y : go bs xs ys
        -- Report the failure under this function's real name, combine2.
        go _ _ _ = error "Data.Array.Parallel.PArray.combine2: length mismatch"

        vec3 = V.fromList
             $ go (V.toList $ V.convert $ U.tagsSel2 tags)
                  (V.toList $ toVectorPA pdata1)
                  (V.toList $ toVectorPA pdata2)
        !(I# n') = V.length vec3

    in  PArray n' $ fromVectorPA vec3
388
389
390 -- Enumerations ---------------------------------------------------------------
-- | Construct the array of integers from @m@ to @n@ inclusive
--   (empty when @m > n@).
enumFromTo :: Int -> Int -> PArray Int
enumFromTo m n = fromList (P.enumFromTo m n)
395
396
-- | Lifted enumeration: one range per pair of bounds at the same position.
enumFromTol :: PArray Int -> PArray Int -> PArray (PArray Int)
enumFromTol los his = lift2 enumFromTo los his
400
401
402 -- Tuples ---------------------------------------------------------------------
-- | O(n). Zip a pair of arrays into an array of pairs.
--   Both arrays must have the same length; otherwise this fails via `die`.
--   (Without the check, `V.zip` would silently truncate while the result
--   still recorded the first array's length.)
zip :: (PA a, PA b) => PArray a -> PArray b -> PArray (a, b)
zip (PArray n1# pdata1) (PArray n2# pdata2)
  | I# n1# /= I# n2#
  = die "zip" $ unlines
      [ "array length mismatch"
      , " first length = "  ++ show (I# n1#)
      , " second length = " ++ show (I# n2#) ]

  | otherwise
  = PArray n1#
  $ fromVectorPA
  $ V.zip (toVectorPA pdata1) (toVectorPA pdata2)
409
410
-- | Lifted zip: zip the inner arrays at each position together.
zipl :: (PA a, PA b) => PArray (PArray a) -> PArray (PArray b) -> PArray (PArray (a, b))
zipl xss yss = lift2 zip xss yss
414
415
-- | O(n). Split an array of pairs into an array of first components and an
--   array of second components, both recorded with the original length.
unzip :: (PA a, PA b) => PArray (a, b) -> (PArray a, PArray b)
unzip (PArray n# pdata)
  = ( PArray n# (fromVectorPA firsts)
    , PArray n# (fromVectorPA seconds) )
  where
    (firsts, seconds) = V.unzip (toVectorPA pdata)
422
423
-- | Lifted unzip: unzip each inner array of pairs.
unzipl :: (PA a, PA b) => PArray (PArray (a, b)) -> PArray (PArray a, PArray b)
unzipl pss = lift1 unzip pss
427
428
429 -- Conversions ----------------------------------------------------------------
-- | Convert a `Vector` to a `PArray`, recording the vector's length.
fromVector :: PA a => Vector a -> PArray a
fromVector vec
  = case V.length vec of
      I# n# -> PArray n# (fromVectorPA vec)
435
436
-- | Convert a `PArray` to a `Vector` of its elements.
toVector :: PA a => PArray a -> Vector a
toVector (PArray _ pdata) = toVectorPA pdata
441
442
-- | Convert a list to a `PArray`, going through an intermediate `Vector`.
fromList :: PA a => [a] -> PArray a
fromList xs = fromVector (V.fromList xs)
448
449
-- | Convert a `PArray` to a list of its elements.
toList :: PA a => PArray a -> [a]
toList (PArray _ pdata) = V.toList (toVectorPA pdata)
454
455
-- | Convert a `U.Array` to a `PArray`, recording the array's length.
fromUArray :: (PA a, U.Elt a) => U.Array a -> PArray a
fromUArray uarr
  = case U.length uarr of
      I# n# -> PArray n# (fromVectorPA (V.convert uarr))
461
462
-- | Convert a `PArray` to a `U.Array`.
toUArray :: (PA a, U.Elt a) => PArray a -> U.Array a
toUArray (PArray _ pdata) = V.convert (toVectorPA pdata)
467
468
-- | Convert a `U.Array` of pairs to a `PArray` of pairs, recording the
--   array's length.
fromUArray2
  :: (PA a, U.Elt a, PA b, U.Elt b)
  => U.Array (a, b) -> PArray (a, b)
fromUArray2 uarr
  = case U.length uarr of
      I# n# -> PArray n# (fromVectorPA (V.convert uarr))