Cleanup dph-prim-par
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE CPP #-}
3 #include "fusion-phases.h"
4
5 -- | Standard combinators for distributed types.
6 module Data.Array.Parallel.Unlifted.Distributed.Combinators (
7 generateD, generateD_cheap,
8 imapD, mapD, zipD, unzipD, fstD, sndD, zipWithD, izipWithD,
9 foldD, scanD, mapAccumLD,
10
11 -- * Monadic combinators
12 mapDST_, mapDST, zipWithDST_, zipWithDST
13 ) where
14 import Data.Array.Parallel.Base ( ST, runST)
15 import Data.Array.Parallel.Unlifted.Distributed.Gang (
16 Gang, gangSize)
17 import Data.Array.Parallel.Unlifted.Distributed.Types (
18 DT, Dist, MDist, indexD, zipD, unzipD, fstD, sndD, deepSeqD,
19 newMD, writeMD, unsafeFreezeMD,
20 checkGangD, measureD, debugD)
21 import Data.Array.Parallel.Unlifted.Distributed.DistST
22 import Debug.Trace
23
24
-- | Qualify a local function name with this module's name. Used to build the
--   location label passed to 'checkGangD' by the combinators below.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s
26
27
-- | Build a distributed value by applying a generator to each thread's index.
--   The per-thread work is executed with 'runDistST' on the given gang.
generateD :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD #-}
generateD gang mk = runDistST gang (do
  tid <- myIndex
  return (mk tid))
33
34
-- | Like 'generateD', but executes the generator through 'runDistST_seq'
--   instead of 'runDistST'.
--
--   NOTE(review): the name suggests the per-thread work is run sequentially,
--   i.e. this is meant for cheap generators. Confirm against the DistST module.
generateD_cheap :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD_cheap #-}
generateD_cheap gang mk = runDistST_seq gang (do
  tid <- myIndex
  return (mk tid))
40
41
42 -- Mapping --------------------------------------------------------------------
-- | Map an index-aware function over every element of a distributed value.
--   The worker receives the index of the thread that owns the element.
--
--   Unlike 'imapD'', each element is first forced with 'deepSeqD' before the
--   worker is applied to it.
imapD :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# INLINE [0] imapD #-}
imapD g f d = imapD' g forcing d
  where
    -- Force the element completely, then run the user's worker on it.
    forcing i x = x `deepSeqD` f i x
50
51
-- | Map an index-aware function over every element of a distributed value
--   without forcing the elements first. The worker receives the index of the
--   thread that owns the element. The 'Dist' argument is evaluated strictly.
imapD' :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# NOINLINE imapD' #-}
imapD' g f !d = checkGangD (here "imapD") g d $
                runDistST g (myIndex >>= \i ->
                             myD d   >>= \x ->
                             return (f i x))
61
62
-- | Map a function over a distributed value, ignoring thread indices.
--   Defined via 'imapD', so each element is forced with 'deepSeqD' before the
--   function is applied.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# INLINE mapD #-}
-- Only the gang is bound on the left-hand side. GHC inlines an INLINE function
-- only once it is applied to that many arguments, so @mapD g@ still inlines
-- when it is partially applied.
mapD g = \f -> imapD g (\_ -> f)

