46a5d0bdf6ae744d95170adac8254d89d4ae1424
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / UPSegd.hs
1 {-# LANGUAGE CPP #-}
2 #include "fusion-phases.h"
3
4 -- | Parallel segment descriptors.
5 --
6 -- See "Data.Array.Parallel.Unlifted" for how this works.
7 --
8 module Data.Array.Parallel.Unlifted.Parallel.UPSegd
9 ( -- * Types
10 UPSegd(..)
11 , valid
12
13 -- * Constructors
14 , mkUPSegd, fromUSegd
15 , empty, singleton, fromLengths
16
17 -- * Projections
18 , length
19 , takeUSegd
20 , takeDistributed
21 , takeLengths
22 , takeIndices
23 , takeElements
24
25 -- * Indices
26 , indicesP
27
28 -- * Replicate
29 , replicateWithP
30
31 -- * Segmented Folds
32 , foldWithP
33 , fold1WithP
34 , sumWithP
35 , foldSegsWithP)
36 where
37 import Data.Array.Parallel.Unlifted.Distributed
38 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
39 import qualified Data.Array.Parallel.Unlifted.Distributed.USegd as USegd
40 import qualified Data.Array.Parallel.Unlifted.Sequential as Seq
41 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as US
42 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
43 import Data.Array.Parallel.Pretty hiding (empty)
44 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
45 import Control.Monad.ST
46 import Prelude hiding (length)
47
-- | Prefix a local function name with this module's name. The result is
--   passed as a location tag to the checked operations used in this module
--   (`US.slice`, `indexD`, `US.index`).
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Parallel.UPSegd." ++)
50
51
-- | A parallel segment descriptor holds a global (undistributed) segment
--   descriptor, as well as a distributed version. The distributed version
--   describes how to split work on the segmented array over the gang.
data UPSegd
  = UPSegd
  { upsegd_usegd :: !USegd
    -- ^ Segment descriptor that describes the whole array.

  , upsegd_dsegd :: Dist ((USegd,Int),Int)
    -- ^ Segment descriptor for each chunk,
    --   along with segment id of first slice in the chunk,
    --   and the offset of that slice in its segment.
    --   See docs of `splitSegdOnElemsD` for an example.
    --
    --   This field has no strictness annotation, so the distributed
    --   descriptor is only computed when it is first demanded.
  }
66
67
68 -- Pretty ---------------------------------------------------------------------
-- | Physical pretty-printing: shows both the global descriptor and the
--   distributed per-chunk descriptors.
instance PprPhysical UPSegd where
  pprp (UPSegd { upsegd_usegd = usegd, upsegd_dsegd = dsegd })
    = text "UPSegd" $$ nest 7 (vcat fields)
    where
      fields =
        [ text "usegd: " <+> pprp usegd
        , text "dsegd: " <+> pprp dsegd
        ]
75
76
77 -- Valid ----------------------------------------------------------------------
-- | O(1).
--   Check the internal consistency of a parallel segment descriptor.
--
--   * TODO: this doesn't do any checks yet; it always returns `True`.
valid :: UPSegd -> Bool
valid _ = True
{-# NOINLINE valid #-}
--  NOINLINE because it's only used during debugging anyway.
86
87
88 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new parallel segment descriptor.
--
--   Builds the global `USegd` from its components and hands it to
--   `fromUSegd`, which attaches the distributed version.
mkUPSegd
  :: Vector Int   -- ^ Length of each segment.
  -> Vector Int   -- ^ Starting index of each segment.
  -> Int          -- ^ Total number of elements in the flat array.
  -> UPSegd

mkUPSegd segLens segStarts nElems
  = let usegd = USegd.mkUSegd segLens segStarts nElems
    in  fromUSegd usegd
{-# INLINE_UP mkUPSegd #-}
99
100
-- | Convert a global `USegd` to a parallel `UPSegd` by distributing
--   it across the gang.
--
--   The distributed part is stored in the non-strict `upsegd_dsegd` field,
--   so the split over the gang is only performed when that field is demanded.
fromUSegd :: USegd -> UPSegd
fromUSegd segd
  = UPSegd
  { upsegd_usegd = segd
  , upsegd_dsegd = USegd.splitSegdOnElemsD theGang segd
  }
{-# INLINE_UP fromUSegd #-}
106
107
-- | O(1). Construct an empty segment descriptor, with no elements or segments.
--   This is the empty global `USegd`, lifted with `fromUSegd`.
empty :: UPSegd
empty = fromUSegd emptyUSegd
  where
    emptyUSegd = USegd.empty
{-# INLINE_UP empty #-}
112
113
-- | O(1). Construct a singleton segment descriptor.
--   The single segment covers the given number of elements.
singleton :: Int -> UPSegd
singleton = fromUSegd . USegd.singleton
{-# INLINE_UP singleton #-}
119
120
-- | O(n). Convert an array of segment lengths into a parallel segment descriptor.
--
--   The array holds the length of each segment; the segment indices are
--   derived from it. Runtime is O(n) in the number of segments.
fromLengths :: Vector Int -> UPSegd
fromLengths lens = fromUSegd (USegd.fromLengths lens)
{-# INLINE_UP fromLengths #-}
129
130
131 -- Projections ----------------------------------------------------------------
132 -- INLINE trivial projections as they'll expand to a single record selector.
133
-- | O(1). Yield the overall number of segments,
--   as recorded by the global descriptor.
length :: UPSegd -> Int
length upsegd = USegd.length (upsegd_usegd upsegd)
{-# INLINE length #-}
138
139
-- | O(1). Yield the global (undistributed) `USegd` of a `UPSegd`.
takeUSegd :: UPSegd -> USegd
takeUSegd (UPSegd usegd _) = usegd
{-# INLINE takeUSegd #-}
144
145
-- | O(1). Yield the distributed `USegd` of a `UPSegd`.
--
--   Each chunk carries a plain `USegd`, the segment id of its first
--   slice, and the starting offset of that slice within its segment.
takeDistributed :: UPSegd -> Dist ((USegd,Int),Int)
takeDistributed (UPSegd _ dsegd) = dsegd
{-# INLINE takeDistributed #-}
154
155
-- | O(1). Yield the lengths of the individual segments,
--   taken from the global descriptor.
takeLengths :: UPSegd -> Vector Int
takeLengths upsegd = USegd.takeLengths (upsegd_usegd upsegd)
{-# INLINE takeLengths #-}
160
161
-- | O(1). Yield the segment indices, taken from the global descriptor.
takeIndices :: UPSegd -> Vector Int
takeIndices upsegd = USegd.takeIndices (upsegd_usegd upsegd)
{-# INLINE takeIndices #-}
166
167
-- | O(1). Yield the total number of array elements.
--
--   @takeElements upsegd = sum (takeLengths upsegd)@
--
takeElements :: UPSegd -> Int
takeElements upsegd = USegd.takeElements (upsegd_usegd upsegd)
{-# INLINE takeElements #-}
175
176
177 -- Indices --------------------------------------------------------------------
-- | O(n). Yield a vector containing indices that give the position of each
--   member of the flat array in its corresponding segment.
--
--   @indicesP (fromLengths [5, 2, 3]) = [0,1,2,3,4,0,1,0,1,2]@
--
--   The per-chunk work is mapped over the gang with `mapD`, and the chunk
--   results are joined back together. Each chunk passes its first slice's
--   offset to `Seq.indicesSU'`; presumably this lets a segment that is split
--   across chunks continue counting from where the previous chunk stopped.
indicesP :: UPSegd -> Vector Int
indicesP
  = joinD theGang balanced
  . mapD theGang indices
  . takeDistributed
  where
    indices ((segd,_k),off) = Seq.indicesSU' off segd
{-# NOINLINE indicesP #-}
--  NOINLINE because we're not using it yet.
192
193
194 -- Replicate ------------------------------------------------------------------
-- | Copying segmented replication. Each element of the vector is physically
--   copied according to the length of each segment in the segment descriptor.
--
--   @replicateWithP (fromLengths [3, 1, 2]) [5, 6, 7] = [5, 5, 5, 6, 7, 7]@
--
--   Each chunk slices out the elements of @xs@ for the segments it covers.
--   The slice starts at the chunk's first segment id and spans as many
--   elements as the chunk has segments. The chunk then replicates them
--   with `Seq.replicateSU`.
replicateWithP :: Unbox a => UPSegd -> Vector a -> Vector a
replicateWithP segd !xs
  = joinD theGang balanced
  . mapD theGang rep
  $ takeDistributed segd
  where
    rep ((dsegd,di),_)
      = Seq.replicateSU dsegd
      $ US.slice (here "replicateWithP")
                 xs di (USegd.length dsegd)
{-# INLINE_UP replicateWithP #-}
211
212
213 -- Fold -----------------------------------------------------------------------
-- | Fold segments specified by a `UPSegd`.
--
--   Each chunk folds its slices with @Seq.foldlSU f z@. When a segment is
--   split across chunks, `foldSegsWithP` combines the partial results with
--   @f@. For the result to agree with a sequential left fold, @f@ should
--   therefore be associative and @z@ should be a neutral element for @f@.
foldWithP :: Unbox a
          => (a -> a -> a) -> a -> UPSegd -> Vector a -> Vector a
foldWithP f !z = foldSegsWithP f (Seq.foldlSU f z)
{-# INLINE_UP foldWithP #-}
219
220
-- | Fold segments specified by a `UPSegd`, without a starting value.
--
--   Each chunk folds its slices with @Seq.fold1SU f@, so presumably every
--   segment must be non-empty (TODO: confirm against `Seq.fold1SU`).
--   When a segment is split across chunks, `foldSegsWithP` combines the
--   partial results with @f@, which should therefore be associative.
fold1WithP :: Unbox a
           => (a -> a -> a) -> UPSegd -> Vector a -> Vector a
fold1WithP f = foldSegsWithP f (Seq.fold1SU f)
{-# INLINE_UP fold1WithP #-}
226
227
-- | Sum up segments specified by a `UPSegd`.
--   This is `foldWithP` with addition and a starting value of 0.
sumWithP :: (Num e, Unbox e) => UPSegd -> Vector e -> Vector e
sumWithP segd xs = foldWithP (+) 0 segd xs
{-# INLINE_UP sumWithP #-}
232
233
-- | Fold the segments specified by a `UPSegd`.
--
--   This low level function takes a per-element worker and a per-segment worker.
--   It folds all the segments with the per-segment worker, then uses the
--   per-element worker to fix up the partial results when a segment
--   is split across multiple threads.
--
--   A chunk's first slice may not start at offset 0, meaning it continues
--   a segment begun in an earlier chunk. In that case the chunk's first
--   per-segment result is split off as a "carry", tagged with that segment's
--   id. The remaining results of all chunks are joined into one mutable
--   vector, and `fixupFold` then merges the carries into it.
foldSegsWithP
  :: Unbox a
  => (a -> a -> a)
  -> (USegd -> Vector a -> Vector a)
  -> UPSegd -> Vector a -> Vector a

{-# INLINE_UP foldSegsWithP #-}
foldSegsWithP fElem fSeg segd xs
  = carries `seq` results `seq`
    runST (do
      mresults <- joinDM theGang results
      fixupFold fElem mresults carries
      US.unsafeFreeze mresults)
  where
    -- Pair each chunk's descriptor with its slice of the elements, fold each
    -- chunk, and separate the carries from the ordinary results.
    (carries, results)
      = unzipD
      $ mapD theGang foldChunk
      $ zipD (takeDistributed segd)
             (splitD theGang balanced xs)

    foldChunk (((chunkSegd, firstSeg), firstOff), chunkElems)
      = ((firstSeg, US.take nCarry folded), US.drop nCarry folded)
      where
        folded = fSeg chunkSegd chunkElems

        -- The leading result is a carry exactly when the chunk starts
        -- partway through a segment.
        {-# INLINE [0] nCarry #-}
        nCarry = if firstOff == 0 then 0 else 1
268
269
-- | Merge the per-chunk carries produced by `foldSegsWithP` into the joined
--   vector of results, in place.
--
--   The loop runs over chunks 1 to @gangSize theGang - 1@ and skips chunk 0.
--   Presumably this is because the first chunk always starts at offset 0 and
--   so never produces a carry. For each chunk with a non-empty carry vector,
--   the value stored at the carry's segment index is replaced by
--   @f stored c0@, where @c0@ is the carry's first element.
fixupFold
  :: Unbox a
  => (a -> a -> a)
  -> MVector s a
  -> Dist (Int,Vector a)
  -> ST s ()
{-# NOINLINE fixupFold #-}
fixupFold f !mrs !dcarry = loop 1
  where
    !nChunks = gangSize theGang

    loop chunk
      | chunk >= nChunks = return ()
      | otherwise        = do
          let (seg, carry) = indexD (here "fixupFold") dcarry chunk
          if US.null carry
            then return ()
            else do
              stored <- US.read mrs seg
              US.write mrs seg (f stored (US.index (here "fixupFold") carry 0))
          loop (chunk + 1)