821638a3f145027ab54387756f6e26cd600cf3a7
[packages/dph.git] / dph-prim-seq / Data / Array / Parallel / Unlifted / Sequential / USSegd.hs
1 {-# LANGUAGE CPP #-}
2 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
3 #include "fusion-phases.h"
4
5 -- | Scattered Segment Descriptors
6 module Data.Array.Parallel.Unlifted.Sequential.USSegd (
7 -- * Types
8 USSegd(..),
9
10 -- * Consistency check
11 valid,
12
13 -- * Constructors
14 mkUSSegd,
15 empty,
16 singleton,
17 fromUSegd,
18
19 -- * Predicates
20 isContiguous,
21
22 -- * Projections
23 length,
24 takeUSegd, takeLengths, takeIndices, takeElements,
25 takeSources, takeStarts,
26
27 getSeg,
28
29 -- * Operators
30 append,
31 cullOnVSegids,
32
33 -- * Streams
34 streamSegs
35 ) where
36 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
37 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, Unbox)
38 import Data.Array.Parallel.Pretty hiding (empty)
39 import Prelude hiding (length)
40
41 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
42 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as U
43 import qualified Data.Vector as V
44 import qualified Data.Vector.Fusion.Stream as S
45 import qualified Data.Vector.Fusion.Stream.Size as S
46 import qualified Data.Vector.Fusion.Stream.Monadic as M
47
48
49 -- USSegd ---------------------------------------------------------------------
-- | Scattered segment descriptors generalise the plain, contiguous
--   segment descriptors of type `USegd`.
--
--   * The segments may be drawn from multiple physical source arrays.
--
--   * The segments need not cover the entire flat array.
--
--   * Different segments may point to the same elements. Because of this
--     sharing, the total number of elements covered by the descriptor
--     may overflow a machine word.
--
data USSegd
        = USSegd
        { -- | True when the starts are identical to the indices field of
          --   the `USegd`, and the sources are all 0's.
          --   In that case all the data elements sit in one contiguous
          --   flat array, and consumers can avoid looking at the real
          --   starts and sources fields.
          ussegd_contiguous     :: !Bool

          -- | Starting index of each segment in its flat array.
          --   IMPORTANT: this field is lazy so that it need not be built
          --   when the flat array is contiguous.
        , ussegd_starts         :: Vector Int

          -- | Which flat array each segment is taken from.
          --   IMPORTANT: this field is lazy so that it need not be built
          --   when the flat array is contiguous.
        , ussegd_sources        :: Vector Int

          -- | Segment descriptor relative to a contiguous index space.
          --   This is what defines the length of each segment.
        , ussegd_usegd          :: !USegd
        }
        deriving Show
85
86
-- | Pretty print the physical representation of a `USSegd`.
--   The starts and source ids are listed under the constructor name,
--   followed by the physical representation of the underlying `USegd`.
--   The contiguity flag is not printed.
instance PprPhysical USSegd where
 pprp (USSegd _ starts sources usegd)
  = vcat [banner, pprp usegd]
  where
        -- Constructor name, with the two index vectors nested beneath it.
        banner
         =  text "USSegd"
         $$ nest 7 (vcat [ ppVec "starts: " starts
                         , ppVec "srcids: " sources ])

        -- A label followed by the vector rendered as a list.
        ppVec label vec
         = text label <+> text (show (U.toList vec))
96
97
98 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new scattered segment descriptor.
--   All the provided arrays must have the same lengths; this function
--   does not check that itself.
--   The result is always flagged as non-contiguous.
mkUSSegd
        :: Vector Int   -- ^ Starting index of each segment in its flat array.
        -> Vector Int   -- ^ Which flat array to take each segment from.
        -> USegd        -- ^ Contiguous segment descriptor.
        -> USSegd

