dph-prim-par: Add Justifications to distributed array functions
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE CPP #-}
4 #include "fusion-phases.h"
5
6 -- | Standard combinators for distributed types.
7 module Data.Array.Parallel.Unlifted.Distributed.Combinators
8 ( What (..)
9 , generateD, generateD_cheap
10 , imapD, mapD
11 , zipD, unzipD
12 , fstD, sndD
13 , zipWithD, izipWithD
14 , foldD
15 , scanD
16 , mapAccumLD
17
18 -- * Monadic combinators
19 , mapDST_, mapDST, zipWithDST_, zipWithDST)
20 where
21 import Data.Array.Parallel.Base ( ST, runST)
22 import Data.Array.Parallel.Unlifted.Distributed.Gang
23 import Data.Array.Parallel.Unlifted.Distributed.Types
24 import Data.Array.Parallel.Unlifted.Distributed.DistST
25 import Data.Array.Parallel.Unlifted.Distributed.What
26 import Debug.Trace
27
-- | Qualify a function name with the name of this module.
--   The result is used as the location label passed to 'checkGangD'
--   and 'indexD'.
here :: String -> String
here fn = modulePrefix ++ fn
 where  modulePrefix = "Data.Array.Parallel.Unlifted.Distributed.Combinators."
29
30
-- | Build a distributed value from a worker function that produces the
--   instance for each thread. The worker is applied to the index
--   reported by 'myIndex'.
--
--   A 'CompGenerate' trace event (flagged @False@, i.e. not the cheap
--   variant) is emitted before the computation is run with 'runDistST'.
generateD
        :: DT a
        => What         -- ^ What is the worker function doing.
        -> Gang
        -> (Int -> a)
        -> Dist a
