dph-prim-par: cleanup API for distributed Segds
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / UPSegd.hs
1 {-# LANGUAGE CPP #-}
2 #include "fusion-phases.h"
3
4 -- | Parallel segment descriptors.
5 --
6 -- See "Data.Array.Parallel.Unlifted" for how this works.
7 --
8 module Data.Array.Parallel.Unlifted.Parallel.UPSegd
9 ( -- * Types
10 UPSegd(..)
11 , valid
12
13 -- * Constructors
14 , mkUPSegd, fromUSegd
15 , empty, singleton, fromLengths
16
17 -- * Projections
18 , length
19 , takeUSegd
20 , takeDistributed
21 , takeLengths
22 , takeIndices
23 , takeElements
24
25 -- * Indices
26 , indicesP
27
28 -- * Replicate
29 , replicateWithP
30
31 -- * Segmented Folds
32 , foldWithP
33 , fold1WithP
34 , sumWithP
35 , foldSegsWithP)
36 where
37 import Data.Array.Parallel.Unlifted.Distributed
38 import Data.Array.Parallel.Unlifted.Distributed.What
39 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
40 import qualified Data.Array.Parallel.Unlifted.Distributed.Data.USegd as USegd
41 import qualified Data.Array.Parallel.Unlifted.Sequential as Seq
42 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as US
43 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
44 import Data.Array.Parallel.Pretty hiding (empty)
45 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
46 import Control.Monad.ST
47 import Prelude hiding (length)
48
-- | Qualify a function name with the name of this module.
--   Used to tag the bounds-checked operations in this module, so the
--   resulting string identifies where the call was made.
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Parallel.UPSegd." ++)
51
52
-- | A parallel segment descriptor.
--
--   Pairs a global (undistributed) segment descriptor with a distributed
--   version of the same descriptor. The distributed version describes how
--   to split work on the segmented array over the gang.
data UPSegd
        = UPSegd
        { -- | Segment descriptor that describes the whole array.
          upsegd_usegd  :: !USegd

          -- | Segment descriptor for each chunk,
          --   along with the segment id of the first slice in the chunk,
          --   and the offset of that slice in its segment.
          --   See the docs of `splitSegdOnElemsD` for an example.
        , upsegd_dsegd  :: Dist ((USegd,Int),Int)
        }
67
68
69 -- Pretty ---------------------------------------------------------------------
-- | Show the physical structure of a `UPSegd`: the global descriptor
--   followed by the distributed one, both nested under a @UPSegd@ heading.
instance PprPhysical UPSegd where
 pprp UPSegd { upsegd_usegd = global, upsegd_dsegd = distributed }
        = text "UPSegd" $$ nest 7 (vcat fields)
        where
          fields
           = [ text "usegd: " <+> pprp global
             , text "dsegd: " <+> pprp distributed ]
76
77
78 -- Valid ----------------------------------------------------------------------
-- | O(1).
--   Check the internal consistency of a parallel segment descriptor.
--
--   * TODO: no checks are performed yet, so every descriptor is
--     reported as valid.
--
--   NOINLINE because it's only used during debugging anyway.
{-# NOINLINE valid #-}
valid :: UPSegd -> Bool
valid = const True
87
88
89 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new parallel segment descriptor.
--
--   The three arguments are used to build a global `USegd` with
--   `USegd.mkUSegd`, which is then converted with `fromUSegd`.
{-# INLINE_UP mkUPSegd #-}
mkUPSegd
        :: Vector Int   -- ^ Length of each segment.
        -> Vector Int   -- ^ Starting index of each segment.
        -> Int          -- ^ Total number of elements in the flat array.
        -> UPSegd

mkUPSegd segLens segIdxs numElems
        = fromUSegd
        $ USegd.mkUSegd segLens segIdxs numElems
100
101
-- | Convert a global `USegd` to a parallel `UPSegd`.
--
--   The global descriptor is kept as-is, and a distributed version is
--   produced by splitting it across the gang with `splitSegdOnElemsD`.
{-# INLINE_UP fromUSegd #-}
fromUSegd :: USegd -> UPSegd
fromUSegd usegd
        = UPSegd
        { upsegd_usegd  = usegd
        , upsegd_dsegd  = USegd.splitSegdOnElemsD theGang usegd }
107
108
-- | O(1). The empty segment descriptor, which has no elements and
--   no segments. Built by converting the empty `USegd`.
{-# INLINE_UP empty #-}
empty :: UPSegd
empty   = fromUSegd USegd.empty
113
114
-- | O(1). Construct a segment descriptor with exactly one segment,
--   where that segment covers the given number of elements.
{-# INLINE_UP singleton #-}
singleton :: Int -> UPSegd
singleton numElems
        = fromUSegd (USegd.singleton numElems)
120
121
-- | O(n). Build a parallel segment descriptor from an array of
--   segment lengths.
--
--   Only the length of each segment is supplied; the segment indices are
--   computed from the lengths. Runtime is O(n) in the number of segments.
--
{-# INLINE_UP fromLengths #-}
fromLengths :: Vector Int -> UPSegd
fromLengths
        = fromUSegd
        . USegd.fromLengths
130
131
132 -- Projections ----------------------------------------------------------------
133 -- INLINE trivial projections as they'll expand to a single record selector.
134
-- | O(1). Yield the overall number of segments, as recorded in the
--   global segment descriptor.
{-# INLINE length #-}
length :: UPSegd -> Int
length  = USegd.length
        . upsegd_usegd
139
140
-- | O(1). Yield the global `USegd` of a `UPSegd`.
--   This is the descriptor that describes the whole array.
{-# INLINE takeUSegd #-}
takeUSegd :: UPSegd -> USegd
takeUSegd       = upsegd_usegd
145
146
-- | O(1). Yield the distributed `USegd` of a `UPSegd`.
--
--   Each chunk of the result holds three things:
--
--   * a plain `USegd` for the chunk,
--
--   * the segment id of the first slice in the chunk, and
--
--   * the starting offset of that slice in its segment.
--
{-# INLINE takeDistributed #-}
takeDistributed :: UPSegd -> Dist ((USegd,Int),Int)
takeDistributed = upsegd_dsegd
155
156
-- | O(1). Yield the lengths of the individual segments,
--   taken from the global segment descriptor.
{-# INLINE takeLengths #-}
takeLengths :: UPSegd -> Vector Int
takeLengths
        = USegd.takeLengths
        . upsegd_usegd
161
162
-- | O(1). Yield the segment indices (the starting index of each segment),
--   taken from the global segment descriptor.
{-# INLINE takeIndices #-}
takeIndices :: UPSegd -> Vector Int
takeIndices
        = USegd.takeIndices
        . upsegd_usegd
167
168
-- | O(1). Yield the total number of array elements covered by the
--   segment descriptor.
--
--   @takeElements upsegd = sum (takeLengths upsegd)@
--
{-# INLINE takeElements #-}
takeElements :: UPSegd -> Int
takeElements
        = USegd.takeElements
        . upsegd_usegd
176
177
178 -- Indices --------------------------------------------------------------------
-- | O(n). Yield a vector containing indices that give the position of each
--   member of the flat array in its corresponding segment.
--
--   @indicesP (fromLengths [5, 2, 3]) = [0,1,2,3,4,0,1,0,1,2]@
--
--   The indices are computed per chunk of the distributed segment
--   descriptor, and the per-chunk results are then joined back together.
--
--   NOINLINE because we're not using it yet.
{-# NOINLINE indicesP #-}
indicesP :: UPSegd -> Vector Int
indicesP upsegd
        = joinD theGang balanced chunkIndices
        where
          chunkIndices
           = mapD (What "UPSegd.indicesP/indices") theGang perChunk
                  (takeDistributed upsegd)

          -- The offset of the chunk's first slice within its segment is
          -- handed to the sequential worker; the id of the first segment
          -- is not needed here.
          perChunk ((chunkSegd, _firstSeg), startOff)
           = Seq.indicesSU' startOff chunkSegd
193
194
195 -- Replicate ------------------------------------------------------------------
-- | Copying segmented replication. Each element of the vector is physically
--   copied according to the length of each segment in the segment descriptor.
--
--   @replicateWithP (fromLengths [3, 1, 2]) [5, 6, 7] = [5, 5, 5, 6, 7, 7]@
--
--   Each chunk of the distributed descriptor replicates its own slice of
--   the source vector, and the per-chunk results are joined together.
--
{-# INLINE_UP replicateWithP #-}
replicateWithP :: Unbox a => UPSegd -> Vector a -> Vector a
replicateWithP segd !xs
        = joinD theGang balanced
        $ mapD (What "UPSegd.replicateWithP/replicateSU") theGang replicateChunk
               (takeDistributed segd)
        where
          replicateChunk ((chunkSegd, firstSeg), _)
           = Seq.replicateSU chunkSegd chunkSource
           where
             -- The slice of the source vector used by this chunk: it starts
             -- at the id of the chunk's first segment and takes as many
             -- elements as the length of the chunk's segment descriptor.
             chunkSource
              = US.slice (here "replicateWithP")
                         xs firstSeg (USegd.length chunkSegd)
212
213
214 -- Fold -----------------------------------------------------------------------
-- | Fold the segments specified by a `UPSegd`, starting from the given
--   initial value.
--
--   The combining function serves both as the per-element worker of
--   `foldSegsWithP`, and together with the initial value, as the basis of
--   the per-segment worker (`Seq.foldlSU`).
{-# INLINE_UP foldWithP #-}
foldWithP
        :: Unbox a
        => (a -> a -> a) -> a -> UPSegd -> Vector a -> Vector a

foldWithP combine !zero
        = foldSegsWithP combine
        $ Seq.foldlSU combine zero
220
221
-- | Fold the segments specified by a `UPSegd`, without an initial value.
--   The original documentation describes this as folding
--   \"with a non-empty vector\".
--
--   The combining function serves both as the per-element worker of
--   `foldSegsWithP`, and as the basis of the per-segment worker
--   (`Seq.fold1SU`).
{-# INLINE_UP fold1WithP #-}
fold1WithP
        :: Unbox a
        => (a -> a -> a) -> UPSegd -> Vector a -> Vector a

fold1WithP combine
        = foldSegsWithP combine
        $ Seq.fold1SU combine
227
228
-- | Sum up the segments specified by a `UPSegd`.
--   This is `foldWithP` specialised to addition, starting from zero.
{-# INLINE_UP sumWithP #-}
sumWithP :: (Num e, Unbox e) => UPSegd -> Vector e -> Vector e
sumWithP
        = foldWithP (+) 0
233
234
-- | Fold the segments specified by a `UPSegd`.
--
--   This low level function takes a per-element worker and a per-segment
--   worker. All the segments are folded with the per-segment worker, and
--   the per-element worker is then used to fix up the partial results when
--   a segment is split across multiple threads.
--
--   Each chunk produces a pair of results:
--
--   * a carry, tagged with the id of the chunk's first segment. When the
--     chunk starts part-way through a segment (non-zero offset) the carry
--     holds the first folded value, otherwise it is empty.
--
--   * the remaining folded values, which are joined into the final result.
--
--   After joining, `fixupFold` combines the carries into the joined vector.
--
{-# INLINE_UP foldSegsWithP #-}
foldSegsWithP
        :: Unbox a
        => (a -> a -> a)
        -> (USegd -> Vector a -> Vector a)
        -> UPSegd -> Vector a -> Vector a

foldSegsWithP fElem fSeg segd xs
        = carries `seq` results `seq`
          runST (do
                joined  <- joinDM theGang results
                fixupFold fElem joined carries
                US.unsafeFreeze joined)

        where
          (carries, results)
           = unzipD
           $ mapD (What "UPSegd.foldSegsWithP/partial") theGang foldChunk
           $ zipD (takeDistributed segd)
                  (splitD theGang balanced xs)

          -- Fold the segments of a single chunk, separating the value that
          -- belongs to a segment started by an earlier chunk (if any) from
          -- the values for segments that start in this chunk.
          foldChunk (((chunkSegd, firstSeg), startOff), chunkElems)
           = ( (firstSeg, US.take nCarry folded)
             , US.drop nCarry folded)
           where
             folded     = fSeg chunkSegd chunkElems

             -- Number of leading folded values that must be carried.
             {-# INLINE [0] nCarry #-}
             nCarry
              | startOff == 0   = 0
              | otherwise       = 1
269
270
-- | Merge carried partial results into the joined result vector.
--
--   For each gang thread, in increasing order starting from thread 1,
--   look up the carry produced by that thread. If the carry vector is
--   non-empty, the result stored at the carried segment id is replaced by
--   the combination of the old result with the first carried value
--   (old result on the left, carried value on the right).
--
--   Thread 0 is never inspected -- presumably because the first chunk
--   always starts at the beginning of a segment, so it has nothing to
--   carry (TODO confirm against `splitSegdOnElemsD`).
fixupFold
        :: Unbox a
        => (a -> a -> a)
        -> MVector s a
        -> Dist (Int,Vector a)
        -> ST s ()
{-# NOINLINE fixupFold #-}
fixupFold f !mrs !dcarry
        = mapM_ fixup [1 .. gangSize theGang - 1]
        where
          fixup thread
           | US.null carry
           = return ()

           | otherwise
           = do acc     <- US.read mrs seg
                US.write mrs seg (f acc (US.index (here "fixupFold") carry 0))

           where
             (seg, carry) = indexD (here "fixupFold") dcarry thread