dph-prim-par: break up distributed types module and enable warnings
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 {-# LANGUAGE CPP #-}
3 #include "fusion-phases.h"
4
5 -- | Standard combinators for distributed types.
6 module Data.Array.Parallel.Unlifted.Distributed.Combinators (
7 generateD, generateD_cheap,
8 imapD, mapD, zipD, unzipD, fstD, sndD, zipWithD, izipWithD,
9 foldD, scanD, mapAccumLD,
10
11 -- * Monadic combinators
12 mapDST_, mapDST, zipWithDST_, zipWithDST
13 ) where
14 import Data.Array.Parallel.Base ( ST, runST)
15 import Data.Array.Parallel.Unlifted.Distributed.Gang
16 import Data.Array.Parallel.Unlifted.Distributed.Types
17 import Data.Array.Parallel.Unlifted.Distributed.DistST
18 import Debug.Trace
19
20
-- | Qualify a function name with this module's name; used to build the
-- location label passed to 'checkGangD'.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s
22
23
-- | Build a distributed value from a function that makes the value in each
-- thread. The function is applied to the index obtained from 'myIndex'
-- inside a 'runDistST' computation over the gang.
generateD :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD #-}
generateD gang mk = runDistST gang (do
  ix <- myIndex
  return (mk ix))
29
30
-- | Like 'generateD', but the per-thread computation is run with
-- 'runDistST_seq' instead of 'runDistST'.
-- NOTE(review): presumably 'runDistST_seq' evaluates the elements
-- sequentially, for element functions too cheap to parallelise — confirm
-- against the DistST module.
generateD_cheap :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD_cheap #-}
generateD_cheap gang mk
  = runDistST_seq gang (myIndex >>= \ix -> return (mk ix))
36
37
38 -- Mapping --------------------------------------------------------------------
-- Mapping helpers -------------------------------------------------------------
-- | Map a function over every element of a distributed value; the worker
-- also receives the current thread index.
-- Unlike the unexported 'imapD'', each element is passed through
-- 'deepSeqD' before the worker is applied to it.
imapD :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# INLINE [0] imapD #-}
imapD gang f dist = imapD' gang strictWorker dist
  where
    strictWorker ix x = deepSeqD x (f ix x)
46
47
-- | Map a function over every element of a distributed value; the worker
-- also receives the current thread index. The gang is first checked against
-- the value with 'checkGangD', labelled with the exported wrapper's name
-- ("imapD"). The input is strict (bang pattern).
imapD' :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# NOINLINE imapD' #-}
imapD' gang f !dist
  = checkGangD (here "imapD") gang dist
      (runDistST gang
        (myIndex  >>= \ix  ->
         myD dist >>= \elt ->
         return (f ix elt)))
57
58
-- | Map a function over a distributed value, ignoring the thread index.
-- Defined via 'imapD', so once inlined it can take part in the 'imapD'
-- rewrite rules.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# INLINE mapD #-}
mapD gang = \f -> imapD gang (const f)
63
-- Fusion rules for imapD: an imapD whose input comes from generateD,
-- generateD_cheap or another imapD on the same gang is rewritten into a
-- single generateD / imapD call. Note that the generateD_cheap rule
-- produces a plain generateD.
{-# RULES

"imapD/generateD" forall gang h k.
  imapD gang h (generateD gang k) = generateD gang (\ix -> h ix (k ix))

"imapD/generateD_cheap" forall gang h k.
  imapD gang h (generateD_cheap gang k) = generateD gang (\ix -> h ix (k ix))

"imapD/imapD" forall gang h k dist.
  imapD gang h (imapD gang k dist) = imapD gang (\ix x -> h ix (k ix x)) dist

  #-}
76
77
78 -- Zipping --------------------------------------------------------------------
-- | Combine two distributed values element-wise with the given function,
-- by zipping them with 'zipD' and mapping over the pairs with 'mapD'.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE zipWithD #-}
zipWithD gang f xs ys = mapD gang combine (zipD xs ys)
  where
    combine = uncurry f
84
85
-- | Combine two distributed values element-wise with the given function,
-- which also receives the index of the current thread. Implemented as an
-- 'imapD' over the 'zipD' of the inputs.
izipWithD :: (DT a, DT b, DT c)
          => Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE izipWithD #-}
izipWithD gang f xs ys = imapD gang worker (zipD xs ys)
  where
    worker ix = uncurry (f ix)
92
-- Fusion rules for zipD: an imapD on either side is moved outside the zip,
-- and a generateD on either side turns the zip into an imapD over the other
-- input.
{-# RULES
"zipD/imapD[1]" forall gang h xs ys.
  zipD (imapD gang h xs) ys
    = imapD gang (\ix (x, y) -> (h ix x, y)) (zipD xs ys)

"zipD/imapD[2]" forall gang h xs ys.
  zipD xs (imapD gang h ys)
    = imapD gang (\ix (x, y) -> (x, h ix y)) (zipD xs ys)

"zipD/generateD[1]" forall gang h xs.
  zipD (generateD gang h) xs
    = imapD gang (\ix x -> (h ix, x)) xs

"zipD/generateD[2]" forall gang h xs.
  zipD xs (generateD gang h)
    = imapD gang (\ix x -> (x, h ix)) xs

  #-}
111
112
113 -- Folding --------------------------------------------------------------------
-- | Fold a distributed value with the given combining function, left to
-- right over thread indices, starting from the element at index 0.
-- The gang is checked against the value with 'checkGangD', labelled "foldD".
-- NOTE(review): index 0 is read unconditionally, so this assumes the gang
-- has at least one thread.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
{-# NOINLINE foldD #-}
foldD g f !d = checkGangD (here "foldD") g d $
               fold 1 (d `indexD` 0)
  where
    !n = gangSize g

    -- Combine the accumulator with the element at index i, for i in [1, n).
    fold i x | i == n    = x
             | otherwise = fold (i+1) (f x $ d `indexD` i)
124
125
-- | Exclusive left scan of a distributed value with the given operator and
-- initial value @z@. Slot i of the result holds @z@ combined (left to right
-- via @f@) with the elements at indices 0 .. i-1; the second component is
-- @z@ combined with all elements.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
{-# NOINLINE scanD #-}
scanD g f z !d = checkGangD (here "scanD") g d $ runST (do
    mdist  <- newMD g
    total  <- loop mdist 0 z
    result <- unsafeFreezeMD mdist
    return (result, total))
  where
    !n = gangSize g

    -- Store the running value in slot i, then fold in element i.
    loop :: forall s. MDist a s -> Int -> a -> ST s a
    loop mdist i !acc
      | i == n    = return acc
      | otherwise = writeMD mdist i acc
                    >> loop mdist (i + 1) (f acc (d `indexD` i))
142
-- | Thread an accumulator left to right through the elements of a
-- distributed value (indices 0 .. n-1), collecting the per-element outputs
-- into a new distributed value and returning the final accumulator.
mapAccumLD :: forall a b acc. (DT a, DT b)
           => Gang -> (acc -> a -> (acc,b))
           -> acc -> Dist a -> (acc,Dist b)
{-# INLINE_DIST mapAccumLD #-}
mapAccumLD g f acc !d = checkGangD (here "mapAccumLD") g d $ runST (do
    mdist  <- newMD g
    final  <- step mdist 0 acc
    result <- unsafeFreezeMD mdist
    return (final, result))
  where
    !n = gangSize g

    -- Apply f to element i with the current accumulator, store its output
    -- in slot i, and continue with the updated accumulator.
    step :: MDist b s -> Int -> acc -> ST s acc
    step mdist i cur
      | i == n    = return cur
      | otherwise =
          case f cur (d `indexD` i) of
            (next, out) -> writeMD mdist i out >> step mdist (i + 1) next
162
163
164 -- Versions that work on DistST -----------------------------------------------
-- NOTE: The following combinators must be strict in the Dists because if they
-- are not, the Dist might be evaluated (in parallel) when it is requested in
-- the current computation which, again, is parallel. This would break our
-- model and lead to a deadlock. Hence the bangs.
169
-- | Apply a 'DistST' action to the elements of a distributed value (via
-- 'mapDST_''), discarding the results. Each element is passed through
-- 'deepSeqD' before the action is applied to it.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
{-# INLINE mapDST_ #-}
mapDST_ gang act dist = mapDST_' gang forcedAct dist
  where
    forcedAct x = deepSeqD x (act x)
173
174
-- | Worker for 'mapDST_'. It checks the gang against the (strict)
-- distributed value with 'checkGangD', then runs 'distST_' with an action
-- that reads the local element via 'myD' and feeds it to the given action.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' gang act !dist
  = checkGangD (here "mapDST_") gang dist
      (distST_ gang (myD dist >>= act))
178
179
-- | Like 'mapDST_', but the action results are collected into a new
-- distributed value. Each element is passed through 'deepSeqD' first; the
-- input is strict (bang pattern).
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
{-# INLINE mapDST #-}
mapDST gang act !dist = mapDST' gang forcedAct dist
  where
    forcedAct x = deepSeqD x (act x)
183
-- | Worker for 'mapDST'. It checks the gang against the (strict)
-- distributed value with 'checkGangD', labelled with the exported wrapper's
-- name "mapDST", then runs 'distST' with an action that reads the local
-- element via 'myD' and feeds it to the given action.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST' g p !d = checkGangD (here "mapDST") g d $
                 distST g (myD d >>= p)
187
188
-- | Run a two-argument 'DistST' action over the paired elements of two
-- distributed values (zipped with 'zipD'), discarding the results.
-- Both inputs are strict (bang patterns).
zipWithDST_ :: (DT a, DT b)
            => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
{-# INLINE zipWithDST_ #-}
zipWithDST_ gang act !xs !ys = mapDST_ gang pairAct (zipD xs ys)
  where
    pairAct = uncurry act
193
194
-- | Run a two-argument 'DistST' action over the paired elements of two
-- distributed values (zipped with 'zipD'), collecting the results into a
-- new distributed value. Both inputs are strict (bang patterns).
zipWithDST :: (DT a, DT b, DT c)
           => Gang
           -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
{-# INLINE zipWithDST #-}
zipWithDST gang act !xs !ys = mapDST gang pairAct (zipD xs ys)
  where
    pairAct = uncurry act
200