Fix fusion for sumPA
[packages/dph.git] / dph-prim-seq / Data / Array / Parallel / Unlifted / Sequential / USSegd.hs
1 {-# LANGUAGE CPP #-}
2 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
3 #include "fusion-phases.h"
4
5 -- | Scattered Segment Descriptors
6 module Data.Array.Parallel.Unlifted.Sequential.USSegd (
7 -- * Types
8 USSegd(..),
9
10 -- * Consistency check
11 valid,
12
13 -- * Constructors
14 mkUSSegd,
15 empty,
16 singleton,
17 fromUSegd,
18
19 -- * Predicates
20 isContiguous,
21
22 -- * Projections
23 length,
24 takeUSegd, takeLengths, takeIndices, takeElements,
25 takeSources, takeStarts,
26
27 getSeg,
28
29 -- * Operators
30 append,
31 cullOnVSegids,
32
33 -- * Streams
34 streamSegs
35 ) where
36 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
37 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, Unbox)
38 import Data.Array.Parallel.Pretty hiding (empty)
39 import Prelude hiding (length)
40
41 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
42 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as U
43 import qualified Data.Vector as V
44 import qualified Data.Vector.Fusion.Stream as S
45 import qualified Data.Vector.Fusion.Stream.Size as S
46 import qualified Data.Vector.Fusion.Stream.Monadic as M
47 import qualified Data.Vector.Unboxed as VU
48
49
50 -- USSegd ---------------------------------------------------------------------
-- | A scattered segment descriptor: a generalisation of the regular,
--   contiguous segment descriptor ('USegd').
--
--   * Segments may be drawn from multiple physical source arrays.
--
--   * The segments need not cover the entire flat array.
--
--   * Different segments may point to the same elements. Because of this,
--     the total number of elements covered by the segment descriptor can
--     overflow a machine word.
--
data USSegd
  = USSegd
  { -- | True when the starts are identical to the indices field of the
    --   'USegd' and the sources are all 0's. In that case all the data
    --   elements are in one contiguous flat array, and consumers can
    --   avoid looking at the real starts and sources fields.
    ussegd_contiguous :: !Bool

    -- | Starting index of each segment in its flat array.
    --   IMPORTANT: this field is lazy so we can avoid creating it when
    --   the flat array is contiguous.
  , ussegd_starts :: Vector Int

    -- | Which flat array to take each segment from.
    --   IMPORTANT: this field is lazy so we can avoid creating it when
    --   the flat array is contiguous.
  , ussegd_sources :: Vector Int

    -- | Segment descriptor relative to a contiguous index space.
    --   This defines the length of each segment.
  , ussegd_usegd :: !USegd
  }
  deriving Show
86
87
-- | Pretty print the physical representation of a `USSegd`.
--   The starts and source ids are printed as lists, followed by the
--   physical representation of the underlying `USegd`.
--   The contiguity flag is not printed.
instance PprPhysical USSegd where
  pprp (USSegd _ starts sources usegd)
    = vcat [ title, pprp usegd ]
    where
      -- The "USSegd" tag, with the two index vectors nested underneath it.
      title      = text "USSegd" $$ nest 7 (vcat [ startsDoc, sourcesDoc ])
      startsDoc  = text "starts: " <+> text (show (U.toList starts))
      sourcesDoc = text "srcids: " <+> text (show (U.toList sources))
97
98
99 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new scattered segment descriptor.
--   All the provided arrays must have the same lengths.
--   The result is never flagged as contiguous.
mkUSSegd
  :: Vector Int   -- ^ Starting index of each segment in its flat array.
  -> Vector Int   -- ^ Which array to take each segment from.
  -> USegd        -- ^ Contiguous segment descriptor.
  -> USSegd
