dph-lifted-vseg: also store pre-demoted segd in nested arrays
[packages/dph.git] / dph-lifted-vseg / Data / Array / Parallel / PArray / PData / Nested.hs
1 {-# OPTIONS_HADDOCK hide #-}
2 {-# LANGUAGE UndecidableInstances, ParallelListComp #-}
3 {-# OPTIONS -fno-spec-constr #-}
4 #include "fusion-phases.h"
5
6 -- | PR instance for nested arrays.
7 module Data.Array.Parallel.PArray.PData.Nested
8 ( PData(..)
9 , PDatas(..)
10 , mkPNested
11 , concatPR, concatlPR
12 , flattenPR, takeSegdPD
13 , unconcatPR
14 , appendlPR
15 , indexlPR
16 , slicelPR
17 , extractvs_delay)
18 where
19 import Data.Array.Parallel.Base
20 import Data.Array.Parallel.Pretty
21 import Data.Array.Parallel.PArray.PData.Base as PA
22 import qualified Data.IntSet as IS
23 import qualified Data.Array.Parallel.Unlifted as U
24 import qualified Data.Vector as V
25 import GHC.Exts
26 import System.IO.Unsafe
27
28 -- Nested arrays --------------------------------------------------------------
data instance PData (PArray a)
        = PNested
        { -- | Virtual segmentation descriptor.
          --   Defines a virtual nested array based on physical data.
          pnested_uvsegd        :: !U.VSegd

          -- | Chunks of array data, where each chunk has a linear index space.
        , pnested_psegdata      :: !(PDatas a)

          -- | A demoted version of the VSegd.
          --   This field is deliberately LAZY (no strictness annotation).
          --   A function that creates the array and already has the plain
          --   Segd should stash it here; otherwise it should store a thunk
          --   that builds it on demand.
        , pnested_segd          :: U.Segd

          -- | A pre-concatenated version of the array.
          --   This field is deliberately LAZY (no strictness annotation).
          --   A function that creates the array and already has a flat form
          --   should stash it here; otherwise it should store a thunk that
          --   builds it on demand.
        , pnested_flat          :: PData a
        }
48
49
-- | A collection of nested arrays, held in a boxed vector.
--   TODO: should we unpack the vsegd fields here?
data instance PDatas (PArray a) = PNesteds (V.Vector (PData (PArray a)))
53
-- | Construct a nested array from the unpacked components of its virtual
--   segment descriptor and the chunks of element data.
--
--   The plain 'U.Segd' and the flat version of the array are not computed
--   here; they are stored as thunks in the lazy fields of the result.
--
--   TODO: this function needs to die.
--
mkPNested
        :: PR a
        => U.Array Int  -- ^ Virtual segment ids.
        -> U.Array Int  -- ^ Lengths of physical segments.
        -> U.Array Int  -- ^ Starting indices of physical segments.
        -> U.Array Int  -- ^ Source id (what chunk to get each segment from).
        -> PDatas a     -- ^ Chunks of array data.
        -> PData (PArray a)
mkPNested vsegids pseglens psegstarts psegsrcids pdatas
        = PNested vsegd pdatas segd flat
        where
                -- Scattered descriptor over the physical segments.
                ussegd  = U.mkSSegd psegstarts psegsrcids
                                (U.lengthsToSegd pseglens)

                -- Virtual descriptor layered on top of the physical one.
                vsegd   = U.mkVSegd vsegids ussegd

                -- Contents of the two lazy fields.
                segd    = U.unsafeDemoteToSegdOfVSegd vsegd
                flat    = extractvs_delay pdatas vsegd

{-# INLINE_PDATA mkPNested #-}
74
75
-- Old projection functions.
-- Each one unpacks a single component of the array's virtual segment
-- descriptor.
-- TODO: refactor to eliminate the need for these.

-- | Virtual segment ids of a nested array.
pnested_vsegids :: PData (PArray a) -> U.Array Int
pnested_vsegids arr
        = U.takeVSegidsOfVSegd (pnested_uvsegd arr)

-- | Lengths of the physical segments of a nested array.
pnested_pseglens :: PData (PArray a) -> U.Array Int
pnested_pseglens arr
        = U.lengthsOfSSegd (U.takeSSegdOfVSegd (pnested_uvsegd arr))

-- | Starting indices of the physical segments of a nested array.
pnested_psegstarts :: PData (PArray a) -> U.Array Int
pnested_psegstarts arr
        = U.startsOfSSegd (U.takeSSegdOfVSegd (pnested_uvsegd arr))

-- | Source ids (which chunk each physical segment lives in) of a nested array.
pnested_psegsrcids :: PData (PArray a) -> U.Array Int
pnested_psegsrcids arr
        = U.sourcesOfSSegd (U.takeSSegdOfVSegd (pnested_uvsegd arr))
89
90
91
92 -- PR Instances ---------------------------------------------------------------
-- Triples of Ints are stored in unlifted arrays by this module:
-- indexsPR builds a @U.Array (Int, Int, Int)@ of (length, start, source id).
instance U.Elt (Int, Int, Int)
94
instance PR a => PR (PArray a) where

  -- Validation -------------------------------------------
  -- Check the invariants of the outermost segment descriptor.
  -- Each individual check is wrapped in 'validBool', so a failed check
  -- raises an error naming the violated invariant instead of yielding False.
  --
  -- TODO: make this check all sub arrays as well
  -- TODO: ensure that all psegdata arrays are referenced from some psegsrc.
  -- TODO: shift segd checks into associated modules.
  {-# NOINLINE validPR #-}
  validPR arr
   = let vsegids        = pnested_vsegids    arr
         pseglens       = pnested_pseglens   arr
         psegstarts     = pnested_psegstarts arr
         psegsrcs       = pnested_psegsrcids arr
         psegdata       = pnested_psegdata   arr

         -- The lengths of the pseglens, psegstarts and psegsrcs fields
         -- must all be the same.
         fieldLensOK
          = validBool "nested array field lengths not identical"
          $ and [ U.length psegstarts == U.length pseglens
                , U.length psegsrcs   == U.length pseglens ]

         -- Every vseg must reference a valid pseg.
         -- The index has to be in range at both ends.
         vsegsRefOK
          = validBool "nested array vseg doesn't ref pseg"
          $ U.and
          $ U.map (\vseg -> vseg >= 0 && vseg < U.length pseglens) vsegids

         -- Every pseg source id must point to a flat data array.
         -- The lower bound matters: psegSlicesOK below feeds these ids to
         -- indexdPR, which is an unchecked vector index.
         psegsrcsRefOK
          = validBool "nested array psegsrc doesn't ref flat array"
          $ U.and
          $ U.map (\srcid -> srcid >= 0 && srcid < lengthdPR psegdata) psegsrcs

         -- Every physical segment must be a valid slice of the corresponding
         -- flat array.
         --
         -- We allow psegs with len 0, start 0 even if the flat array is empty.
         -- This occurs with [ [] ].
         --
         -- As a generalisation of the above, we allow segments with len 0,
         -- start <= srclen. This occurs when there is an empty array as the
         -- last segment.
         -- For example:
         --      [ [5, 4, 3, 2] [ ] ].
         --      PNested  vsegids:    [0,1]
         --               pseglens:   [4,0]
         --               psegstarts: [0,4]  -- last '4' here is <= length of flat array
         --               psegsrcs:   [0,0]
         --      PInt     [5, 4, 3, 2]
         --
         psegSlicesOK
          = validBool "nested array pseg slices are invalid"
          $ U.and
          $ U.zipWith3
                (\len start srcid
                   -> let pdata = psegdata `indexdPR` srcid
                      in  and [ coversPR (len == 0) pdata start
                              , coversPR True       pdata (start + len) ])
                pseglens psegstarts psegsrcs

         -- Every pseg must be referenced by some vseg.
         vsegs  = IS.fromList $ U.toList vsegids
         psegsReffedOK
          = validBool "nested array pseg not reffed by vseg"
          $  (U.length pseglens == 0)
          || (U.and $ U.map (flip IS.member vsegs)
                    $ U.enumFromTo 0 (U.length pseglens - 1))

         -- The order of this list is significant: 'and' evaluates it left to
         -- right, so the field-length and source-id checks run before
         -- psegSlicesOK, which relies on both of them.
     in  and [ fieldLensOK
             , vsegsRefOK
             , psegsrcsRefOK
             , psegSlicesOK
             , psegsReffedOK ]


  {-# NOINLINE nfPR #-}
  nfPR = error "nfPR[PArray]: not defined yet"


  -- Two nested arrays are similar when they have the same number of
  -- elements and the elements are pairwise similar.
  -- The explicit length comparison is needed because V.zipWith stops at the
  -- end of the shorter vector.
  {-# NOINLINE similarPR #-}
  similarPR (PArray _ pdata1) (PArray _ pdata2)
   = let xs     = toVectorPR pdata1
         ys     = toVectorPR pdata2
     in  V.length xs == V.length ys
      && V.and (V.zipWith similarPR xs ys)


  -- Check whether an index is covered by the array.
  -- In weak mode the index may also equal the number of virtual segments.
  {-# NOINLINE coversPR #-}
  coversPR weak (PNested vsegd _ _ _) ix
   | weak       = ix <= (U.length $ U.takeVSegidsOfVSegd vsegd)
   | otherwise  = ix <  (U.length $ U.takeVSegidsOfVSegd vsegd)


  -- Pretty print the physical representation of an array.
  {-# NOINLINE pprpPR #-}
  pprpPR (PArray n# pdata)
   =   (text "PArray " <+> int (I# n#))
   $+$ ( nest 4
       $ pprpDataPR pdata)


  -- Pretty print the physical representation of the array data.
  -- Only the vsegd and the data chunks are shown; the two lazy fields are
  -- not demanded.
  {-# NOINLINE pprpDataPR #-}
  pprpDataPR (PNested vsegd pdatas _ _)
   =   text "PNested"
   $+$ ( nest 4
       $ pprp vsegd $$ pprp pdatas)


  -- Constructors -----------------------------------------
  {-# INLINE_PDATA emptyPR #-}
  emptyPR = PNested U.emptyVSegd emptydPR U.emptySegd emptyPR


  -- When replicating an array we use the source as the single physical
  -- segment, then point all the virtual segments to it.
  {-# INLINE_PDATA replicatePR #-}
  replicatePR c (PArray n# darr)
   = {-# SCC "replicatePR" #-}
     checkNotEmpty "replicatePR[PArray]" c
   $ let -- Physical segment descriptor contains a single segment.
         ussegd = U.singletonSSegd (I# n#)

         -- All virtual segments point to the same physical segment.
         vsegd  = U.mkVSegd (U.replicate c 0) ussegd
         pdatas = singletondPR darr

         -- Lazily demoted segd and pre-concatenated version.
         segd   = U.unsafeDemoteToSegdOfVSegd vsegd
         flat   = extractvs_delay pdatas vsegd

     in  PNested vsegd pdatas segd flat


  -- For segmented replicates, we just replicate the vsegids field.
  --
  -- TODO: Does replicate_s really need the whole segd,
  --       or could we get away without creating the indices field?
  --
  -- TODO: If we know the lens does not contain zeros, then we don't need
  --       to cull down the psegs.
  --
  {-# INLINE_PDATA replicatesPR #-}
  replicatesPR segd (PNested uvsegd pdatas _ _)
   = let vsegd' = U.updateVSegsOfVSegd
                        (\vsegids -> U.replicate_s segd vsegids)
                        uvsegd
         segd'  = U.unsafeDemoteToSegdOfVSegd vsegd'
         flat'  = extractvs_delay pdatas vsegd'
     in  PNested vsegd' pdatas segd' flat'


  -- Append nested arrays by appending the segment descriptors,
  -- and putting all physical arrays in the result.
  {-# NOINLINE appendPR #-}
  appendPR (PNested uvsegd1 pdatas1 _ _) (PNested uvsegd2 pdatas2 _ _)
   = let vsegd' = U.appendVSegd
                        uvsegd1 (lengthdPR pdatas1)
                        uvsegd2 (lengthdPR pdatas2)

         pdatas' = appenddPR pdatas1 pdatas2
         segd'   = U.unsafeDemoteToSegdOfVSegd vsegd'
         flat'   = extractvs_delay pdatas' vsegd'

     in  PNested vsegd' pdatas' segd' flat'


  -- Performing segmented append requires segments from the physical arrays to
  -- be interspersed, so we need to copy data from the second level of nesting.
  --
  -- In the implementation we can safely flatten out replication in the vsegs
  -- because the source program result would have this same physical size
  -- anyway. Once this is done we use copying segmented append on the flat
  -- arrays, and then reconstruct the segment descriptor.
  --
  {-# NOINLINE appendsPR #-}
  appendsPR rsegd segd1 xarr segd2 yarr
   = let (xsegd, xs)    = flattenPR xarr
         (ysegd, ys)    = flattenPR yarr

         xsegd' = U.lengthsToSegd
                $ U.sum_s segd1 (U.lengthsSegd xsegd)

         ysegd' = U.lengthsToSegd
                $ U.sum_s segd2 (U.lengthsSegd ysegd)

         segd'  = U.lengthsToSegd
                $ U.append_s rsegd segd1 (U.lengthsSegd xsegd)
                                   segd2 (U.lengthsSegd ysegd)

         -- The pdatas only contains a single flat chunk.
         vsegd'  = U.promoteSegdToVSegd segd'
         flat'   = appendsPR (U.plusSegd xsegd' ysegd')
                             xsegd' xs
                             ysegd' ys

         pdatas' = singletondPR flat'

     in  PNested vsegd' pdatas' segd' flat'


  -- Projections ------------------------------------------
  {-# INLINE_PDATA lengthPR #-}
  lengthPR (PNested vsegd _ _ _)
   = U.lengthOfVSegd vsegd


  -- To index into a nested array, first determine what segment the index
  -- corresponds to, and extract that as a slice from that physical array.
  --
  -- IMPORTANT:
  --  We need to go through the vsegd here, instead of demanding the
  --  flat version, because we don't want to force creation of the
  --  entire manifest array.
  {-# INLINE_PDATA indexPR #-}
  indexPR (PNested uvsegd pdatas _ _) ix
   | (pseglen@(I# pseglen#), psegstart, psegsrcid) <- U.getSegOfVSegd uvsegd ix
   = let !psrc   = pdatas `indexdPR` psegsrcid
         !pdata' = extractPR psrc psegstart pseglen
     in  PArray pseglen# pdata'


  -- Index several source arrays at once. Each element of 'srcixs' pairs the
  -- id of a source array with the index of the segment to take from it.
  {-# INLINE_PDATA indexsPR #-}
  indexsPR pdatas@(PNesteds arrs) srcixs
   = let (srcids, ixs)  = U.unzip srcixs

         -- See Note: psrcoffset
         !psrcoffset    = V.prescanl (+) 0
                        $ V.map (lengthdPR . pnested_psegdata) arrs

         -- length, start and srcid of the segments we're returning.
         --   Note that we need to offset the srcid
         -- TODO: don't unbox the VSegd for every iteration.
         seginfo :: U.Array (Int, Int, Int)
         seginfo
          = U.zipWith (\srcid ix ->
                        let (PNested vsegd _ _ _)   = pdatas `indexdPR` srcid
                            (len, start, srcid')    = U.getSegOfVSegd vsegd ix
                        in  (len, start, srcid' + (psrcoffset `V.unsafeIndex` srcid)))
                srcids
                ixs

         (pseglens', psegstarts', psegsrcs')
                = U.unzip3 seginfo

         -- TODO: check that doing lengthsToSegd won't cause overflow
         segd'   = U.lengthsToSegd pseglens'
         vsegd'  = U.promoteSSegdToVSegd
                 $ U.mkSSegd psegstarts' psegsrcs' segd'

         -- All flat data arrays in the sources go into the result.
         pdatas' = fromVectordPR
                 $ V.concat $ V.toList
                 $ V.map (toVectordPR . pnested_psegdata) arrs

         flat'   = extractvs_delay pdatas' vsegd'

     in  PNested vsegd' pdatas' segd' flat'


  -- To extract a range of elements from a nested array, perform the extract
  -- on the vsegids field. The `updateVSegsOfVSegd` function will then filter
  -- out all of the psegs that are no longer reachable from the new vsegids.
  --
  -- IMPORTANT:
  --  We need to go through the vsegd here, instead of demanding the
  --  flat version, because we don't want to force creation of the
  --  entire manifest array.
  {-# INLINE_PDATA extractPR #-}
  extractPR (PNested uvsegd pdatas _ _) start len
   = let vsegd' = U.updateVSegsOfVSegd
                        (\vsegids -> U.extract vsegids start len)
                        uvsegd
         segd'  = U.unsafeDemoteToSegdOfVSegd vsegd'
         flat'  = extractvs_delay pdatas vsegd'
     in  PNested vsegd' pdatas segd' flat'


  -- [Note: psrcoffset]
  -- ~~~~~~~~~~~~~~~~~~
  -- As all the flat data arrays in the sources are present in the result array,
  -- we need to offset the psegsrcs field when combining multiple sources.
  --
  -- Example
  --  Source Arrays:
  --   arr0 ...
  --        psrcids  : [0, 0, 0, 1, 1]
  --        psegdata : [PInt xs1, PInt xs2]
  --
  --   arr1 ...
  --        psrcids  : [0, 0, 1, 1, 2, 2, 2]
  --        psegdata : [PInt ys1, PInt ys2, PInt ys3]
  --
  --  Result Array:
  --        psrcids  : [...]
  --        psegdata : [PInt xs1, PInt xs2, PInt ys1, PInt ys2, PInt ys3]
  --
  --  Note that references to flatdata arrays [0, 1, 2] in arr1 need to be offset
  --  by 2 (which is length arr0.psegdata) to refer to the same flat data arrays
  --  in the result.
  --
  --  We encode these offsets in the psrcoffset vector:
  --        psrcoffset : [0, 2]
  --
  -- TODO: cleanup pnested projections
  --       use getSegOfUVSegd like in indexlPR
  --
  {-# NOINLINE extractssPR #-}
  extractssPR (PNesteds arrs) ussegd
   = let segsrcs        = U.sourcesOfSSegd ussegd
         seglens        = U.lengthsOfSSegd ussegd

         vsegids_src    = U.extracts_nss ussegd (V.map pnested_vsegids arrs)
         srcids'        = U.replicate_s (U.lengthsToSegd seglens) segsrcs

         -- See Note: psrcoffset
         psrcoffset     = V.prescanl (+) 0
                        $ V.map (lengthdPR . pnested_psegdata) arrs

         -- Unpack the lens and srcids arrays so we don't need to
         -- go through all the segment descriptors each time.
         !arrs_pseglens   = V.map pnested_pseglens   arrs
         !arrs_psegstarts = V.map pnested_psegstarts arrs
         !arrs_psegsrcids = V.map pnested_psegsrcids arrs

         !here' = "extractssPR[Nested]"

         -- Function to get one element of the result.
         {-# INLINE get #-}
         get srcid vsegid
          = let !pseglen   = U.index here' (arrs_pseglens   `V.unsafeIndex` srcid) vsegid
                !psegstart = U.index here' (arrs_psegstarts `V.unsafeIndex` srcid) vsegid
                !psegsrcid = (U.index here' (arrs_psegsrcids `V.unsafeIndex` srcid) vsegid)
                           + (psrcoffset `V.unsafeIndex` srcid)
            in  (pseglen, psegstart, psegsrcid)

         (pseglens', psegstarts', psegsrcs')
                = U.unzip3 $ U.zipWith get srcids' vsegids_src

         -- All flat data arrays in the sources go into the result.
         pdatas' = fromVectordPR
                 $ V.concat $ V.toList
                 $ V.map (toVectordPR . pnested_psegdata) arrs

         -- Build the result segment descriptor.
         segd'   = U.lengthsToSegd pseglens'
         vsegd'  = U.promoteSSegdToVSegd
                 $ U.mkSSegd psegstarts' psegsrcs' segd'

         flat'   = extractvs_delay pdatas' vsegd'

     in  PNested vsegd' pdatas' segd' flat'


  -- Extraction through a virtual segment descriptor is done by demoting it
  -- to a scattered descriptor and delegating to extractssPR.
  {-# INLINE_PDATA extractvsPR #-}
  extractvsPR pdatas vsegd
   = extractssPR pdatas (U.demoteToSSegdOfVSegd vsegd)


  -- Pack and Combine -------------------------------------
  -- Pack the vsegids to determine which of the vsegs are present in the result.
  --  eg  tags:           [0 1 1 1 0 0 0 0 1 0 0 0 0 1 0 1 0 1 1]   tag = 1
  --      vsegids:        [0 0 1 1 2 2 2 2 3 3 4 4 4 5 5 5 5 6 6]
  --  =>  vsegids_packed: [  0 1 1         3         5   5   6 6]
  --
  {-# INLINE_PDATA packByTagPR #-}
  packByTagPR (PNested vsegd pdatas _ _) tags tag
   = let vsegd' = U.updateVSegsOfVSegd
                        (\vsegids -> U.packByTag vsegids tags tag)
                        vsegd
         segd'  = U.unsafeDemoteToSegdOfVSegd vsegd'
         flat'  = extractvs_delay pdatas vsegd'
     in  PNested vsegd' pdatas segd' flat'


  -- Combine nested arrays by combining the segment descriptors,
  -- and putting all physical arrays in the result.
  {-# INLINE_PDATA combine2PR #-}
  combine2PR sel2 (PNested vsegd1 pdatas1 _ _) (PNested vsegd2 pdatas2 _ _)
   = let vsegd' = U.combine2VSegd sel2
                        vsegd1 (lengthdPR pdatas1)
                        vsegd2 (lengthdPR pdatas2)

         pdatas' = appenddPR pdatas1 pdatas2
         segd'   = U.unsafeDemoteToSegdOfVSegd vsegd'
         flat'   = extractvs_delay pdatas' vsegd'
     in  PNested vsegd' pdatas' segd' flat'


  -- Conversions ------------------------------------------
  -- Build a nested array from a boxed vector of arrays. All element data is
  -- appended into a single chunk, and every segment is its own vseg.
  --
  -- TODO: pack in pre-existing segd and flat version
  {-# NOINLINE fromVectorPR #-}
  fromVectorPR xx
   | V.length xx == 0 = emptyPR
   | otherwise
   = let segd   = U.lengthsToSegd $ U.fromList $ V.toList $ V.map PA.length xx
     in  mkPNested
                (U.enumFromTo 0 (V.length xx - 1))
                (U.lengthsSegd segd)
                (U.indicesSegd segd)
                (U.replicate (V.length xx) 0)
                -- foldl1 is safe here: the empty vector is handled above.
                (singletondPR (V.foldl1 appendPR $ V.map takeData xx))


  -- Convert to a boxed vector by indexing out each virtual segment in turn.
  {-# NOINLINE toVectorPR #-}
  toVectorPR arr
   = V.generate (U.length (pnested_vsegids arr))
   $ indexPR arr


  -- PDatas -----------------------------------------------
  {-# INLINE_PDATA emptydPR #-}
  emptydPR
   = PNesteds $ V.empty

  {-# INLINE_PDATA singletondPR #-}
  singletondPR pdata
   = PNesteds $ V.singleton pdata

  {-# INLINE_PDATA lengthdPR #-}
  lengthdPR (PNesteds vec)
   = V.length vec

  -- NOTE: unchecked index; the caller must ensure 0 <= ix < length.
  {-# INLINE_PDATA indexdPR #-}
  indexdPR (PNesteds vec) ix
   = vec `V.unsafeIndex` ix

  {-# INLINE_PDATA appenddPR #-}
  appenddPR (PNesteds xs) (PNesteds ys)
   = PNesteds $ xs V.++ ys

  {-# INLINE_PDATA fromVectordPR #-}
  fromVectordPR vec
   = PNesteds vec

  {-# INLINE_PDATA toVectordPR #-}
  toVectordPR (PNesteds vec)
   = vec
525
526 ------------------------------------------------------------------------------
-- | O(len result). Lifted indexing.
--
--   Each index in the second array is paired with its own position
--   (0, 1, 2, ...) and the pairs are handed to 'indexvsPR' together with the
--   nested array's vsegd and data chunks.
indexlPR :: PR a => PData (PArray a) -> PData Int -> PData a
indexlPR (PNested vsegd pdatas _ _) (PInt ixs)
 = indexvsPR pdatas vsegd srcixs
 where  positions = U.enumFromTo 0 (U.length ixs - 1)
        srcixs    = U.zip positions ixs
{-# INLINE_PDATA indexlPR #-}
534
535
536 -------------------------------------------------------------------------------
-- | O(len result). Concatenate a nested array.
--
--   The implementation just returns the (lazy) 'pnested_flat' field, so the
--   copying described below only happens if that field is still a thunk.
--
--   OLD COMMENTS Attach to extracts instead.
--   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
--
--   This physically performs a gather operation, whereby array data is copied
--   through the index-space transformation defined by the segment descriptor
--   in the nested array. We must perform this copy because reducing the level
--   of nesting corresponds to discarding the segment descriptor, which means we
--   can no longer represent the layout of the array other than by physically
--   creating it.
--
--   As an optimisation, if the segment descriptor knows that the segments are
--   already in a single contiguous `PData` with no sharing, then concat can
--   just return the underlying array directly, in constant time.
--
--   WARNING:
--   Concatenating a replicated array can cause index overflow, because the
--   source array can define more elements than we can count with a single
--   machine word.
--   For example, if we replicate an array with 1Meg elements 1Meg times then
--   the result defines a total of 1Meg*1Meg = 1Tera elements. This in itself
--   is fine, because the nested array is defined by an index space transform
--   that maps all the inner arrays back to the original data. However, if we
--   then concatenate the replicated array then we must physically copy the
--   data as we lose the segment descriptor that defines the mapping. Sad
--   things will happen when the library tries to construct a physical array
--   1Tera elements long, especially on 32 bit machines.

--   IMPORTANT:
--   In the case where there is sharing between segments, or they are scattered
--   through multiple arrays, only the outer-most two levels of nesting are
--   physically merged. The data for lower levels is not touched. This ensures
--   that concat has complexity proportional to the length of the result array,
--   instead of the total number of elements within it.
--
concatPR :: PR a => PData (PArray a) -> PData a
concatPR arr = pnested_flat arr
{-# INLINE concatPR #-}
576
577
-- | Wrapper for 'extractvsPR' that is NOT INLINED.
--
--   This is experimental. The constructors in this module use it to build
--   the thunk stored in the 'pnested_flat' field of a nested array.
--   It is marked NOINLINE to avoid code explosion: we don't want a copy of
--   the extracts loop to be generated at every use site.
--
--   TODO: at a later fusion stage we could rewrite this to an INLINED
--         version to generate core for the occurrences we actually use.
extractvs_delay :: PR a => PDatas a -> U.VSegd -> PData a
extractvs_delay chunks descriptor
        = extractvsPR chunks descriptor
{-# NOINLINE extractvs_delay #-}
591
592
-- | Lifted concatenation.
--
--   Concatenate all the arrays in a triply nested array.
--
--   The result holds a single flat chunk (the doubly-flattened data) and a
--   plain segment descriptor promoted to a vsegd.
--
concatlPR :: PR a => PData (PArray (PArray a)) -> PData (PArray a)
concatlPR arr
 = PNested vsegd' pdatas' segd' flat'
 where
        -- Peel off the outer two levels of segmentation.
        (outerSegd, inner)      = flattenPR arr
        (innerSegd, flat')      = flattenPR inner

        outerStarts     = U.indicesSegd outerSegd
        innerStarts     = U.indicesSegd innerSegd
        innerCount      = U.length innerStarts

        -- Generate indices for the result array.
        -- There is a tedious edge case when the last segment in the nested
        -- array has length 0. For example:
        --
        --    concatl [ [[1, 2, 3] [4, 5, 6]] [] ]
        --
        -- After the calls to flattenPR we get:
        --    outer segd: lengths = [ 2 0 ]
        --                indices = [ 0 2 ]
        --
        --    inner segd: lengths = [ 3 3 ]
        --                indices = [ 0 3 ]
        --
        -- The last element of the outer indices points off the end of the
        -- inner indices, so we can't simply backpermute the inner indices by
        -- the outer ones. Instead, we explicitly check for the out-of-bounds
        -- condition and use 0 in that case.
        -- TODO: We want a faster way of doing this, that doesn't require the
        --       test for every element.
        startOf ix
         | ix >= innerCount     = 0
         | otherwise            = U.index "concatlPR" innerStarts ix

        resultStarts    = U.map startOf outerStarts
        resultLens      = U.sum_s outerSegd (U.lengthsSegd innerSegd)

        segd'           = U.mkSegd resultLens resultStarts
                                   (U.elementsSegd innerSegd)
        vsegd'          = U.promoteSegdToVSegd segd'
        pdatas'         = singletondPR flat'
{-# INLINE_PDATA concatlPR #-}
641
642
-- | Build a nested array given a single flat data vector,
--   and a template nested array that defines the segmentation.

--   Although the template nested array may be using vsegids to describe
--   internal sharing, the provided data array has manifest elements
--   for every segment. Because of this we need to flatten out the virtual
--   segmentation of the template array.
--
--   WARNING:
--   This can cause index space overflow, see the note in `concatPR`.
--
unconcatPR :: PR b => PData (PArray a) -> PData b -> PData (PArray b)
unconcatPR (PNested templateVSegd _ _ _) pdata
 = {-# SCC "unconcatPD" #-}
   let
        -- Demote the template's vsegd to a manifest segd, so it contains all
        -- the segment lengths individually without going through the vsegids.
        !manifestSegd   = U.unsafeDemoteToSegdOfVSegd templateVSegd

        -- Rebuild a vsegd on top of the manifest segd.
        -- The vsegids will be just [0..len-1], but this field is constructed
        -- lazily and consumers aren't required to demand it.
        !manifestVSegd  = U.promoteSegdToVSegd manifestSegd

        -- The supplied flat data is the only chunk of the result,
        -- and also serves as its pre-concatenated form.
        chunks          = singletondPR pdata

   in   PNested manifestVSegd chunks manifestSegd pdata
{-# INLINE_PDATA unconcatPR #-}
671
672
-- | Flatten a nested array, yielding a plain segment descriptor and
--   concatenated data.
--
--   Both components are read straight from the lazy fields of the array,
--   so nothing is computed until the caller demands them.
--
flattenPR :: PR a => PData (PArray a) -> (U.Segd, PData a)
flattenPR (PNested _ _ plainSegd flatData)
        = (plainSegd, flatData)
{-# INLINE_PDATA flattenPR #-}
680
681
-- | Lifted append.
--   Both arrays must contain the same number of elements.
--
--   The two arrays are flattened, their data is combined with the segmented
--   append of the element type, and the result is a single-chunk nested
--   array whose segment descriptor is the sum of the two input descriptors.
appendlPR :: PR a => PData (PArray a) -> PData (PArray a) -> PData (PArray a)
appendlPR arr1 arr2
 = PNested vsegd' pdatas' segd' flat'
 where  (leftSegd,  leftData)   = flattenPR arr1
        (rightSegd, rightData)  = flattenPR arr2

        segd'   = U.plusSegd leftSegd rightSegd
        vsegd'  = U.promoteSegdToVSegd segd'

        flat'   = appendsPR segd' leftSegd leftData rightSegd rightData
        pdatas' = singletondPR flat'
{-# INLINE_PDATA appendlPR #-}
695
696
-- | Extract some slices from some arrays.
--
--   All three parameters must have the same length, and we take
--   one slice from each of the source arrays.
--
--   The element data is not touched: the result reuses the data chunks of
--   the source and only gets a new segment descriptor.

-- TODO: cleanup pnested projections
slicelPR
        :: PR a
        => PData Int            -- ^ Starting indices of slices.
        -> PData Int            -- ^ Lengths of slices.
        -> PData (PArray a)     -- ^ Arrays to slice.
        -> PData (PArray a)

slicelPR (PInt sliceStarts) (PInt sliceLens) arr
 = mkPNested newVSegids sliceLens newStarts newSrcs (pnested_psegdata arr)
 where
        oldVSegids = pnested_vsegids arr

        -- Every slice becomes its own physical segment.
        newVSegids = U.enumFromTo 0 (U.length oldVSegids - 1)

        -- Start of each slice: start of the segment it comes from,
        -- plus the requested offset into that segment.
        newStarts  = U.zipWith (+)
                        (U.bpermute (pnested_psegstarts arr) oldVSegids)
                        sliceStarts

        -- Each slice lives in the same chunk as the segment it comes from.
        newSrcs    = U.bpermute (pnested_psegsrcids arr) oldVSegids

{-# NOINLINE slicelPR #-}
--  NOINLINE because it won't fuse with anything.
--  The operation is also entirely on the segment descriptor, so we don't
--  need to inline it to specialise it for the element type.
726 -- The operation is also entierly on the segment descriptor, so we don't
727 -- need to inline it to specialise it for the element type.
728
729
730 -- PD Functions ---------------------------------------------------------------
-- These functions work on nested PData arrays, but don't need a PR or PA
-- dictionary. They are segment descriptor operations that only care about the
-- outermost layer of segmentation, and thus are oblivious to the element type.
734 --
735
-- | Take the segment descriptor from a nested array and demote it to a
--   plain Segd.
--
--   The demoted descriptor is read from the lazy 'pnested_segd' field of
--   the array.
--
--   WARNING:
--   This can cause index space overflow, see the note in `concatPR`.
--
takeSegdPD :: PData (PArray a) -> U.Segd
takeSegdPD arr = pnested_segd arr
{-# INLINE_PDATA takeSegdPD #-}
746
747
748
749 -- Testing --------------------------------------------------------------------
-- TODO: slurp debug flag from base

-- | Assert that a validity check holds.
--   Returns 'True' when the check passed; otherwise raises an error whose
--   message includes the given description of the check.
validBool :: String -> Bool -> Bool
validBool str b
 | b            = True
 | otherwise    = error $ "validBool check failed -- " ++ str
755
756
757 -- Pretty ---------------------------------------------------------------------
-- Standalone-derived Show instances for the nested array representations.
-- Both need Show for the element data (PData a) and for the chunk
-- collection (PDatas a), which are listed in the instance contexts.
deriving instance (Show (PDatas a), Show (PData a)) => Show (PDatas (PArray a))
deriving instance (Show (PDatas a), Show (PData a)) => Show (PData  (PArray a))
760
761