Extract dph-prim-par
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Combinators.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module : Data.Array.Parallel.Unlifted.Distributed.Combinators
4 -- Copyright : (c) 2006 Roman Leshchinskiy
5 -- License : see libraries/ndp/LICENSE
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : experimental
9 -- Portability : non-portable (GHC Extensions)
10 --
11 -- Standard combinators for distributed types.
12 --
13
14 module Data.Array.Parallel.Unlifted.Distributed.Combinators (
15 mapD, zipD, unzipD, fstD, sndD, zipWithD,
16 foldD, scanD,
17
18 -- * Monadic combinators
19 mapDST_, mapDST, zipWithDST_, zipWithDST
20 ) where
21
22 import Data.Array.Parallel.Base (
23 (:*:)(..), uncurryS, unsafe_pairS, unsafe_unpairS, ST, runST)
24 import Data.Array.Parallel.Unlifted.Distributed.Gang (
25 Gang, gangSize)
26 import Data.Array.Parallel.Unlifted.Distributed.Types (
27 DT, Dist, indexD, zipD, unzipD, fstD, sndD,
28 newMD, writeMD, unsafeFreezeMD,
29 checkGangD)
30 import Data.Array.Parallel.Unlifted.Distributed.DistST (
31 DistST, distST_, distST, runDistST, myD)
32
33 here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s
34
35 -- | Map a function over a distributed value.
36 mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
37 {-# INLINE [1] mapD #-}
38 mapD g f !d = checkGangD (here "mapD") g d
39 (runDistST g (myD d >>= return . f))
40
41 {-# RULES
42
43 "mapD/mapD" forall gang f g d.
44 mapD gang f (mapD gang g d) = mapD gang (\x -> f (g x)) d
45
46 "zipD/mapD[1]" forall gang f xs ys.
47 zipD (mapD gang f xs) ys
48 = mapD gang (unsafe_pairS . (\(xs, ys) -> (f xs, ys)) . unsafe_unpairS)
49 (zipD xs ys)
50
51 "zipD/mapD[2]" forall gang f xs ys.
52 zipD xs (mapD gang f ys)
53 = mapD gang (unsafe_pairS . (\(xs, ys) -> (xs, f ys)) . unsafe_unpairS)
54 (zipD xs ys)
55
56 #-}
57
58 -- zipD, unzipD, fstD, sndD reexported from Types
59
60 -- | Combine two distributed values with the given function.
61 zipWithD :: (DT a, DT b, DT c)
62 => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
63 {-# INLINE zipWithD #-}
64 zipWithD g f dx dy = mapD g (uncurry f . unsafe_unpairS) (zipD dx dy)
65
66 -- | Fold a distributed value.
67 foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
68 foldD g f d = checkGangD ("here foldD") g d $
69 fold 1 (d `indexD` 0)
70 where
71 n = gangSize g
72 --
73 fold i x | i == n = x
74 | otherwise = fold (i+1) (f x $ d `indexD` i)
75
76 -- | Prefix sum of a distributed value.
77 scanD :: DT a => Gang -> (a -> a -> a) -> a -> Dist a -> Dist a :*: a
78 scanD g f z d = checkGangD (here "scanD") g d $
79 runST (do
80 md <- newMD g
81 s <- scan md 0 z
82 d' <- unsafeFreezeMD md
83 return (d' :*: s))
84 where
85 n = gangSize g
86 --
87 scan md i x | i == n = return x
88 | otherwise = do
89 writeMD md i x
90 scan md (i+1) (f x $ d `indexD` i)
91
92 -- NOTE: The following combinators must be strict in the Dists because if they
93 -- are not, the Dist might be evaluated (in parallel) when it is requested in
94 -- the current computation which, again, is parallel. This would break our
95 -- model andlead to a deadlock. Hence the bangs.
96
97 mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
98 mapDST_ g p !d = checkGangD (here "mapDST_") g d $
99 distST_ g (myD d >>= p)
100
101 mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
102 mapDST g p !d = checkGangD (here "mapDST_") g d $
103 distST g (myD d >>= p)
104
105 zipWithDST_ :: (DT a, DT b)
106 => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
107 zipWithDST_ g p !dx !dy = mapDST_ g (uncurryS p) (zipD dx dy)
108
109 zipWithDST :: (DT a, DT b, DT c)
110 => Gang
111 -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
112 zipWithDST g p !dx !dy = mapDST g (uncurryS p) (zipD dx dy)
113