-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
mkUSSegd
 = \starts sources usegd -> USSegd False starts sources usegd
{-# INLINE mkUSSegd #-}
109
110
-- | O(1). Check the internal consistency of a scattered segment descriptor.
--   There must be exactly one starting index and one source id for every
--   segment of the underlying `USegd`. The contiguity flag is not inspected.
valid :: USSegd -> Bool
valid (USSegd _ starts sources usegd)
 =  U.length starts  == segs
 && U.length sources == segs
 where  segs    = USegd.length usegd

{-# NOINLINE valid #-}
--  NOINLINE because it's only enabled during debugging anyway.
119
120
-- | O(1). Yield an empty segment descriptor, with no elements or segments.
--   The result is flagged as contiguous.
empty :: USSegd
empty   = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = U.empty
        , ussegd_sources        = U.empty
        , ussegd_usegd          = USegd.empty }
{-# INLINE_U empty #-}
125
126
-- | O(1). Yield a singleton segment descriptor.
--   Its only segment covers the given number of elements, starting at
--   index 0 of the flat array with sourceid 0. The result is flagged as
--   contiguous.
singleton :: Int -> USSegd
singleton n
        = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = U.singleton 0
        , ussegd_sources        = U.singleton 0
        , ussegd_usegd          = USegd.singleton n }
{-# INLINE_U singleton #-}
134
135
-- | O(segs). Promote a plain `USegd` to a `USSegd`.
--   All segments are assumed to come from a flat array with sourceid 0,
--   so the starts are the indices of the `USegd` and the result is flagged
--   as contiguous.
fromUSegd :: USegd -> USSegd
fromUSegd usegd
        = USSegd
        { ussegd_contiguous     = True
        , ussegd_starts         = USegd.takeIndices usegd
        , ussegd_sources        = U.replicate (USegd.length usegd) 0
        , ussegd_usegd          = usegd }
{-# INLINE_U fromUSegd #-}
145
146
147 -- Predicates -----------------------------------------------------------------
-- | O(1). Yield the contiguity flag of a `USSegd`.
isContiguous :: USSegd -> Bool
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
isContiguous
 = \(USSegd contiguous _ _ _) -> contiguous
{-# INLINE isContiguous #-}
151
152
153 -- Projections ----------------------------------------------------------------
154 -- INLINE trivial projections as they'll expand to a single record selector.
155
-- | O(1). Yield the overall number of segments,
--   as recorded by the underlying `USegd`.
length :: USSegd -> Int
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
length
 = \ussegd -> USegd.length (ussegd_usegd ussegd)
{-# INLINE length #-}
160
161
-- | O(1). Yield the `USegd` of a `USSegd`.
takeUSegd :: USSegd -> USegd
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
takeUSegd
 = \ussegd -> ussegd_usegd ussegd
{-# INLINE takeUSegd #-}
166
167
-- | O(1). Yield the lengths of the segments of a `USSegd`,
--   taken from the underlying `USegd`.
takeLengths :: USSegd -> Vector Int
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
takeLengths
 = \ussegd -> USegd.takeLengths (ussegd_usegd ussegd)
{-# INLINE takeLengths #-}
172
173
-- | O(1). Yield the segment indices of a `USSegd`,
--   taken from the underlying `USegd`.
takeIndices :: USSegd -> Vector Int
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
takeIndices
 = \ussegd -> USegd.takeIndices (ussegd_usegd ussegd)
{-# INLINE takeIndices #-}
178
179
-- | O(1). Yield the total number of elements covered by a `USSegd`,
--   as recorded by the underlying `USegd`.
takeElements :: USSegd -> Int
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
takeElements
 = \ussegd -> USegd.takeElements (ussegd_usegd ussegd)
{-# INLINE takeElements #-}
184
185
-- | O(1). Yield the starting indices of a `USSegd`.
takeStarts :: USSegd -> Vector Int
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
takeStarts
 = \ussegd -> ussegd_starts ussegd
{-# INLINE takeStarts #-}
190
191
-- | O(1). Yield the source ids of a `USSegd`.
takeSources :: USSegd -> Vector Int
-- No left-hand-side arguments, so the INLINE unfolding is usable even at
-- unapplied occurrences.
takeSources
 = \ussegd -> ussegd_sources ussegd
{-# INLINE takeSources #-}
196
197
-- | O(1). Get the length, segment index, starting index, and source id
--   of a segment, in that order.
--   The first two components come from the underlying `USegd`; the last
--   two are looked up at the same position in the starts and sources.
getSeg :: USSegd -> Int -> (Int, Int, Int, Int)
getSeg (USSegd _ starts sources usegd) ix
 = (len, index, start, source)
 where  (len, index)    = USegd.getSeg usegd ix
        start           = starts  U.! ix
        source          = sources U.! ix
{-# INLINE_U getSeg #-}
207
208
209 -- Operators ------------------------------------------------------------------
-- | O(n). Produce a segment descriptor that describes the result of appending
--   two arrays.
--
--   The source ids of the second descriptor are shifted up by the number of
--   physical data arrays of the first one. The result is always flagged as
--   non-contiguous.
append  :: USSegd       -- ^ Segment descriptor of the first array.
        -> Int          -- ^ Number of physical data arrays of the first array.
        -> USSegd       -- ^ Segment descriptor of the second array.
        -> Int          -- ^ Number of physical data arrays of the second array
                        --   (not used by this function).
        -> USSegd
append (USSegd _ startsL sourcesL usegdL) pdatasL
       (USSegd _ startsR sourcesR usegdR) _
        = USSegd
        { ussegd_contiguous     = False
        , ussegd_starts         = startsL  U.++ startsR
        , ussegd_sources        = sourcesL U.++ U.map (pdatasL +) sourcesR
        , ussegd_usegd          = USegd.append usegdL usegdR }

{-# INLINE_U append #-}
223
224
-- | Cull the segments in a SSegd down to only those reachable from an array
--   of vsegids, and also update the vsegids to point to the same segments
--   in the result.
--
--   Returns the rewritten vsegids together with the culled descriptor.
--   The culled descriptor is always flagged as non-contiguous.
--
--   TODO: bpermuteDft isn't parallelised
--
cullOnVSegids :: Vector Int -> USSegd -> (Vector Int, USSegd)
cullOnVSegids vsegids (USSegd _ starts sources usegd)
 = let  -- Determine which of the psegs are still reachable from the vsegs.
        -- This produces an array of flags,
        --    with reachable   psegs corresponding to 1
        --    and  unreachable psegs corresponding to 0
        --
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --   => psegids_used:   [1 1 0 1 0 1 1]
        --
        --  Note that psegids '2' and '4' are not in vsegids.
        psegids_used
         = U.bpermuteDft (USegd.length usegd)
                         (const False)
                         (U.zip vsegids (U.replicate (U.length vsegids) True))

        -- Produce an array of used psegs.
        --  eg  psegids_used:   [1 1 0 1 0 1 1]
        --      psegids_packed: [0 1 3 5 6]
        --
        -- enumFromTo is inclusive at both ends, so the upper bound must be
        -- the last valid index; this gives exactly one candidate psegid
        -- per flag.
        psegids_packed
         = U.pack (U.enumFromTo 0 (U.length psegids_used - 1)) psegids_used

        -- Produce an array that maps psegids in the source array onto
        -- psegids in the result array. If a particular pseg isn't present
        -- in the result this maps onto -1.
        --
        -- Note that if psegids_used has 0 in some position, then psegids_map
        -- has -1 in the same position, corresponding to an unused pseg.
        --
        --  eg  psegids_packed: [0 1 3 5 6]
        --                      [0 1 2 3 4]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        psegids_map
         = U.bpermuteDft (USegd.length usegd)
                         (const (-1))
                         (U.zip psegids_packed
                                (U.enumFromTo 0 (U.length psegids_packed - 1)))

        -- Use the psegids_map to rewrite the vsegids to point to the
        -- corresponding psegs in the result.
        --
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        --
        --      vsegids':       [0 1 1 2 3 3 4 4]
        --
        vsegids'        = U.map (psegids_map U.!) vsegids

        -- Rebuild the ussegd, keeping only the entries of the used psegs.
        starts'         = U.pack starts  psegids_used
        sources'        = U.pack sources psegids_used

        lengths'        = U.pack (USegd.takeLengths usegd) psegids_used
        usegd'          = USegd.fromLengths lengths'

        ussegd'         = USSegd False starts' sources' usegd'

   in   (vsegids', ussegd')

{-# NOINLINE cullOnVSegids #-}
--  NOINLINE because it's complicated and won't fuse with anything
291
292
293
294
-- | Stream some physical segments from many data arrays.
--
--   Segments are visited in order. For each one, the source id selects the
--   data array and the elements are read from the segment's starting index
--   onwards, for the length given by the underlying `USegd`.
--
--   TODO: make this more efficient.
--         We should be able to eliminate a lot of the indexing happening in the
--         inner loop by being cleverer about the loop state.
--
--   TODO: If this is contiguous then we can stream the lot without worrying
--         about jumping between segments. EXCEPT that this information must be
--         statically visible else streamSegs won't fuse, so we can't have an
--         ifThenElse checking the manifest flag.
streamSegs
        :: Unbox a
        => USSegd               -- ^ Segment descriptor defining segments based on source vectors.
        -> V.Vector (Vector a)  -- ^ Source vectors.
        -> S.Stream a

streamSegs (USSegd _ starts sources usegd) pdatas
 = let
        -- Length of each segment.
        pseglens        = USegd.takeLengths usegd

        -- Total number of segments. This does not change between steps,
        -- so it is bound once out here.
        psegs           = USegd.length usegd

        -- Step function. The state is the current pseg and the offset of
        -- the next element within it.
        -- It must not call itself: GHC does not inline recursive bindings,
        -- so a recursive step would make the INLINE pragma ineffective.
        -- Moving on to the next segment is therefore expressed as a Skip.
        {-# INLINE fn #-}
        fn (pseg, ix)
         -- All psegs are done.
         | pseg >= psegs
         = return $ S.Done

         -- Current pseg is done, continue with the next one.
         | ix   >= pseglens U.! pseg
         = return $ S.Skip (pseg + 1, 0)

         -- Stream an element from this pseg.
         | otherwise
         = let  srcid   = sources U.! pseg
                pdata   = pdatas  V.! srcid
                start   = starts  U.! pseg
                result  = pdata   U.! (start + ix)
           in   return $ S.Yield result (pseg, ix + 1)

   in   M.Stream fn (0, 0) S.Unknown

{-# INLINE_STREAM streamSegs #-}
339