dph-prim-par: Add Justifications to distributed array functions
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / USegd.hs
1 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
2 {-# LANGUAGE CPP #-}
3 #include "fusion-phases.h"
4
5 -- | Operations on Distributed Segment Descriptors
6 module Data.Array.Parallel.Unlifted.Distributed.USegd
7 ( splitSegdOnSegsD
8 , splitSegdOnElemsD
9 , splitSD
10 , joinSegdD
11 , glueSegdD)
12 where
13 import Data.Array.Parallel.Unlifted.Distributed.Arrays
14 import Data.Array.Parallel.Unlifted.Distributed.Combinators
15 import Data.Array.Parallel.Unlifted.Distributed.Types
16 import Data.Array.Parallel.Unlifted.Distributed.Gang
17 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
18 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, Unbox)
19 import Data.Array.Parallel.Base
20 import Data.Bits (shiftR)
21 import Control.Monad (when)
22 import qualified Data.Array.Parallel.Unlifted.Distributed.Types.USegd as DUSegd
23 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
24 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
25
-- | Prefix a function name with this module's name.
--   The result is used as the location label passed to the checked
--   accessors in this module (for example 'Seq.index' and 'indexD').
here :: String -> String
here = ("Data.Array.Parallel.Unlifted.Distributed.USegd." ++)
28
-------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, segment wise.
--   Whole segments are placed on each thread, and we try to balance out
--   the segments so each thread has the same number of array elements.
--
--   We don't split segments across threads, as this would limit our ability
--   to perform intra-thread fusion of lifted operations. The down side
--   of this is that if we have few segments with an un-even size distribution
--   then large segments can cause the gang to become unbalanced.
--
--   In the following example the segment with size 100 dominates and
--   unbalances the gang. There is no reason to put any segments on the
--   last thread because we need to wait for the first to finish anyway.
--
--   @  > pprp $ splitSegdOnSegsD theGang
--            $ lengthsToUSegd $ fromList [100, 10, 20, 40, 50  :: Int]
--
--     DUSegd lengths:   DVector lengths: [ 1,    3,   1,  0]
--                                chunks:  [[100],[10,20,40],[50],[]]
--
--            indices:   DVector lengths: [1,3,1,0]
--                                chunks:  [[0],  [0,10,30], [0], []]
--
--            elements:  DInt [100,70,50,0]
--   @
--
--   NOTE: This splitSegdOnSegsD function isn't currently used.
--
splitSegdOnSegsD :: Gang -> USegd -> Dist USegd
splitSegdOnSegsD g !segd
  = mapD (What "USegd.splitSegdOnSegds/fromLengths") g USegd.fromLengths
         (splitAsD g segsPerThread segLens)
  where
        -- Number of whole segments assigned to each thread, computed from
        -- the per-thread element budgets produced by 'splitLenD'.
        !segsPerThread
          = snd (mapAccumLD g takeSegs 0
                    (splitLenD g (USegd.takeElements segd)))

        -- Total number of segments, and the length of each one.
        nSegs   = USegd.length      segd
        segLens = USegd.takeLengths segd

        -- Starting at segment 'segIx' with an element budget of 'budget',
        -- yield the index of the first segment that is not taken, paired
        -- with how many segments were taken.
        takeSegs !segIx !budget
          = nextIx `seq` (nextIx, nextIx - segIx)
          where nextIx = advance segIx budget

        -- Step over segments until the budget is used up or we run out of
        -- segments. Zero-length segments are always stepped over, even
        -- when the budget is already exhausted.
        advance !segIx !budget
          | segIx >= nSegs = segIx
          | segLen == 0    = advance (segIx + 1) budget
          | budget <= 0    = segIx
          | otherwise      = advance (segIx + 1) (budget - segLen)
          where
                segLen = Seq.index (here "splitSegdOnSegsD") segLens segIx
{-# NOINLINE splitSegdOnSegsD #-}
81
82
-------------------------------------------------------------------------------
-- | Split a segment descriptor across the gang, element wise.
--   We try to put the same number of elements on each thread, which means
--   that segments are sometimes split across threads.
--
--   Each thread gets a slice of segment descriptor, the segid of the first
--   slice, and the offset of the first slice in its segment.
--
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments
--    in total.
--
-- @   segs:    ----------------------- --- ------- --------------- -------------------
--    elems:   |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--             |     thread1     |     thread2     |     thread3     |     thread4     |
--    segid:   0                 0                 3                 4
--    offset:  0                 45                0                 5
--
--    pprp $ splitSegdOnElemsD theGang4
--          $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50 :: Int]
--
--     segd:    DUSegd lengths:  DVector lengths: [1,3,2,1]
--                                        chunks:  [[45],[15,10,20],[40,5],[45]]
--                     indices:  DVector lengths: [1,3,2,1]
--                                        chunks:  [[0], [0,15,25], [0,40],[0]]
--                    elements:  DInt [45,45,45,45]
--
--     segids: DInt [0,0,3,4]     (segment id of first slice on thread)
--    offsets: DInt [0,45,0,5]    (offset of that slice in its segment)
-- @
--
splitSegdOnElemsD :: Gang -> USegd -> Dist ((USegd,Int),Int)
splitSegdOnElemsD g !segd
  = {-# SCC "splitSegdOnElemsD" #-}
    imapD (What "USegd.splitSegdOnElemsD/splitLenIdx") g chunkFor
        $ splitLenIdxD g (USegd.takeElements segd)
  where
        -- Number of threads in gang.
        !nThreads = gangSize g

        -- Build the result for one thread from its share of the flat array.
        chunkFor
                :: Int                  -- Thread index.
                -> (Int, Int)           -- Number of elements on this thread,
                                        --   and starting offset into the flat array.
                -> ((USegd, Int), Int)  -- Segd for this thread, segid of first slice,
                                        --   and offset of first slice.

        chunkFor tid (nElems, ixStart)
          = case getChunk segd ixStart nElems isLast of
                (# sliceLens, firstSeg, firstOff #)
                  -> ((USegd.fromLengths sliceLens, firstSeg), firstOff)
          where
                -- The final thread is flagged so 'getChunk' can treat it
                -- specially.
                isLast = tid == nThreads - 1

{-# NOINLINE splitSegdOnElemsD #-}
--  NOINLINE because this function has a large body of code and we don't want
--  to blow up the client modules by inlining it everywhere.
136
137
-------------------------------------------------------------------------------
-- | Determine what elements go on a thread.
--   The 'chunk' refers to a chunk of the flat array, and is defined
--   by a set of segment slices.
--
--   Example:
--    In this picture each X represents 5 elements, and we have 5 segments
--    in total.
--
-- @
--    segs:     ----------------------- --- ------- --------------- -------------------
--    elems:   |X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|X X X X X X X X X|
--             |     thread1     |     thread2     |     thread3     |     thread4     |
--    segid:   0                 0                 3                 4
--    offset:  0                 45                0                 5
--    k:       0                 1                 3                 5
--    k':      1                 3                 5                 5
--    left:    0                 15                0                 45
--    right:   45                20                5                 0
--    left_len: 0                1                 0                 1
--    left_off: 0                45                0                 5
--    n':      1                 3                 2                 1
-- @
getChunk
        :: USegd          -- ^ Segment descriptor of entire array.
        -> Int            -- ^ Starting offset into the flat array for the first
                          --   slice on this thread.
        -> Int            -- ^ Number of elements in this thread.
        -> Bool           -- ^ Whether this is the last thread in the gang.
        -> (# Vector Int  --   Lengths of segment slices,
            , Int         --   segid of first slice,
            , Int #)      --   offset of first slice.

getChunk !segd !nStart !nElems is_last
  = (# lens'', k-left_len, left_off #)
  where
        -- Lengths of all segments.
        -- eg: [60, 10, 20, 40, 50]
        !lens = USegd.takeLengths segd

        -- Starting indices of all segments.
        -- eg: [0, 60, 70, 90, 130]
        !idxs = USegd.takeIndices segd

        -- Total number of segments defined by segment descriptor.
        -- eg: 5
        !n    = Seq.length lens

        -- Segid of the first seg that starts at or after the left edge of
        -- this chunk, or 'n' if there is no such segment.
        !k    = search nStart idxs

        -- Segid of the first seg that starts at or after the right edge of
        -- this chunk. The last thread always runs to the final segment.
        !k'       | is_last     = n
                  | otherwise   = search (nStart + nElems) idxs

        -- The length of the left-most slice of this chunk.
        -- This is non-zero only when the chunk begins part way through
        -- a segment.
        !left     | k == n      = nElems
                  | otherwise   = min ((Seq.index (here "getChunk") idxs k) - nStart) nElems

        -- The length of the right-most slice of this chunk.
        !right    | k' == k     = 0
                  | otherwise   = nStart + nElems - (Seq.index (here "getChunk") idxs (k'-1))

        -- Whether the first element in this chunk is an internal element
        -- of a segment. Alternatively, indicates that the first element of
        -- the chunk is not the first element of a segment.
        !left_len | left == 0   = 0
                  | otherwise   = 1

        -- If the first element of the chunk starts within a segment,
        -- then gives the index within that segment, otherwise 0.
        !left_off | left == 0   = 0
                  | otherwise   = nStart - (Seq.index (here "getChunk") idxs (k-1))

        -- How many segments this chunk straddles.
        !n' = left_len + (k'-k)

        -- Create the lengths for this chunk by first copying out the lengths
        -- from the original segment descriptor. If the slices on the left
        -- and right cover partial segments, then we update the corresponding
        -- lengths.
        !lens'
         = runST (do
                -- Create a new array big enough to hold all the lengths for this chunk.
                !mlens' <- Seq.newM n'

                -- If the first element is inside a segment,
                -- then update the length to be the length of the slice.
                when (left /= 0)
                 $ Seq.write mlens' 0 left

                -- Copy out array lengths for this chunk.
                -- The copy lands after the left slice's slot (if any), so it
                -- does not disturb the value written above.
                Seq.copy (Seq.mdrop left_len mlens')
                         (Seq.slice (here "getChunk") lens k (k'-k))

                -- If the last element is inside a segment,
                -- then update the length to be the length of the slice.
                -- This has to follow the copy, because it overwrites the
                -- full segment length that the copy put in the final slot.
                when (right /= 0)
                 $ Seq.write mlens' (n' - 1) right

                Seq.unsafeFreeze mlens')

        !lens'' = lens'
{-       = trace
                (render $ vcat
                        [ text "CHUNK"
                        , pprp segd
                        , text "nStart:  " <+> int nStart
                        , text "nElems:  " <+> int nElems
                        , text "k:       " <+> int k
                        , text "k':      " <+> int k'
                        , text "left:    " <+> int left
                        , text "right:   " <+> int right
                        , text "left_len:" <+> int left_len
                        , text "left_off:" <+> int left_off
                        , text "n':      " <+> int n'
                        , text ""]) lens'
-}

{-# INLINE getChunk #-}
--  INLINE even though it should be inlined into splitSegdOnElemsD anyway
--  because that function contains the only use.
259
260
-------------------------------------------------------------------------------
-- | O(log n). Binary search over a monotonically increasing vector of `Int`s.
--
--   Returns the /index/ of the first element that is greater than or equal
--   to the given value. If every element is smaller than the value, the
--   result is the length of the vector.
--
--   eg  search 75  [0, 60, 70, 90, 130] = 3    (the element 90)
--       search 43  [0, 60, 70, 90, 130] = 1    (the element 60)
--       search 60  [0, 60, 70, 90, 130] = 1    (an equal element counts)
--       search 200 [0, 60, 70, 90, 130] = 5    (no such element)
--
search :: Int -> Vector Int -> Int
search !x ys = go 0 (Seq.length ys)
  where
        -- 'lo' is the first index still under consideration and 'len' is
        -- the number of candidate indices starting from 'lo'.
        go !lo !len
          | len <= 0
          = lo

          -- The probed element is too small, so the answer lies strictly
          -- to its right.
          | Seq.index (here "search") ys mid < x
          = go (mid + 1) (len - half - 1)

          -- The probed element is large enough, so the answer is at 'mid'
          -- or to its left.
          | otherwise
          = go lo half
          where
                half = len `shiftR` 1
                mid  = lo + half
280
281
-------------------------------------------------------------------------------
-- | time O(segs)
--   Join a distributed segment descriptor into a global one.
--   This simply joins the distributed lengths and rebuilds a descriptor
--   from them; it does not reconstruct the original segment descriptor
--   as it was before splitting.
--
-- @ > pprp $ joinSegdD theGang4
--          $ fstD $ fstD $ splitSegdOnElemsD theGang
--          $ lengthsToUSegd $ fromList [60, 10, 20, 40, 50]
--
--   USegd lengths:  [45,15,10,20,40,5,45]
--         indices:  [0,45,60,70,90,130,135]
--         elements: 180
-- @
--
-- TODO: sequential runtime is O(segs) due to application of 'USegd.fromLengths'
--
joinSegdD :: Gang -> Dist USegd -> USegd
joinSegdD gang
  = \dsegd
  -> USegd.fromLengths
   $ joinD gang unbalanced
   $ mapD (What "joinSegdD/takeLengths") gang USegd.takeLengths dsegd
{-# INLINE_DIST joinSegdD #-}
305
306
-------------------------------------------------------------------------------
-- | Glue the pieces of segments that were split across chunk boundaries
--   back together.
--
--   Intended property, as originally stated:
--     glueSegdD gang $ splitSegdOnElemsD gang usegd  =  usegd
--
--   NOTE(review): the result here is still distributed ('Dist USegd'), so the
--   property presumably holds only after the chunks are joined -- confirm.
--
--   NOTE: This runs sequentially and should only be used for testing purposes.
--
glueSegdD :: Gang -> Dist ((USegd, Int), Int) -> Dist USegd
glueSegdD gang bundle
 = let  -- Per-thread segment descriptors and their segment lengths.
        !chunkSegds     = fstD (fstD bundle)
        !chunkLens      = DUSegd.takeLengthsD chunkSegds

        -- Offset of each thread's first slice within its segment.
        !firstOffs      = sndD bundle

        -- Whether the last segment in this chunk extends into the next chunk.
        -- That is the case when the following chunk starts at a non-zero
        -- offset; the final chunk has no successor, so it never does.
        splitsAtEnd :: Dist Bool
        !splitsAtEnd
         = generateD_cheap (What "glueSegdD/segd_offsegs") gang $ \ix
         -> ix < sizeD chunkLens - 1
         && indexD (here "glueSegdD") firstOffs (ix + 1) /= 0

        -- Segment lengths after carrying split segments between chunks.
        !gluedLens      = fst (carryD gang (+) 0 splitsAtEnd chunkLens)

   in   mapD (What "glueSegdD/fromLenghts") gang USegd.fromLengths gluedLens
{-# INLINE_DIST glueSegdD #-}
334
335
-------------------------------------------------------------------------------
-- | Split a flat vector across the gang with 'splitAsD', using the
--   element counts taken from the distributed segment descriptor.
splitSD :: Unbox a => Gang -> Dist USegd -> Vector a -> Dist (Vector a)
splitSD g distSegd arr
  = splitAsD g
        (DUSegd.takeElementsD distSegd)
        arr
{-# INLINE_DIST splitSD #-}

-- Rewrite rules that push 'splitSD' inwards:
--  * splitting the result of a 'splitJoinD' applies the worker function
--    directly to the split input;
--  * splitting a zipped vector zips the two split vectors chunk-wise.
{-# RULES

"splitSD/splitJoinD"
  forall gang segd fn arr
  . splitSD gang segd (splitJoinD gang fn arr)
  = fn (splitSD gang segd arr)

"splitSD/Seq.zip"
  forall gang segd as bs
  . splitSD gang segd (Seq.zip as bs)
  = zipWithD WhatZip gang Seq.zip
        (splitSD gang segd as)
        (splitSD gang segd bs)

  #-}
355
356 #-}