dph-prim-par: Add Justifications to distributed array functions
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Arrays.hs
1 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
2 {-# LANGUAGE EmptyDataDecls, ScopedTypeVariables #-}
3 {-# LANGUAGE CPP #-}
4 #include "fusion-phases.h"
5
6 -- | Operations on distributed arrays.
7 module Data.Array.Parallel.Unlifted.Distributed.Arrays
8 ( -- * Distribution phantom parameter
9 Distribution, balanced, unbalanced
10
11 -- * Array Lengths
12 , lengthD, splitLenD, splitLenIdxD
13
14 -- * Splitting and joining
15 , splitAsD, splitD, joinLengthD, joinD, splitJoinD, joinDM
16
17 -- * Permutations
18 , permuteD, bpermuteD
19
20 -- * Update
21 , atomicUpdateD
22
23 -- * Carry
24 , carryD)
25 where
26 import Data.Array.Parallel.Base (ST, runST)
27 import Data.Array.Parallel.Unlifted.Distributed.Gang
28 import Data.Array.Parallel.Unlifted.Distributed.DistST
29 import Data.Array.Parallel.Unlifted.Distributed.Types
30 import Data.Array.Parallel.Unlifted.Distributed.Combinators
31 import Data.Array.Parallel.Unlifted.Distributed.Scalars
32 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
33 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
34 import GHC.Base ( quotInt, remInt )
35 import Control.Monad
36
-- | Prefix a function name with this module's name, for error messages.
here :: String -> String
here suffix = modulePrefix Prelude.++ suffix
  where
    modulePrefix = "Data.Array.Parallel.Unlifted.Distributed.Arrays."
39
40
41 -- Distribution ---------------------------------------------------------------
-- | Phantom marker recording whether a distributed value is balanced evenly
--   among the threads of a gang. RULES use it to pass this property between
--   rewrites; the value itself is never meant to be demanded.
data Distribution

-- | Marker: the distributed value is balanced evenly among the threads.
--   Forcing it raises an error. NOINLINE keeps the name intact so that
--   RULES can match on it.
balanced :: Distribution
balanced = error (here "balanced: touched")
{-# NOINLINE balanced #-}

-- | Marker: no balance guarantee is made.
--   Forcing it raises an error. NOINLINE keeps the name intact so that
--   RULES can match on it.
unbalanced :: Distribution
unbalanced = error (here "unbalanced: touched")
{-# NOINLINE unbalanced #-}
55
56
57 -- Splitting and Joining array lengths ----------------------------------------
-- | O(threads).
--   Distribute an array length over a 'Gang'.
--   Each thread holds the number of elements it is responsible for.
--   When the length does not divide evenly among the threads, each of the
--   first (n `rem` gangSize) threads gets exactly one extra element.
--
--   @splitLenD theGangN4 511
--      = [128,128,128,127]@
--
splitLenD :: Gang -> Int -> Dist Int
splitLenD gang n
  = generateD_cheap WhatLength gang chunkLen
  where
    !threads = gangSize gang
    !base    = n `quotInt` threads
    !extra   = n `remInt` threads

    {-# INLINE [0] chunkLen #-}
    chunkLen i
      | i >= extra = base
      | otherwise  = base + 1
{-# NOINLINE splitLenD #-}
-- NOINLINE because it's cheap and doesn't need to fuse with anything.
80
81
-- | O(threads).
--   Distribute an array length over a 'Gang'.
--   Each thread holds a pair of the number of elements it is responsible
--   for and the index at which its chunk starts. Uneven lengths are split
--   as in 'splitLenD': the first (n `rem` gangSize) threads get one extra.
--
--   @splitLenIdxD theGangN4 511
--      = [(128,0),(128,128),(128,256),(127,384)]@
--
splitLenIdxD :: Gang -> Int -> Dist (Int, Int)
splitLenIdxD gang n
  = generateD_cheap WhatLengthIdx gang lenAndStart
  where
    !threads = gangSize gang
    !base    = n `quotInt` threads
    !extra   = n `remInt` threads

    {-# INLINE [0] lenAndStart #-}
    lenAndStart i
      | i >= extra = (base,     i * base + extra)
      | otherwise  = (base + 1, i * (base + 1))
{-# NOINLINE splitLenIdxD #-}
-- NOINLINE because it's cheap and doesn't need to fuse with anything.
103
104
-- | O(threads).
--   Total length of a distributed array: the chunk length held by each
--   thread, summed across the gang.
joinLengthD :: Unbox a => Gang -> Dist (Vector a) -> Int
joinLengthD gang darr = sumD gang (lengthD darr)
{-# NOINLINE joinLengthD #-}
-- NOINLINE because it's cheap and doesn't need to fuse with anything.
-- No operations are performed on the elements, so we don't need
-- to specialise for the element type.
115
116
117 -- Splitting and Joining arrays -----------------------------------------------
-- | Distribute an array over a 'Gang' such that each thread gets the given
--   number of elements. Chunk start positions are the running sum of the
--   requested lengths, computed with 'scanD'.
--
--   @splitAsD theGangN4 (splitLenD theGangN4 10) [1 2 3 4 5 6 7 8 9 0]
--      = [[1 2 3] [4 5 6] [7 8] [9 0]]@
--
splitAsD
  :: Unbox a
  => Gang -> Dist Int -> Vector a -> Dist (Vector a)
splitAsD gang dlen !arr
  = zipWithD WhatSlice (seqGang gang) sliceFrom starts dlen
  where
    sliceFrom = Seq.slice "splitAsD" arr
    starts    = fst (scanD gang (+) 0 dlen)
{-# INLINE_DIST splitAsD #-}
133
134
-- | Distribute an array over a 'Gang', one contiguous chunk per thread.
--   The 'Distribution' argument is only consulted by RULES and is never
--   forced here.
--
--   NOTE: This is defined in terms of splitD_impl to avoid introducing loops
--   through RULES. Without it, splitJoinD would be a loop breaker.
--
splitD :: Unbox a => Gang -> Distribution -> Vector a -> Dist (Vector a)
splitD gang _dist vec = splitD_impl gang vec
{-# INLINE_DIST splitD #-}
144
145
-- | Worker for 'splitD': slice the vector into one contiguous chunk per
--   thread. When the length does not divide evenly, each of the first
--   (n `rem` gangSize) threads gets one extra element.
splitD_impl :: Unbox a => Gang -> Vector a -> Dist (Vector a)
splitD_impl g !arr
  = generateD_cheap WhatSlice g chunk
  where
    n        = Seq.length arr
    !threads = gangSize g
    !base    = n `quotInt` threads
    !extra   = n `remInt` threads

    chunk i = Seq.slice "splitD_impl" arr (start i) (size i)

    -- Index of the first element owned by thread i.
    {-# INLINE [0] start #-}
    start i
      | i >= extra = base * i + extra
      | otherwise  = (base + 1) * i

    -- Number of elements owned by thread i.
    {-# INLINE [0] size #-}
    size i
      | i >= extra = base
      | otherwise  = base + 1
{-# INLINE_DIST splitD_impl #-}
164
165
-- | Join a distributed array into a single vector.
--   The chunk lengths are summed, a result vector of that size is allocated,
--   and each chunk is copied into place. The 'Distribution' argument is only
--   consulted by RULES and is never forced here.
--
--   NOTE: This is defined in terms of joinD_impl to avoid introducing loops
--   through RULES. Without it, splitJoinD would be a loop breaker.
--
joinD :: Unbox a => Gang -> Distribution -> Dist (Vector a) -> Vector a
joinD gang _dist chunks = joinD_impl gang chunks
{-# INLINE CONLIKE [1] joinD #-}
176
177
-- | Worker for 'joinD'. 'scanD' over the chunk lengths yields each chunk's
--   starting offset and the total length. The result vector is then filled
--   by every thread copying its chunk to its offset.
joinD_impl :: forall a. Unbox a => Gang -> Dist (Vector a) -> Vector a
joinD_impl g !darr
  = checkGangD (here "joinD") g darr
  $ Seq.new total (\ma -> zipWithDST_ g (copyChunk ma) offsets darr)
  where
    (!offsets, !total) = scanD g (+) 0 (lengthD darr)

    -- Copy one chunk into the shared result, starting at the given offset.
    copyChunk :: forall s. MVector s a -> Int -> Vector a -> DistST s ()
    copyChunk ma off chunk
      = stToDistST (Seq.copy (Seq.mslice off (Seq.length chunk) ma) chunk)
{-# INLINE_DIST joinD_impl #-}
188
189
-- | Split a vector over a gang, apply a distributed computation to the
--   chunks, then join the resulting chunks back into one vector.
splitJoinD
  :: (Unbox a, Unbox b)
  => Gang
  -> (Dist (Vector a) -> Dist (Vector b))
  -> Vector a
  -> Vector b
splitJoinD gang worker !vec
  = joinD_impl gang $ worker $ splitD_impl gang vec
{-# INLINE_DIST splitJoinD #-}
201
202
203
-- | Join a distributed array into a freshly allocated mutable vector.
--   Each thread copies its chunk to the offset computed by 'scanD' over the
--   chunk lengths.
joinDM :: Unbox a => Gang -> Dist (Vector a) -> ST s (MVector s a)
joinDM g darr
  = checkGangD (here "joinDM") g darr
  $ Seq.newM total >>= \marr ->
      zipWithDST_ g (copyChunk marr) offsets darr >> return marr
  where
    (!offsets, !total) = scanD g (+) 0 (lengthD darr)

    copyChunk ma off chunk
      = stToDistST (Seq.copy (Seq.mslice off (Seq.length chunk) ma) chunk)
{-# INLINE joinDM #-}
216
217
-- Rewrite rules that cancel adjacent split/join pairs and fuse nested
-- splitJoinD calls.
{-# RULES

"splitD[unbalanced]/joinD" forall gang dist darr.
  splitD gang unbalanced (joinD gang dist darr) = darr

"splitD[balanced]/joinD" forall gang darr.
  splitD gang balanced (joinD gang balanced darr) = darr

"splitD/splitJoinD" forall gang dist work vec.
  splitD gang dist (splitJoinD gang work vec) = work (splitD gang dist vec)

"splitJoinD/joinD" forall gang dist work darr.
  splitJoinD gang work (joinD gang dist darr) = joinD gang dist (work darr)

"splitJoinD/splitJoinD" forall gang work1 work2 vec.
  splitJoinD gang work1 (splitJoinD gang work2 vec)
    = splitJoinD gang (work1 . work2) vec

 #-}
236
-- Push Seq.zip through joinD / splitJoinD, so that zipping is done
-- chunk-wise inside the distributed computation.
{-# RULES

"Seq.zip/joinD[1]" forall gang dxs ys.
  Seq.zip (joinD gang balanced dxs) ys
    = joinD gang balanced (zipWithD WhatZip gang Seq.zip dxs (splitD gang balanced ys))

"Seq.zip/joinD[2]" forall gang xs dys.
  Seq.zip xs (joinD gang balanced dys)
    = joinD gang balanced (zipWithD WhatZip gang Seq.zip (splitD gang balanced xs) dys)

"Seq.zip/splitJoinD"
  forall what1 what2 gang f1 f2 xs ys.
  Seq.zip (splitJoinD gang (imapD what1 gang f1) xs)
          (splitJoinD gang (imapD what2 gang f2) ys)
    = splitJoinD gang
        (imapD (WhatFusedZipMap what1 what2) gang
               (\i zs -> let (us, vs) = Seq.unzip zs in Seq.zip (f1 i us) (f2 i vs)))
        (Seq.zip xs ys)

 #-}
258
259
260 -- Permutation ----------------------------------------------------------------
-- | Forward permutation of a distributed array.
--   A result vector of length 'joinLengthD' is allocated. Each thread then
--   applies 'Seq.mpermute' to it, using its chunk of @darr@ and the matching
--   chunk of indices from @dis@.
permuteD
  :: forall a. Unbox a
  => Gang -> Dist (Vector a) -> Dist (Vector Int) -> Vector a
permuteD g darr dis
  = Seq.new total (\ma -> zipWithDST_ g (scatter ma) darr dis)
  where
    total = joinLengthD g darr

    scatter :: forall s. MVector s a -> Vector a -> Vector Int -> DistST s ()
    scatter ma chunk idxs = stToDistST (Seq.mpermute ma chunk idxs)
{-# INLINE_DIST permuteD #-}
273
274
-- | Backwards permutation: each thread applies 'Seq.bpermute' to the shared
--   source vector, using its own chunk of indices.
--
--   NOTE: The bang is necessary because the array must be fully evaluated
--   before we pass it to the parallel computation.
bpermuteD :: Unbox a
          => Gang
          -> Vector a
          -> Dist (Vector Int)
          -> Dist (Vector a)
bpermuteD gang !src idxs
  = mapD WhatBpermute gang gather idxs
  where
    gather = Seq.bpermute src
{-# INLINE bpermuteD #-}
286
287
288 -- Update ---------------------------------------------------------------------
-- | Apply distributed (index, value) updates to a distributed array.
--   The chunks of @darr@ are first joined into one mutable vector. Each
--   thread then applies its chunk of @upd@ with 'Seq.mupdate', and the
--   vector is frozen in place.
--
-- NB: This does not (and cannot) try to prevent two threads from writing to
-- the same position. We probably want to consider this an (unchecked) user
-- error.
atomicUpdateD :: forall a. Unbox a
              => Gang -> Dist (Vector a) -> Dist (Vector (Int,a)) -> Vector a
atomicUpdateD g darr upd
  = runST (do target <- joinDM g darr
              mapDST_ g (writeChunk target) upd
              Seq.unsafeFreeze target)
  where
    writeChunk :: forall s. MVector s a -> Vector (Int,a) -> DistST s ()
    writeChunk target writes = stToDistST (Seq.mupdate target writes)
{-# INLINE atomicUpdateD #-}
303
304
305 -- Carry ----------------------------------------------------------------------
-- | Selectively combine the last elements of some chunks with the
--   first elements of others. The chunks are rewritten into a fresh
--   distributed value, which is returned together with the value carried
--   out of the final chunk.
--
--   NOTE: This runs sequentially and should only be used for testing purposes.
--
-- @
-- pprp $ splitD theGang unbalanced $ fromList [80, 10, 20, 40, 50, 10 :: Int]
-- DVector lengths: [2,2,1,1]
--         chunks:  [[80,10],[20,40],[50],[10]]
--
-- pprp $ fst
--      $ carryD theGang (+) 0
--         (mkDPrim $ fromList [True, False, True, False])
--         (splitD theGang unbalanced $ fromList [80, 10, 20, 40, 50, 10 :: Int])
--
-- DVector lengths: [1,2,0,1]
--         chunks:  [[80],[30,40],[],[60]]
-- @
--
carryD :: forall a
       .  (Unbox a, DT a)
       => Gang
       -> (a -> a -> a) -> a
       -> Dist Bool
       -> Dist (Vector a)
       -> (Dist (Vector a), a)
carryD gang f zero shouldCarry vec
  = runST (do mdist   <- newMD gang
              carried <- carryD' f zero shouldCarry vec mdist
              chunks  <- unsafeFreezeMD mdist
              return (chunks, carried))
339
340
-- | Sequential worker for 'carryD'.
--   Walks the chunks of @vec@ in order, threading an accumulator (initially
--   @zero@) from one chunk to the next. Each rewritten chunk is written into
--   @md@, and the value carried out of the last chunk is returned.
--
--   Per chunk:
--
--   * An empty chunk is stored unchanged and the incoming value passes
--     through untouched.
--
--   * Otherwise its first element is replaced by @f prev first@.
--
--   * If the chunk's @shouldCarry@ flag is set, its last element is dropped
--     and carried into the next chunk. A one-element chunk instead carries
--     the combined accumulator, since its only element is both first and last.
--
--   * If the flag is not set, @zero@ is carried into the next chunk.
carryD' :: forall a s
        .  (Unbox a, DT a)
        => (a -> a -> a) -> a
        -> Dist Bool
        -> Dist (Vector a)
        -> MDist (Vector a) s
        -> ST s a
carryD' f zero shouldCarry vec md
  = go zero 0
  where
    go prev ix
      | ix >= sizeD vec = return prev
      | otherwise
      = do let chunk :: Vector a
               !chunk = indexD (here "carryD'") vec ix
           let !chunkLen = Seq.length chunk

           -- Whether to carry the last value of this chunk into the next chunk.
           let !carry = indexD (here "carryD") shouldCarry ix

           -- The new length for this chunk.
           let !chunkLen'
                 | chunkLen == 0 = 0
                 | carry         = chunkLen - 1
                 | otherwise     = chunkLen

           -- Incoming value combined with the chunk's first element.
           -- Deliberately lazy: it must never be demanded for an empty
           -- chunk, which has no element 0.
           let acc = f prev (Seq.index (here "carryD'") chunk 0)

           -- Copy the surviving prefix of the chunk into a fresh vector,
           -- then overwrite its first element with the accumulator.
           mchunk' <- Seq.newM chunkLen'
           Seq.copy mchunk' (Seq.slice (here "carryD'") chunk 0 chunkLen')
           when (chunkLen' /= 0) $ Seq.write mchunk' 0 acc

           -- Store the new chunk.
           chunk' <- Seq.unsafeFreeze mchunk'
           writeMD md ix chunk'

           -- What value to carry into the next chunk.
           --  * Empty chunk: pass the incoming value through. Previously
           --    this fell into the next guard and indexed an empty vector.
           --  * Carried single-element chunk: it was folded wholly into acc.
           let next
                 | chunkLen  == 0 = prev
                 | chunkLen' == 0 = acc
                 | carry          = Seq.index (here "next") chunk (chunkLen - 1)
                 | otherwise      = zero

           go next (ix + 1)
389