78b0f4630c2297386ad9a9226385098d48f823a3
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 {-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
2 {-# LANGUAGE ScopedTypeVariables #-}
3 {-# LANGUAGE CPP #-}
4 #include "fusion-phases.h"
5
6 -- | Standard combinators for distributed types.
7 module Data.Array.Parallel.Unlifted.Distributed.Combinators
8 ( generateD, generateD_cheap
9 , imapD, mapD
10 , zipD, unzipD
11 , fstD, sndD
12 , zipWithD, izipWithD
13 , foldD
14 , scanD
15 , mapAccumLD
16
17 -- * Monadic combinators
18 , mapDST_, mapDST, zipWithDST_, zipWithDST)
19 where
20 import Data.Array.Parallel.Base ( ST, runST)
21 import Data.Array.Parallel.Unlifted.Distributed.Gang
22 import Data.Array.Parallel.Unlifted.Distributed.Types
23 import Data.Array.Parallel.Unlifted.Distributed.DistST
24
25
-- | Qualify a function name with this module's name. The result is handed
--   to 'checkGangD' and 'indexD' as their location argument.
here :: String -> String
here name = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ name
27
-- | Build a distributed value from a function that produces each thread's
--   instance. The function is applied to the index of the thread it is
--   evaluated for.
generateD :: DT a => Gang -> (Int -> a) -> Dist a
generateD g f
  = runDistST g
      (do ix <- myIndex
          return (f ix))
{-# NOINLINE generateD #-}
34
35
-- | Sequential variant of 'generateD'.
--
--   Intended for distributed values that carry very little data, where
--   waking up the whole gang would cost more than the work itself: for
--   instance, handing a single integer to every thread.
--
generateD_cheap :: DT a => Gang -> (Int -> a) -> Dist a
generateD_cheap g f
  = runDistST_seq g
      (do ix <- myIndex
          return (f ix))
{-# NOINLINE generateD_cheap #-}
47
48
49 -- Mapping --------------------------------------------------------------------
-- | Apply a worker to every instance of a distributed value.
--   Besides the instance itself, the worker receives the index of the
--   thread that owns it.
--   In contrast to @imapD'@, each instance is forced with 'deepSeqD' before
--   it is handed to the worker.
imapD :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
imapD g f d = imapD' g forcing d
  where
    -- Evaluate the instance fully, then run the user's worker on it.
    forcing ix x = x `deepSeqD` f ix x
{-# INLINE [0] imapD #-}
57
58
-- | Apply a worker to every instance of a distributed value, passing the
--   owning thread's index as well. No 'deepSeqD' is applied to the
--   instances here; the 'Dist' argument itself is matched strictly.
imapD' :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
imapD' g f !d
  = checkGangD (here "imapD") g d
  $ runDistST g
      (myIndex >>= \ix ->
       myD d   >>= \x  ->
       return (f ix x))
{-# NOINLINE imapD' #-}
69
70
-- | Apply a function to each instance of a distributed value.
--
--   The function sees one instance per thread, not the individual values
--   stored inside that instance. To reach those, map inside the instance:
--
--   @mapD theGang (V.map (+ 1)) :: Dist (Vector Int) -> Dist (Vector Int)@
--
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
mapD g = \f -> imapD g (const f)   -- the thread index is simply ignored
{-# INLINE mapD #-}
81
82
-- Fusion rules for 'imapD': a map over a freshly generated value collapses
-- into a single generate, and two consecutive maps collapse into one.
-- NOTE(review): the "imapD/generateD_cheap" rule produces 'generateD', not
-- 'generateD_cheap' -- presumably because the fused worker may no longer be
-- cheap; confirm this is intended.
{-# RULES

"imapD/generateD" forall gang f gen.
  imapD gang f (generateD gang gen)
    = generateD gang (\ix -> f ix (gen ix))

"imapD/generateD_cheap" forall gang f gen.
  imapD gang f (generateD_cheap gang gen)
    = generateD gang (\ix -> f ix (gen ix))

"imapD/imapD" forall gang f inner d.
  imapD gang f (imapD gang inner d)
    = imapD gang (\ix x -> f ix (inner ix x)) d

  #-}
95
96
97 -- Zipping --------------------------------------------------------------------
-- | Pair up the instances of two distributed values and combine each pair
--   with the given function.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
zipWithD g f dx dy = mapD g onPair (zipD dx dy)
  where
    onPair = uncurry f
{-# INLINE zipWithD #-}
103
104
-- | Pair up the instances of two distributed values and combine each pair
--   with the given function, which is also told the index of the thread
--   it is running for.
izipWithD :: (DT a, DT b, DT c)
          => Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c
izipWithD g f dx dy = imapD g onPair (zipD dx dy)
  where
    onPair ix = uncurry (f ix)
{-# INLINE izipWithD #-}
111
112
-- Fusion rules for 'zipD': a map or generate feeding either side of a zip
-- is pulled outwards so that it becomes a single 'imapD' over the zip
-- (or over the remaining argument, in the generate case).
{-# RULES

"zipD/imapD[1]" forall gang f ds es.
  zipD (imapD gang f ds) es
    = imapD gang (\ix (l, r) -> (f ix l, r)) (zipD ds es)

"zipD/imapD[2]" forall gang f ds es.
  zipD ds (imapD gang f es)
    = imapD gang (\ix (l, r) -> (l, f ix r)) (zipD ds es)

"zipD/generateD[1]" forall gang gen ds.
  zipD (generateD gang gen) ds
    = imapD gang (\ix v -> (gen ix, v)) ds

"zipD/generateD[2]" forall gang gen ds.
  zipD ds (generateD gang gen)
    = imapD gang (\ix v -> (v, gen ix)) ds

  #-}
131
132
133 -- Folding --------------------------------------------------------------------
-- | Fold all the instances of a distributed value.
--
--   This is a left fold over the instances at indices @0 .. gangSize g - 1@,
--   seeded with the instance at index 0, so that instance is always read.
--   The fold runs in the calling thread via 'indexD'; the gang is used only
--   for its size and for the 'checkGangD' check.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
foldD g f !d
  -- The location tag is built with 'here', as in the other combinators.
  -- Previously the literal string "here foldD" was passed by mistake.
  = checkGangD (here "foldD") g d
  $ fold 1 (indexD (here "foldD") d 0)
  where
    !n = gangSize g

    -- @fold i x@: @x@ is the result of combining instances @0 .. i-1@.
    fold i x
      | i == n    = x
      | otherwise = fold (i + 1) (f x $ indexD (here "foldD") d i)
{-# NOINLINE foldD #-}
145
146
-- | Prefix scan over the instances of a distributed value.
--
--   Slot @i@ of the resulting 'Dist' holds the seed combined with the
--   instances at indices @0 .. i-1@ (so slot 0 holds the seed itself).
--   The second component is the seed combined with every instance.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
scanD g f z !d
  = checkGangD (here "scanD") g d
  $ runST (do
        out    <- newMD g
        total  <- go out 0 z
        frozen <- unsafeFreezeMD out
        return (frozen, total))
  where
    !n = gangSize g

    -- Store the running value at the current slot, then advance with the
    -- running value combined with that slot's input instance.
    go :: forall s. MDist a s -> Int -> a -> ST s a
    go out ix !acc
      | ix /= n
      = writeMD out ix acc
        >> go out (ix + 1) (f acc $ indexD (here "scanD") d ix)
      | otherwise
      = return acc
{-# NOINLINE scanD #-}
166
167
-- | Map over the instances of a distributed value while threading an
--   accumulator through them from index 0 upwards.
--
--   The worker receives the current accumulator and one instance, and
--   returns the next accumulator together with the output for that slot.
--   The result is the final accumulator and the 'Dist' of outputs.
mapAccumLD
        :: forall a b acc. (DT a, DT b)
        => Gang
        -> (acc -> a -> (acc, b))
        -> acc -> Dist a -> (acc, Dist b)

mapAccumLD g f acc !d
  = checkGangD (here "mapAccumLD") g d
  $ runST (do
        out    <- newMD g
        final  <- go out 0 acc
        frozen <- unsafeFreezeMD out
        return (final, frozen))
  where
    !n = gangSize g

    go :: MDist b s -> Int -> acc -> ST s acc
    go out ix cur
      | ix /= n
      -- The result pair is matched with 'case', so it is forced before
      -- the output slot is written.
      = case f cur (indexD (here "mapAccumLD") d ix) of
          (next, y) -> writeMD out ix y >> go out (ix + 1) next
      | otherwise
      = return cur
{-# INLINE_DIST mapAccumLD #-}
194
195
196 -- Versions that work on DistST -----------------------------------------------
-- NOTE: The following combinators must be strict in the Dists because if they
-- are not, the Dist might be evaluated (in parallel) when it is requested in
-- the current computation which, again, is parallel. This would break our
-- model and lead to a deadlock. Hence the bangs.
201
-- | Run a 'DistST' action on every instance of a distributed value,
--   discarding the results. Each instance is forced with 'deepSeqD' before
--   the action is applied to it.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_ g p d
  = mapDST_' g forcing d
  where
    forcing x = x `deepSeqD` p x
{-# INLINE mapDST_ #-}
206
207
-- | Worker for 'mapDST_': checks the value against the gang, then runs the
--   action through 'distST_' on the instance obtained from 'myD'. The
--   'Dist' argument is matched strictly.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' g p !d
  = checkGangD (here "mapDST_") g d
  $ distST_ g
      (do x <- myD d
          p x)
212
213
-- | Run a 'DistST' action on every instance of a distributed value and
--   collect the results into a new distributed value. Each instance is
--   forced with 'deepSeqD' before the action is applied to it.
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST g p !d = mapDST' g forcing d
  where
    forcing x = x `deepSeqD` p x
{-# INLINE mapDST #-}
217
218
-- | Worker for 'mapDST': checks the value against the gang, then runs the
--   action through 'distST' on the instance obtained from 'myD', collecting
--   the results. The 'Dist' argument is matched strictly.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST' g p !d
  -- Report the location as "mapDST". The tag used to read "mapDST_",
  -- which pointed at the wrong combinator.
  = checkGangD (here "mapDST") g d
  $ distST g (myD d >>= p)
223
224
-- | Pair up the instances of two distributed values and run a 'DistST'
--   action on each pair, discarding the results. Both 'Dist' arguments are
--   matched strictly.
zipWithDST_
        :: (DT a, DT b)
        => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
zipWithDST_ g p !dx !dy
  = mapDST_ g onPair (zipD dx dy)
  where
    onPair = uncurry p
{-# INLINE zipWithDST_ #-}
231
232
-- | Pair up the instances of two distributed values and run a 'DistST'
--   action on each pair, collecting the results into a new distributed
--   value. Both 'Dist' arguments are matched strictly.
zipWithDST
        :: (DT a, DT b, DT c)
        => Gang
        -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
zipWithDST g p !dx !dy
  = mapDST g onPair (zipD dx dy)
  where
    onPair = uncurry p
{-# INLINE zipWithDST #-}
240