52f315f0e15cf7ed182e2f0bd2f97c8cb7c90987
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Arrays.hs
1 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
2 {-# LANGUAGE EmptyDataDecls, ScopedTypeVariables #-}
3 {-# LANGUAGE CPP #-}
4 #include "fusion-phases.h"
5
6 -- | Operations on distributed arrays.
7 module Data.Array.Parallel.Unlifted.Distributed.Arrays
8 ( -- * Distribution phantom parameter
9 Distribution, balanced, unbalanced
10
11 -- * Array Lengths
12 , lengthD, splitLenD, splitLenIdxD
13
14 -- * Splitting and joining
15 , splitAsD, splitD, joinLengthD, joinD, splitJoinD, joinDM
16
17 -- * Permutations
18 , permuteD, bpermuteD
19
20 -- * Update
21 , atomicUpdateD
22
23 -- * Carry
24 , carryD)
25 where
26 import Data.Array.Parallel.Base (ST, runST)
27 import Data.Array.Parallel.Unlifted.Distributed.Gang
28 import Data.Array.Parallel.Unlifted.Distributed.DistST
29 import Data.Array.Parallel.Unlifted.Distributed.Types
30 import Data.Array.Parallel.Unlifted.Distributed.Combinators
31 import Data.Array.Parallel.Unlifted.Distributed.Scalars
32 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
33 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
34 import GHC.Base ( quotInt, remInt )
35 import Control.Monad
36
-- | Qualify a function name with this module's name, for use in
--   error and bounds-check messages.
here :: String -> String
here name = modulePrefix Prelude.++ name
  where
    modulePrefix = "Data.Array.Parallel.Unlifted.Distributed.Arrays."
39
40
41 -- Distribution ---------------------------------------------------------------
-- | Phantom parameter recording whether a distributed value is balanced
--   evenly among the threads.
--
--   The type has no constructors: values of this type ('balanced' and
--   'unbalanced') exist only so that RULES can match on them, and the
--   actual value is never used.
data Distribution
46
-- | 'Distribution' tag for a value that is balanced evenly among the threads.
--   It is only a marker; evaluating it raises an error.
{-# NOINLINE balanced #-}
balanced :: Distribution
balanced = error (here "balanced: touched")
50
51
-- | 'Distribution' tag for a value that is not known to be balanced evenly
--   among the threads. It is only a marker; evaluating it raises an error.
{-# NOINLINE unbalanced #-}
unbalanced :: Distribution
unbalanced = error (here "unbalanced: touched")
55
56
57 -- Splitting and Joining array lengths ----------------------------------------
-- | O(threads).
--   Distribute an array length over a 'Gang'.
--   Each thread holds the number of elements it is responsible for.
--   When the length does not divide evenly by the gang size, the remainder
--   is handed out one element each to the lowest-numbered threads.
--
--   @splitLenD theGangN4 511
--      = [128,128,128,127]@
--
splitLenD :: Gang -> Int -> Dist Int
splitLenD gang total = generateD_cheap gang chunkLen
  where
    !threads = gangSize gang
    !base    = total `quotInt` threads
    !extra   = total `remInt`  threads

    -- Threads below 'extra' take one element more than the base share.
    {-# INLINE [0] chunkLen #-}
    chunkLen tid = if tid < extra then base + 1 else base
{-# NOINLINE splitLenD #-}
--  NOINLINE because it's cheap and doesn't need to fuse with anything.
79
80
-- | O(threads).
--   Distribute an array length over a 'Gang'.
--   Each thread holds a pair of the number of elements it is responsible
--   for and the index at which its chunk starts.
--
--   @splitLenIdxD theGangN4 511
--      = [(128,0),(128,128),(128,256),(127,384)]@
--
splitLenIdxD :: Gang -> Int -> Dist (Int, Int)
splitLenIdxD gang total = generateD_cheap gang lenAndStart
  where
    !threads = gangSize gang
    !base    = total `quotInt` threads
    !extra   = total `remInt`  threads

    -- Threads below 'extra' hold (base + 1) elements; all chunks before
    -- such a thread have that same size. Later threads hold 'base'
    -- elements and start after the 'extra' additional elements.
    {-# INLINE [0] lenAndStart #-}
    lenAndStart tid
      = if tid < extra
          then (base + 1, tid * (base + 1))
          else (base,     tid * base + extra)
{-# NOINLINE splitLenIdxD #-}
--  NOINLINE because it's cheap and doesn't need to fuse with anything.
101
102
-- | O(threads).
--   Overall length of a distributed array: the chunk length held by each
--   thread is read and the lengths are summed.
joinLengthD :: Unbox a => Gang -> Dist (Vector a) -> Int
joinLengthD g darr = sumD g (lengthD darr)
{-# NOINLINE joinLengthD #-}
--  NOINLINE because it's cheap and doesn't need to fuse with anything.
--  The elements themselves are never touched, so there is no need to
--  specialise for the element type.
113
114
115 -- Splitting and Joining arrays -----------------------------------------------
-- | Distribute an array over a 'Gang' such that each thread gets the given
--   number of elements.
--
--   @splitAsD theGangN4 (splitLenD theGangN4 10) [1 2 3 4 5 6 7 8 9 0]
--      = [[1 2 3] [4 5 6] [7 8] [9 0]]@
--
splitAsD :: Unbox a => Gang -> Dist Int -> Vector a -> Dist (Vector a)
splitAsD g dlen !arr
  = zipWithD (seqGang g)
             (\start len -> Seq.slice "splitAsD" arr start len)
             starts
             dlen
  where
    -- Start index of every chunk: the first component of scanning the
    -- chunk lengths with (+) from 0.
    (starts, _) = scanD g (+) 0 dlen
{-# INLINE_DIST splitAsD #-}
128
129
-- | Distribute an array over a 'Gang'.
--   The 'Distribution' argument is ignored by the implementation.
--
--   NOTE: This is a wrapper around 'splitD_impl' so that the RULES do not
--         introduce loops. Without it, 'splitJoinD' would be a loop breaker.
--
splitD :: Unbox a => Gang -> Distribution -> Vector a -> Dist (Vector a)
splitD gang _dist xs = splitD_impl gang xs
{-# INLINE_DIST splitD #-}
138
139
-- | Worker for 'splitD': cut the array into one slice per gang member.
--   When the length does not divide evenly by the gang size, the
--   lowest-numbered threads each receive one extra element.
splitD_impl :: Unbox a => Gang -> Vector a -> Dist (Vector a)
splitD_impl gang !xs
  = generateD_cheap gang
      (\tid -> Seq.slice "splitD_impl" xs (start tid) (size tid))
  where
    total    = Seq.length xs
    !threads = gangSize gang
    !base    = total `quotInt` threads
    !extra   = total `remInt`  threads

    -- Index of the first element belonging to thread 'tid'.
    {-# INLINE [0] start #-}
    start tid = if tid < extra then (base + 1) * tid else base * tid + extra

    -- Number of elements belonging to thread 'tid'.
    {-# INLINE [0] size #-}
    size tid = if tid < extra then base + 1 else base
{-# INLINE_DIST splitD_impl #-}
157
158
-- | Join a distributed array.
--   The lengths of the chunks are summed, a result array of that size is
--   allocated, and every chunk is copied into it.
--   The 'Distribution' argument is ignored by the implementation.
--
--   NOTE: This is a wrapper around 'joinD_impl' so that the RULES do not
--         introduce loops. Without it, 'splitJoinD' would be a loop breaker.
--
joinD :: Unbox a => Gang -> Distribution -> Dist (Vector a) -> Vector a
joinD gang _dist chunks = joinD_impl gang chunks
{-# INLINE CONLIKE [1] joinD #-}
169
170
-- | Worker for 'joinD': allocate a vector as long as all chunks together
--   and have the gang copy each chunk to its offset in that vector.
joinD_impl :: forall a. Unbox a => Gang -> Dist (Vector a) -> Vector a
joinD_impl gang !chunks
  = checkGangD (here "joinD") gang chunks
  $ Seq.new total (\dest -> zipWithDST_ gang (copyChunk dest) offsets chunks)
  where
    -- Scanning the chunk lengths gives the offset of every chunk in the
    -- result, together with the overall length.
    (!offsets, !total) = scanD gang (+) 0 (lengthD chunks)

    -- Copy one chunk into the part of the destination starting at 'off'.
    copyChunk :: forall s. MVector s a -> Int -> Vector a -> DistST s ()
    copyChunk dest off chunk
      = stToDistST $ Seq.copy target chunk
      where
        target = Seq.mslice off (Seq.length chunk) dest
{-# INLINE_DIST joinD_impl #-}
181
182
-- | Split a vector over a gang, apply a distributed computation to the
--   pieces, then join the results back into a single vector.
splitJoinD
  :: (Unbox a, Unbox b)
  => Gang
  -> (Dist (Vector a) -> Dist (Vector b))
  -> Vector a
  -> Vector b
splitJoinD gang worker !xs
  = joinD_impl gang $ worker $ splitD_impl gang xs
{-# INLINE_DIST splitJoinD #-}
194
195
196
-- | Join a distributed array, yielding a mutable global array.
--   A fresh mutable vector as long as all chunks together is allocated,
--   and the gang copies every chunk to its offset in that vector.
joinDM :: Unbox a => Gang -> Dist (Vector a) -> ST s (MVector s a)
joinDM gang chunks
  = checkGangD (here "joinDM") gang chunks
  $ Seq.newM total >>= \dest ->
    zipWithDST_ gang (copyChunk dest) offsets chunks >>
    return dest
  where
    -- Offset of every chunk in the result, and the overall length.
    (!offsets, !total) = scanD gang (+) 0 (lengthD chunks)

    -- Copy one chunk into the part of the destination starting at 'off'.
    copyChunk dest off chunk
      = stToDistST (Seq.copy (Seq.mslice off (Seq.length chunk) dest) chunk)
{-# INLINE joinDM #-}
209
210
-- Rewrite rules that remove or merge adjacent split/join pairs on the same gang.
--
--  * "splitD[unbalanced]/joinD": an unbalanced split of a joined array is
--    replaced by the distributed array that was joined, whatever
--    'Distribution' the join carried.
--  * "splitD[balanced]/joinD": a balanced split of a balanced join is
--    replaced by the distributed array that was joined.
--  * "splitD/splitJoinD": splitting the result of 'splitJoinD' applies the
--    worker function directly to the split input.
--  * "splitJoinD/joinD": 'splitJoinD' over a joined array applies the
--    worker function to the distributed array and joins with the same tag.
--  * "splitJoinD/splitJoinD": two nested 'splitJoinD's become one, using
--    the composition of the two worker functions.
{-# RULES

"splitD[unbalanced]/joinD" forall g b da.
  splitD g unbalanced (joinD g b da) = da

"splitD[balanced]/joinD" forall g da.
  splitD g balanced (joinD g balanced da) = da

"splitD/splitJoinD" forall g b f xs.
  splitD g b (splitJoinD g f xs) = f (splitD g b xs)

"splitJoinD/joinD" forall g b f da.
  splitJoinD g f (joinD g b da) = joinD g b (f da)

"splitJoinD/splitJoinD" forall g f1 f2 xs.
  splitJoinD g f1 (splitJoinD g f2 xs) = splitJoinD g (f1 . f2) xs

  #-}
229
-- Rewrite rules that push 'Seq.zip' inside distributed computations.
--
--  * "Seq.zip/joinD[1]" and "Seq.zip/joinD[2]": zipping with a balanced
--    'joinD' on either side is rewritten to a balanced split of the other
--    operand, a chunk-wise 'Seq.zip' via 'zipWithD', and a single balanced
--    'joinD' of the result.
--  * "Seq.zip/splitJoinD": zipping the results of two 'splitJoinD's whose
--    workers are 'imapD's on the same gang is rewritten to one 'splitJoinD'
--    over the zipped inputs; each chunk is unzipped, both worker functions
--    are applied to their half, and the results are zipped again.
{-# RULES

"Seq.zip/joinD[1]" forall g xs ys.
  Seq.zip (joinD g balanced xs) ys
    = joinD g balanced (zipWithD g Seq.zip xs (splitD g balanced ys))

"Seq.zip/joinD[2]" forall g xs ys.
  Seq.zip xs (joinD g balanced ys)
    = joinD g balanced (zipWithD g Seq.zip (splitD g balanced xs) ys)

"Seq.zip/splitJoinD" forall gang f g xs ys.
  Seq.zip (splitJoinD gang (imapD gang f) xs) (splitJoinD gang (imapD gang g) ys)
    = splitJoinD gang (imapD gang (\i zs -> let (as,bs) = Seq.unzip zs
                                            in Seq.zip (f i as) (g i bs)))
                      (Seq.zip xs ys)

  #-}
247
248
249 -- Permutation ----------------------------------------------------------------
-- | Permute for distributed arrays.
--   A result vector as long as all source chunks together is allocated;
--   each gang member then runs 'Seq.mpermute' on that vector with its own
--   source chunk and its own chunk of indices.
permuteD
  :: forall a. Unbox a
  => Gang -> Dist (Vector a) -> Dist (Vector Int) -> Vector a
permuteD gang chunks targets
  = Seq.new total (\dest -> zipWithDST_ gang (permuteChunk dest) chunks targets)
  where
    total = joinLengthD gang chunks

    permuteChunk :: forall s. MVector s a -> Vector a -> Vector Int -> DistST s ()
    permuteChunk dest src ixs = stToDistST $ Seq.mpermute dest src ixs
{-# INLINE_DIST permuteD #-}
262
263
-- | Apply 'Seq.bpermute' with the given source array to every chunk of a
--   distributed index array.
--
--   NOTE: The bang on the source array is necessary because the array must
--         be fully evaluated before it is passed to the parallel computation.
bpermuteD :: Unbox a => Gang -> Vector a -> Dist (Vector Int) -> Dist (Vector a)
bpermuteD gang !src dixs = mapD gang (Seq.bpermute src) dixs
{-# INLINE bpermuteD #-}
269
270
271 -- Update ---------------------------------------------------------------------
-- | Join a distributed array into a mutable array, let every gang member
--   apply its chunk of (index, value) updates with 'Seq.mupdate', and
--   freeze the result.
--
--   NB: This does not (and cannot) try to prevent two threads from writing
--       to the same position. We probably want to consider this an
--       (unchecked) user error.
atomicUpdateD :: forall a. Unbox a
              => Gang -> Dist (Vector a) -> Dist (Vector (Int,a)) -> Vector a
atomicUpdateD gang chunks updates
  = runST $ do
      dest <- joinDM gang chunks
      mapDST_ gang (applyUpdates dest) updates
      Seq.unsafeFreeze dest
  where
    applyUpdates :: forall s. MVector s a -> Vector (Int,a) -> DistST s ()
    applyUpdates dest ups = stToDistST $ Seq.mupdate dest ups
{-# INLINE atomicUpdateD #-}
286
287
288 -- Carry ----------------------------------------------------------------------
-- | Selectively combine the last elements of some chunks with the
--   first elements of others.
--
--   Returns the new distributed array together with the value left in the
--   accumulator after the final chunk.
--
--   NOTE: This runs sequentially and should only be used for testing purposes.
--
-- @
-- pprp $ splitD theGang unbalanced $ fromList [80, 10, 20, 40, 50, 10 :: Int]
-- DVector lengths: [2,2,1,1]
--         chunks:  [[80,10],[20,40],[50],[10]]
--
-- pprp $ fst
--      $ carryD theGang (+) 0
--          (mkDPrim $ fromList [True, False, True, False])
--          (splitD theGang unbalanced $ fromList [80, 10, 20, 40, 50, 10 :: Int])
--
-- DVector lengths: [1,2,0,1]
--         chunks: [[80],[30,40],[],[60]]
-- @
--
carryD  :: forall a
        .  (Unbox a, DT a)
        => Gang
        -> (a -> a -> a) -> a
        -> Dist Bool
        -> Dist (Vector a)
        -> (Dist (Vector a), a)

carryD gang f zero shouldCarry vec
  = runST $ do
      -- The worker fills this mutable distributed value chunk by chunk.
      out     <- newMD gang
      lastAcc <- carryD' f zero shouldCarry vec out
      frozen  <- unsafeFreezeMD out
      return (frozen, lastAcc)
322
323
-- | Sequential worker for 'carryD'.
--
--   Walks the chunks in index order. For every chunk the incoming carry
--   value is combined (using the supplied function) with the chunk's first
--   element, and the chunk is rewritten into the mutable distributed value
--   with that combined value in its first position. When the chunk's carry
--   flag is set, its last element is dropped from the chunk and becomes the
--   carry for the next chunk; otherwise the next chunk starts from the
--   supplied zero value.
--
--   An empty chunk has no first element to combine with, so the incoming
--   carry is passed through it unchanged. (Previously element 0 of the
--   empty chunk was read, which is out of bounds.)
--
--   Returns the carry value left after the last chunk.
carryD' :: forall a s
        .  (Unbox a, DT a)
        => (a -> a -> a) -> a
        -> Dist Bool
        -> Dist (Vector a)
        -> MDist (Vector a) s
        -> ST s a

carryD' f zero shouldCarry vec md_
  = go md_ zero 0
  where
    go (md :: MDist (Vector a) s) prev ix
      | ix >= sizeD vec = return prev
      | otherwise
      = do  let chunk :: Vector a
                !chunk    = indexD (here "carryD'") vec ix
            let !chunkLen = Seq.length chunk

            -- Whether to carry the last value of this chunk into the next chunk
            let !carry    = indexD (here "carryD") shouldCarry ix

            -- The new length for this chunk
            let !chunkLen'
                    | chunkLen == 0 = 0
                    | carry         = chunkLen - 1
                    | otherwise     = chunkLen

            -- The new value of the accumulator.
            -- An empty chunk has no element 0, so the incoming carry
            -- passes through unchanged instead of indexing out of bounds.
            let acc
                    | chunkLen == 0 = prev
                    | otherwise     = f prev (Seq.index (here "carryD'") chunk 0)

            -- Allocate a mutable vector to hold the new chunk and copy
            -- source elements into it.
            mchunk' <- Seq.newM chunkLen'
            Seq.copy mchunk' (Seq.slice (here "carryD'") chunk 0 chunkLen')

            -- The first element of a non-empty result chunk is the
            -- combined value.
            when (chunkLen' /= 0)
              $ Seq.write mchunk' 0 acc

            -- Store the new chunk in the gang
            chunk'  <- Seq.unsafeFreeze mchunk'
            writeMD md ix chunk'

            -- What value to carry into the next chunk.
            -- If nothing remains in this chunk the accumulator itself moves on.
            let next
                    | chunkLen' == 0 = acc
                    | carry          = Seq.index (here "next") chunk (chunkLen - 1)
                    | otherwise      = zero

            go md next (ix + 1)
372