Merge branch 'master' of /Users/benl/devel/ghc/ghc-head/libraries/dph
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE CPP #-}
3 -----------------------------------------------------------------------------
4 -- |
5 -- Module : Data.Array.Parallel.Unlifted.Distributed.Combinators
6 -- Copyright : (c) 2006 Roman Leshchinskiy
7 -- License : see libraries/ndp/LICENSE
8 --
9 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
10 -- Stability : experimental
11 -- Portability : non-portable (GHC Extensions)
12 --
13 -- Standard combinators for distributed types.
14 --
15
16 #include "fusion-phases.h"
17
18 module Data.Array.Parallel.Unlifted.Distributed.Combinators (
19 generateD, generateD_cheap,
20 imapD, mapD, zipD, unzipD, fstD, sndD, zipWithD, izipWithD,
21 foldD, scanD, mapAccumLD,
22
23 -- * Monadic combinators
24 mapDST_, mapDST, zipWithDST_, zipWithDST
25 ) where
26
27 import Data.Array.Parallel.Base ( ST, runST)
28 import Data.Array.Parallel.Unlifted.Distributed.Gang (
29 Gang, gangSize)
30 import Data.Array.Parallel.Unlifted.Distributed.Types (
31 DT, Dist, MDist, indexD, zipD, unzipD, fstD, sndD, deepSeqD,
32 newMD, writeMD, unsafeFreezeMD,
33 checkGangD, measureD, debugD)
34 import Data.Array.Parallel.Unlifted.Distributed.DistST
35
36 import Debug.Trace
37
-- Qualify a function name with the name of this module. The result is used
-- as the location label handed to the gang checks in this file.
here :: String -> String
here name = modulePrefix ++ name
  where
    modulePrefix = "Data.Array.Parallel.Unlifted.Distributed.Combinators."
39
-- | Create a distributed value, given a function that makes the value in
--   each thread.
--
--   The worker function receives the index of the current thread (obtained
--   with 'myIndex') and its result becomes that thread's element.
--
--   The function is NOINLINE so that the rewrite rules in this module, which
--   match on applications of 'generateD', can see it.
generateD :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD #-}
generateD g f
        -- Debug output: the message is printed when the result is demanded.
        = trace "generate full"
        $ runDistST g (myIndex >>= return . f)
46
47
-- | Create a distributed value from a function of the thread index, like
--   'generateD', but run the computation through 'runDistST_seq'.
--
--   NOTE(review): presumably 'runDistST_seq' evaluates the elements
--   sequentially instead of on the gang; confirm against the DistST module.
generateD_cheap :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD_cheap #-}
generateD_cheap g f
  = trace "generate cheap"
  $ runDistST_seq g (do
      ix <- myIndex
      return (f ix))
54
55
56 -- Mapping --------------------------------------------------------------------
-- | Apply a worker function to every element of a distributed value.
--   The worker is also given the index of the current thread.
--
--   In contrast to the primed variant, each element is forced with
--   'deepSeqD' before the worker is applied to it.
imapD :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# INLINE [0] imapD #-}
imapD g f d = imapD' g forcing d
  where
    -- Force the element first, then hand it to the worker.
    forcing ix el = deepSeqD el (f ix el)
64
65
-- | Apply a worker function to every element of a distributed value.
--   The worker is also given the index of the current thread.
--
--   The distributed value is checked against the gang with 'checkGangD'
--   and is taken strictly (bang pattern).
imapD' :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# NOINLINE imapD' #-}
imapD' g f !d
  = checkGangD (here "imapD") g d
      (runDistST g
        (myIndex >>= \ix ->
         myD d   >>= \el ->
         return (f ix el)))
