dph-prim-*: add updateVSegsReachable for when we know the result covers all psegs
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / UPVSegd.hs
1 {-# LANGUAGE CPP #-}
2 #include "fusion-phases.h"
3
4 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
5
6 -- | Parallel virtual segment descriptors.
7 module Data.Array.Parallel.Unlifted.Parallel.UPVSegd (
8 -- * Types
9 UPVSegd,
10
11 -- * Consistency check
12 valid,
13
14 -- * Constructors
15 mkUPVSegd,
16 fromUPSegd,
17 fromUPSSegd,
18 empty,
19 singleton,
20
21 -- * Predicates
22 isManifest,
23 isContiguous,
24
25 -- * Projections
26 length,
27 takeVSegids,
28 takeUPSSegd,
29 takeLengths,
30 getSeg,
31
32 -- * Demotion
33 demoteToUPSSegd,
34 unsafeDemoteToUPSegd,
35
36 -- * Operators
37 updateVSegs,
38 updateVSegsReachable,
39
40 appendWith,
41 combine2,
42 ) where
43 import Data.Array.Parallel.Unlifted.Parallel.Permute
44 import Data.Array.Parallel.Unlifted.Parallel.UPSel (UPSel2)
45 import Data.Array.Parallel.Unlifted.Parallel.UPSSegd (UPSSegd)
46 import Data.Array.Parallel.Unlifted.Parallel.UPSegd (UPSegd)
47 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector)
48 import Data.Array.Parallel.Pretty hiding (empty)
49 import Prelude hiding (length)
50
51 import qualified Data.Array.Parallel.Unlifted.Sequential.USSegd as USSegd
52 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
53 import qualified Data.Array.Parallel.Unlifted.Parallel.UPSel as UPSel
54 import qualified Data.Array.Parallel.Unlifted.Parallel.UPSegd as UPSegd
55 import qualified Data.Array.Parallel.Unlifted.Parallel.UPSSegd as UPSSegd
56
57
58 -- UPVSegd ---------------------------------------------------------------------
-- | A parallel virtual segment descriptor extends a `UPSSegd` with an
--   explicit mapping from virtual segments to physical segments, so that
--   several virtual segments can share the data of one physical segment.
--
--   TODO: It would probably be better to represent the vsegids as a lens
--   (function) instead of a vector of segids, since much of the time the
--   vsegids are just [0..n].
--
data UPVSegd
  = UPVSegd
  { upvsegd_manifest :: !Bool
    -- ^ True when the vsegids field is the lazy enumeration
    --   (V.enumFromTo 0 (len - 1)), i.e. virtual segment i uses physical
    --   segment i. Operations such as demoteToUPSSegd check this flag so
    --   they never have to build the enumeration.

  , upvsegd_vsegids  :: Vector Int
    -- ^ For each virtual segment, the index of the physical segment that
    --   holds its data.
    --
    --   IMPORTANT:
    --   This field must stay lazy (no bang): when it holds
    --   (V.enumFromTo 0 (len - 1)) we do not want the enumeration built
    --   unless it is strictly demanded.

  , upvsegd_upssegd  :: !UPSSegd
    -- ^ Scattered segment descriptor defining the physical segments.
  }
  deriving Show
83
84
-- | Pretty print the physical representation of a `UPVSegd`:
--   its vsegids followed by the underlying `UPSSegd`.
instance PprPhysical UPVSegd where
  pprp (UPVSegd _ vsegids upssegd)
    = vcat [ header, pprp upssegd ]
    where
      header = text "UPVSegd"
            $$ nest 7 (text "vsegids: " <+> text (show (V.toList vsegids)))
