Add bangs
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module : Data.Array.Parallel.Unlifted.Distributed.Combinators
4 -- Copyright : (c) 2006 Roman Leshchinskiy
5 -- License : see libraries/ndp/LICENSE
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable (GHC Extensions)
10 --
11 -- Standard combinators for distributed types.
12 --
13
14 {-# LANGUAGE CPP #-}
15
16 #include "fusion-phases.h"
17
18 module Data.Array.Parallel.Unlifted.Distributed.Combinators (
19 mapD, zipD, unzipD, fstD, sndD, zipWithD,
20 foldD, scanD, mapAccumLD,
21
22 -- * Monadic combinators
23 mapDST_, mapDST, zipWithDST_, zipWithDST
24 ) where
25
26 import Data.Array.Parallel.Base (
27 (:*:)(..), uncurryS, unsafe_pairS, unsafe_unpairS, ST, runST)
28 import Data.Array.Parallel.Unlifted.Distributed.Gang (
29 Gang, gangSize)
30 import Data.Array.Parallel.Unlifted.Distributed.Types (
31 DT, Dist, indexD, zipD, unzipD, fstD, sndD,
32 newMD, writeMD, unsafeFreezeMD,
33 checkGangD)
34 import Data.Array.Parallel.Unlifted.Distributed.DistST (
35 DistST, distST_, distST, runDistST, myD)
36
-- | Qualify a function name with this module's name; used to build the
-- location labels passed to 'checkGangD'.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s
38
-- | Apply a function to a distributed value.
--
-- The gang and the distributed value are first passed to 'checkGangD'
-- (labelled with this function's name); the result is produced by running,
-- via 'runDistST', a computation that applies @f@ to the value obtained
-- from 'myD'.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# NOINLINE mapD #-}
mapD gang f !dist =
  checkGangD (here "mapD") gang dist
    (runDistST gang (do
       x <- myD dist
       return (f x)))
44
-- Rewrite rules for 'mapD':
--
--  * two nested 'mapD's over the same gang collapse into a single 'mapD'
--    of the composed function;
--
--  * a 'mapD' in either argument of 'zipD' is moved outside the 'zipD',
--    applying the function to the matching component of each pair.
{-# RULES

"mapD/mapD" forall gang outer inner dist.
  mapD gang outer (mapD gang inner dist)
    = mapD gang (\x -> outer (inner x)) dist

"zipD/mapD[1]" forall gang f left right.
  zipD (mapD gang f left) right
    = mapD gang (unsafe_pairS . (\(l, r) -> (f l, r)) . unsafe_unpairS)
                (zipD left right)

"zipD/mapD[2]" forall gang f left right.
  zipD left (mapD gang f right)
    = mapD gang (unsafe_pairS . (\(l, r) -> (l, f r)) . unsafe_unpairS)
                (zipD left right)

  #-}
61
62 -- zipD, unzipD, fstD, sndD reexported from Types
63
-- | Combine two distributed values element-wise with a binary function.
--
-- Implemented as a 'mapD' over the 'zipD' of both arguments, so the
-- @zipD/mapD@ and @mapD/mapD@ rewrite rules can apply once this is inlined.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE zipWithD #-}
zipWithD gang op lhs rhs =
  mapD gang (\pr -> uncurry op (unsafe_unpairS pr)) (zipD lhs rhs)
69
-- | Fold a distributed value.
--
-- Starts from the element at index 0 and combines the elements at indices
-- @1@ up to @gangSize g - 1@ into it, left to right, using @f@.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
-- {-# INLINE_DIST foldD #-}
{-# NOINLINE foldD #-}
-- The location label is built with 'here' so that it carries the module
-- prefix, like the labels of the other combinators in this module.
foldD g f !d = checkGangD (here "foldD") g d $
               fold 1 (d `indexD` 0)
  where
    !n = gangSize g
    --
    -- @fold i x@: @x@ is the result for the elements below index @i@.
    fold i x | i == n    = x
             | otherwise = fold (i+1) (f x $ d `indexD` i)
81
-- | Prefix sum of a distributed value.
--
-- Position @i@ of the result holds @z@ combined (left to right, with @f@)
-- with the elements at indices below @i@; the second component of the
-- returned pair is the combination over all @gangSize g@ elements.
scanD :: DT a => Gang -> (a -> a -> a) -> a -> Dist a -> Dist a :*: a
{-# INLINE_DIST scanD #-}
scanD g f z !d = checkGangD (here "scanD") g d $
  runST (do
    out    <- newMD g
    total  <- fill out 0 z
    frozen <- unsafeFreezeMD out
    return (frozen :*: total))
  where
    !size = gangSize g

    -- Store the running value at the current index before folding the
    -- element at that index into it.
    fill out ix running =
      if ix == size
        then return running
        else do
          writeMD out ix running
          fill out (ix + 1) (f running (d `indexD` ix))
98
-- | Map over a distributed value while threading an accumulator from left
-- to right.
--
-- For each index from @0@ to @gangSize g - 1@, @f@ is applied to the current
-- accumulator and the element at that index; it yields the next accumulator
-- together with the output element written at the same index. The result
-- pairs the final accumulator with the distributed value of outputs.
mapAccumLD :: (DT a, DT b) => Gang -> (acc -> a -> acc :*: b)
                           -> acc -> Dist a -> acc :*: Dist b
{-# INLINE_DIST mapAccumLD #-}
mapAccumLD g f acc !d = checkGangD (here "mapAccumLD") g d $
  runST (do
    out    <- newMD g
    final  <- walk out 0 acc
    frozen <- unsafeFreezeMD out
    return (final :*: frozen))
  where
    !size = gangSize g

    walk out ix st
      | ix == size = return st
      | otherwise  =
          case f st (d `indexD` ix) of
            st' :*: y -> writeMD out ix y >> walk out (ix + 1) st'
116
117 -- NOTE: The following combinators must be strict in the Dists because if they
118 -- are not, the Dist might be evaluated (in parallel) when it is requested in
119 -- the current computation which, again, is parallel. This would break our
120 -- model and lead to a deadlock. Hence the bangs.
121
-- | Run a 'DistST' action on a distributed value, discarding any result.
--
-- The value obtained from 'myD' is handed to the action, and the resulting
-- computation is run with 'distST_' on the given gang. The distributed value
-- is forced first (bang pattern).
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_ gang action !dist =
  checkGangD (here "mapDST_") gang dist $
    distST_ gang (action =<< myD dist)
125
-- | Run a 'DistST' action on a distributed value, collecting the results in
-- a new distributed value.
--
-- The value obtained from 'myD' is handed to @p@, and the resulting
-- computation is run with 'distST' on the given gang. The distributed value
-- is forced first (bang pattern).
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
-- The location label names this function (not 'mapDST_') so that a failed
-- gang check points at the right combinator.
mapDST g p !d = checkGangD (here "mapDST") g d $
                distST g (myD d >>= p)
129
-- | Run a binary 'DistST' action on two distributed values, discarding any
-- result.
--
-- Both values are forced (bang patterns), zipped with 'zipD', and each pair
-- is passed to the action through 'uncurryS' by 'mapDST_'.
zipWithDST_ :: (DT a, DT b)
            => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
zipWithDST_ gang action !lhs !rhs = mapDST_ gang onPair (zipD lhs rhs)
  where
    onPair = uncurryS action
133
-- | Run a binary 'DistST' action on two distributed values, collecting the
-- results in a new distributed value.
--
-- Both values are forced (bang patterns), zipped with 'zipD', and each pair
-- is passed to the action through 'uncurryS' by 'mapDST'.
zipWithDST :: (DT a, DT b, DT c)
           => Gang
           -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
zipWithDST gang action !lhs !rhs =
  mapDST gang (uncurryS action) $ zipD lhs rhs
138