395cf1f1d1bc6f217f1230e7e92a44a861922cf7
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / USegd.hs
1 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
2 {-# LANGUAGE CPP #-}
3 #include "fusion-phases.h"
4
5 -- | Operations on Distributed Segment Descriptors
6 module Data.Array.Parallel.Unlifted.Distributed.USegd
7 ( splitSegdOnSegsD
8 , splitSegdOnElemsD
9 , splitSD
10 , joinSegdD
11 , glueSegdD)
12 where
13 import Data.Array.Parallel.Unlifted.Distributed.Arrays
14 import Data.Array.Parallel.Unlifted.Distributed.Combinators
15 import Data.Array.Parallel.Unlifted.Distributed.Types
16 import Data.Array.Parallel.Unlifted.Distributed.Gang
17 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
18 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, Unbox)
19 import Data.Array.Parallel.Base
20 import Data.Bits (shiftR)
21 import Control.Monad (when)
22 import qualified Data.Array.Parallel.Unlifted.Distributed.Types.USegd as DUSegd
23 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
24 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
25
-- | Prefix a function name with the name of this module.
--   The result is used as the location tag handed to checked operations
--   such as 'Seq.index' and 'indexD'.
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Distributed.USegd." ++)
28
29 -------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, segment wise.
--
--   Every thread receives whole segments only, and the segments are handed
--   out so that each thread ends up with roughly the same number of array
--   elements.
--
--   Segments are never split across threads, as this would limit our ability
--   to perform intra-thread fusion of lifted operations. The down side is
--   that a few segments with an uneven size distribution can cause the gang
--   to become unbalanced.
--
--   In the following example the segment with size 100 dominates and
--   unbalances the gang. There is no reason to put any segments on the last
--   thread, because we need to wait for the first thread to finish anyway.
--
--   @
--   > pprp $ splitSegdOnSegsD theGang
--          $ lengthsToUSegd $ fromList [100, 10, 20, 40, 50 :: Int]
--
--   DUSegd lengths:   DVector lengths: [ 1,    3,          1,   0]
--                              chunks:  [[100],[10,20,40],[50],[]]
--
--          indices:   DVector lengths: [1,3,1,0]
--                              chunks:  [[0],  [0,10,30], [0], []]
--
--         elements:   DInt [100,70,50,0]
--   @
--
--   NOTE: This splitSegdOnSegsD function isn't currently used.
--
splitSegdOnSegsD :: Gang -> USegd -> Dist USegd
splitSegdOnSegsD gang !segd
  = mapD gang USegd.fromLengths (splitAsD gang segsPerThread segLens)
  where
    -- How many whole segments each thread receives.
    -- NOTE(review): 'splitLenD' presumably yields the per-thread element
    -- quota and 'mapAccumLD' presumably threads the running segment index
    -- from one thread to the next -- confirm against their definitions.
    !segsPerThread
      = snd (mapAccumLD gang assign 0 (splitLenD gang (USegd.takeElements segd)))

    -- Total number of segments, and the length of each one.
    segCount = USegd.length segd
    segLens  = USegd.takeLengths segd

    -- Given the index of the first unassigned segment and an element quota,
    -- produce the index of the next unassigned segment together with the
    -- number of segments taken.
    assign !start !budget
      = let !end = advance start budget
        in  (end, end - start)

    -- Walk forward over the segments, charging each segment's length against
    -- the quota. Zero-length segments are always taken; otherwise we stop as
    -- soon as the quota is used up or there are no segments left.
    -- The bounds test must stay first: 'len' is only demanded once we know
    -- that 'ix' is a valid segment index.
    advance !ix !budget
      | ix >= segCount = ix
      | len == 0       = advance (ix + 1) budget
      | budget <= 0    = ix
      | otherwise      = advance (ix + 1) (budget - len)
      where
        len = Seq.index (here "splitSegdOnSegsD") segLens ix
{-# NOINLINE splitSegdOnSegsD #-}
81
82
83 -------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, element wise.
--   We try to put the same number of elements on each thread, which means
--   that segments are sometimes split across threads.
--
--   Each thread gets a slice of the segment descriptor, the segid of its
--   first slice, and the offset of that first slice within its segment.
--
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments
--    in total.
--
--   @
--   segs:     ----------------------- --- ------- --------------- -------------------
--   elems:   |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--            |     thread1     |     thread2     |     thread3     |     thread4     |
--   segid:    0                 0                 3                 4
--   offset:   0                 45                0                 5
--
--   pprp $ splitSegdOnElemsD theGang4
--        $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50 :: Int]
--
--   segd:    DUSegd lengths:  DVector lengths: [1,3,2,1]
--                                      chunks:  [[45],[15,10,20],[40,5],[45]]
--                    indices:  DVector lengths: [1,3,2,1]
--                                      chunks:  [[0], [0,15,25], [0,40],[0]]
--                   elements:  DInt [45,45,45,45]
--
--   segids:  DInt [0,0,3,4]     (segment id of first slice on thread)
--   offsets: DInt [0,45,0,5]    (offset of that slice in its segment)
--   @
--
splitSegdOnElemsD :: Gang -> USegd -> Dist ((USegd,Int),Int)
splitSegdOnElemsD gang !segd
  = {-# SCC "splitSegdOnElemsD" #-}
    imapD gang chunkFor (splitLenIdxD gang (USegd.takeElements segd))
  where
    -- Number of threads in the gang.
    !threadCount = gangSize gang

    -- Build the result for a single thread.
    chunkFor
      :: Int                  -- Thread index.
      -> (Int, Int)           -- Number of elements on this thread,
                              -- and starting offset into the flat array.
      -> ((USegd, Int), Int)  -- Segd for this thread, segid of first slice,
                              -- and offset of first slice.
    chunkFor tid (count, start)
      = case getChunk segd start count isLast of
          (# sliceLens, firstSeg, firstOff #)
            -> ((USegd.fromLengths sliceLens, firstSeg), firstOff)
      where
        -- 'getChunk' treats the final thread specially.
        isLast = tid == threadCount - 1

{-# NOINLINE splitSegdOnElemsD #-}
--  NOINLINE because this function has a large body of code and we don't want
--  to blow up the client modules by inlining it everywhere.
135
136
137 -------------------------------------------------------------------------------
-- | Determine what elements go on a thread.
--   The \'chunk\' is a contiguous range of the flat array, and is described
--   by a sequence of segment slices.
--
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments
--    in total. The rows below the picture give, for every thread, the value
--    of the local bindings used in the implementation.
--
--   @
--   segs:         ----------------------- --- ------- --------------- -------------------
--   elems:       |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--                |     thread1     |     thread2     |     thread3     |     thread4     |
--   segid:        0                 0                 3                 4
--   offset:       0                 45                0                 5
--   segLo:        0                 1                 3                 5
--   segHi:        1                 3                 5                 5
--   headLen:      0                 15                0                 45
--   tailLen:      45                20                5                 0
--   headCount:    0                 1                 0                 1
--   headOffset:   0                 45                0                 5
--   sliceCount:   1                 3                 2                 1
--   @
getChunk
        :: USegd          -- ^ Segment descriptor of entire array.
        -> Int            -- ^ Starting offset into the flat array for the first
                          --   slice on this thread.
        -> Int            -- ^ Number of elements in this thread.
        -> Bool           -- ^ Whether this is the last thread in the gang.
        -> (# Vector Int  --   Lengths of segment slices,
            , Int         --   segid of first slice,
            , Int #)      --   offset of first slice.

getChunk !segd !chunkStart !chunkLen isLast
  = (# sliceLens, segLo - headCount, headOffset #)
  where
    -- Lengths of all segments.
    -- eg: [60, 10, 20, 40, 50]
    !segLens    = USegd.takeLengths segd

    -- Starting indices of all segments.
    -- eg: [0, 60, 70, 90, 130]
    !segStarts  = USegd.takeIndices segd

    -- Total number of segments defined by the segment descriptor.
    -- eg: 5
    !segCount   = Seq.length segLens

    -- Segid of the first segment that starts at or after the left edge of
    -- this chunk, or 'segCount' if there is none.
    !segLo      = search chunkStart segStarts

    -- Segid of the first segment that starts at or after the right edge of
    -- this chunk. The last thread always runs to the end of the descriptor.
    !segHi
      | isLast    = segCount
      | otherwise = search chunkEnd segStarts

    -- Number of chunk elements that precede the start of segment 'segLo'.
    -- These belong to a segment that began before this chunk; the value is
    -- zero when the chunk starts exactly on a segment boundary.
    !headLen
      | segLo == segCount = chunkLen
      | otherwise
      = min (Seq.index (here "getChunk") segStarts segLo - chunkStart) chunkLen

    -- Number of chunk elements that fall in segment @segHi - 1@,
    -- which is the right-most slice of this chunk.
    !tailLen
      | segHi == segLo = 0
      | otherwise
      = chunkEnd - Seq.index (here "getChunk") segStarts (segHi - 1)

    -- 1 if the chunk begins part-way through a segment, otherwise 0.
    !headCount  = if headLen == 0 then 0 else 1

    -- If the chunk begins part-way through a segment then this is the index
    -- of its first element within that segment, otherwise 0.
    !headOffset
      = if headLen == 0
          then 0
          else chunkStart - Seq.index (here "getChunk") segStarts (segLo - 1)

    -- How many segment slices this chunk is made of.
    !sliceCount = headCount + (segHi - segLo)

    -- Build the slice lengths for this chunk. The lengths of the segments
    -- 'segLo' up to (not including) 'segHi' are copied from the original
    -- descriptor; the entries for a partial slice on the left or on the
    -- right are then set to the length of that slice.
    --
    -- When debugging it can help to wrap this value in a 'trace' that prints
    -- the bindings listed in the table at the top of this function.
    !sliceLens
      = runST (do
          -- A new array big enough to hold all the lengths for this chunk.
          !buf <- Seq.newM sliceCount

          -- The chunk begins inside a segment: the first entry is the
          -- length of that leading slice.
          when (headLen /= 0)
            (Seq.write buf 0 headLen)

          -- Copy the segment lengths, skipping over the leading slice entry
          -- if there is one.
          Seq.copy (Seq.mdrop headCount buf)
                   (Seq.slice "getChunk" segLens segLo (segHi - segLo))

          -- The final entry becomes the length of the right-most slice.
          -- This has to come after the copy, which also writes that entry.
          when (tailLen /= 0)
            (Seq.write buf (sliceCount - 1) tailLen)

          Seq.unsafeFreeze buf)

    -- Offset into the flat array of the first element after this chunk.
    chunkEnd = chunkStart + chunkLen

{-# INLINE getChunk #-}
--  INLINE even though it should be inlined into splitSegdOnElemsD anyway
--  because that function contains the only use.
258
259
260 -------------------------------------------------------------------------------
-- | O(log n). Binary search over a monotonically increasing vector of 'Int's.
--
--   Returns the /index/ of the first element that is not smaller than the
--   given value, or the length of the vector when every element is smaller.
--
--   eg  search 75  [0, 60, 70, 90, 130] = 3    (the element 90)
--       search 43  [0, 60, 70, 90, 130] = 1    (the element 60)
--       search 90  [0, 60, 70, 90, 130] = 3    (an exact match is returned)
--       search 200 [0, 60, 70, 90, 130] = 5    (past the end)
--
search :: Int -> Vector Int -> Int
search !target sorted = bisect 0 (Seq.length sorted)
  where
    -- Invariant: the answer lies in the window of 'width' elements that
    -- starts at index 'lo', or is the index just past that window.
    bisect lo width
      | width <= 0     = lo
      | probe < target = bisect (pivot + 1) (width - step - 1)
      | otherwise      = bisect lo step
      where
        step  = width `shiftR` 1
        pivot = lo + step
        probe = Seq.index (here "search") sorted pivot
279
280
281 -------------------------------------------------------------------------------
-- | time O(segs)
--   Join a distributed segment descriptor into a global one.
--
--   The per-thread lengths are concatenated and a new descriptor is built
--   from them. Segments that were split across threads are /not/ merged, so
--   the result is generally not the descriptor that was originally split.
--
--   @
--   > pprp $ joinSegdD theGang4
--          $ fstD $ fstD $ splitSegdOnElemsD theGang
--          $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50]
--
--   USegd lengths:  [45,15,10,20,40,5,45]
--         indices:  [0,45,60,70,90,130,135]
--         elements: 180
--   @
--
--   TODO: sequential runtime is O(segs) due to the application of
--         'USegd.fromLengths'.
--
joinSegdD :: Gang -> Dist USegd -> USegd
joinSegdD gang
  = USegd.fromLengths . gather . localLengths
  where
    -- Take the lengths vector of the descriptor on every thread.
    localLengths = mapD gang USegd.takeLengths

    -- Concatenate the per-thread vectors into a single one.
    gather       = joinD gang unbalanced
{-# INLINE_DIST joinSegdD #-}
--  The definition deliberately takes a single argument on the left-hand
--  side, so the number of arguments needed before it can inline is unchanged.
304
305
306 -------------------------------------------------------------------------------
-- | Glue a distributed segment descriptor back together, so that a segment
--   which was split across threads is represented by a single length again.
--
--   The argument has the shape produced by 'splitSegdOnElemsD': for each
--   thread a descriptor, the segid of its first slice, and the offset of
--   that slice within its segment. Only the descriptors and the offsets
--   are used here.
--
--   Prop, as stated by the original authors:
--     glueSegdD gang $ splitSegdOnElemsD gang usegd  =  usegd
--
--   NOTE(review): the result is a 'Dist USegd', so the property presumably
--   holds only after joining the result -- confirm.
--
--   NOTE: This runs sequentially and should only be used for testing
--         purposes.
--
glueSegdD :: Gang -> Dist ((USegd, Int), Int) -> Dist USegd
glueSegdD gang bundle
  = glued
  where
    -- Per-thread descriptors and their length vectors.
    !chunks       = fstD (fstD bundle)
    !chunkLens    = DUSegd.takeLengthsD chunks

    -- For every thread, the offset of its first slice within its segment.
    !firstOffsets = sndD bundle

    -- Whether the last segment of a chunk extends into the next chunk.
    -- That is the case when the following chunk starts at a non-zero
    -- offset. Never the case for the final chunk.
    continues :: Dist Bool
    !continues
      = generateD_cheap gang $ \tid ->
          tid < sizeD chunkLens - 1
            && indexD (here "glueSegdD") firstOffsets (tid + 1) /= 0

    -- NOTE(review): 'carryD' presumably moves the trailing length of a chunk
    -- flagged above into the next chunk, adding it to that chunk's first
    -- length -- confirm against its definition.
    !gluedLens    = fst (carryD gang (+) 0 continues chunkLens)
    !glued        = mapD gang USegd.fromLengths gluedLens
{-# INLINE_DIST glueSegdD #-}
332
333
334 -------------------------------------------------------------------------------
-- | Split a flat vector across the gang with 'splitAsD', using the element
--   counts of a distributed segment descriptor ('DUSegd.takeElementsD') as
--   the distribution.
splitSD :: Unbox a => Gang -> Dist USegd -> Vector a -> Dist (Vector a)
splitSD gang dsegd xs
  = splitAsD gang elemCounts xs
  where
    elemCounts = DUSegd.takeElementsD dsegd
{-# INLINE_DIST splitSD #-}

--  Rewrite rules that push 'splitSD' inwards:
--
--   * splitting the result of a 'splitJoinD' applies the worker directly to
--     the split input, so the intermediate joined vector is not needed;
--
--   * splitting a zipped vector becomes a per-thread zip of the two
--     separately split vectors.
{-# RULES

"splitSD/splitJoinD" forall gang dsegd worker xs.
  splitSD gang dsegd (splitJoinD gang worker xs)
    = worker (splitSD gang dsegd xs)

"splitSD/Seq.zip" forall gang dsegd xs ys.
  splitSD gang dsegd (Seq.zip xs ys)
    = zipWithD gang Seq.zip (splitSD gang dsegd xs)
                            (splitSD gang dsegd ys)

  #-}