5ab5457fff29fe0efcfa1664a983f524fd63d0c4
[packages/dph.git] / dph-prim-seq / Data / Array / Parallel / Unlifted / Sequential / USSegd.hs
1 {-# LANGUAGE CPP #-}
2 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
3 #include "fusion-phases.h"
4
5 -- | Scattered Segment Descriptors
6 module Data.Array.Parallel.Unlifted.Sequential.USSegd (
7 -- * Types
8 USSegd(..),
9
10 -- * Consistency check
11 valid,
12
13 -- * Constructors
14 mkUSSegd,
15 empty,
16 singleton,
17 fromUSegd,
18
19 -- * Predicates
20 isContiguous,
21
22 -- * Projections
23 length,
24 takeUSegd, takeLengths, takeIndices, takeElements,
25 takeSources, takeStarts,
26
27 getSeg,
28
29 -- * Operators
30 append,
31 cullOnVSegids,
32
33 -- * Streams
34 streamSegs
35 ) where
36 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
37 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, Unbox)
38 import Data.Array.Parallel.Pretty hiding (empty)
39 import Prelude hiding (length)
40
41 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
42 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as U
43 import qualified Data.Vector as V
44 import qualified Data.Vector.Fusion.Stream as S
45 import qualified Data.Vector.Fusion.Stream.Size as S
46 import qualified Data.Vector.Fusion.Stream.Monadic as M
47 import qualified Data.Vector.Unboxed as VU
48
49
50 -- USSegd ---------------------------------------------------------------------
-- | Scattered segment descriptors generalise the regular, contiguous
--   segment descriptors of type 'USegd'.
--
--   * Segments may be drawn from multiple physical source arrays.
--
--   * The segments need not cover the entire flat array.
--
--   * Different segments may point to the same elements. Because of this
--     sharing, the total number of elements covered by the descriptor
--     can overflow a machine word.
--
data USSegd
        = USSegd
        { -- | True when the starts are identical to the indices field of
          --   the 'USegd' and the sources are all 0's. In that case all the
          --   data elements live in one contiguous flat array, and consumers
          --   can avoid looking at the real starts and sources fields.
          ussegd_contiguous     :: !Bool

          -- | Starting index of each segment in its flat array.
          --
          --   IMPORTANT: this field is deliberately lazy, so it need not be
          --   created when the flat array is contiguous.
        , ussegd_starts         :: Vector Int

          -- | Which flat array to take each segment from.
          --
          --   IMPORTANT: this field is deliberately lazy, so it need not be
          --   created when the flat array is contiguous.
        , ussegd_sources        :: Vector Int

          -- | Segment descriptor relative to a contiguous index space.
          --   This defines the length of each segment.
        , ussegd_usegd          :: !USegd
        }
        deriving Show
86
87
-- | Pretty print the physical representation of a `USSegd`.
--   The starts and source ids are shown as lists, followed by the physical
--   representation of the underlying `USegd`. The contiguity flag is not
--   printed.
instance PprPhysical USSegd where
 pprp (USSegd _ startIxs sourceIds usegd)
  = vcat [ text "USSegd" $$ nest 7 fields
         , pprp usegd ]
  where
        fields
         = vcat [ text "starts: " <+> text (show (U.toList startIxs))
                , text "srcids: " <+> text (show (U.toList sourceIds)) ]
97
98
99 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new scattered segment descriptor.
--   All the provided arrays must have the same lengths.
--
--   The result is always flagged as non-contiguous; no attempt is made to
--   detect whether the given starts and sources happen to be contiguous.
mkUSSegd
        :: Vector Int   -- ^ starting index of each segment in its flat array
        -> Vector Int   -- ^ which array to take each segment from
        -> USegd        -- ^ contiguous segment descriptor
        -> USSegd