-- Fusion rules: collapse an 'imapD' applied to a generator, or to another
-- 'imapD' on the same gang, into a single traversal.
--
-- NOTE(review): "imapD/generateD_cheap" rewrites to the parallel 'generateD'
-- rather than 'generateD_cheap'. This is presumably deliberate, because the
-- fused generator now also runs the mapped worker. Confirm.
{-# RULES

"imapD/generateD" forall gang f g.
  imapD gang f (generateD gang g) = generateD gang (\i -> f i (g i))

"imapD/generateD_cheap" forall gang f g.
  imapD gang f (generateD_cheap gang g) = generateD gang (\i -> f i (g i))

"imapD/imapD" forall gang f g d.
  imapD gang f (imapD gang g d) = imapD gang (\i x -> f i (g i x)) d

 #-}
80
81
82 -- Zipping --------------------------------------------------------------------
-- | Combine two distributed values element-wise with a binary function.
--   The values are paired with 'zipD' and the pairs are then mapped with
--   'mapD'.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE zipWithD #-}
-- The irrefutable pattern keeps the pair as lazy as 'uncurry' would.
zipWithD g f dx dy = mapD g (\ ~(x, y) -> f x y) (zipD dx dy)
88
89
-- | Combine two distributed values element-wise with an index-aware function.
--   The worker also receives the index of the thread that owns the elements.
izipWithD :: (DT a, DT b, DT c)
          => Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE izipWithD #-}
-- The irrefutable pattern keeps the pair as lazy as 'uncurry' would.
izipWithD g f dx dy = imapD g (\i ~(x, y) -> f i x y) (zipD dx dy)

-- Fusion rules: push 'zipD' inside an 'imapD' or 'generateD' on either side,
-- so the zip becomes part of a single 'imapD' traversal.
{-# RULES
"zipD/imapD[1]" forall gang f xs ys.
  zipD (imapD gang f xs) ys
    = imapD gang (\i (x,y) -> (f i x,y)) (zipD xs ys)

"zipD/imapD[2]" forall gang f xs ys.
  zipD xs (imapD gang f ys)
    = imapD gang (\i (x,y) -> (x, f i y)) (zipD xs ys)

"zipD/generateD[1]" forall gang f xs.
  zipD (generateD gang f) xs
    = imapD gang (\i x -> (f i, x)) xs

"zipD/generateD[2]" forall gang f xs.
  zipD xs (generateD gang f)
    = imapD gang (\i x -> (x, f i)) xs

 #-}
115
116
117 -- Folding --------------------------------------------------------------------
-- | Fold a distributed value with the given operator.
--
--   The element of thread 0 is the initial accumulator. The elements of
--   threads @1 .. gangSize g - 1@ are then combined in order from the left,
--   i.e. @f (... (f (d!0) (d!1)) ...) (d!(n-1))@.
--
--   NOTE(review): assumes @gangSize g >= 1@. For an empty gang the loop below
--   never reaches its @i == n@ exit.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
{-# NOINLINE foldD #-}
-- Label the gang check with the fully qualified name, like the other
-- combinators. This previously passed the literal string "here foldD".
foldD g f !d = checkGangD (here "foldD") g d $
               go 1 (d `indexD` 0)
  where
    !n = gangSize g

    -- Left fold over the remaining thread indices [1 .. n-1].
    go i x | i == n    = x
           | otherwise = go (i+1) (f x $ d `indexD` i)
128
129
-- | Exclusive prefix scan of a distributed value.
--
--   Element @i@ of the resulting distributed value is @z@ folded (with @f@,
--   left to right) over the elements of threads @0 .. i-1@. The second
--   component is @z@ folded over all elements.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
{-# NOINLINE scanD #-}
scanD g f z !d = checkGangD (here "scanD") g d $
                 runST (do
                   mdist  <- newMD g
                   total  <- loop mdist 0 z
                   frozen <- unsafeFreezeMD mdist
                   return (frozen, total))
  where
    !n = gangSize g

    -- Store the running accumulator for thread i, then fold in that thread's
    -- element and continue with the next index. The accumulator is strict.
    loop :: forall s. MDist a s -> Int -> a -> ST s a
    loop mdist i !acc
      | i == n    = return acc
      | otherwise = writeMD mdist i acc >>
                    loop mdist (i + 1) (f acc (d `indexD` i))
146
-- | Thread an accumulator through the elements of a distributed value in
--   thread order (0, 1, ...), mapping each element at the same time.
--   Returns the final accumulator together with the mapped distributed value.
mapAccumLD :: forall a b acc. (DT a, DT b)
           => Gang -> (acc -> a -> (acc,b))
           -> acc -> Dist a -> (acc,Dist b)
{-# INLINE_DIST mapAccumLD #-}
mapAccumLD g f acc !d = checkGangD (here "mapAccumLD") g d $
                        runST (do
                          mdist  <- newMD g
                          final  <- step mdist 0 acc
                          frozen <- unsafeFreezeMD mdist
                          return (final, frozen))
  where
    !n = gangSize g

    -- Walk the thread indices, storing each mapped element and passing the
    -- updated accumulator on. The 'case' forces the pair returned by @f@ to
    -- WHNF but leaves its components lazy.
    step :: MDist b s -> Int -> acc -> ST s acc
    step mdist i cur
      | i == n    = return cur
      | otherwise =
          case f cur (d `indexD` i) of
            (next, y) -> writeMD mdist i y >> step mdist (i + 1) next
166
167
168 -- Versions that work on DistST -----------------------------------------------
-- NOTE: The following combinators must be strict in the Dists. Otherwise a
-- Dist might be evaluated (in parallel) at the point where it is requested
-- inside the current computation, which is itself parallel. This would break
-- our model and lead to a deadlock. Hence the bangs.
173
-- | Run a 'DistST' action on every element of a distributed value, discarding
--   the results. Each element is forced with 'deepSeqD' before the action is
--   run on it.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
{-# INLINE mapDST_ #-}
mapDST_ g p d = mapDST_' g forcing d
  where
    forcing x = x `deepSeqD` p x
177
178
-- | Worker for 'mapDST_'. Runs the action on each thread's element via
--   'distST_' without forcing the element first. Strict in the 'Dist'.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' g p !d = checkGangD (here "mapDST_") g d (distST_ g (p =<< myD d))
182
183
-- | Run a 'DistST' action on every element of a distributed value and collect
--   the results in a new distributed value. Each element is forced with
--   'deepSeqD' before the action is run on it. Strict in the 'Dist'.
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
{-# INLINE mapDST #-}
mapDST g p !d = mapDST' g forcing d
  where
    forcing x = x `deepSeqD` p x
187
-- | Worker for 'mapDST'. Runs the action on each thread's element via
--   'distST' without forcing the element first, and collects the per-thread
--   results in a new distributed value. Strict in the 'Dist'.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
-- The gang-check label names the wrapper 'mapDST', following the other
-- workers ('imapD'' uses "imapD", 'mapDST_'' uses "mapDST_"). It previously
-- said "mapDST_" by mistake.
mapDST' g p !d = checkGangD (here "mapDST") g d $
                 distST g (myD d >>= p)
191
192
-- | Run a 'DistST' action on the paired elements of two distributed values,
--   discarding the results. Strict in both 'Dist' arguments.
zipWithDST_ :: (DT a, DT b)
            => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
{-# INLINE zipWithDST_ #-}
-- The irrefutable pattern keeps the pair as lazy as 'uncurry' would.
zipWithDST_ g p !dx !dy = mapDST_ g (\ ~(x, y) -> p x y) (zipD dx dy)
197
198
-- | Run a 'DistST' action on the paired elements of two distributed values
--   and collect the results in a new distributed value. Strict in both
--   'Dist' arguments.
zipWithDST :: (DT a, DT b, DT c)
           => Gang
           -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
{-# INLINE zipWithDST #-}
-- The irrefutable pattern keeps the pair as lazy as 'uncurry' would.
zipWithDST g p !dx !dy = mapDST g (\ ~(x, y) -> p x y) (zipD dx dy)
204