mkUSSegd starts sources usegd
  = USSegd
  { ussegd_contiguous = False
  , ussegd_starts     = starts
  , ussegd_sources    = sources
  , ussegd_usegd      = usegd }
{-# INLINE mkUSSegd #-}
110
111
-- | O(1). Check the internal consistency of a scattered segment descriptor.
--   This only checks that the starts and sources vectors have exactly one
--   entry per segment of the underlying `USegd`; the contiguity flag is
--   not inspected.
valid :: USSegd -> Bool
valid (USSegd _ starts srcids usegd)
  = U.length starts == segs
  && U.length srcids == segs
  where
    -- Number of segments according to the contiguous descriptor.
    segs = USegd.length usegd
{-# NOINLINE valid #-}
-- NOINLINE because it's only enabled during debugging anyway.
120
121
-- | O(1). Yield an empty segment descriptor, with no elements or segments.
--   The result is flagged as contiguous.
empty :: USSegd
empty
  = USSegd
  { ussegd_contiguous = True
  , ussegd_starts     = U.empty
  , ussegd_sources    = U.empty
  , ussegd_usegd      = USegd.empty }
{-# INLINE_U empty #-}
126
127
-- | O(1). Yield a singleton segment descriptor.
--   The single segment covers the given number of elements, starting at
--   index 0 of the flat array with sourceid 0. The result is flagged as
--   contiguous.
singleton :: Int -> USSegd
singleton n
  = USSegd
  { ussegd_contiguous = True
  , ussegd_starts     = U.singleton 0
  , ussegd_sources    = U.singleton 0
  , ussegd_usegd      = USegd.singleton n }
{-# INLINE_U singleton #-}
135
136
-- | O(segs). Promote a plain `USegd` to a `USSegd`.
--   All segments are assumed to come from a flat array with sourceid 0,
--   so the starts are the indices of the `USegd` and the result is flagged
--   as contiguous.
fromUSegd :: USegd -> USSegd
fromUSegd usegd
  = USSegd
  { ussegd_contiguous = True
  , ussegd_starts     = USegd.takeIndices usegd
  , ussegd_sources    = U.replicate (USegd.length usegd) 0
  , ussegd_usegd      = usegd }
{-# INLINE_U fromUSegd #-}
146
147
148 -- Predicates -----------------------------------------------------------------
-- | O(1). Yield the contiguity flag stored in a `USSegd`.
isContiguous :: USSegd -> Bool
isContiguous (USSegd contiguous _ _ _) = contiguous
{-# INLINE isContiguous #-}
152
153
154 -- Projections ----------------------------------------------------------------
155 -- INLINE trivial projections as they'll expand to a single record selector.
156
-- | O(1). Yield the overall number of segments,
--   as recorded by the underlying `USegd`.
length :: USSegd -> Int
length ussegd = USegd.length (ussegd_usegd ussegd)
{-# INLINE length #-}
161
162
-- | O(1). Yield the contiguous `USegd` stored inside a `USSegd`.
takeUSegd :: USSegd -> USegd
takeUSegd (USSegd _ _ _ usegd) = usegd
{-# INLINE takeUSegd #-}
167
168
-- | O(1). Yield the lengths of the segments of a `USSegd`,
--   taken from the underlying `USegd`.
takeLengths :: USSegd -> Vector Int
takeLengths ussegd = USegd.takeLengths (ussegd_usegd ussegd)
{-# INLINE takeLengths #-}
173
174
-- | O(1). Yield the segment indices of a `USSegd`,
--   taken from the underlying `USegd`.
takeIndices :: USSegd -> Vector Int
takeIndices ussegd = USegd.takeIndices (ussegd_usegd ussegd)
{-# INLINE takeIndices #-}
179
180
-- | O(1). Yield the total number of elements covered by a `USSegd`,
--   as recorded by the underlying `USegd`.
takeElements :: USSegd -> Int
takeElements ussegd = USegd.takeElements (ussegd_usegd ussegd)
{-# INLINE takeElements #-}
185
186
-- | O(1). Yield the starting indices of a `USSegd`:
--   where each segment begins in its flat source array.
takeStarts :: USSegd -> Vector Int
takeStarts (USSegd _ starts _ _) = starts
{-# INLINE takeStarts #-}
191
192
-- | O(1). Yield the source ids of a `USSegd`:
--   which flat array each segment is taken from.
takeSources :: USSegd -> Vector Int
takeSources (USSegd _ _ sources _) = sources
{-# INLINE takeSources #-}
197
198
-- | O(1). Get the length, segment index, starting index, and source id
--   of the segment at the given position, in that order.
getSeg :: USSegd -> Int -> (Int, Int, Int, Int)
getSeg (USSegd _ starts sources usegd) ix
  = (len, index, start, source)
  where
    -- Length and index come from the contiguous descriptor.
    (len, index) = USegd.getSeg usegd ix
    start        = starts  U.! ix
    source       = sources U.! ix
{-# INLINE_U getSeg #-}
208
209
210 -- Operators ------------------------------------------------------------------
-- | O(n). Produce a segment descriptor that describes the result of
--   appending two arrays.
--
--   The source ids of the second descriptor are shifted up by the number
--   of physical data arrays of the first, so they refer to positions after
--   the first array's data. The result is never flagged as contiguous.
append
  :: USSegd   -- ^ Descriptor of the first array.
  -> Int      -- ^ Number of physical data arrays of the first array.
  -> USSegd   -- ^ Descriptor of the second array.
  -> Int      -- ^ Number of physical data arrays of the second array
              --   (not needed to build the result).
  -> USSegd
append (USSegd _ startsA sourcesA usegdA) pdatasA
       (USSegd _ startsB sourcesB usegdB) _
  = USSegd
  { ussegd_contiguous = False
  , ussegd_starts     = startsA U.++ startsB
  , ussegd_sources    = sourcesA U.++ shiftedB
  , ussegd_usegd      = USegd.append usegdA usegdB }
  where
    -- Source ids of the second array, renumbered to follow the first's.
    shiftedB = U.map (\src -> src + pdatasA) sourcesB
{-# INLINE_U append #-}
224
225
-- | Cull the segments in a SSegd down to only those reachable from an array
--   of vsegids, and also update the vsegids to point to the same segments
--   in the result.
--
--   Returns the rewritten vsegids together with the culled descriptor.
--   The culled descriptor is never flagged as contiguous.
--
--   NOTE(review): each vsegid is used as an index into a vector with one
--   entry per pseg, so presumably all vsegids must be valid pseg indices
--   of the input descriptor -- confirm against callers.
--
--   TODO: bpermuteDft isn't parallelised
--
cullOnVSegids :: Vector Int -> USSegd -> (Vector Int, USSegd)
cullOnVSegids vsegids (USSegd _ starts sources usegd)
 = let  -- Determine which of the psegs are still reachable from the vsegs.
        -- This produces an array of flags (written as 1 for True and
        -- 0 for False in the examples below),
        --    with reachable   psegs corresponding to 1
        --    and  unreachable psegs corresponding to 0
        --
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --   => psegids_used:   [1 1 0 1 0 1 1]
        --
        --  Note that psegids '2' and '4' are not in vsegids.
        psegids_used
         = U.bpermuteDft (USegd.length usegd)
                         (const False)
                         (U.zip vsegids (U.replicate (U.length vsegids) True))

        -- Produce an array of used psegs.
        --  eg  psegids_used:   [1 1 0 1 0 1 1]
        --      psegids_packed: [0 1 3 5 6]
        --
        -- enumFromTo is inclusive, so the upper bound is (length - 1);
        -- this gives exactly one candidate psegid per flag.
        psegids_packed
         = U.pack (U.enumFromTo 0 (U.length psegids_used - 1)) psegids_used

        -- Produce an array that maps psegids in the source array onto
        -- psegids in the result array. If a particular pseg isn't present
        -- in the result this maps onto -1.
        --
        -- Note that if psegids_used has 0 in some position, then psegids_map
        -- has -1 in the same position, corresponding to an unused pseg.
        --
        --  eg  psegids_packed: [0 1 3 5 6]
        --                      [0 1 2 3 4]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        psegids_map
         = U.bpermuteDft (USegd.length usegd)
                         (const (-1))
                         (U.zip psegids_packed
                                (U.enumFromTo 0 (U.length psegids_packed - 1)))

        -- Use the psegids_map to rewrite the vsegids to point to the
        -- corresponding psegs in the result.
        --
        --  eg  vsegids:        [0 1 1 3 5 5 6 6]
        --      psegids_map:    [0 1 -1 2 -1 3 4]
        --
        --      vsegids':       [0 1 1 2 3 3 4 4]
        --
        vsegids'        = U.map (psegids_map U.!) vsegids

        -- Rebuild the ussegd, keeping only the entries of used psegs.
        starts'         = U.pack starts  psegids_used
        sources'        = U.pack sources psegids_used

        lengths'        = U.pack (USegd.takeLengths usegd) psegids_used
        usegd'          = USegd.fromLengths lengths'

        ussegd'         = USSegd False starts' sources' usegd'

   in   (vsegids', ussegd')

{-# NOINLINE cullOnVSegids #-}
--  NOINLINE because it's complicated and won't fuse with anything
292
293
294
295
-- | Stream some physical segments from many data arrays.
--
--   The segments are streamed in order. For each segment, the elements are
--   read from the source vector selected by the segment's source id,
--   beginning at the segment's starting index. Indexing is unchecked.
--
--   TODO: make this more efficient, and fix fusion.
--         We should be able to eliminate a lot of the indexing happening in
--         the inner loop by being cleverer about the loop state.
--
--   TODO: If this is contiguous then we can stream the lot without worrying
--         about jumping between segments. EXCEPT that this information must
--         be statically visible else streamSegs won't fuse, so we can't have
--         an ifThenElse checking the manifest flag.
streamSegs
  :: Unbox a
  => USSegd               -- ^ Segment descriptor defining segments based on
                          --   source vectors.
  -> V.Vector (Vector a)  -- ^ Source vectors.
  -> S.Stream a
streamSegs (USSegd _ starts sources usegd) pdatas
  = M.Stream step (0, 0) S.Unknown
  where
    -- Length of each segment.
    seglens = USegd.takeLengths usegd

    -- Stream state is (current segment, offset within that segment).
    {-# INLINE_INNER step #-}
    step (seg, off)
      -- All segments are done.
      | seg >= USegd.length usegd
      = return S.Done

      -- Current segment is done, move on to the next one.
      | off >= VU.unsafeIndex seglens seg
      = return (S.Skip (seg + 1, 0))

      -- Stream an element from the current segment.
      | otherwise
      = let !source = VU.unsafeIndex sources seg
            !chunk  = V.unsafeIndex pdatas source
            !base   = VU.unsafeIndex starts seg
            !x      = VU.unsafeIndex chunk (base + off)
        in  return (S.Yield x (seg, off + 1))
{-# INLINE_STREAM streamSegs #-}
339