dph-lifted-vseg: fix edge case in concatl when the last segment is empty
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / UPSegd.hs
1 {-# LANGUAGE CPP #-}
2 #include "fusion-phases.h"
3
4 -- | Parallel segment descriptors.
5 module Data.Array.Parallel.Unlifted.Parallel.UPSegd (
6 -- * Types
7 UPSegd, valid,
8
9 -- * Constructors
10 mkUPSegd, fromUSegd,
11 empty, singleton, fromLengths,
12
13 -- * Projections
14 length,
15 takeUSegd,
16 takeDistributed,
17 takeLengths,
18 takeIndices,
19 takeElements,
20
21 -- * Indices
22 indicesP,
23
24 -- * Replicate
25 replicateWithP,
26
27 -- * Segmented Folds
28 foldWithP,
29 fold1WithP,
30 sumWithP,
31 foldSegsWithP
32 ) where
33 import Data.Array.Parallel.Unlifted.Distributed
34 import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd)
35 import qualified Data.Array.Parallel.Unlifted.Distributed.USegd as USegd
36 import qualified Data.Array.Parallel.Unlifted.Sequential.Basics as Seq
37 import qualified Data.Array.Parallel.Unlifted.Sequential.Combinators as Seq
38 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
39 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd
40 import Data.Array.Parallel.Pretty hiding (empty)
41 import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector, MVector, Unbox)
42
43 import Control.Monad.ST
44 import Prelude hiding (length)
45
46
-- | A parallel segment descriptor holds a global (undistributed) segment
--   descriptor, as well as a distributed version. The distributed version
--   describes how to split work on the segmented array over the gang.
data UPSegd
        = UPSegd
        { upsegd_usegd :: !USegd
          -- ^ Segment descriptor that describes the whole array.

        , upsegd_dsegd :: Dist ((USegd,Int),Int)
          -- ^ Segment descriptor for each chunk,
          --   along with segment id of first slice in the chunk,
          --   and the offset of that slice in its segment.
          --   See docs of `splitSegdOnElemsD` for an example.
          --
          --   NOTE(review): unlike `upsegd_usegd` this field has no strictness
          --   annotation, presumably so the distributed descriptor is only
          --   built when a parallel operation demands it — confirm before
          --   making it strict.
        }
61
62
63 -- Pretty ---------------------------------------------------------------------
-- | Physical pretty printing: shows both the global and the distributed
--   segment descriptor, nested under a "UPSegd" heading.
instance PprPhysical UPSegd where
 pprp UPSegd { upsegd_usegd = whole, upsegd_dsegd = chunks }
  = text "UPSegd" $$ nest 7 (vcat fields)
  where
        fields
         = [ text "usegd: " <+> pprp whole
           , text "dsegd: " <+> pprp chunks ]