mkUSSegd starts sources usegd
        = USSegd
        { ussegd_contiguous     = False
        , ussegd_starts         = starts
        , ussegd_sources        = sources
        , ussegd_usegd          = usegd }
{-# INLINE mkUSSegd #-}
110
111
-- | O(1). Check the internal consistency of a scattered segment descriptor.
--   This only checks that the starts and sources arrays each have one
--   entry per segment of the underlying `USegd`.
valid :: USSegd -> Bool
valid (USSegd _ starts srcids usegd)
 = all (== segs) [U.length starts, U.length srcids]
 where  segs    = USegd.length usegd

{-# NOINLINE valid #-}
--  NOINLINE because it's only enabled during debugging anyway.
120
121
-- | O(1). Yield an empty segment descriptor, with no elements or segments.
--   The result is flagged as contiguous.
empty :: USSegd
empty   = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = U.empty
        , ussegd_sources        = U.empty
        , ussegd_usegd          = USegd.empty }
{-# INLINE_U empty #-}
126
127
-- | O(1). Yield a singleton segment descriptor.
--   Its only segment starts at index 0 of the flat array with sourceid 0,
--   and covers the given number of elements.
singleton :: Int -> USSegd
singleton n
        = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = zero
        , ussegd_sources        = zero
        , ussegd_usegd          = USegd.singleton n }
        where   zero    = U.singleton 0
{-# INLINE_U singleton #-}
135
136
-- | O(segs). Promote a plain `USegd` to a `USSegd`.
--   All segments are assumed to come from a flat array with sourceid 0,
--   so the starts are just the indices of the `USegd` and the result is
--   flagged as contiguous.
fromUSegd :: USegd -> USSegd
fromUSegd usegd
        = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = USegd.takeIndices usegd
        , ussegd_sources        = U.replicate (USegd.length usegd) 0
        , ussegd_usegd          = usegd }
{-# INLINE_U fromUSegd #-}
146
147
148 -- Predicates -----------------------------------------------------------------
149 -- INLINE trivial projections as they'll expand to a single record selector.
-- | O(1). Yield the contiguity flag stored in a `USSegd`.
isContiguous :: USSegd -> Bool
isContiguous ussegd = ussegd_contiguous ussegd
{-# INLINE isContiguous #-}
153
154
155 -- Projections ----------------------------------------------------------------
156 -- INLINE trivial projections as they'll expand to a single record selector.
157
-- | O(1). Yield the overall number of segments,
--   as recorded by the underlying `USegd`.
length :: USSegd -> Int
length ussegd = USegd.length (ussegd_usegd ussegd)
{-# INLINE length #-}
162
163
-- | O(1). Yield the underlying `USegd` of a `USSegd`.
takeUSegd :: USSegd -> USegd
takeUSegd (USSegd _ _ _ usegd) = usegd
{-# INLINE takeUSegd #-}
168
169
-- | O(1). Yield the lengths of the segments of a `USSegd`,
--   taken from its underlying `USegd`.
takeLengths :: USSegd -> Vector Int
takeLengths ussegd = USegd.takeLengths (ussegd_usegd ussegd)
{-# INLINE takeLengths #-}
174
175
-- | O(1). Yield the segment indices of a `USSegd`,
--   taken from its underlying `USegd`.
takeIndices :: USSegd -> Vector Int
takeIndices ussegd = USegd.takeIndices (ussegd_usegd ussegd)
{-# INLINE takeIndices #-}
180
181
-- | O(1). Yield the total number of elements covered by a `USSegd`,
--   as reported by its underlying `USegd`.
takeElements :: USSegd -> Int
takeElements ussegd = USegd.takeElements (ussegd_usegd ussegd)
{-# INLINE takeElements #-}
186
187
-- | O(1). Yield the starting indices of a `USSegd`.
takeStarts :: USSegd -> Vector Int
takeStarts (USSegd _ starts _ _) = starts
{-# INLINE takeStarts #-}
192
193
-- | O(1). Yield the source ids of a `USSegd`.
takeSources :: USSegd -> Vector Int
takeSources (USSegd _ _ sources _) = sources
{-# INLINE takeSources #-}
198
199
-- | O(1). Get the length, segment index, starting index, and source id
--   of the segment at the given position.
--
--   The length and segment index come from the underlying `USegd`;
--   the starting index and source id are looked up in the starts and
--   sources arrays of the `USSegd` itself.
getSeg :: USSegd -> Int -> (Int, Int, Int, Int)
getSeg (USSegd _ starts sources usegd) ix
 = (len, index, start, source)
 where  (len, index)    = USegd.getSeg usegd ix
        start           = starts  U.! ix
        source          = sources U.! ix
{-# INLINE_U getSeg #-}
209
210
211 -- Operators ------------------------------------------------------------------
-- | O(n). Produce a segment descriptor that describes the result of appending
--   two arrays.
--
--   The source ids of the second descriptor are shifted up by the number of
--   physical data arrays of the first, so that they index into the
--   concatenation of both sets of data arrays. The number of data arrays
--   of the second descriptor is not needed and is ignored.
--   The result is always flagged as non-contiguous.
append  :: USSegd -> Int        -- ^ ussegd of array, and number of physical data arrays
        -> USSegd -> Int        -- ^ ussegd of array, and number of physical data arrays
        -> USSegd
append  (USSegd _ startsL sourcesL usegdL) pdatasL
        (USSegd _ startsR sourcesR usegdR) _
        = USSegd
        { ussegd_contiguous     = False
        , ussegd_starts         = startsL  U.++ startsR
        , ussegd_sources        = sourcesL U.++ sourcesR'
        , ussegd_usegd          = USegd.append usegdL usegdR }
        where
                -- Source ids of the right descriptor, rebased past the
                -- data arrays of the left one.
                sourcesR'       = U.map (+ pdatasL) sourcesR

{-# INLINE_U append #-}
225
226
-- | Cull the segments in a SSegd down to only those reachable from an array
--   of vsegids, and also update the vsegids to point to the same segments
--   in the result.
--
--   Returns the rewritten vsegids together with the culled descriptor.
--   The culled descriptor is always flagged as non-contiguous.
--
--   TODO: bpermuteDft isn't parallelised
--
cullOnVSegids :: Vector Int -> USSegd -> (Vector Int, USSegd)
cullOnVSegids vsegids (USSegd _ starts sources usegd)
 = {-# SCC "cullOnVSegids" #-}
   let  -- Determine which of the psegs are still reachable from the vsegs.
        -- This produces an array of flags,
        --    with reachable   psegs corresponding to True  (shown as 1)
        --    and  unreachable psegs corresponding to False (shown as 0)
        --
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --   => psegids_used:   [1 1 0 1 0 1 1]
        --
        --  Note that psegids '2' and '4' are not in vsegids.
        psegids_used
         = U.bpermuteDft (USegd.length usegd)
                         (const False)
                         (U.zip vsegids (U.replicate (U.length vsegids) True))

        -- Produce an array of used psegs.
        --  eg  psegids_used:   [1 1 0 1 0 1 1]
        --      psegids_packed: [0 1 3 5 6]
        --
        -- The candidate indices run from 0 to (length - 1) inclusive, so
        -- there is exactly one index per flag.
        psegids_packed
         = U.pack (U.enumFromTo 0 (U.length psegids_used - 1)) psegids_used

        -- Produce an array that maps psegids in the source array onto
        -- psegids in the result array. If a particular pseg isn't present
        -- in the result this maps onto -1.

        -- Note that if psegids_used has 0 in some position, then psegids_map
        -- has -1 in the same position, corresponding to an unused pseg.

        --  eg  psegids_packed: [0 1 3 5 6]
        --                      [0 1 2 3 4]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        psegids_map
         = U.bpermuteDft (USegd.length usegd)
                         (const (-1))
                         (U.zip psegids_packed (U.enumFromTo 0 (U.length psegids_packed - 1)))

        -- Use the psegids_map to rewrite the vsegids to point to the
        -- corresponding psegs in the result.
        --
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        --
        --      vsegids':       [0 1 1 2 3 3 4 4]
        --
        vsegids'        = U.map (psegids_map U.!) vsegids

        -- Rebuild the ussegd, keeping only the entries of the used psegs.
        starts'         = U.pack starts  psegids_used
        sources'        = U.pack sources psegids_used

        lengths'        = U.pack (USegd.takeLengths usegd) psegids_used
        usegd'          = USegd.fromLengths lengths'

        ussegd'         = USSegd False starts' sources' usegd'

   in   (vsegids', ussegd')

{-# NOINLINE cullOnVSegids #-}
--  NOINLINE because it's complicated and won't fuse with anything
--  This can also be expensive and we want to see the SCC in profiling builds.
295
296
297 -- Stream Functions -----------------------------------------------------------
-- | Stream some physical segments from many data arrays.
--
--   The segments are streamed in order; for each one the elements are read
--   from the source vector selected by its source id, beginning at its
--   starting index. All indexing is unchecked, so the descriptor must be
--   consistent with the given source vectors.
--
--   TODO: make this more efficient, and fix fusion.
--         We should be able to eliminate a lot of the indexing happening in the
--         inner loop by being cleverer about the loop state.
--
--   TODO: If this is contiguous then we can stream the lot without worrying
--         about jumping between segments. EXCEPT that this information must be
--         statically visible else streamSegs won't fuse, so we can't have an
--         ifThenElse checking the manifest flag.

streamSegs
        :: Unbox a
        => USSegd               -- ^ Segment descriptor defining segments base
                                --   on source vectors.
        -> V.Vector (Vector a)  -- ^ Source vectors.
        -> S.Stream a

{-# INLINE_STREAM streamSegs #-}
streamSegs ussegd@(USSegd _ starts sources usegd) pdatas
 = M.Stream step (0, 0) S.Unknown
 where
        -- Length of each segment.
        seglens = USegd.takeLengths usegd

        -- The state is the current segment, paired with the offset of the
        -- next element to yield within that segment.
        {-# INLINE_INNER step #-}
        step (seg, off)
         -- Every segment has been streamed.
         | seg >= length ussegd
         = return S.Done

         -- The current segment is exhausted, move on to the next one.
         | off >= seglens `VU.unsafeIndex` seg
         = return (S.Skip (seg + 1, 0))

         -- Yield the next element of the current segment.
         | otherwise
         = let  !srcid  = sources `VU.unsafeIndex` seg
                !pdata  = pdatas  `V.unsafeIndex`  srcid
                !start  = starts  `VU.unsafeIndex` seg
                !result = pdata   `VU.unsafeIndex` (start + off)
           in   return (S.Yield result (seg, off + 1))

341
342