dph-lifted-vseg: also store pre-demoted segd in nested arrays
[packages/dph.git] / dph-lifted-vseg / Data / Array / Parallel / PArray.hs
1 {-# OPTIONS -fno-spec-constr #-}
2 {-# LANGUAGE UndecidableInstances #-}
3 #include "fusion-phases.h"
4
5 -- | Unvectorised parallel arrays.
6 --
7 -- * These operators may be used directly by unvectorised client programs.
8 --
9 -- * They are also used by the "Data.Array.Parallel.Lifted.Combinators"
10 -- module to define the closure converted versions that vectorised code
11 -- uses.
12 --
13 -- * In general, the operators here are all unsafe and don't do bounds checks.
14 -- The lifted versions also don't check that each of the argument arrays
15 -- have the same length.
16
17 -- TODO:
18 -- Export unsafe versions from Data.Array.Parallel.PArray.Unsafe, and ensure
19 -- this module exports safe wrappers. We want to use the unsafe versions in
20 -- D.A.P.Lifted.Combinators for performance reasons, but the user facing PArray
21 -- functions should all be safe. In particular, the vectoriser guarantees
22 -- that all arrays passed to lifted functions will have the same length, but
23 -- the user may not obey this restriction.
24 --
25 module Data.Array.Parallel.PArray
26 ( PArray, PA
27 , valid
28 , nf
29
30 -- * Constructors
31 , empty
32 , singleton, singletonl
33 , replicate, replicatel, replicates, replicates'
34 , append, appendl
35 , concat, concatl
36 , unconcat
37 , nestUSegd
38
39 -- * Projections
40 , length, lengthl -- length from D.A.P.PArray.PData.Base
41 , index, indexl
42 , extract, extracts, extracts'
43 , slice, slicel
44 , takeUSegd
45
46 -- * Pack and Combine
47 , pack, packl
48 , packByTag
49 , combine2
50
51 -- * Enumerations
52 , enumFromTo, enumFromTol -- from D.A.P.PArray.Scalar
53
54 -- * Tuples
55 , zip, zipl
56 , zip3
57 , zip4
58 , zip5
59 , unzip, unzipl
60
61 -- * Conversions
62 , fromVector, toVector
63 , fromList, toList
64 , fromUArray, toUArray -- from D.A.P.PArray.Scalar
65 , fromUArray2) -- from D.A.P.PArray.Scalar
66 where
67 import qualified Data.Array.Parallel.Pretty as T
68 import Data.Array.Parallel.PArray.PData
69 import Data.Array.Parallel.PArray.PRepr
70 import Data.Array.Parallel.PArray.Scalar
71 import Data.Array.Parallel.PArray.Reference
72 import GHC.Exts
73 import Data.Vector (Vector)
74 import Data.Array.Parallel.Base (Tag)
75 import qualified Data.Array.Parallel.Array as A
76 import qualified Data.Array.Parallel.Unlifted as U
77 import qualified Data.Vector as V
78 import qualified "dph-lifted-base" Data.Array.Parallel.PArray as R
79 import qualified "dph-lifted-base" Data.Array.Parallel.PArray.Reference as R
80 -- import qualified Debug.Trace
81
82 import qualified Prelude as P
83 import Prelude hiding
84 ( length, replicate, concat
85 , enumFromTo
86 , zip, zip3, unzip)
87
88
89 -- Pretty ---------------------------------------------------------------------
-- | Physical pretty-printing: a header line with the stored length, followed
--   by the PData representation nested 4 columns to the right.
instance PA a => T.PprPhysical (PArray a) where
        pprp (PArray n# pdata)
         = header T.$+$ T.nest 4 body
         where  header = T.text "PArray " T.<+> T.int (I# n#)
                body   = pprpDataPA pdata
95
-- | Element similarity (used when comparing against the reference
--   implementation) is delegated to 'similarPA'.
instance PA a => Similar a where
        similar x = similarPA x
98
-- | Physical pretty-printing of a single element, delegated to 'pprpPA'.
instance PA a => R.PprPhysical1 a where
        pprp1 x = pprpPA x
101
102
-- | Debugging hook. Currently a no-op: the message is discarded and the
--   second argument is returned unchanged. To enable tracing, replace the
--   body with the commented-out 'Debug.Trace.trace' call below (and
--   re-enable the Debug.Trace import).
trace :: String -> a -> a
trace _msg val = val
-- trace msg val = Debug.Trace.trace ("! " ++ msg) val
106
107 -- Array -----------------------------------------------------------------------
108 -- Generic interface to PArrays.
109 --
110 -- NOTE:
-- The toVector conversion is defined by looking up every index instead of
-- using the bulk toVectorPA function.
-- We do this to convert arrays of type (PArray Void) properly, as although a
-- (PArray Void) has an intrinsic length, a (PData Void) does not. If we try
-- to use the toVectorPA function at this type we'll just get an `error`.
116 -- Arrays of type PArray Void aren't visible in the user API, but during
117 -- debugging we need to be able to print them out with the implied length.
118 --
instance PA e => A.Array PArray e where
        valid           = valid
        singleton       = singleton
        append          = append
        length          = length
        fromVector      = fromVector

        index (PArray _ pdata) ix
         = indexPA pdata ix

        -- Convert by indexing each position in turn rather than calling
        -- toVectorPA on the PData, because a (PData Void) has no intrinsic
        -- length; the number of positions comes from A.length on the PArray.
        toVector arr
         = let  ixs = V.enumFromTo 0 (A.length arr - 1)
           in   V.map (A.index arr) ixs
130
131 -- Operators ==================================================================
-- Many of these operators are wrapped in withRef functions so that we can
-- compare their outputs to the reference implementation.
-- See D.A.P.PArray.Reference for details.
135
136
137 -- Basics ---------------------------------------------------------------------
-- | Arrays are compared element-wise after converting their PData to
--   Vectors with 'toVectorPA'; the stored lengths are not compared directly.
instance (Eq a, PA a) => Eq (PArray a) where
        PArray _ xs == PArray _ ys = toVectorPA xs == toVectorPA ys
        PArray _ xs /= PArray _ ys = toVectorPA xs /= toVectorPA ys
141
142
-- | Check that an array has a valid internal representation.
--   The PData must satisfy 'validPA', and 'coversPA' (called with 'True')
--   must accept the length recorded in the PArray wrapper.
valid :: PA a => PArray a -> Bool
valid (PArray n# pdata)
 = trace "valid" (pdataOk && coversLen)
 where  pdataOk   = validPA pdata
        coversLen = coversPA True pdata (I# n#)
{-# INLINE_PA valid #-}
150
151
-- | Force an array to normal form by forcing its underlying PData with
--   'nfPA'. The stored length is not inspected.
nf :: PA a => PArray a -> ()
nf (PArray _ pdata) = trace "nf" (nfPA pdata)
{-# INLINE_PA nf #-}
158
159
160 -- Constructors ----------------------------------------------------------------
-- | O(1). An empty array, wrapped in 'withRef1' so that it can be compared
--   against the reference 'R.empty'.
empty :: PA a => PArray a
empty = withRef1 "empty" R.empty (PArray 0# emptyPA)
{-# INLINE_PA empty #-}
168
169
-- | O(1). Produce an array containing a single element.
singleton :: PA a => a -> PArray a
singleton x
 = withRef1 "singleton" (R.singleton x) arr
 where  arr = PArray 1# (replicatePA 1 x)
{-# INLINE_PA singleton #-}
176
177
-- | O(n). Produce an array of singleton arrays, by lifted-replicating every
--   element of the argument exactly once.
singletonl :: PA a => PArray a -> PArray (PArray a)
singletonl arr
 = withRef2 "singletonl" (R.singletonl (toRef1 arr)) result
 where  ones   = replicate_ (length arr) 1
        result = replicatel_ ones arr
{-# INLINE_PA singletonl #-}
184
185
-- | O(n). Define an array of the given size that maps all elements to the
--   same value. The replication count is required to be > 0 so that it's
--   easier to maintain the validPR invariants for nested arrays; note that
--   this requirement is not checked here.
replicate :: PA a => Int -> a -> PArray a
replicate n x
 = withRef1 "replicate" (R.replicate n x) (replicate_ n x)
{-# INLINE_PA replicate #-}
194
-- | Worker for 'replicate' without the reference comparison.
replicate_ :: PA a => Int -> a -> PArray a
replicate_ n@(I# n#) x
 = PArray n# (replicatePA n x)
{-# INLINE_PA replicate_ #-}
199
200
-- | O(sum lengths). Lifted replicate, checked against the reference
--   'R.replicatel'. The counts array and the element array are assumed to
--   have the same length; this is not checked.
replicatel :: PA a => PArray Int -> PArray a -> PArray (PArray a)
replicatel reps arr
 = withRef2 "replicatel" (R.replicatel (toRef1 reps) (toRef1 arr))
 $ replicatel_ reps arr
-- Every other public wrapper in this module carries an INLINE_PA pragma;
-- this one was missing it.
{-# INLINE_PA replicatel #-}
206
-- | Worker for 'replicatel' without the reference comparison.
--   An empty counts array yields 'empty'. Otherwise the flat data is
--   replicated with 'replicatesPA' according to the Segd built from the
--   counts, and the result is a nested array with one segment per count,
--   all taken from that single flat array.
replicatel_ :: PA a => PArray Int -> PArray a -> PArray (PArray a)
replicatel_ (PArray n# (PInt lens)) (PArray _ pdata)
 = case n# of
        0# -> empty
        _  -> let !segd   = U.lengthsToSegd lens
                  !pdata' = replicatesPA segd pdata
                  !c      = I# n#
              in  PArray n#
                   $ mkPNestedPA
                        (U.enumFromTo 0 (c - 1))    -- 0 .. c-1, one entry per segment
                        lens
                        (U.indicesSegd segd)
                        (U.replicate c 0)           -- every entry refers to index 0
                        (singletondPA pdata')       -- the single flat data array
{-# INLINE_PA replicatel_ #-}
222
223
-- | O(sum lengths). Segmented replicate: replicate the elements of @arr@ as
--   described by @segd@. The result length is the number of elements
--   described by the segment descriptor.
replicates :: PA a => U.Segd -> PArray a -> PArray a
replicates segd arr@(PArray _ pdata)
 = trace msg
 $ withRef1 "replicates" (R.replicates segd (toRef1 arr)) result
 where  msg    = T.render $ T.text "!!! replicates " T.$+$ T.pprp segd T.$+$ T.pprp arr
        result = case U.elementsSegd segd of
                  I# n# -> PArray n# (replicatesPA segd pdata)
{-# INLINE_PA replicates #-}
232
233
-- | O(sum lengths). Variant of 'replicates' that takes the replication counts
--   directly and builds the `U.Segd` from them.
replicates' :: PA a => PArray Int -> PArray a -> PArray a
replicates' (PArray _ (PInt reps)) arr
 = trace "replicates'" (replicates segd arr)
 where  segd = U.lengthsToSegd reps
{-# INLINE_PA replicates' #-}
241
242
-- | Append two arrays. The result length is the sum of the two stored
--   lengths.
append :: PA a => PArray a -> PArray a -> PArray a
append arr1@(PArray n1# pdata1) arr2@(PArray n2# pdata2)
 = withRef1 "append" (R.append (toRef1 arr1) (toRef1 arr2)) result
 where  result = PArray (n1# +# n2#) (appendPA pdata1 pdata2)
{-# INLINE_PA append #-}
249
250
-- | Lifted append.
--   Both arrays must have the same length. This is not checked, and the
--   result takes its outer length from the first argument.
appendl :: PA a => PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
appendl arr1@(PArray n# pdata1) arr2@(PArray _ pdata2)
 = withRef2 "appendl" (R.appendl (toRef2 arr1) (toRef2 arr2)) result
 where  result = PArray n# (appendlPA pdata1 pdata2)
{-# INLINE_PA appendl #-}
258
259
-- | Concatenate a nested array. The result length is read back from the
--   concatenated PData with 'lengthPA'.
concat :: PA a => PArray (PArray a) -> PArray a
concat arr@(PArray _ pdata)
 = withRef1 "concat" (R.concat (toRef2 arr)) result
 where  flat   = concatPA pdata
        result = case lengthPA flat of
                  I# n# -> PArray n# flat
{-# INLINE_PA concat #-}
268
269
-- | Lifted concat. The result has the same outer length as the argument.
concatl :: PA a => PArray (PArray (PArray a)) -> PArray (PArray a)
concatl arr@(PArray n# pdata)
 = withRef2 "concatl" (R.concatl (toRef3 arr)) result
 where  result = PArray n# (concatlPA pdata)
{-# INLINE_PA concatl #-}
276
277
-- | Impose the nesting structure of the first (nested) array on the flat
--   data of the second. The outer length of the result is that of the
--   first argument.
unconcat :: (PA a, PA b) => PArray (PArray a) -> PArray b -> PArray (PArray b)
unconcat (PArray n# pdata1) (PArray _ pdata2)
 = trace "! unconcat" (PArray n# (unconcatPA pdata1 pdata2))
{-# INLINE_PA unconcat #-}
284
285
-- | Create a nested array from a segment descriptor and some flat data.
--   The segment descriptor must describe exactly as many elements as are
--   present in the flat data array, else `error`.
--   The result stores the descriptor both promoted to a VSegd and as the
--   plain (pre-demoted) Segd, alongside the flat data.
nestUSegd :: PA a => U.Segd -> PArray a -> PArray (PArray a)
nestUSegd segd (PArray n# pdata)
        | U.elementsSegd segd == I# n#
        , I# n2# <- U.lengthSegd segd
        = PArray n2#
        $ PNested (U.promoteSegdToVSegd segd) (singletondPA pdata) segd pdata

        -- The value reported here is 'U.elementsSegd' (the number of
        -- elements described), not the number of segments, so label it so.
        | otherwise
        = error $ unlines
        [ "Data.Array.Parallel.PArray.nestUSegd: number of elements defined by "
                ++ "segment descriptor and data array do not match"
        , " elements in segment descriptor = " ++ show (U.elementsSegd segd)
        , " length of data array = " ++ show (I# n#) ]
{-# NOINLINE nestUSegd #-}
303
304
305 -- Projections ---------------------------------------------------------------
-- | Take the length of each inner array of a nested array.
--   The lengths are read from the virtual segment descriptor.
lengthl :: PA a => PArray (PArray a) -> PArray Int
lengthl arr@(PArray n# (PNested vsegd _ _ _))
 = withRef1 "lengthl" (R.lengthl (toRef2 arr)) result
 where  result = PArray n# (PInt (U.takeLengthsOfVSegd vsegd))
{-# INLINE_PA lengthl #-}
312
313
-- | O(1). Look up a single element of the source array.
--   Like the other operators here, this does no bounds check.
index :: PA a => PArray a -> Int -> a
index (PArray _ pdata) ix = trace "index" (indexPA pdata ix)
{-# INLINE_PA index #-}
320
321
-- | O(len indices). Look up several elements from several source arrays.
--   The result length is taken from the nested array.
indexl :: PA a => PArray (PArray a) -> PArray Int -> PArray a
indexl (PArray n# pdata) (PArray _ ixs)
 = trace "indexl" (PArray n# (indexlPA pdata ixs))
{-# INLINE_PA indexl #-}
328
329
-- | Extract a range of elements from an array, given the start index and
--   the number of elements. The result length is the requested length.
extract :: PA a => PArray a -> Int -> Int -> PArray a
extract (PArray _ pdata) start len@(I# len#)
 = trace "extract" (PArray len# (extractPA pdata start len))
{-# INLINE_PA extract #-}
336
337
-- | Segmented extract: gather the segments described by the `U.SSegd` from
--   several source arrays. The result length is the sum of the segment
--   lengths.
extracts :: PA a => Vector (PArray a) -> U.SSegd -> PArray a
extracts arrs ssegd
 = trace "extracts"
 $ case U.sum (U.lengthsOfSSegd ssegd) of
        I# n# -> PArray n# (extractssPA pdatas ssegd)
 where  pdatas = fromVectordPA (V.map (\(PArray _ pdata) -> pdata) arrs)
{-# INLINE_PA extracts #-}
347
348
-- | Wrapper for `extracts` that takes arrays of sources, starts and lengths
--   of the segments, and uses these to build the `U.SSegd`.
--   TODO: The sources, starts and lengths arrays must all have the same
--   length, but this is not checked.
--   All source ids must point to valid data arrays.
--   Segments must lie within their corresponding source array.
extracts'
        :: PA a
        => Vector (PArray a)
        -> PArray Int   -- ^ id of source array for each segment.
        -> PArray Int   -- ^ starting index of each segment in its source array.
        -> PArray Int   -- ^ length of each segment.
        -> PArray a
extracts' arrs (PArray _ (PInt sources)) (PArray _ (PInt starts)) (PArray _ (PInt lengths))
 = trace "extracts'" (extracts arrs ssegd)
 where  ssegd = U.mkSSegd starts sources (U.lengthsToSegd lengths)
{-# INLINE_PA extracts' #-}
368
369
-- | Extract a range of elements from an array.
--   Like `extract` but with the parameters in a different order:
--   start index, then length, then the source array.
slice :: PA a => Int -> Int -> PArray a -> PArray a
slice start len@(I# len#) (PArray _ pdata)
 = trace "slice" (PArray len# (extractPA pdata start len))
{-# INLINE_PA slice #-}
377
378
-- | Extract some slices from some arrays.
--   The arrays of starting indices and lengths must themselves have the
--   same length; the result length is taken from the starts array.
slicel :: PA a => PArray Int -> PArray Int -> PArray (PArray a) -> PArray (PArray a)
slicel (PArray n# starts) (PArray _ lens) (PArray _ pdata)
 = trace "slicel" (PArray n# (slicelPA starts lens pdata))
{-# INLINE_PA slicel #-}
387
388
-- | Take the segment descriptor from a nested array and demote it to a
--   plain `U.Segd` (via 'takeSegdPD'). This is unsafe because it can cause
--   index space overflow.
takeUSegd :: PArray (PArray a) -> U.Segd
takeUSegd (PArray _ pdata) = trace "takeUSegd" (takeSegdPD pdata)
{-# INLINE_PA takeUSegd #-}
396
397
398 -- Pack and Combine -----------------------------------------------------------
-- | Select the elements of an array whose corresponding flag is True,
--   i.e. those tagged 1 in the flags' selector.
pack :: PA a => PArray a -> PArray Bool -> PArray a
pack arr@(PArray _ xs) flags@(PArray _ (PBool sel2))
 = withRef1 "pack" (R.pack (toRef1 arr) (toRef1 flags))
 $ case U.elementsSel2_1 sel2 of
        -- The selector already knows how many elements are tagged 1,
        -- so that count is the length of the result.
        I# m# -> PArray m# (packByTagPA xs (U.tagsSel2 sel2) 1)
{-# INLINE_PA pack #-}
411
412
-- | Lifted pack: within each inner array, keep the elements whose
--   corresponding flag is True. The outer length is unchanged.
packl :: PA a => PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
packl xss@(PArray n# xdata@(PNested _ _ segd _))
      fss@(PArray _ fdata)
 = withRef2 "packl" (R.packl (toRef2 xss) (toRef2 fss))
 $ PArray n# (PNested vsegd' pdatas' segd' flat')
 where
        -- Flatten both arguments. Their virtual segmentation should agree,
        -- but their physical segmentation may differ, so work on the
        -- concatenated data rather than on the underlying PData directly.
        xflat           = concatPA xdata
        PBool fsel      = concatPA fdata
        tags            = U.tagsSel2 fsel

        -- Number of surviving (tag 1) elements in each segment.
        segd'           = U.lengthsToSegd (U.count_s segd tags 1)

        -- Assemble the result: one flat array of surviving elements,
        -- described both by the plain Segd and by its promotion to a VSegd.
        flat'           = packByTagPA xflat tags 1
        pdatas'         = singletondPA flat'
        vsegd'          = U.promoteSegdToVSegd segd'
{-# INLINE_PA packl #-}
436
437
-- | Filter an array based on some tags, keeping the elements selected by
--   @tag@ (as done by 'packByTagPA'). The result length is read back from
--   the filtered PData.
packByTag :: PA a => PArray a -> U.Array Tag -> Tag -> PArray a
packByTag arr@(PArray _ pdata) tags tag
 = withRef1 "packByTag" (R.packByTag (toRef1 arr) tags tag) result
 where  packed = packByTagPA pdata tags tag
        result = case lengthPA packed of
                  I# n# -> PArray n# packed
{-# INLINE_PA packByTag #-}
447
448
-- | Combine two arrays based on a selector. The result length is read back
--   from the combined PData.
combine2 :: forall a. PA a => U.Sel2 -> PArray a -> PArray a -> PArray a
combine2 sel arr1@(PArray _ pdata1) arr2@(PArray _ pdata2)
 = withRef1 "combine2" (R.combine2 sel (toRef1 arr1) (toRef1 arr2)) result
 where  combined = combine2PA sel pdata1 pdata2
        result   = case lengthPA combined of
                    I# n# -> PArray n# combined
{-# INLINE_PA combine2 #-}
457
458
459 -- Tuples ---------------------------------------------------------------------
-- | O(1). Zip a pair of arrays into an array of pairs.
--   The two arrays must have the same length. This function does not
--   compare the lengths itself; the result takes its length from the first.
zip :: PArray a -> PArray b -> PArray (a, b)
zip (PArray n# xs) (PArray _ ys) = trace "zip" (PArray n# (zipPD xs ys))
{-# INLINE_PA zip #-}
467
468
-- | Lifted zip. The outer length is taken from the first argument.
zipl :: (PA a, PA b)
     => PArray (PArray a) -> PArray (PArray b) -> PArray (PArray (a, b))
zipl (PArray n# xss) (PArray _ yss)
 = trace "zipl" (PArray n# (ziplPA xss yss))
{-# INLINE_PA zipl #-}
476
477
-- | O(1). Zip three arrays.
--   All arrays must have the same length. This function does not compare
--   the lengths itself; the result takes its length from the first array.
zip3 :: PArray a -> PArray b -> PArray c -> PArray (a, b, c)
zip3 (PArray n# xs) (PArray _ ys) (PArray _ zs)
 = trace "zip3" (PArray n# (zip3PD xs ys zs))
{-# INLINE_PA zip3 #-}
485
486
-- | O(1). Zip four arrays.
--   All arrays must have the same length. This function does not compare
--   the lengths itself; the result takes its length from the first array.
zip4 :: PArray a -> PArray b -> PArray c -> PArray d -> PArray (a, b, c, d)
zip4 (PArray n# ws) (PArray _ xs) (PArray _ ys) (PArray _ zs)
 = trace "zip4" (PArray n# (zip4PD ws xs ys zs))
{-# INLINE_PA zip4 #-}
494
495
-- | O(1). Zip five arrays.
--   All arrays must have the same length. This function does not compare
--   the lengths itself; the result takes its length from the first array.
zip5 :: PArray a -> PArray b -> PArray c -> PArray d -> PArray e -> PArray (a, b, c, d, e)
zip5 (PArray n# d1) (PArray _ d2) (PArray _ d3) (PArray _ d4) (PArray _ d5)
 = trace "zip5" (PArray n# (zip5PD d1 d2 d3 d4 d5))
{-# INLINE_PA zip5 #-}
503
504
-- | O(1). Unzip an array of pairs into a pair of arrays; both components
--   share the stored length of the input.
unzip :: PArray (a, b) -> (PArray a, PArray b)
unzip (PArray n# (PTuple2 xs ys))
 = trace "unzip" (PArray n# xs, PArray n# ys)
{-# INLINE_PA unzip #-}
511
512
-- | Lifted unzip. The outer length is preserved.
unzipl :: PArray (PArray (a, b)) -> PArray (PArray a, PArray b)
unzipl (PArray n# pdata) = trace "unzipl" (PArray n# (unziplPD pdata))
{-# INLINE_PA unzipl #-}
519
520
521 -- Conversions ----------------------------------------------------------------
-- | Convert a `Vector` to a `PArray`; the stored length is the vector's
--   length.
fromVector :: PA a => Vector a -> PArray a
fromVector vec
 = trace "fromVector"
 $ case V.length vec of
        I# n# -> PArray n# (fromVectorPA vec)
{-# INLINE_PA fromVector #-}
529
530
-- | Convert a `PArray` to a `Vector` using the bulk 'toVectorPA' on the
--   underlying PData; the stored length is not consulted.
toVector :: PA a => PArray a -> Vector a
toVector (PArray _ pdata) = trace "toVector" (toVectorPA pdata)
{-# INLINE_PA toVector #-}
537
538
-- | Convert a list to a `PArray`.
--   The list is converted to a `Vector` once. The length is then taken from
--   that vector in O(1), instead of traversing the list a second time with
--   'P.length'.
fromList :: PA a => [a] -> PArray a
fromList xx
 = trace "fromList"
 $ let  vec      = V.fromList xx
        !(I# n#) = V.length vec
   in   PArray n# (fromVectorPA vec)
{-# INLINE_PA fromList #-}
546
547
-- | Convert a `PArray` to a list, going via a `Vector` built with
--   'toVectorPA'.
toList :: PA a => PArray a -> [a]
toList (PArray _ pdata)
 = trace "toList" (V.toList (toVectorPA pdata))
{-# INLINE_PA toList #-}
554