68c4c5dc5b13f9a751b290bbea1781e08bcd099e
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / UPSSegd.hs
1 {-# LANGUAGE CPP #-}
2 #include "fusion-phases.h"
3
4 -- | Parallel Scattered Segment descriptors.
5 --
6 -- See "Data.Array.Parallel.Unlifted" for how this works.
7 --
8 module Data.Array.Parallel.Unlifted.Parallel.UPSSegd
9 ( -- * Types
10 UPSSegd, valid
11
12 -- * Constructors
13 , mkUPSSegd, fromUSSegd, fromUPSegd
14 , empty, singleton
15
16 -- * Predicates
17 , isContiguous
18
19 -- * Projections
20 , length
21 , takeUSSegd
22 , takeDistributed
23 , takeLengths
24 , takeIndices
25 , takeElements
26 , takeStarts
27 , takeSources
28 , getSeg
29
30 -- * Append
31 , appendWith
32
33 -- * Segmented Folds
34 , foldWithP
35 , fold1WithP
36 , sumWithP
37 , foldSegsWithP)
38 where
39 import Data.Array.Parallel.Pretty hiding (empty)
40 import Data.Array.Parallel.Unlifted.Distributed
41 import Data.Array.Parallel.Unlifted.Parallel.UPSegd (UPSegd)
42 import Data.Array.Parallel.Unlifted.Sequential.USSegd (USSegd)
43 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
44 import Data.Array.Parallel.Unlifted.Vectors (Vectors, Unboxes)
45 import qualified Data.Array.Parallel.Unlifted.Parallel.UPSegd as UPSegd
46 import qualified Data.Array.Parallel.Unlifted.Distributed.USSegd as DUSSegd
47 import qualified Data.Array.Parallel.Unlifted.Sequential.USSegd as USSegd
48 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as US
49 import qualified Data.Array.Parallel.Unlifted.Sequential as Seq
50 import Control.Monad.ST
51 import Prelude hiding (length)
52
-- | Qualify a function name with this module's name, for use in the
--   location strings passed to bounds-checked operations.
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Parallel.UPSSegd." ++)
55
56
-- | Parallel Scattered Segment descriptor.
--
--   Holds the descriptor for the whole array together with a distributed
--   version that has one chunk per member of the gang.
data UPSSegd
        = UPSSegd
        { -- | Segment descriptor that describes the whole array.
          upssegd_ussegd :: !USSegd

          -- | Segment descriptor for each chunk,
          --   along with segment id of first slice in the chunk,
          --   and the offset of that slice in its segment.
          --   See docs of `splitSegdOfElemsD` for an example.
        , upssegd_dssegd :: Dist ((USSegd,Int),Int)
        }
        deriving (Show)
70
71
-- | Print a header line, followed by the global and distributed
--   descriptors nested underneath it.
instance PprPhysical UPSSegd where
 pprp (UPSSegd ussegd dssegd)
  = text "UPSSegd" $$ nest 7 (vcat fields)
  where fields
         = [ text "ussegd: " <+> pprp ussegd
           , text "dssegd: " <+> pprp dssegd ]
78
79
-- | O(1).
--   Check the internal consistency of a scattered segment descriptor.
---
--   * TODO: no checks are implemented yet, so every descriptor is
--     reported as valid.
valid :: UPSSegd -> Bool
valid = const True
{-# NOINLINE valid #-}
--  NOINLINE because it's only used during debugging anyway.
88
89
90 -- Constructors ---------------------------------------------------------------
-- | Construct a new segment descriptor.
--
--   The sequential `USSegd` is built first from the given fields, and is
--   then promoted with `fromUSSegd`.
mkUPSSegd
        :: Vector Int   -- ^ Starting index of each segment in its flat array.
        -> Vector Int   -- ^ Source id of the flat array to take each segment from.
        -> UPSegd       -- ^ Contiguous (unscattered) segment descriptor.
        -> UPSSegd

mkUPSSegd starts sources upsegd
        = fromUSSegd ussegd
        where   usegd   = UPSegd.takeUSegd upsegd
                ussegd  = USSegd.mkUSSegd starts sources usegd
{-# INLINE_UP mkUPSSegd #-}
101
102
-- | Promote a global `USSegd` to a parallel `UPSSegd` by distributing
--   it across the gang.
--
--   The global descriptor is kept as-is alongside the distributed one.
fromUSSegd :: USSegd -> UPSSegd
fromUSSegd ssegd
        = UPSSegd ssegd distributed
        where   distributed = DUSSegd.splitSSegdOnElemsD theGang ssegd
{-# INLINE_UP fromUSSegd #-}
109
110
-- | Promote a plain `UPSegd` to a `UPSSegd`, by assuming that all segments
--   come from a single flat array with source id 0.
---
--   * TODO:
--     This sequentially constructs the indices and source fields, and we
--     throw out the existing distributed `USegd`. We could probably keep
--     some of the existing fields and save reconstructing them.
--
fromUPSegd :: UPSegd -> UPSSegd
fromUPSegd upsegd
        = fromUSSegd (USSegd.fromUSegd (UPSegd.takeUSegd upsegd))
{-# INLINE_UP fromUPSegd #-}
123
124
-- | O(1). Yield an empty segment descriptor, with no elements or segments.
--
--   Built by promoting the empty sequential descriptor.
empty :: UPSSegd
empty   = fromUSSegd emptySSegd
        where   emptySSegd = USSegd.empty
{-# INLINE_UP empty #-}
129
130
-- | O(1).
--   Yield a singleton segment descriptor.
--   The single segment covers the given number of elements.
singleton :: Int -> UPSSegd
singleton n
        = fromUSSegd (USSegd.singleton n)
{-# INLINE_UP singleton #-}
137
138
139 -- Predicates -----------------------------------------------------------------
140 -- INLINE trivial predicates as they'll expand to simple calls.
141
-- | O(1). True when the starts are identical to the usegd indices field and
--   the sources are all 0's.
--
--   In this case all the data elements are in one contiguous flat
--   array, and consumers can avoid looking at the real starts and
--   sources fields.
--
--   The answer is taken from the global `USSegd`.
isContiguous :: UPSSegd -> Bool
isContiguous upssegd
        = USSegd.isContiguous (upssegd_ussegd upssegd)
{-# INLINE isContiguous #-}
152
153
154 -- Projections ----------------------------------------------------------------
155 -- INLINE trivial projections as they'll expand to a single record selector.
156
-- | O(1). Yield the overall number of segments,
--   as recorded in the global `USSegd`.
length :: UPSSegd -> Int
length upssegd
        = USSegd.length (upssegd_ussegd upssegd)
{-# INLINE length #-}
161
-- | O(1). Yield the global `USSegd` of a `UPSSegd`.
takeUSSegd :: UPSSegd -> USSegd
takeUSSegd (UPSSegd ussegd _)
        = ussegd
{-# INLINE takeUSSegd #-}
166
167
-- | O(1). Yield the distributed `USSegd` of a `UPSSegd`.
--
--   Each chunk comes with the segment id of its first slice, and the
--   offset of that slice in its segment.
takeDistributed :: UPSSegd -> Dist ((USSegd, Int), Int)
takeDistributed (UPSSegd _ dssegd)
        = dssegd
{-# INLINE takeDistributed #-}
172
173
-- | O(1). Yield the lengths of the individual segments,
--   read from the global `USSegd`.
takeLengths :: UPSSegd -> Vector Int
takeLengths upssegd
        = USSegd.takeLengths (upssegd_ussegd upssegd)
{-# INLINE takeLengths #-}
178
179
-- | O(1). Yield the segment indices,
--   read from the global `USSegd`.
takeIndices :: UPSSegd -> Vector Int
takeIndices upssegd
        = USSegd.takeIndices (upssegd_ussegd upssegd)
{-# INLINE takeIndices #-}
184
185
-- | O(1). Yield the total number of data elements.
--
--   @takeElements upssegd = sum (takeLengths upssegd)@
--
--   The count is read from the global `USSegd`.
takeElements :: UPSSegd -> Int
takeElements upssegd
        = USSegd.takeElements (upssegd_ussegd upssegd)
{-# INLINE takeElements #-}
193
194
-- | O(1). Yield the starting indices,
--   read from the global `USSegd`.
takeStarts :: UPSSegd -> Vector Int
takeStarts upssegd
        = USSegd.takeStarts (upssegd_ussegd upssegd)
{-# INLINE takeStarts #-}
199
200
-- | O(1). Yield the source ids,
--   read from the global `USSegd`.
takeSources :: UPSSegd -> Vector Int
takeSources upssegd
        = USSegd.takeSources (upssegd_ussegd upssegd)
{-# INLINE takeSources #-}
205
206
-- | O(1). Get the length, segment index, starting index, and source id of a segment.
--
--   The lookup is delegated to the global `USSegd`.
getSeg :: UPSSegd -> Int -> (Int, Int, Int, Int)
getSeg upssegd ix
        = USSegd.getSeg ussegd ix
        where   ussegd  = upssegd_ussegd upssegd
{-# INLINE_UP getSeg #-}
212
213
214 -- Append ---------------------------------------------------------------------
-- | O(n)
--   Produce a segment descriptor that describes the result of appending two
--   segmented arrays.
--
--   Appending two nested arrays is an index space transformation. Because
--   a `UPSSegd` can contain segments from multiple flat data arrays, we can
--   represent the result of the append without copying elements from the
--   underlying flat data arrays.
---
--   * TODO: This calls out to the sequential version, then redistributes
--     the result with `fromUSSegd`.
--
appendWith
        :: UPSSegd      -- ^ Segment descriptor of first nested array.
        -> Int          -- ^ Number of flat data arrays used to represent first nested array.
        -> UPSSegd      -- ^ Segment descriptor of second nested array.
        -> Int          -- ^ Number of flat data arrays used to represent second nested array.
        -> UPSSegd
appendWith upssegd1 pdatas1 upssegd2 pdatas2
        = fromUSSegd appended
        where   ussegd1  = upssegd_ussegd upssegd1
                ussegd2  = upssegd_ussegd upssegd2
                appended = USSegd.appendWith ussegd1 pdatas1 ussegd2 pdatas2
{-# NOINLINE appendWith #-}
--  NOINLINE because we're not using it yet.
240
241
242 -- Fold -----------------------------------------------------------------------
-- | Fold segments specified by a `UPSSegd`.
--
--   The same operator is used both to fold within each segment, starting
--   from the given initial value, and to combine partial results.
foldWithP :: (Unbox a, Unboxes a)
         => (a -> a -> a) -> a -> UPSSegd -> Vectors a -> Vector a
foldWithP f !z
        = foldSegsWithP f foldSegments
        where   foldSegments = Seq.foldlSSU f z
{-# INLINE_UP foldWithP #-}
248
249
-- | Fold segments specified by a `UPSSegd`, with a non-empty vector.
--
--   Like `foldWithP`, but there is no initial value: the per-segment
--   worker is `Seq.fold1SSU`.
fold1WithP :: (Unbox a, Unboxes a)
         => (a -> a -> a) -> UPSSegd -> Vectors a -> Vector a
fold1WithP f
        = foldSegsWithP f foldSegments
        where   foldSegments = Seq.fold1SSU f
{-# INLINE_UP fold1WithP #-}
255
256
-- | Sum up segments specified by a `UPSSegd`.
--
--   This is `foldWithP` specialised to addition, starting from zero.
sumWithP :: (Num a, Unbox a, Unboxes a)
        => UPSSegd -> Vectors a -> Vector a
sumWithP = foldWithP (+) zero
        where   zero = 0
{-# INLINE_UP sumWithP #-}
262
263
-- | Fold the segments specified by a `UPSSegd`.
--
--   Low level function takes a per-element worker and a per-segment worker.
--   It folds all the segments with the per-segment worker, then uses the
--   per-element worker to fixup the partial results when a segment
--   is split across multiple threads.
--
--   Each gang member folds its own chunk of the distributed descriptor.
--   When the chunk's first slice has a non-zero offset in its segment, the
--   first result of that chunk is peeled off as a carry, tagged with the
--   id of the segment it belongs to. The remaining results are joined into
--   one mutable vector, and `fixupFold` then merges the carries into it.
--
foldSegsWithP
        :: (Unbox a, Unboxes a)
        => (a -> a -> a)
        -> (USSegd -> Vectors a -> Vector a)
        -> UPSSegd -> Vectors a -> Vector a

foldSegsWithP fElem fSeg segd xss
 = carries `seq` bodies `seq`
   runST (do
        joined  <- joinDM theGang bodies
        fixupFold fElem joined carries
        US.unsafeFreeze joined)

 where  (carries, bodies)
         = unzipD (mapD theGang foldChunk (takeDistributed segd))

        -- Fold one chunk, separating a leading partial result (if any)
        -- from the results that stay in this chunk's part of the output.
        foldChunk ((chunkSegd, firstSeg), offset)
         = ((firstSeg, US.take spill folded), US.drop spill folded)
         where  folded  = fSeg chunkSegd xss

                -- Number of leading results to carry over: one when the
                -- chunk's first slice starts part-way through its segment.
                {-# INLINE [0] spill #-}
                spill   | offset == 0   = 0
                        | otherwise     = 1
{-# INLINE_UP foldSegsWithP #-}
296
297
-- | Merge carried partial results into the joined result vector.
--
--   For every gang member from index 1 upwards, look up its
--   (index, carry) pair. When the carry vector is non-empty, combine the
--   value currently stored at that index of the mutable vector with the
--   first carried element, and write the combination back in place.
--
--   NOTE(review): member 0 is never inspected, presumably because the
--   first chunk cannot start part-way through a segment -- confirm.
fixupFold
        :: Unbox a
        => (a -> a -> a)
        -> MVector s a
        -> Dist (Int,Vector a)
        -> ST s ()

fixupFold f !mrs !dcarry = loop 1
 where
        !members = gangSize theGang

        loop i
         | i >= members
         = return ()

         | otherwise
         = do   -- Only index the distributed carries once i is known to
                -- be in range.
                let (slot, carry) = indexD (here "fixupFold") dcarry i
                if US.null carry
                 then return ()
                 else do
                        current <- US.read mrs slot
                        US.write mrs slot
                                (f current (US.index (here "fixupFold") carry 0))
                loop (i + 1)
{-# NOINLINE fixupFold #-}
318