dph-lifted-vseg: eliminate sharing in arrays during zipl
[packages/dph.git] / dph-lifted-vseg / Data / Array / Parallel / PArray.hs
1 {-# OPTIONS -fno-spec-constr #-}
2 #include "fusion-phases.h"
3
4 -- | Functions that work directly on PArrays.
5
6 -- * The functions in this module are used by the D.A.P.Lifted.Combinator module to
7 -- define the closures that the vectoriser uses.
8 --
9 -- * The functions in this module may also be used directly by user programs.
10 --
11 -- * In general, these functions are all unsafe and don't do bounds checks.
12 -- The lifted versions also don't check that each of the argument arrays
13 -- have the same length.
14 --
15 -- TODO:
16 -- Export unsafe versions from Data.Array.Parallel.PArray.Unsafe,
17 -- and make this module export safe wrappers.
18 -- We want to use the unsafe versions in D.A.P.Lifted.Combinators
19 -- for performance reasons, but the user facing PArray functions
20 -- should all be safe.
21 --
22 module Data.Array.Parallel.PArray
23 ( PArray(..), PA
24 , valid
25 , nf
26
27 -- * Constructors
28 , empty
29 , singleton, singletonl
30 , replicate, replicatel, replicates, replicates'
31 , append, appendl
32 , concat, concatl
33 , unconcat
34 , nestUSegd
35
36 -- * Projections
37 , length, lengthl -- length from D.A.P.PArray.PData.Base
38 , index, indexl
39 , extract, extracts, extracts'
40 , slice, slicel
41 , unsafeTakeSegd
42
43 -- * Pack and Combine
44 , pack, packl
45 , packByTag
46 , combine2
47
48 -- * Enumerations
49 , enumFromTo, enumFromTol -- from D.A.P.PArray.Scalar
50
51 -- * Tuples
52 , zip, zipl -- from D.A.P.PArray.Tuple
53 , unzip, unzipl -- from D.A.P.PArray.Tuple
54
55 -- * Conversions
56 , fromVector, toVector
57 , fromList, toList
58 , fromUArray, toUArray -- from D.A.P.PArray.Scalar
59 , fromUArray2) -- from D.A.P.PArray.Scalar
60 where
61 import qualified Data.Array.Parallel.Pretty as T
62 import Data.Array.Parallel.PArray.PData
63 import Data.Array.Parallel.PArray.PRepr
64 import Data.Array.Parallel.PArray.Scalar
65 import GHC.Exts
66 import Data.Maybe
67 import Data.Vector (Vector)
68 import Data.Array.Parallel.Base (Tag)
69 import qualified "dph-lifted-reference" Data.Array.Parallel.PArray as R
70 import qualified Data.Array.Parallel.Array as A
71 import qualified Data.Array.Parallel.Unlifted as U
72 import qualified Data.Vector as V
73 import qualified Prelude as P
74 import Prelude hiding
75 ( length, replicate, concat
76 , enumFromTo
77 , zip, unzip)
78
79 import Debug.Trace
80
81 -- Config ---------------------------------------------------------------------
-- | When 'True', 'withRef1' and 'withRef2' emit the operator name and a
--   pretty-printed copy of the result array through 'trace'.
debugLiftedTrace   :: Bool
debugLiftedTrace   = False

-- | When 'True', 'withRef1' and 'withRef2' compare the result against the
--   array produced by the reference implementation and call 'error' on a
--   mismatch.
debugLiftedCompare :: Bool
debugLiftedCompare = False
84
85
86 -- Tracing --------------------------------------------------------------------
-- TODO: we could use this to trace the lengths of the vectors being used,
-- as well as the types that each operator is being called at.
89
-- | Generic array interface for 'PArray', delegating to the operators
--   defined in this module.
instance PA e => A.Array PArray e where
 length arr
        = length arr

 index (PArray _ pdata) ix
        = indexPA pdata ix

 append xs ys
        = append xs ys

 -- This conversion (used for testing) looks up every index individually
 -- rather than going through the bulk conversion on the PData.
 -- That is needed to convert arrays of type (PArray Void) properly:
 -- a (PArray Void) has an intrinsic length but a (PData Void) does not.
 -- Such arrays are not visible in the user API, but when debugging we
 -- want to print them with the correct length.
 toVector arr
        = V.generate (A.length arr) (A.index arr)

 fromVector vec
        = fromVector vec
109
-- | Physical pretty printer: a header line carrying the array length,
--   with the pretty-printed element data nested by 4 columns underneath.
instance PA a => PprPhysical (PArray a) where
 pprp (PArray n# pdata)
        = header T.$+$ T.nest 4 (pprpDataPA pdata)
        where   header  = T.text "PArray " T.<+> T.int (I# n#)
115
-- | Physical pretty printer for boxed vectors: the elements are printed
--   with 'pprpPA', separated by commas and wrapped in brackets.
instance PA a => PprPhysical (Vector a) where
 pprp vec
        = T.brackets (T.hcat (T.punctuate (T.text ", ") elems))
        where   elems   = map pprpPA (V.toList vec)
122
-- TODO: shift this stuff to the reference implementation module.
--       make the PArray constructor polymorphic
-- | Compare a flat array against a reference.
--
--   The result is always the implementation array @arrImpl@.
--
--   * With 'debugLiftedTrace' set, the operator name and the array are
--     also written out via 'trace'.
--
--   * With 'debugLiftedCompare' set, the array must be 'valid', have the
--     same length as the reference, and be element-wise 'similarPA' to it;
--     otherwise 'error' is called with both arrays pretty-printed.
withRef1 :: PA a
         => String              -- ^ name of the operator, used in messages.
         -> R.PArray a          -- ^ result of the reference implementation.
         -> PArray a            -- ^ result of this implementation.
         -> PArray a

withRef1 name arrRef arrImpl
        = trace' checked
        where
                -- Either the identity, or a tracing wrapper around it.
                trace'
                 | debugLiftedTrace
                 = trace $ T.render
                         ( T.text " "
                           T.$$ T.text name
                           T.$$ T.nest 8 (pprpPA arrImpl))

                 | otherwise
                 = id

                -- The comparison only runs when the debug flag is set.
                checked
                 | not debugLiftedCompare = arrImpl
                 | resultOk               = arrImpl
                 | otherwise              = resultFail

                resultOk
                 =  valid arrImpl
                 && A.length arrRef == A.length arrImpl
                 && V.and (V.zipWith similarPA
                                (A.toVectors1 arrRef)
                                (A.toVectors1 arrImpl))

                resultFail
                 = error $ T.render $ T.vcat
                        [ T.text "withRef1: failure " T.<> T.text name
                        , T.nest 4 (pprp (A.toVectors1 arrRef))
                        , T.nest 4 (pprpPA arrImpl) ]
{-# INLINE withRef1 #-}
157
158
-- | Compare a doubly nested array against a reference.
--
--   Works like the flat comparison: the implementation array is always
--   returned; it is traced when 'debugLiftedTrace' is set, and checked for
--   validity, equal outer length and element-wise 'similarPA' against the
--   reference when 'debugLiftedCompare' is set. A failed check calls
--   'error' with the implementation array pretty-printed.
withRef2 :: PA a
         => String                      -- ^ name of the operator, used in messages.
         -> R.PArray (R.PArray a)       -- ^ result of the reference implementation.
         -> PArray (PArray a)           -- ^ result of this implementation.
         -> PArray (PArray a)

withRef2 name arrRef arrImpl
        = trace' checked
        where
                -- Either the identity, or a tracing wrapper around it.
                trace'
                 | debugLiftedTrace
                 = trace $ T.render
                         ( T.text " "
                           T.$$ T.text name
                           T.$$ T.nest 8 (pprpPA arrImpl))

                 | otherwise
                 = id

                -- The comparison only runs when the debug flag is set.
                checked
                 | not debugLiftedCompare = arrImpl
                 | resultOK               = arrImpl
                 | otherwise              = resultFail

                -- Inner arrays are compared pairwise, element by element.
                similarInner xs ys
                 = V.and (V.zipWith similarPA xs ys)

                resultOK
                 =  valid arrImpl
                 && A.length arrRef == A.length arrImpl
                 && V.and (V.zipWith similarInner
                                (A.toVectors2 arrRef)
                                (A.toVectors2 arrImpl))

                resultFail
                 = error $ T.render $ T.vcat
                        [ T.text "withRef2: failure " T.<> T.text name
                        , T.nest 4 (pprpPA arrImpl) ]
{-# INLINE withRef2 #-}
189
190
-- TODO: shift this stuff to the reference implementation module.
--       make the parray constructor polymorphic.
-- | Convert a flat array to the reference representation,
--   going via a boxed vector of its elements.
toRef1 :: PA a => PArray a -> R.PArray a
toRef1 arr = A.fromVectors1 (A.toVectors1 arr)
195
-- | Convert a doubly nested array to the reference representation,
--   going via nested boxed vectors.
toRef2 :: PA a => PArray (PArray a) -> R.PArray (R.PArray a)
toRef2 arr = A.fromVectors2 (A.toVectors2 arr)
198
-- | Convert a triply nested array to the reference representation,
--   going via nested boxed vectors.
toRef3 :: PA a => PArray (PArray (PArray a)) -> R.PArray (R.PArray (R.PArray a))
toRef3 arr = A.fromVectors3 (A.toVectors3 arr)
201
202
203 -- Basics ---------------------------------------------------------------------
-- | Arrays are compared by converting their element data to boxed vectors.
--   The length field stored in the 'PArray' wrapper is not consulted.
instance (Eq a, PA a) => Eq (PArray a) where
 PArray _ xs == PArray _ ys
        = toVectorPA xs == toVectorPA ys

 PArray _ xs /= PArray _ ys
        = toVectorPA xs /= toVectorPA ys
207
208
-- | Check that an array has a valid internal representation.
--   Both the 'validPA' check on the element data and the 'coversPA'
--   check of that data against the stored length must succeed.
valid :: PA a => PArray a -> Bool
valid (PArray n# pdata)
        = reprOk && lengthOk
        where   reprOk          = validPA pdata
                lengthOk        = coversPA True pdata (I# n#)
{-# INLINE_PA valid #-}
215
216
-- | Force an array to normal form.
--   Only the element data is forced, via 'nfPA'; the length field is
--   not consulted, so it is matched with a wildcard.
nf :: PA a => PArray a -> ()
nf (PArray _ pdata)
        = nfPA pdata
{-# INLINE_PA nf #-}
222
223
224 -- Constructors ----------------------------------------------------------------
-- | O(1). An array with no elements.
empty :: PA a => PArray a
empty   = withRef1 "empty" R.empty (PArray 0# emptyPA)
{-# INLINE_PA empty #-}
232
233
-- | O(1). Produce an array containing a single element.
--   Built as a replication of the element with count 1.
singleton :: PA a => a -> PArray a
singleton x
        = withRef1 "singleton" (R.singleton x) (PArray 1# (replicatePA 1 x))
{-# INLINE_PA singleton #-}
240
241
-- | O(n). Produce an array of singleton arrays.
--   Implemented as a lifted replicate where every element of the source
--   array gets a replication count of 1.
singletonl :: PA a => PArray a -> PArray (PArray a)
singletonl arr
        = withRef2 "singletonl" (R.singletonl (toRef1 arr))
                (replicatel ones arr)
        where   ones    = replicate (length arr) 1
{-# INLINE_PA singletonl #-}
248
249
-- | O(n). Define an array of the given size, that maps all elements to the
--   same value.
--
--   The replication count is required to be > 0 so that it's easier to
--   maintain the validPR invariants for nested arrays. This is not checked
--   here.
replicate :: PA a => Int -> a -> PArray a
replicate n@(I# n#) x
        = withRef1 "replicate" (R.replicate n x) (PArray n# (replicatePA n x))
{-# INLINE_PA replicate #-}
258
259
-- | O(sum lengths). Lifted replicate.
--
--   Each element of the source array is replicated by the corresponding
--   count in the first array. An empty array of counts gives 'empty'.
replicatel :: PA a => PArray Int -> PArray a -> PArray (PArray a)
replicatel reps@(PArray n# (PInt lens)) arr@(PArray _ pdata)
        = withRef2 "replicatel" (R.replicatel (toRef1 reps) (toRef1 arr))
        $ if n# ==# 0# then empty else nested
        where
                count   = I# n#

                -- Segment descriptor built from the replication counts.
                segd    = U.lengthsToSegd lens

                -- All replicated elements go into a single flat data block.
                flat    = replicatesPA segd pdata

                -- NOTE(review): the argument roles below (segment ids,
                -- lengths, start indices, source ids, data blocks) are
                -- inferred from the values passed -- confirm against mkPNested.
                nested  = PArray n#
                        $ mkPNested
                                (U.enumFromTo 0 (count - 1))
                                lens
                                (U.indicesSegd segd)
                                (U.replicate count 0)
                                (singletondPA flat)
{-# INLINE_PA replicatel #-}
278
279
-- | O(sum lengths). Segmented replicate.
--   The length of the result is the element count of the segment descriptor.
replicates :: PA a => U.Segd -> PArray a -> PArray a
replicates segd arr@(PArray _ pdata)
        = withRef1 "replicates" (R.replicates segd (toRef1 arr))
        $ case U.elementsSegd segd of
                I# n#   -> PArray n# (replicatesPA segd pdata)
{-# INLINE_PA replicates #-}
287
288
-- | O(sum lengths). Wrapper for segmented replicate that takes the
--   replication counts and builds the `U.Segd` from them.
replicates' :: PA a => PArray Int -> PArray a -> PArray a
replicates' (PArray _ (PInt counts)) arr
        = replicates segd arr
        where   segd    = U.lengthsToSegd counts
{-# INLINE_PA replicates' #-}
295
296
-- | Append two arrays.
--   The length of the result is the sum of the lengths of the arguments.
append :: PA a => PArray a -> PArray a -> PArray a
append arr1@(PArray len1# xs) arr2@(PArray len2# ys)
        = withRef1 "append" (R.append (toRef1 arr1) (toRef1 arr2))
                (PArray (len1# +# len2#) (appendPA xs ys))
{-# INLINE_PA append #-}
303
304
-- | Lifted append.
--   Both arrays must have the same length. This is not checked: the
--   result simply takes the length of the first array.
appendl :: PA a => PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
appendl arr1@(PArray n# xss) arr2@(PArray _ yss)
        = withRef2 "appendl" (R.appendl (toRef2 arr1) (toRef2 arr2))
                (PArray n# (appendlPA xss yss))
{-# INLINE_PA appendl #-}
312
313
-- | Concatenate a nested array.
--   The length of the result is read back from the concatenated data.
concat :: PA a => PArray (PArray a) -> PArray a
concat arr@(PArray _ darr)
        = withRef1 "concat" (R.concat (toRef2 arr))
        $ case lengthPA flat of
                I# n#   -> PArray n# flat
        where   flat    = concatPA darr
{-# INLINE_PA concat #-}
322
323
-- | Lifted concat.
--   The outer length is unchanged; one level of inner nesting is removed.
concatl :: PA a => PArray (PArray (PArray a)) -> PArray (PArray a)
concatl arr@(PArray n# xsss)
        = withRef2 "concatl" (R.concatl (toRef3 arr))
                (PArray n# (concatlPA xsss))
{-# INLINE_PA concatl #-}
330
331
-- | Impose a nesting structure on a flat array.
--   The first array supplies the nesting and the outer length of the
--   result, the second one supplies the flat data.
unconcat :: (PA a, PA b) => PArray (PArray a) -> PArray b -> PArray (PArray b)
unconcat (PArray n# shape) (PArray _ flat)
        = PArray n# (unconcatPA shape flat)
{-# INLINE_PA unconcat #-}
337
338
-- | Create a nested array from a segment descriptor and some flat data.
--
--   The segment descriptor must represent as many elements as are present
--   in the flat data array, else `error`. The outer length of the result
--   is the number of segments in the descriptor.
nestUSegd :: PA a => U.Segd -> PArray a -> PArray (PArray a)
nestUSegd segd (PArray n# pdata)
        | U.elementsSegd segd == I# n#
        , I# n2#        <- U.lengthSegd segd
        = PArray n2#
        $ PNested (U.promoteSegdToVSegd segd) (singletondPA pdata)

        -- The message names this function and labels the descriptor's
        -- element count as such, as that is the value being compared.
        | otherwise
        = error $ unlines
                [ "Data.Array.Parallel.PArray.nestUSegd: number of elements defined by "
                        ++ "segment descriptor and data array do not match"
                , " elements of segment descriptor = " ++ show (U.elementsSegd segd)
                , " length of data array           = " ++ show (I# n#) ]
{-# NOINLINE nestUSegd #-}
356
357
358 -- Projections ---------------------------------------------------------------
-- | Take the length of some arrays.
--   The lengths are read directly from the virtual segment descriptor of
--   the nested array; the element data is not inspected.
lengthl :: PA a => PArray (PArray a) -> PArray Int
lengthl arr@(PArray n# (PNested vsegd _))
        = withRef1 "lengthl" (R.lengthl (toRef2 arr))
        $ PArray n# $ PInt $ U.takeLengthsOfVSegd vsegd
{-# INLINE_PA lengthl #-}
364
365
-- | O(1). Lookup a single element from the source array.
--   The index is passed straight to 'indexPA'; no bounds check is done here.
index :: PA a => PArray a -> Int -> a
index (PArray _ pdata) ix
        = indexPA pdata ix
{-# INLINE_PA index #-}
371
372
-- | O(len indices). Lookup several elements from several source arrays.
--   The result takes its length from the array of source arrays.
indexl :: PA a => PArray (PArray a) -> PArray Int -> PArray a
indexl (PArray n# xss) (PArray _ ixs)
        = PArray n# (indexlPA xss ixs)
{-# INLINE_PA indexl #-}
378
379
-- | Extract a range of elements from an array, given the starting index
--   and the number of elements. The requested length becomes the length
--   of the result without being checked against the source.
extract :: PA a => PArray a -> Int -> Int -> PArray a
extract (PArray _ pdata) start len@(I# len#)
        = PArray len# (extractPA pdata start len)
{-# INLINE_PA extract #-}
385
386
-- | Segmented extract.
--   The length of the result is the sum of the segment lengths in the
--   scattered segment descriptor.
extracts :: PA a => Vector (PArray a) -> U.SSegd -> PArray a
extracts arrs ssegd
        = case U.sum (U.lengthsSSegd ssegd) of
                I# n#   -> PArray n# (extractsPA pdatas ssegd)
        where
                -- Strip the length wrappers to get at the element data.
                pdatas                  = fromVectordPA (V.map unwrap arrs)
                unwrap (PArray _ pdata) = pdata
{-# INLINE_PA extracts #-}
395
396
-- | Wrapper for `extracts` that takes arrays of sources, starts and lengths
--   of the segments, and uses these to build the `U.SSegd`.
--
--   TODO: The lengths of the sources, starts and lengths arrays must be the
--         same, but this is not checked.
--         All sourceids must point to valid data arrays.
--         Segments must be within their corresponding source array.
extracts'
        :: PA a
        => Vector (PArray a)
        -> PArray Int           -- ^ id of source array for each segment.
        -> PArray Int           -- ^ starting index of each segment in its source array.
        -> PArray Int           -- ^ length of each segment.
        -> PArray a
extracts' arrs (PArray _ (PInt sources)) (PArray _ (PInt starts)) (PArray _ (PInt lengths))
        = extracts arrs ssegd
        where   ssegd   = U.mkSSegd starts sources (U.lengthsToSegd lengths)
{-# INLINE_PA extracts' #-}
415
416
-- | Extract a range of elements from an array.
--   Like `extract` but with the parameters in a different order:
--   starting index, number of elements, then the source array.
slice :: PA a => Int -> Int -> PArray a -> PArray a
slice start len@(I# len#) (PArray _ pdata)
        = PArray len# (extractPA pdata start len)
{-# INLINE_PA slice #-}
423
424
-- | Extract some slices from some arrays.
--   The arrays of starting indices and lengths must themselves have the
--   same length. The result takes its length from the array of starting
--   indices.
slicel :: PA a => PArray Int -> PArray Int -> PArray (PArray a) -> PArray (PArray a)
slicel (PArray n# starts) (PArray _ lens) (PArray _ xss)
        = PArray n# (slicelPD starts lens xss)
{-# INLINE_PA slicel #-}
432
433
-- | Take the segment descriptor from a nested array and demote it to a
--   plain Segd. This is unsafe because it can cause index space overflow.
unsafeTakeSegd :: PArray (PArray a) -> U.Segd
unsafeTakeSegd (PArray _ nested)
        = unsafeTakeSegdPD nested
{-# INLINE_PA unsafeTakeSegd #-}
440
441
442 -- Pack and Combine -----------------------------------------------------------
-- | Select the elements of an array that have their tag set to True.
pack :: PA a => PArray a -> PArray Bool -> PArray a
pack arr@(PArray _ xs) flags@(PArray _ (PBool sel2))
        = withRef1 "pack" (R.pack (toRef1 arr) (toRef1 flags))
        -- The selector already knows how many elements are tagged '1',
        -- so that count is used as the length of the result.
        $ case U.elementsSel2_1 sel2 of
                I# m#   -> PArray m# (packByTagPA xs (U.tagsSel2 sel2) 1)
{-# INLINE_PA pack #-}
455
456
-- | Lifted pack.
--   Selects from each inner array the elements whose flag is set. The
--   outer length is unchanged and the result is stored in one flat block.
packl :: PA a => PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
packl xss@(PArray n# xdata@(PNested vsegd _))
      fss@(PArray _ fdata)
        = withRef2 "packl" (R.packl (toRef2 xss) (toRef2 fss))
        $ PArray n# (PNested vsegdPacked (singletondPA xsPacked))
        where
                -- Demote the vsegd to get the virtual segmentation of the two
                -- arrays. Both must have the same virtual segmentation, but
                -- this is not checked.
                segd            = U.demoteToSegdOfVSegd vsegd

                -- Flatten both arrays. Although the virtual segmentation
                -- should agree, the physical segmentation may differ.
                xsFlat          = concatPA xdata
                PBool sel       = concatPA fdata
                tags            = U.tagsSel2 sel

                -- Count how many elements are kept in each segment.
                segdPacked      = U.lengthsToSegd (U.count_s segd tags 1)

                -- Pieces of the result array.
                vsegdPacked     = U.promoteSegdToVSegd segdPacked
                xsPacked        = packByTagPA xsFlat tags 1
{-# INLINE_PA packl #-}
483
484
-- | Filter an array based on some tags.
--   The length of the result is read back from the packed data.
packByTag :: PA a => PArray a -> U.Array Tag -> Tag -> PArray a
packByTag arr@(PArray _ darr) tags tag
        = withRef1 "packByTag" (R.packByTag (toRef1 arr) tags tag)
        $ case lengthPA packed of
                I# n#   -> PArray n# packed
        where   packed  = packByTagPA darr tags tag
{-# INLINE_PA packByTag #-}
494
495
-- | Combine two arrays based on a selector.
--   The length of the result is read back from the combined data.
combine2 :: forall a. PA a => U.Sel2 -> PArray a -> PArray a -> PArray a
combine2 sel arr1@(PArray _ xs) arr2@(PArray _ ys)
        = withRef1 "combine2" (R.combine2 sel (toRef1 arr1) (toRef1 arr2))
        $ case lengthPA combined of
                I# n#   -> PArray n# combined
        where   combined = combine2PA sel xs ys
{-# INLINE_PA combine2 #-}
504
505
506 -- Conversions ----------------------------------------------------------------
-- | Convert a `Vector` to a `PArray`.
--   The array length is taken from the length of the vector.
fromVector :: PA a => Vector a -> PArray a
fromVector vec
        = case V.length vec of
                I# n#   -> PArray n# (fromVectorPA vec)
{-# INLINE_PA fromVector #-}
513
514
-- | Convert a `PArray` to a `Vector`.
--   Only the element data is converted; the stored length is not consulted.
toVector :: PA a => PArray a -> Vector a
toVector (PArray _ pdata)
        = toVectorPA pdata
{-# INLINE_PA toVector #-}
520
521
-- | Convert a list to a `PArray`.
--
--   The list is converted to a boxed vector once, and the array length is
--   read from that vector, so the list is not walked a second time just
--   to count its elements.
fromList :: PA a => [a] -> PArray a
fromList xx
        = case V.length vec of
                I# n#   -> PArray n# (fromVectorPA vec)
        where   vec     = V.fromList xx
{-# INLINE_PA fromList #-}
528
529
-- | Convert a `PArray` to a list, going via a boxed vector of the elements.
toList :: PA a => PArray a -> [a]
toList (PArray _ pdata)
        = V.toList (toVectorPA pdata)
{-# INLINE_PA toList #-}
535