generateD what gang worker
 = traceEvent (show (CompGenerate False what))
 $ runDistST gang
        (do ix <- myIndex
            return (worker ix))
{-# NOINLINE generateD #-}
-- NOINLINE so that the rewrite rules in this module that mention
-- generateD can still match calls to it.
43
44
-- | Build a distributed value, but do the work sequentially.
--
--   Intended for the case where a distributed value is needed but very
--   little data is involved. For example, handing a single integer to each
--   thread does not justify firing up the gang.
--
--   A 'CompGenerate' trace event (flagged @True@, i.e. the cheap variant)
--   is emitted before the computation is run with 'runDistST_seq'.
generateD_cheap
        :: DT a
        => What         -- ^ What is the worker function doing.
        -> Gang
        -> (Int -> a)
        -> Dist a
generateD_cheap what gang worker
 = traceEvent (show (CompGenerate True what))
 $ runDistST_seq gang
        (do ix <- myIndex
            return (worker ix))
{-# NOINLINE generateD_cheap #-}
-- NOINLINE so that the rewrite rules in this module that mention
-- generateD_cheap can still match calls to it.
63
64
65 -- Mapping --------------------------------------------------------------------
66 --
67 -- Fusing maps
68 -- ~~~~~~~~~~~
69 -- The staging here is important.
70 -- Our rewrite rules only operate on the imapD form, so fusion between the worker
71 -- functions of consecutive maps takes place before phase [0].
72 --
73 -- At phase [0] we then inline imapD which introduces the call to imapD' which
74 -- uses the gang to evaluate its (now fused) worker.
75 --
76
-- | Apply a function to every instance of a distributed value.
--
--   The function is applied once per thread, to the whole value held by
--   that thread, not to the individual elements inside it. To transform
--   the elements themselves, map inside the worker:
--
--   @mapD what theGang (V.map (+ 1)) :: Dist (Vector Int) -> Dist (Vector Int)@
--
mapD    :: (DT a, DT b)
        => What         -- ^ What is the worker function doing.
        -> Gang
        -> (a -> b)
        -> Dist a
        -> Dist b
mapD what gang
 = \worker -> imapD what gang (const worker)
{-# INLINE mapD #-}
-- INLINE because this is only a convenience wrapper around imapD that
-- ignores the thread index. None of the rewrite rules are specific to mapD.
96
97
-- | Map a function across all instances of a distributed value.
--   The worker function is also given the index of the current thread.
--
--   Unlike @imapD'@, this version applies 'deepSeqD' to each instance
--   before handing it to the worker.
imapD   :: (DT a, DT b)
        => What         -- ^ What is the worker function doing.
        -> Gang
        -> (Int -> a -> b)
        -> Dist a -> Dist b
imapD what gang worker dist
 = imapD' what gang forcedWorker dist
 where  -- Force the instance with deepSeqD, then run the worker on it.
        forcedWorker ix x = deepSeqD x (worker ix x)
{-# INLINE [0] imapD #-}
-- INLINE [0] because the call to imapD' must not be introduced before
-- phase [0]. The rewrite rules operate directly on the imapD form, so
-- once imapD has been inlined no more fusion can take place.
113
114
-- | Map a function across all instances of a distributed value.
--   The worker function is also given the index of the current thread.
--
--   The distributed value is forced on entry (bang pattern). A 'CompMap'
--   trace event is emitted before the computation is run with 'runDistST'.
--   Unlike 'imapD', the instances are not passed through 'deepSeqD' here.
imapD'  :: (DT a, DT b)
        => What -> Gang -> (Int -> a -> b) -> Dist a -> Dist b
imapD' what gang worker !dist
 = traceEvent (show (CompMap what))
 $ runDistST gang
        (myIndex  >>= \ix ->
         myD dist >>= \x  ->
         return (worker ix x))
{-# NOINLINE imapD' #-}
-- NOINLINE, as in the original source.
-- NOTE(review): the reason is not recorded here; presumably this is the
-- out-of-line entry point that actually runs the gang. Confirm.
127
128
-- Fusion rules for maps.
--
-- A map applied to the result of a generate, or to the result of another
-- map, is collapsed into a single operation whose worker is the
-- composition of the two workers. The descriptions of both stages are
-- kept, combined with WhatFusedMapGen or WhatFusedMapMap.
--
-- Note that mapping over a generateD_cheap yields a plain generateD.
{-# RULES

"imapD/generateD"
  forall whatMap whatGen gang fMap fGen
  . imapD whatMap gang fMap (generateD whatGen gang fGen)
  = generateD (WhatFusedMapGen whatMap whatGen) gang
              (\ix -> fMap ix (fGen ix))

"imapD/generateD_cheap"
  forall whatMap whatGen gang fMap fGen
  . imapD whatMap gang fMap (generateD_cheap whatGen gang fGen)
  = generateD (WhatFusedMapGen whatMap whatGen) gang
              (\ix -> fMap ix (fGen ix))

"imapD/imapD"
  forall whatOuter whatInner gang fOuter fInner dist
  . imapD whatOuter gang fOuter (imapD whatInner gang fInner dist)
  = imapD (WhatFusedMapMap whatOuter whatInner) gang
          (\ix x -> fOuter ix (fInner ix x)) dist

  #-}
147
148
149 -- Zipping --------------------------------------------------------------------
-- | Combine two distributed values instance-by-instance with the given
--   function. The two values are first paired with 'zipD', and the
--   uncurried function is then mapped over the pairs with 'mapD'.
zipWithD :: (DT a, DT b, DT c)
        => What                 -- ^ What is the worker function doing.
        -> Gang
        -> (a -> b -> c)
        -> Dist a -> Dist b -> Dist c
zipWithD what gang combine dx dy
 = mapD what gang (uncurry combine)
 $ zipD dx dy
{-# INLINE zipWithD #-}
160
161
-- | Combine two distributed values instance-by-instance with the given
--   function, which is also given the index of the current thread.
--   The two values are first paired with 'zipD', and the worker is then
--   mapped over the pairs with 'imapD'.
izipWithD :: (DT a, DT b, DT c)
        => What                 -- ^ What is the worker function doing.
        -> Gang
        -> (Int -> a -> b -> c)
        -> Dist a -> Dist b -> Dist c
izipWithD what gang combine dx dy
 = imapD what gang (uncurry . combine)
 $ zipD dx dy
{-# INLINE izipWithD #-}
173
174
-- Fusion rules for zips.
--
-- When one side of a zipD is itself produced by imapD or generateD, the
-- producer is pushed into a single imapD over the zipped (or remaining)
-- value, so that no intermediate distributed value is built for that side.
-- The description of the fused-in producer is reused for the new imapD.
{-# RULES

"zipD/imapD[1]"
  forall what gang worker xs ys
  . zipD (imapD what gang worker xs) ys
  = imapD what gang (\ix (x, y) -> (worker ix x, y)) (zipD xs ys)

"zipD/imapD[2]"
  forall what gang worker xs ys
  . zipD xs (imapD what gang worker ys)
  = imapD what gang (\ix (x, y) -> (x, worker ix y)) (zipD xs ys)

"zipD/generateD[1]"
  forall what gang gen xs
  . zipD (generateD what gang gen) xs
  = imapD what gang (\ix x -> (gen ix, x)) xs

"zipD/generateD[2]"
  forall what gang gen xs
  . zipD xs (generateD what gang gen)
  = imapD what gang (\ix x -> (x, gen ix)) xs

  #-}
197
198
199 -- Folding --------------------------------------------------------------------
-- | Fold all the instances of a distributed value.
--
--   The fold is a plain sequential loop that does not invoke the gang.
--   It starts from the instance at index 0 and combines in the instances
--   at indices 1 up to (but not including) the gang size, left to right.
--
--   The distributed value is forced on entry (bang pattern) and passed
--   through 'checkGangD' together with the gang before folding.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
foldD g f !d
 = checkGangD (here "foldD") g d
 $ fold 1 (indexD (here "foldD") d 0)
 where
        !n = gangSize g

        -- Combine in the instance at index i, stopping once every
        -- index below the gang size has been visited.
        fold i x
         | i == n       = x
         | otherwise    = fold (i + 1) (f x $ indexD (here "foldD") d i)
{-# NOINLINE foldD #-}
211
212
-- | Prefix sum of the instances of a distributed value.
--
--   This is an exclusive scan: the instance written at index @i@ is the
--   starting value combined with the input instances at indices below @i@.
--   The second component of the result is the total over all instances.
--
--   The scan is a sequential loop over the indices below the gang size.
--   The input is forced on entry (bang pattern) and passed through
--   'checkGangD' together with the gang.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
scanD gang f z !dist
 = checkGangD (here "scanD") gang dist
 $ runST (do
        mdist   <- newMD gang
        total   <- loop mdist 0 z
        frozen  <- unsafeFreezeMD mdist
        return (frozen, total))
 where
        !size = gangSize gang

        -- Write the running value at the current index, then fold in the
        -- input instance at that index. The running value is kept strict.
        loop :: forall s. MDist a s -> Int -> a -> ST s a
        loop mdist ix !running
         | ix == size
         = return running

         | otherwise
         = do   writeMD mdist ix running
                loop mdist (ix + 1) (f running (indexD (here "scanD") dist ix))
{-# NOINLINE scanD #-}
232
233
234
235 -- MapAccumL ------------------------------------------------------------------
-- | Combination of map and fold.
--
--   Walks the instances from index 0 up to (but not including) the gang
--   size, threading an accumulator through the worker. For each instance
--   the worker returns the next accumulator and the output instance that
--   is written at the same index. The result is the final accumulator
--   paired with the distributed value of outputs.
--
--   The input is forced on entry (bang pattern) and passed through
--   'checkGangD' together with the gang.
mapAccumLD
        :: forall a b acc. (DT a, DT b)
        => Gang
        -> (acc -> a -> (acc, b))
        -> acc -> Dist a -> (acc, Dist b)
mapAccumLD gang step acc0 !dist
 = checkGangD (here "mapAccumLD") gang dist
 $ runST (do
        mdist    <- newMD gang
        accFinal <- loop mdist 0 acc0
        frozen   <- unsafeFreezeMD mdist
        return (accFinal, frozen))
 where
        !size = gangSize gang

        -- Run the worker on the instance at the current index, store its
        -- output and carry the new accumulator on to the next index.
        loop :: MDist b s -> Int -> acc -> ST s acc
        loop mdist ix accNow
         | ix == size
         = return accNow

         | otherwise
         = case step accNow (indexD (here "mapAccumLD") dist ix) of
                (accNext, out)
                 -> do  writeMD mdist ix out
                        loop mdist (ix + 1) accNext
{-# INLINE_DIST mapAccumLD #-}
262
263
264 -- Versions that work on DistST -----------------------------------------------
-- NOTE: The following combinators must be strict in the Dists. If they were
-- not, a Dist could be evaluated (in parallel) at the point where it is
-- demanded inside the current computation, which is itself parallel. This
-- would break our model and lead to a deadlock. Hence the bangs.
269
-- | Run a 'DistST' action on every instance of a distributed value,
--   discarding the results.
--
--   Each instance is passed through 'deepSeqD' before the action is run
--   on it. The distributed value is forced on entry (bang pattern), in
--   line with the other combinators in this section: they must be strict
--   in their Dist arguments so that a Dist is never first evaluated from
--   inside the parallel computation.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_ g p !d
 = mapDST_' g (\x -> x `deepSeqD` p x) d
{-# INLINE mapDST_ #-}
274
275
-- | Run a 'DistST' action on every instance of a distributed value,
--   discarding the results. Unlike 'mapDST_', the instances are not
--   passed through 'deepSeqD' first.
--
--   The distributed value is forced on entry (bang pattern) and passed
--   through 'checkGangD' together with the gang.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' gang action !dist
 = checkGangD (here "mapDST_") gang dist
 $ distST_ gang
        (do x <- myD dist
            action x)
280
281
-- | Run a 'DistST' action on every instance of a distributed value,
--   collecting the results into a new distributed value.
--
--   Each instance is passed through 'deepSeqD' before the action is run
--   on it. The distributed value is forced on entry (bang pattern).
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST gang action !dist
 = mapDST' gang forcedAction dist
 where  -- Force the instance with deepSeqD, then run the action on it.
        forcedAction x = deepSeqD x (action x)
{-# INLINE mapDST #-}
285
286
-- | Run a 'DistST' action on every instance of a distributed value,
--   collecting the results into a new distributed value. Unlike 'mapDST',
--   the instances are not passed through 'deepSeqD' first.
--
--   The distributed value is forced on entry (bang pattern) and passed
--   through 'checkGangD' together with the gang. The location label given
--   to the check names this function's family (mapDST), not mapDST_.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST' g p !d
 = checkGangD (here "mapDST") g d
 $ distST g (myD d >>= p)
291
292
-- | Run a two-argument 'DistST' action on the corresponding instances of
--   two distributed values, discarding the results.
--
--   Both distributed values are forced on entry (bang patterns). They are
--   paired with 'zipD' and the uncurried action is run with 'mapDST_'.
zipWithDST_
        :: (DT a, DT b)
        => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
zipWithDST_ gang action !dx !dy
 = mapDST_ gang (uncurry action)
 $ zipD dx dy
{-# INLINE zipWithDST_ #-}
299
300
-- | Run a two-argument 'DistST' action on the corresponding instances of
--   two distributed values, collecting the results into a new distributed
--   value.
--
--   Both distributed values are forced on entry (bang patterns). They are
--   paired with 'zipD' and the uncurried action is run with 'mapDST'.
zipWithDST
        :: (DT a, DT b, DT c)
        => Gang
        -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
zipWithDST gang action !dx !dy
 = mapDST gang (uncurry action)
 $ zipD dx dy
{-# INLINE zipWithDST #-}
308