75
76
-- | Apply a function to every element of a distributed value.
--   This is 'imapD' with a worker that ignores the thread index.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# INLINE mapD #-}
mapD g = \f -> imapD g (const f)
81
-- Rewrite rules that fuse an indexed map with the producer feeding it, so
-- that no intermediate distributed value is built. Note that a map over a
-- generateD_cheap producer is rewritten to the full generateD.
{-# RULES

"imapD/generateD" forall gng outer inner.
  imapD gng outer (generateD gng inner)
    = generateD gng (\ix -> outer ix (inner ix))

"imapD/generateD_cheap" forall gng outer inner.
  imapD gng outer (generateD_cheap gng inner)
    = generateD gng (\ix -> outer ix (inner ix))

"imapD/imapD" forall gng outer inner dist.
  imapD gng outer (imapD gng inner dist)
    = imapD gng (\ix el -> outer ix (inner ix el)) dist

  #-}
94
95 {- RULES
96
97 "mapD/generateD" forall gang f g.
98 mapD gang f (generateD gang g) = generateD gang (\x -> f (g x))
99
100 "mapD/mapD" forall gang f g d.
101 mapD gang f (mapD gang g d) = mapD gang (\x -> f (g x)) d
102
103 "zipD/mapD[1]" forall gang f xs ys.
104 zipD (mapD gang f xs) ys
105 = mapD gang (unsafe_pairS . (\(xs, ys) -> (f xs, ys)) . unsafe_unpairS)
106 (zipD xs ys)
107
108 "zipD/mapD[2]" forall gang f xs ys.
109 zipD xs (mapD gang f ys)
110 = mapD gang (unsafe_pairS . (\(xs, ys) -> (xs, f ys)) . unsafe_unpairS)
111 (zipD xs ys)
112 -}
113
114
115 -- Zipping --------------------------------------------------------------------
-- | Combine two distributed values elementwise with the given function.
--   The values are paired up with 'zipD' and the function is mapped over
--   the pairs.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE zipWithD #-}
zipWithD g f dx dy = mapD g (uncurry f) paired
  where
    paired = zipD dx dy
121
122
-- | Combine two distributed values elementwise with the given function.
--   The function is also given the index of the current thread.
izipWithD :: (DT a, DT b, DT c)
          => Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE izipWithD #-}
izipWithD g f dx dy = imapD g (uncurry . f) paired
  where
    paired = zipD dx dy
129
-- Rewrite rules that push a zip below an indexed map or a generator on
-- either side, so that the result is a single imapD over the remaining
-- distributed value(s).
{-# RULES

"zipD/imapD[1]" forall gng fn ls rs.
  zipD (imapD gng fn ls) rs
    = imapD gng (\ix (l, r) -> (fn ix l, r)) (zipD ls rs)

"zipD/imapD[2]" forall gng fn ls rs.
  zipD ls (imapD gng fn rs)
    = imapD gng (\ix (l, r) -> (l, fn ix r)) (zipD ls rs)

"zipD/generateD[1]" forall gng mk rs.
  zipD (generateD gng mk) rs
    = imapD gng (\ix r -> (mk ix, r)) rs

"zipD/generateD[2]" forall gng mk ls.
  zipD ls (generateD gng mk)
    = imapD gng (\ix l -> (l, mk ix)) ls

  #-}
148
149
150 -- Folding --------------------------------------------------------------------
-- | Fold a distributed value.
--
--   The fold runs from left to right over the per-thread elements: it starts
--   with the element at index 0 and combines in the elements at indices
--   1 up to @gangSize g - 1@ in order.
--
--   The distributed value is taken strictly and is checked against the gang
--   with 'checkGangD' before folding.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
-- {-# INLINE_DIST foldD #-}
{-# NOINLINE foldD #-}
foldD g f !d = checkGangD (here "foldD") g d $
               fold 1 (d `indexD` 0)
  where
    !n = gangSize g
    --
    -- @fold i x@: @x@ is the result for the elements below index @i@.
    fold i x | i == n    = x
             | otherwise = fold (i+1) (f x $ d `indexD` i)
162
163
-- | Prefix sum of a distributed value.
--
--   Element @i@ of the result is the starting value combined, left to right,
--   with the input elements at the indices below @i@. The second component
--   of the result is the combination of the starting value with all input
--   elements.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
{-# NOINLINE scanD #-}
scanD g f z !d
  = checkGangD (here "scanD") g d
  $ runST (do
      out    <- newMD g
      total  <- loop out 0 z
      frozen <- unsafeFreezeMD out
      return (frozen, total))
  where
    !len = gangSize g

    -- Write the running value at the current index, then continue with the
    -- running value combined with the input element at that index.
    loop :: forall s. MDist a s -> Int -> a -> ST s a
    loop out ix !running
      | ix == len = return running
      | otherwise = writeMD out ix running
                 >> loop out (ix + 1) (f running (d `indexD` ix))
180
-- | Combination of map and fold.
--
--   The worker is applied to the accumulator and each input element in
--   index order. It yields the next accumulator together with the output
--   element stored at the same index. The result is the final accumulator
--   and the distributed value of outputs.
mapAccumLD :: forall a b acc. (DT a, DT b)
           => Gang -> (acc -> a -> (acc,b))
           -> acc -> Dist a -> (acc,Dist b)
{-# INLINE_DIST mapAccumLD #-}
mapAccumLD g f acc !d
  = checkGangD (here "mapAccumLD") g d
  $ runST (do
      out      <- newMD g
      finalAcc <- loop out 0 acc
      frozen   <- unsafeFreezeMD out
      return (finalAcc, frozen))
  where
    !len = gangSize g

    loop :: MDist b s -> Int -> acc -> ST s acc
    loop out ix cur
      | ix == len = return cur
      | otherwise
      = case f cur (d `indexD` ix) of
          (next, y) -> writeMD out ix y >> loop out (ix + 1) next
200
201
202 -- Versions that work on DistST -----------------------------------------------
203 -- NOTE: The following combinators must be strict in the Dists because if they
204 -- are not, the Dist might be evaluated (in parallel) when it is requested in
205 -- the current computation which, again, is parallel. This would break our
206 -- model and lead to a deadlock. Hence the bangs.
207
-- | Run a DistST action on every element of a distributed value, for its
--   effect only. Each element is forced with 'deepSeqD' before the action
--   is applied to it.
--
--   The distributed value is taken strictly here, as in the other DistST
--   combinators of this module, so that it is already evaluated when the
--   computation starts.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
{-# INLINE mapDST_ #-}
mapDST_ g p !d = mapDST_' g (\x -> x `deepSeqD` p x) d
211
-- Run a DistST action on the current thread's element of a distributed
-- value, for its effect only. The value is taken strictly and is checked
-- against the gang first; elements are not forced here.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' g p !d
  = checkGangD (here "mapDST_") g d
  $ distST_ g (p =<< myD d)
215
-- | Run a DistST action on every element of a distributed value and collect
--   the results in a new distributed value. Each element is forced with
--   'deepSeqD' before the action is applied to it.
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
{-# INLINE mapDST #-}
mapDST g p !d = mapDST' g forcing d
  where
    -- Force the element first, then run the action on it.
    forcing el = deepSeqD el (p el)
219
-- Run a DistST action on the current thread's element of a distributed
-- value and collect the results. The value is taken strictly and is checked
-- against the gang first; elements are not forced here.
--
-- The check is labelled "mapDST" so that a failure is attributed to this
-- combinator and not to the effect-only variant.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST' g p !d = checkGangD (here "mapDST") g d $
                 distST g (myD d >>= p)
223
-- | Run a two-argument DistST action on the paired elements of two
--   distributed values, for its effect only. Both values are taken strictly.
zipWithDST_ :: (DT a, DT b)
            => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
{-# INLINE zipWithDST_ #-}
zipWithDST_ g p !dx !dy = mapDST_ g (uncurry p) paired
  where
    paired = zipD dx dy
228
-- | Run a two-argument DistST action on the paired elements of two
--   distributed values and collect the results in a new distributed value.
--   Both values are taken strictly.
zipWithDST :: (DT a, DT b, DT c)
           => Gang
           -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
{-# INLINE zipWithDST #-}
zipWithDST g p !dx !dy = mapDST g (uncurry p) paired
  where
    paired = zipD dx dy
234