70
71
72 -- Valid ----------------------------------------------------------------------
-- | O(1). Check the internal consistency of a parallel segment descriptor.
--
--   * TODO: no checks are implemented yet, so every descriptor is
--     reported as valid.
valid :: UPSegd -> Bool
valid = const True
{-# NOINLINE valid #-}
-- NOINLINE because it's only used during debugging anyway.
81
82
83 -- Constructors ---------------------------------------------------------------
-- | O(1). Construct a new parallel segment descriptor from the segment
--   lengths, the segment starting indices and the total element count.
mkUPSegd
        :: Vector Int   -- ^ Length of each segment.
        -> Vector Int   -- ^ Starting index of each segment.
        -> Int          -- ^ Total number of elements in the flat array.
        -> UPSegd

mkUPSegd segLens segStarts nElems
        = fromUSegd $ USegd.mkUSegd segLens segStarts nElems
{-# INLINE_UP mkUPSegd #-}
94
95
-- | Convert a global `USegd` to a parallel `UPSegd`. The global descriptor
--   is kept as-is, and its distribution across the gang is computed with
--   `USegd.splitSegdOnElemsD`.
fromUSegd :: USegd -> UPSegd
fromUSegd segd
        = UPSegd
        { upsegd_usegd  = segd
        , upsegd_dsegd  = USegd.splitSegdOnElemsD theGang segd }
{-# INLINE_UP fromUSegd #-}
101
102
-- | O(1). The empty parallel segment descriptor: no segments, no elements.
empty :: UPSegd
empty = fromUSegd noSegs
  where noSegs = USegd.empty
{-# INLINE_UP empty #-}
107
108
-- | O(1). A parallel segment descriptor with exactly one segment, which
--   covers the given number of elements.
singleton :: Int -> UPSegd
singleton = fromUSegd . USegd.singleton
{-# INLINE_UP singleton #-}
114
115
-- | O(n). Build a parallel segment descriptor from the length of each
--   segment.
--
--   The segment indices are derived from the lengths, so the runtime is
--   O(n) in the number of segments.
--
fromLengths :: Vector Int -> UPSegd
fromLengths lens = fromUSegd (USegd.fromLengths lens)
{-# INLINE_UP fromLengths #-}
124
125
126 -- Projections ----------------------------------------------------------------
127 -- INLINE trivial projections as they'll expand to a single record selector.
128
-- | O(1). The total number of segments described.
length :: UPSegd -> Int
length segd = USegd.length (upsegd_usegd segd)
{-# INLINE length #-}
133
134
-- | O(1). The global (undistributed) `USegd` held by a `UPSegd`.
takeUSegd :: UPSegd -> USegd
takeUSegd (UPSegd whole _) = whole
{-# INLINE takeUSegd #-}
139
140
-- | O(1). The distributed `USegd` held by a `UPSegd`.
--
--   Each chunk carries its own plain `USegd`, the id of the segment that
--   its first slice belongs to, and the offset of that slice within that
--   segment.
--
takeDistributed :: UPSegd -> Dist ((USegd,Int),Int)
takeDistributed (UPSegd _ chunks) = chunks
{-# INLINE takeDistributed #-}
149
150
-- | O(1). The length of every individual segment.
takeLengths :: UPSegd -> Vector Int
takeLengths segd = USegd.takeLengths (upsegd_usegd segd)
{-# INLINE takeLengths #-}
155
156
-- | O(1). The starting index of every segment.
takeIndices :: UPSegd -> Vector Int
takeIndices segd = USegd.takeIndices (upsegd_usegd segd)
{-# INLINE takeIndices #-}
161
162
-- | O(1). The total number of elements in the flat array.
--
--   @takeElements upsegd = sum (takeLengths upsegd)@
--
takeElements :: UPSegd -> Int
takeElements segd = USegd.takeElements (upsegd_usegd segd)
{-# INLINE takeElements #-}
170
171
172 -- Indices --------------------------------------------------------------------
-- | O(n). Yield a vector containing indices that give the position of each
--   member of the flat array in its corresponding segment.
--
--   @indicesP (fromLengths [5, 2, 3]) = [0,1,2,3,4,0,1,0,1,2]@
--
indicesP :: UPSegd -> Vector Int
indicesP
        = joinD theGang balanced
        . mapD  theGang indices
        . takeDistributed
        where
                -- Number the elements of one chunk. 'off' is the offset of
                -- the chunk's first slice within its segment, which is passed
                -- to 'Seq.indicesSU'' so that a segment split across chunks
                -- is numbered continuously. The first-segment id is not
                -- needed here.
                indices ((segd, _), off) = Seq.indicesSU' off segd
{-# NOINLINE indicesP #-}
-- NOINLINE because we're not using it yet.
187
188
189 -- Replicate ------------------------------------------------------------------
-- | Copying segmented replication. Each element of the vector is physically
--   copied according to the length of each segment in the segment descriptor.
--
--   @replicateWithP (fromLengths [3, 1, 2]) [5, 6, 7] = [5, 5, 5, 6, 7, 7]@
--
--   NOTE(review): the source vector presumably needs one element per
--   segment — confirm against callers.
--
replicateWithP :: Unbox a => UPSegd -> Vector a -> Vector a
replicateWithP segd !xs
        = joinD theGang balanced
        . mapD  theGang rep
        $ takeDistributed segd
        where
                -- Each chunk replicates only the source elements for the
                -- segments it covers: a slice of 'xs' starting at the chunk's
                -- first segment id 'di', with one element per segment in the
                -- chunk's local descriptor.
                rep ((dsegd, di), _)
                 = Seq.replicateSU dsegd (Seq.slice xs di (USegd.length dsegd))
{-# INLINE_UP replicateWithP #-}
204
205
206 -- Fold -----------------------------------------------------------------------
-- | Fold segments specified by a `UPSegd`.
--
--   Every chunk of the distributed array is folded with @Seq.foldlSU f z@,
--   so a segment that is split across several threads is folded piecewise
--   with the seed @z@ used once per piece, and the partial results are then
--   combined with @f@ (see `foldSegsWithP`). For the result not to depend on
--   how segments are distributed over the gang, @f@ should be associative
--   and @z@ should be a neutral element of @f@, e.g. @(+)@ with @0@.
foldWithP :: Unbox a
          => (a -> a -> a) -> a -> UPSegd -> Vector a -> Vector a
foldWithP f !z  = foldSegsWithP f (Seq.foldlSU f z)
{-# INLINE_UP foldWithP #-}
212
213
-- | Fold segments specified by a `UPSegd`, without a seed element.
--
--   Every chunk of the distributed array is folded with @Seq.fold1SU f@, and
--   the partial results of a segment split across threads are combined with
--   @f@ (see `foldSegsWithP`), so @f@ should be associative.
--
--   NOTE(review): as there is no seed, every segment is presumably required
--   to be non-empty — confirm against `Seq.fold1SU`.
fold1WithP :: Unbox a
           => (a -> a -> a) -> UPSegd -> Vector a -> Vector a
fold1WithP f    = foldSegsWithP f (Seq.fold1SU f)
{-# INLINE_UP fold1WithP #-}
219
220
-- | Sum the elements of each segment specified by a `UPSegd`.
sumWithP :: (Num e, Unbox e) => UPSegd -> Vector e -> Vector e
sumWithP segd xs = foldWithP (+) 0 segd xs
{-# INLINE_UP sumWithP #-}
225
226
-- | Fold the segments specified by a `UPSegd`.
--
--   This low level function takes a per-element worker and a per-segment
--   worker. Each chunk of the distributed array is folded with the
--   per-segment worker. When a chunk starts part-way through a segment, its
--   first result is a partial result for a segment begun by an earlier
--   chunk; that value is split off as a carry, and `fixupFold` later merges
--   it into the joined result using the per-element worker.
--
foldSegsWithP
        :: Unbox a
        => (a -> a -> a)
        -> (USegd -> Vector a -> Vector a)
        -> UPSegd -> Vector a -> Vector a

{-# INLINE_UP foldSegsWithP #-}
foldSegsWithP fElem fSeg segd xs
 = dcarry `seq` drs `seq` result
 where
        -- Splice the per-chunk results into one mutable vector, merge the
        -- carries into it, then freeze it.
        result
         = runST (do
                mrs <- joinDM theGang drs
                fixupFold fElem mrs dcarry
                Seq.unsafeFreeze mrs)

        -- Fold every chunk, then separate the carries from the results.
        (dcarry, drs)
         = unzipD
         $ mapD theGang foldChunk
         $ zipD (takeDistributed segd)
                (splitD theGang balanced xs)

        -- Fold one chunk. A non-zero offset means the chunk's first slice
        -- continues a segment from an earlier chunk, so the first result is
        -- returned as the carry (tagged with that segment's id) rather than
        -- being kept with the remaining results.
        foldChunk (((chunkSegd, firstSeg), offset), chunk)
         = ((firstSeg, Seq.take nCarry results), Seq.drop nCarry results)
         where  results = fSeg chunkSegd chunk

                {-# INLINE [0] nCarry #-}
                nCarry  | offset == 0 = 0
                        | otherwise   = 1
261
262
-- | Merge the per-chunk carries produced by `foldSegsWithP` into the joined
--   result vector. For every chunk whose carry is non-empty, the element
--   stored for the carry's segment id is combined with the carry using the
--   per-element worker. Chunks are visited in ascending order.
fixupFold
        :: Unbox a
        => (a -> a -> a)
        -> MVector s a
        -> Dist (Int,Vector a)
        -> ST s ()
{-# NOINLINE fixupFold #-}
fixupFold f !mrs !dcarry
 = mapM_ fixup [1 .. gangSize theGang - 1]
 where
        -- Chunk 0 is not visited; presumably its first slice always starts
        -- a segment, so it never has a carry.
        fixup i
         | Seq.null carry = return ()
         | otherwise
         = do   x <- Seq.read mrs k
                Seq.write mrs k (f x (carry Seq.! 0))
         where  (k, carry) = indexD dcarry i