aa29de64396d1550076375aed01aa6d5bf5f4dec
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / UPSegd.hs
1 {-# LANGUAGE CPP #-}
2 #include "fusion-phases.h"
3
4 -- | Parallel segment descriptors.
5 module Data.Array.Parallel.Unlifted.Parallel.UPSegd (
6 -- * Types
7 UPSegd, valid,
8
9 -- * Constructors
10 mkUPSegd, fromUSegd,
11 empty, singleton, fromLengths,
12
13 -- * Projections
14 length,
15 takeUSegd,
16 takeDistributed,
17 takeLengths,
18 takeIndices,
19 takeElements,
20
21 -- * Indices
22 indicesP,
23
24 -- * Replicate
25 replicateWithP,
26
27 -- * Segmented Folds
28 foldWithP,
29 fold1WithP,
30 sumWithP,
31 foldSegsWithP
32 ) where
33 import Data.Array.Parallel.Unlifted.Distributed
34 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
35 import qualified Data.Array.Parallel.Unlifted.Distributed.USegd as USegd
36 import qualified Data.Array.Parallel.Unlifted.Sequential.Basics as Seq
37 import qualified Data.Array.Parallel.Unlifted.Sequential.Combinators as Seq
38 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
39 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
40 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
41
42 import Control.Monad.ST
43 import Prelude hiding (length)
44
45
-- | A parallel segment descriptor holds a global (undistributed) segment descriptor,
--   as well as a distributed version. The distributed version describes how to split
--   work on the segmented array over the gang.
data UPSegd
        = UPSegd
        { upsegd_usegd  :: !USegd
          -- ^ Segment descriptor that describes the whole array.

        , upsegd_dsegd  :: Dist ((USegd,Int),Int)
          -- ^ Segment descriptor for each chunk,
          --   along with segment id of first slice in the chunk,
          --   and the offset of that slice in its segment.
          --   See docs of `splitSegdOnElemsD` for an example.
          --
          --   Unlike the global descriptor, this field is not strict, so it
          --   is not forced when a `UPSegd` is constructed, only when
          --   something demands it.
        }
60
61
62 -- Valid ----------------------------------------------------------------------
-- | O(1).
--   Check the internal consistency of a parallel segment descriptor.
--
--   * TODO: this doesn't do any checks yet; every descriptor is accepted.
valid :: UPSegd -> Bool
valid = const True
{-# NOINLINE valid #-}
-- NOINLINE because it's only used during debugging anyway.
71
72
73 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new parallel segment descriptor.
--
--   The global descriptor is built with `USegd.mkUSegd` and then handed to
--   `fromUSegd`, which attaches the distributed version.
mkUPSegd
        :: Vector Int   -- ^ Length of each segment.
        -> Vector Int   -- ^ Starting index of each segment.
        -> Int          -- ^ Total number of elements in the flat array.
        -> UPSegd

mkUPSegd segLens segStarts nElems
        = fromUSegd $ USegd.mkUSegd segLens segStarts nElems
{-# INLINE_UP mkUPSegd #-}
84
85
-- | Convert a global `USegd` to a parallel `UPSegd` by distributing
--   it across the gang.
--
--   The distributed part is produced by @USegd.splitSegdOnElemsD theGang@.
fromUSegd :: USegd -> UPSegd
fromUSegd usegd
        = UPSegd
        { upsegd_usegd  = usegd
        , upsegd_dsegd  = USegd.splitSegdOnElemsD theGang usegd }
{-# INLINE_UP fromUSegd #-}
91
92
-- | O(1). Yield an empty segment descriptor, with no elements or segments.
empty :: UPSegd
empty   = fromUSegd noSegs
        where noSegs = USegd.empty
{-# INLINE_UP empty #-}
97
98
-- | O(1). Yield a singleton segment descriptor:
--   one segment that covers the given number of elements.
singleton :: Int -> UPSegd
singleton len
        = fromUSegd (USegd.singleton len)
{-# INLINE_UP singleton #-}
104
105
-- | O(n). Convert an array of segment lengths into a parallel segment descriptor.
--
--   Only the length of each segment is supplied; the starting indices are
--   computed from them. Runtime is O(n) in the number of segments.
--
fromLengths :: Vector Int -> UPSegd
fromLengths
        = \lens -> fromUSegd (USegd.fromLengths lens)
{-# INLINE_UP fromLengths #-}
114
115
116 -- Projections ----------------------------------------------------------------
117 -- INLINE trivial projections as they'll expand to a single record selector.
118
-- | O(1). Yield the overall number of segments,
--   as recorded by the global descriptor.
length :: UPSegd -> Int
length = USegd.length . takeUSegd
{-# INLINE length #-}
123
124
-- | O(1). Yield the global (undistributed) `USegd` of a `UPSegd`.
takeUSegd :: UPSegd -> USegd
takeUSegd (UPSegd { upsegd_usegd = usegd })
        = usegd
{-# INLINE takeUSegd #-}
129
130
-- | O(1). Yield the distributed `USegd` of a `UPSegd`.
--
--   Each gang member gets a plain `USegd` for its chunk, the segment id of
--   the first slice in the chunk, and the starting offset of that slice
--   in its segment.
--
takeDistributed :: UPSegd -> Dist ((USegd,Int),Int)
takeDistributed (UPSegd _ dsegd)
        = dsegd
{-# INLINE takeDistributed #-}
139
140
-- | O(1). Yield the lengths of the individual segments,
--   taken from the global descriptor.
takeLengths :: UPSegd -> Vector Int
takeLengths = USegd.takeLengths . takeUSegd
{-# INLINE takeLengths #-}
145
146
-- | O(1). Yield the segment indices,
--   taken from the global descriptor.
takeIndices :: UPSegd -> Vector Int
takeIndices = USegd.takeIndices . takeUSegd
{-# INLINE takeIndices #-}
151
152
-- | O(1). Yield the total number of array elements.
--
--   @takeElements upsegd = sum (takeLengths upsegd)@
--
takeElements :: UPSegd -> Int
takeElements = USegd.takeElements . takeUSegd
{-# INLINE takeElements #-}
160
161
162 -- Indices --------------------------------------------------------------------
-- | O(n). Yield a vector containing indices that give the position of each
--   member of the flat array in its corresponding segment.
--
--   @indicesP (fromLengths [5, 2, 3]) = [0,1,2,3,4,0,1,0,1,2]@
--
--   Each chunk computes its indices independently. The chunk's starting
--   offset within its first segment is passed to `Seq.indicesSU'` so the
--   count for a segment split across chunks does not restart at zero.
--
indicesP :: UPSegd -> Vector Int
indicesP segd
        = joinD theGang balanced
        $ mapD  theGang chunkIndices
        $ takeDistributed segd
        where
          chunkIndices ((chunkSegd, _), firstOff)
                = Seq.indicesSU' firstOff chunkSegd
{-# NOINLINE indicesP #-}
-- NOINLINE because we're not using it yet.
177
178
179 -- Replicate ------------------------------------------------------------------
-- | Copying segmented replication. Each element of the vector is physically
--   copied according to the length of each segment in the segment descriptor.
--
--   @replicateWithP (fromLengths [3, 1, 2]) [5, 6, 7] = [5, 5, 5, 6, 7, 7]@
--
--   Each chunk replicates the slice of the source vector that starts at the
--   id of the chunk's first segment and has one element per segment in
--   the chunk.
--
replicateWithP :: Unbox a => UPSegd -> Vector a -> Vector a
replicateWithP segd !xs
        = joinD theGang balanced
        $ mapD  theGang replicateChunk
        $ takeDistributed segd
        where
          replicateChunk ((chunkSegd, firstSeg), _)
           = let nSegs  = USegd.length chunkSegd
                 srcs   = Seq.slice xs firstSeg nSegs
             in  Seq.replicateSU chunkSegd srcs
{-# INLINE_UP replicateWithP #-}
194
195
196 -- Fold -----------------------------------------------------------------------
-- | Fold segments specified by a `UPSegd`.
--
--   Each chunk folds its slices with @Seq.foldlSU f z@, and `foldSegsWithP`
--   combines the partial results of segments split across chunks using @f@.
--   Assuming `Seq.foldlSU` seeds every segment with @z@, a split segment
--   picks up @z@ once per chunk it touches. The result therefore matches a
--   sequential fold only when @z@ is neutral for @f@ and @f@ is associative.
foldWithP :: Unbox a
          => (a -> a -> a) -> a -> UPSegd -> Vector a -> Vector a
foldWithP f !z
        = foldSegsWithP f foldChunk
        where foldChunk = Seq.foldlSU f z
{-# INLINE_UP foldWithP #-}
202
203
-- | Fold segments specified by a `UPSegd`, with a non-empty vector.
--
--   No starting value is used. Each chunk folds its slices with
--   @Seq.fold1SU f@, and partial results of segments split across chunks
--   are combined with @f@. Presumably every segment must be non-empty.
fold1WithP :: Unbox a
           => (a -> a -> a) -> UPSegd -> Vector a -> Vector a
fold1WithP f
        = foldSegsWithP f foldChunk
        where foldChunk = Seq.fold1SU f
{-# INLINE_UP fold1WithP #-}
209
210
-- | Sum up segments specified by a `UPSegd`.
--   This is `foldWithP` with addition and a starting value of zero.
sumWithP :: (Num e, Unbox e) => UPSegd -> Vector e -> Vector e
sumWithP = foldWithP add 0
        where add x y = x + y
{-# INLINE_UP sumWithP #-}
215
216
-- | Fold the segments specified by a `UPSegd`.
--
--   This low level function takes a per-element worker and a per-segment worker.
--   It folds all the segments with the per-segment worker, then uses the
--   per-element worker to fix up the partial results when a segment
--   is split across multiple threads.
--
--   A chunk that starts part-way into a segment (non-zero offset) holds back
--   its first per-segment result as a carry, tagged with that segment's id.
--   The remaining results of all chunks are joined into the output vector,
--   and `fixupFold` then merges each carry into the joined result at its
--   segment id.
--
foldSegsWithP
        :: Unbox a
        => (a -> a -> a)
        -> (USegd -> Vector a -> Vector a)
        -> UPSegd -> Vector a -> Vector a

foldSegsWithP fElem fSeg segd xs
        = dcarry `seq` drs `seq` runST (do
                mrs <- joinDM theGang drs
                fixupFold fElem mrs dcarry
                Seq.unsafeFreeze mrs)
        where
          -- Fold every chunk locally, separating carries from full results.
          (dcarry, drs)
                = unzipD
                $ mapD theGang foldChunk
                $ zipD (takeDistributed segd)
                       (splitD theGang balanced xs)

          -- Fold the slices of one chunk. When the chunk begins part-way into
          -- segment k, its first result is only a partial value for k, so it
          -- is returned separately as the carry (at most one element).
          foldChunk (((chunkSegd, k), off), chunkElems)
           = let results = fSeg chunkSegd chunkElems

                 {-# INLINE [0] nCarry #-}
                 nCarry = if off == 0 then 0 else 1

             in  ( (k, Seq.take nCarry results)
                 , Seq.drop nCarry results)
{-# INLINE_UP foldSegsWithP #-}
251
252
-- | Merge the carries produced by `foldSegsWithP` into the joined results.
--
--   This visits gang members @1 .. gangSize theGang - 1@ in order. When a
--   member's carry @(k, c)@ is non-empty, the value stored at index @k@ is
--   replaced by @f old (c ! 0)@. Member 0 is not visited; presumably its
--   chunk always starts at offset 0 and so never yields a carry.
fixupFold
        :: Unbox a
        => (a -> a -> a)
        -> MVector s a
        -> Dist (Int,Vector a)
        -> ST s ()
{-# NOINLINE fixupFold #-}
fixupFold f !mrs !dcarry
        = mapM_ fixup [1 .. gangSize theGang - 1]
        where
          fixup i
           = let (k, c) = indexD dcarry i
             in  if Seq.null c
                   then return ()
                   else do
                        old <- Seq.read mrs k
                        Seq.write mrs k (f old (c Seq.! 0))