91
92
-- | O(1). Check the internal consistency of a virtual segment descriptor.
--
--   * TODO: this doesn't do any checks yet. It only evaluates the
--     descriptor to weak head normal form and then yields True.
--
valid :: UPVSegd -> Bool
valid upvsegd = upvsegd `seq` True
{-# NOINLINE valid #-}
-- NOINLINE because it's only used during debugging anyway.
101
102
103 -- Constructors ---------------------------------------------------------------
104 -- NOTE: these are NOINLINE for now just so it's easier to read the core.
105 -- we can INLINE them later.
106
-- | O(1). Construct a new virtual segment descriptor.
--
--   The result is flagged as not manifest, so consumers that test the
--   manifest flag will use the supplied vsegids rather than assuming
--   the identity mapping.
mkUPVSegd
  :: Vector Int   -- ^ Array saying which physical segment to use for each virtual segment.
  -> UPSSegd      -- ^ Scattered segment descriptor defining the physical segments.
  -> UPVSegd
mkUPVSegd vsegids upssegd
  = UPVSegd False vsegids upssegd
{-# NOINLINE mkUPVSegd #-}
115
116
-- | O(segs). Promote a `UPSSegd` to a `UPVSegd` that has exactly one
--   virtual segment per physical segment of the given descriptor.
--
--   The vsegids are the enumeration [0 .. segs - 1], stored unevaluated in
--   the lazy vsegids field, and the descriptor is flagged as manifest.
--   Consumers that check the flag can therefore avoid building them.
--
--   TODO: make this parallel, use parallel version of enumFromTo.
--
fromUPSSegd :: UPSSegd -> UPVSegd
fromUPSSegd upssegd
  = UPVSegd True identityVSegids upssegd
  where
    identityVSegids = V.enumFromTo 0 (UPSSegd.length upssegd - 1)
{-# NOINLINE fromUPSSegd #-}
129
130
-- | O(segs). Promote a `UPSegd` to a `UPVSegd`.
--
--   The segments are first turned into a scattered descriptor whose data
--   all comes from a flat array with sourceid 0, and that is then promoted
--   with `fromUPSSegd`. The result has one virtual segment for every
--   segment of the provided `UPSegd`.
--
fromUPSegd :: UPSegd -> UPVSegd
fromUPSegd upsegd
  = fromUPSSegd (UPSSegd.fromUPSegd upsegd)
{-# NOINLINE fromUPSegd #-}
139
140
-- | O(1). The segment descriptor with no segments and no elements.
empty :: UPVSegd
empty
  = UPVSegd
      True            -- no virtual segments, so trivially manifest
      V.empty         -- no vsegids
      UPSSegd.empty   -- and no physical segments
{-# NOINLINE empty #-}
145
146
-- | O(1). A segment descriptor with a single segment.
--
--   That segment covers the given number of elements of a flat array with
--   sourceid 0, and its only vsegid refers to physical segment 0.
singleton :: Int -> UPVSegd
singleton n
  = UPVSegd True onlyVSegid (UPSSegd.singleton n)
  where
    onlyVSegid = V.singleton 0
{-# NOINLINE singleton #-}
153
154
155 -- Predicates -----------------------------------------------------------------
-- | O(1). Test whether every segment is manifest (unshared / non-virtual).
--   When this holds, the vsegids field is [0 .. len - 1].
--
--   Consumers can test this flag and then avoid demanding the vsegids field.
--   Because that field is lazy, this can mean it is never generated at all.
--
isManifest :: UPVSegd -> Bool
isManifest upvsegd = upvsegd_manifest upvsegd
{-# INLINE isManifest #-}
166
167
-- | O(1). Delegates to `UPSSegd.isContiguous` on the physical descriptor:
--   True when the starts are identical to the usegd indices field and the
--   sources are all 0's.
--
--   In that case all the data elements live in one contiguous flat array,
--   and consumers can skip looking at the real starts and sources fields.
--
isContiguous :: UPVSegd -> Bool
isContiguous upvsegd = UPSSegd.isContiguous (upvsegd_upssegd upvsegd)
{-# INLINE isContiguous #-}
178
179
180 -- Projections ----------------------------------------------------------------
181 -- INLINE trivial projections as they'll expand to a single record selector.
182
-- | O(1). Yield the overall number of (virtual) segments.
--
--   When the descriptor is manifest, the vsegids field may still be the
--   unevaluated (V.enumFromTo 0 (len - 1)). Taking its length would build
--   the whole enumeration just to count it. Instead, we take the count from
--   the physical descriptor, which has exactly one segment per virtual
--   segment in that case.
length :: UPVSegd -> Int
length (UPVSegd manifest vsegids upssegd)
  | manifest  = UPSSegd.length upssegd
  | otherwise = V.length vsegids
{-# INLINE length #-}
187
188
-- | O(1). Project out the virtual segment ids of a `UPVSegd`.
--   Demanding the result forces the (lazy) vsegids field.
takeVSegids :: UPVSegd -> Vector Int
takeVSegids upvsegd = upvsegd_vsegids upvsegd
{-# INLINE takeVSegids #-}
193
194
-- | O(1). Project out the physical `UPSSegd` of a `UPVSegd`.
takeUPSSegd :: UPVSegd -> UPSSegd
takeUPSSegd upvsegd = upvsegd_upssegd upvsegd
{-# INLINE takeUPSSegd #-}
199
200
-- | O(segs). Yield the length of every virtual segment of a `UPVSegd`.
--
--   A manifest descriptor returns the physical lengths unchanged. Otherwise
--   each vsegid is looked up in the physical lengths.
--
--   TODO: This is slow and sequential.
--
takeLengths :: UPVSegd -> Vector Int
takeLengths (UPVSegd manifest vsegids upssegd)
  = if manifest
      then plens
      else V.map (\pseg -> plens V.! pseg) vsegids
  where
    plens = UPSSegd.takeLengths upssegd
{-# NOINLINE takeLengths #-}
-- NOINLINE because we don't want a case expression due to the test on the
-- manifest flag to appear in the core program.
212
213
-- | O(1). Get the length, starting index, and source id of the virtual
--   segment at the given index.
--
--   NOTE: The segment index field of the USSegd is deliberately not
--         returned. It is a flat index relative to the SSegd array, not
--         relative to the UVSegd array, and promoting it to a UVSegd index
--         could overflow.
--
getSeg :: UPVSegd -> Int -> (Int, Int, Int)
getSeg (UPVSegd _ vsegids upssegd) ix
  = (segLen, segStart, segSource)
  where
    pseg                               = vsegids V.! ix
    (segLen, _, segStart, segSource)   = UPSSegd.getSeg upssegd pseg
{-# INLINE_UP getSeg #-}
226
227
228 -- Demotion -------------------------------------------------------------------
-- | O(segs). Yield a `UPSSegd` that describes each segment of a `UPVSegd`
--   individually.
--
--   * A manifest descriptor already has one physical segment per virtual
--     segment, so its `UPSSegd` is returned as is.
--
--   * Otherwise the starts, sources and lengths of the physical segments
--     are gathered through the vsegids. Information about virtual segments
--     that share the same physical segment is lost in the process.
--
--   * This operation is used in concatPR as the first step in eliminating
--     segmentation from a nested array.
--
demoteToUPSSegd :: UPVSegd -> UPSSegd
demoteToUPSSegd (UPVSegd manifest vsegids upssegd)
  | manifest
  = upssegd

  | otherwise
  = {-# SCC "demoteToUPSegd" #-}
    let gather field = bpermuteUP (field upssegd) vsegids
        lengths'     = gather UPSSegd.takeLengths
    in  UPSSegd.mkUPSSegd
          (gather UPSSegd.takeStarts)
          (gather UPSSegd.takeSources)
          (UPSegd.fromLengths lengths')
{-# NOINLINE demoteToUPSSegd #-}
-- NOINLINE because it's complicated and won't fuse with anything.
-- In core we want to see when VSegds are being demoted.
252
253
-- | O(segs). Given a virtual segment descriptor, produce a `UPSegd` that
--   describes the entire array.
--
--   WARNING:
--   Trying to take the `UPSegd` of a nested array that has been constructed
--   with replication can cause index overflow. This is because the virtual
--   size of the corresponding flat data can be larger than physical memory.
--
--   You should only apply this function to a nested array when you're about
--   to construct something with the same size as the corresponding flat
--   array. In this case the index overflow doesn't matter too much because
--   the program would OOM anyway.
--
--   When the descriptor is manifest, its vsegids are the identity mapping.
--   The physical segment lengths are then used directly, which skips the
--   backwards permutation and avoids forcing the lazy vsegids field.
--
--   TODO: if the upvsegd is manifest and contiguous this can be O(1).
--
unsafeDemoteToUPSegd :: UPVSegd -> UPSegd
unsafeDemoteToUPSegd (UPVSegd manifest vsegids upssegd)
  = {-# SCC "unsafeDemoteToUPSegd" #-}
    UPSegd.fromLengths
  $ if manifest
      then UPSSegd.takeLengths upssegd
      else bpermuteUP (UPSSegd.takeLengths upssegd) vsegids
{-# NOINLINE unsafeDemoteToUPSegd #-}
-- NOINLINE because it's complicated and won't fuse with anything.
-- In core we want to see when VSegds are being demoted.
277
278
279 -- Operators ------------------------------------------------------------------
-- | Apply a function to the virtual segment ids of a `UPVSegd`, then cull
--   the physical segment descriptor so that every remaining physical
--   segment is reachable from some virtual segment.
--
--   Use this for filtering-style operations on the virtual segments. The
--   cull restores the invariant that all physical segments are referenced
--   by at least one virtual segment.
--
--   * TODO: make this parallel.
--     It currently runs the sequential 'USSegd.cullOnVSegids' and then
--     rebuilds the `UPSSegd` from the culled `USSegd`.
--
updateVSegs :: (Vector Int -> Vector Int) -> UPVSegd -> UPVSegd
updateVSegs fUpdate (UPVSegd _ vsegids upssegd)
  = UPVSegd False culledVSegids (UPSSegd.fromUSSegd culledUSSegd)
  where
    (culledVSegids, culledUSSegd)
      = USSegd.cullOnVSegids (fUpdate vsegids)
                             (UPSSegd.takeUSSegd upssegd)
{-# INLINE_UP updateVSegs #-}
-- INLINE_UP because we want to inline the parameter function fUpdate.
300
301
-- | Apply a function to the virtual segment ids of a `UPVSegd`, keeping the
--   physical segment descriptor unchanged.
--
--   * The caller promises that the new vsegids still reference every
--     physical segment. If they don't, some physical segments become
--     unreachable, and operations such as segmented fold will waste work
--     on them.
--
--   * Compared with `updateVSegs`, this skips the 'cull' that discards
--     unreachable physical segments. That cull is O(result segments) but
--     can still be expensive in absolute terms.
--
updateVSegsReachable :: (Vector Int -> Vector Int) -> UPVSegd -> UPVSegd
updateVSegsReachable fUpdate (UPVSegd _ vsegids upssegd)
  = UPVSegd False vsegids' upssegd
  where
    vsegids' = fUpdate vsegids
{-# INLINE_UP updateVSegsReachable #-}
-- INLINE_UP because we want to inline the parameter function fUpdate.
319
320
321 -- Append ---------------------------------------------------------------------
322 -- NOTE: these are NOINLINE for now just so it's easier to read the core.
323 -- we can INLINE them later.
324
-- | Produce a segment descriptor that describes the result of appending
--   two arrays.
--
--   The physical segments of the result are those of the first descriptor
--   followed by those of the second. The second array's vsegids are
--   therefore shifted by the number of physical segments in the first.
--
--   * TODO: make this parallel.
--
appendWith
  :: UPVSegd -> Int     -- ^ uvsegd of array, and number of physical data arrays
  -> UPVSegd -> Int     -- ^ uvsegd of array, and number of physical data arrays
  -> UPVSegd

appendWith (UPVSegd _ vsegids1 upssegd1) pdatas1
           (UPVSegd _ vsegids2 upssegd2) pdatas2
  = UPVSegd False vsegids' upssegd'
  where
    -- Index of the first physical segment contributed by the second array.
    psegShift = UPSSegd.length upssegd1

    vsegids'  = vsegids1 V.++ V.map (+ psegShift) vsegids2

    -- All data from both source arrays goes into the result.
    upssegd'  = UPSSegd.appendWith upssegd1 pdatas1 upssegd2 pdatas2
{-# NOINLINE appendWith #-}
352
353
354 -- Combine --------------------------------------------------------------------
355 -- NOTE: these are NOINLINE for now just so it's easier to read the core.
356 -- we can INLINE them later.
357
-- | Combine two virtual segment descriptors, interleaving their virtual
--   segments according to the tags of the given selector.
--
--   * The physical segments of the result are those of the first descriptor
--     followed by those of the second (via `UPSSegd.appendWith`). The
--     vsegids of the second descriptor are therefore shifted by the number
--     of /physical/ segments in the first, exactly as in `appendWith`.
--
--   * TODO: make this parallel.
--
combine2
  :: UPSel2
  -> UPVSegd -> Int     -- ^ uvsegd of array, and number of physical data arrays
  -> UPVSegd -> Int     -- ^ uvsegd of array, and number of physical data arrays
  -> UPVSegd

combine2
        upsel2
        (UPVSegd _ vsegids1 upssegd1) pdatas1
        (UPVSegd _ vsegids2 upssegd2) pdatas2
  = let -- vsegids relative to the combined psegs.
        -- The second array's psegs start right after all psegs of the
        -- first, so shift by the number of *physical* segments in upssegd1.
        -- Shifting by the number of virtual segments (V.length vsegids1)
        -- is wrong whenever vsegs and psegs differ, e.g. after replication
        -- (many vsegs sharing one pseg) or culling.
        psegShift = UPSSegd.length upssegd1
        vsegids1' = vsegids1
        vsegids2' = V.map (+ psegShift) vsegids2

        -- combine the vsegids
        vsegids'  = V.combine2ByTag (UPSel.tagsUPSel2 upsel2)
                                    vsegids1' vsegids2'

        -- All data from the source arrays goes into the result
        upssegd'  = UPSSegd.appendWith
                        upssegd1 pdatas1
                        upssegd2 pdatas2

    in  UPVSegd False vsegids' upssegd'
{-# NOINLINE combine2 #-}
388