cc88c112b748e692ca07aaa99fec8d3fd7dca574
[packages/dph.git] / dph-base / Data / Array / Parallel / Stream.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module : Data.Array.Parallel.Stream
4 -- Copyright : (c) 2010 Roman Leshchinskiy
5 -- License : see libraries/ndp/LICENSE
6 --
7 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
8 -- Stability : internal
9 -- Portability : non-portable (existentials)
10 --
11 -- Stream functions not implemented in vector
12 --
-- TODO: The use of INLINE pragmas in some of these functions isn't consistent.
14 -- for indexedS and combine2ByTagS, there is an INLINE_INNER on the 'next'
15 -- function, but replicateEachS uses a plain INLINE and fold1SS uses
16 -- a hard INLINE [0]. Can we make a rule that all top-level stream functions
17 -- in this module have INLINE_STREAM, and all 'next' functions have
18 -- INLINE_INNER? If not we should document the reasons for the special cases.
19 --
-- TODO: The behaviour of indicesSS looks suspiciously inconsistent.
21 --
22
23 #include "fusion-phases.h"
24
25 module Data.Array.Parallel.Stream (
26
27 -- * Flat stream operators
28 indexedS, replicateEachS, replicateEachRS,
29 interleaveS, combine2ByTagS,
30 enumFromToEachS, enumFromStepLenEachS,
31
32 -- * Segmented stream operators
33 foldSS, fold1SS, combineSS, appendSS,
34 foldValuesR,
35 indicesSS
36 ) where
37
38 import Data.Array.Parallel.Base ( Tag )
39
40 import qualified Data.Vector.Fusion.Stream as S
41 import Data.Vector.Fusion.Stream.Monadic ( Stream(..), Step(..) )
42 import Data.Vector.Fusion.Stream.Size ( Size(..) )
43
-- | Pair every element of a stream with its position in that stream,
--   counting from zero.
--
-- @
-- indexed [42,93,13]
--  = [(0,42), (1,93), (2,13)]
-- @
--
--   The size hint of the source stream is passed through unchanged.
indexedS :: S.Stream a -> S.Stream (Int,a)
{-# INLINE_STREAM indexedS #-}
indexedS (Stream step start size) = Stream step' (0,start) size
  where
    -- State: (index to attach to the next element, source stream state).
    -- The index only advances when the source actually yields an element.
    {-# INLINE_INNER step' #-}
    step' (ix,st) =
      step st >>= \res ->
        case res of
          Yield x st' -> return $ Yield (ix,x) (ix+1,st')
          Skip    st' -> return $ Skip (ix,st')
          Done        -> return Done
61
62
-- | Given a stream of pairs containing a count and an element,
--   emit each element the number of times given by its count.
--
--   The first parameter is used as the 'Exact' size hint of the resulting
--   stream, so callers are expected to pass the total number of elements
--   that will be produced.
--
--   Pairs whose count is zero or negative contribute no elements.
--
-- @
-- replicateEach 10 [(2,10), (5,20), (3,30)]
--  = [10,10,20,20,20,20,20,30,30,30]
-- @
replicateEachS :: Int -> S.Stream (Int,a) -> S.Stream a
{-# INLINE_STREAM replicateEachS #-}
replicateEachS n (Stream next s _) =
  Stream next' (0,Nothing,s) (Exact n)
  where
    -- State: (copies still to emit, element being copied, source state).
    {-# INLINE next' #-}
    next' (k,held,st)
      -- Nothing left to emit for the current pair: pull the next one.
      -- Testing with (<=) rather than matching the literal 0 makes a
      -- negative count behave like zero instead of yielding forever.
      | k <= 0
      = do
          r <- next st
          case r of
            Done            -> return Done
            Skip st'        -> return $ Skip (0,Nothing,st')
            Yield (c,x) st' -> return $ Skip (c,Just x,st')

      | otherwise
      = case held of
          Just x  -> return $ Yield x (k-1,held,st)
          -- Not reachable: the count only becomes positive together with
          -- a held element.
          Nothing -> return Done
87
88
-- | Repeat each element in the stream the given number of times.
--
--   The size hint of the result is the source hint multiplied by the
--   repetition count. A count of zero or less produces an empty stream.
--
-- @
-- replicateEachR 2 [10,20,30]
--  = [10,10,20,20,30,30]
-- @
--
replicateEachRS :: Int -> S.Stream a -> S.Stream a
{-# INLINE_STREAM replicateEachRS #-}
replicateEachRS !n (Stream next s sz)
  -- Clamp the factor so that a negative count cannot produce a negative
  -- size hint.
  = Stream next' (0,Nothing,s) (sz `multSize` (max 0 n))
  where
    -- State: (copies still to emit, element being copied, source state).
    next' (i,held,st)
      -- Current element exhausted: pull the next one. Testing with (<=)
      -- rather than matching the literal 0 keeps a negative count from
      -- yielding the element forever.
      | i <= 0
      = do
          r <- next st
          case r of
            Done        -> return Done
            Skip st'    -> return $ Skip (0,Nothing,st')
            Yield x st' -> return $ Skip (n,Just x,st')

      | otherwise
      = case held of
          Just x  -> return $ Yield x (i-1,held,st)
          -- Not reachable: the counter only becomes positive together
          -- with a held element.
          Nothing -> return Done
110
111
-- | Scale a size hint by an integer factor. An 'Unknown' hint stays
--   'Unknown'; the factor is not inspected in that case.
multSize :: Size -> Int -> Size
multSize sz k =
  case sz of
    Exact n -> Exact (n*k)
    Max   n -> Max   (n*k)
    Unknown -> Unknown
117
118
-- | Interleave the elements of two streams. We alternate between the first
--   and second streams, starting with the first, and stop as soon as the
--   stream whose turn it is has no more elements.
--
-- @
-- interleave [2,3,4] [10,20,30] = [2,10,3,20,4,30]
-- interleave [2,3]   [10,20,30] = [2,10,3,20]
-- interleave [2,3,4] [10,20]    = [2,10,3,20,4]
-- @
--
--   When both inputs carry an 'Exact' size hint the result hint is the
--   exact number of elements produced by the rule above; otherwise it is
--   the sum of the two hints.
--
interleaveS :: S.Stream a -> S.Stream a -> S.Stream a
{-# INLINE_STREAM interleaveS #-}
interleaveS (Stream next1 s1 n1) (Stream next2 s2 n2)
  = Stream next (False,s1,s2) size
  where
    -- Simply adding two 'Exact' hints overstates the length whenever the
    -- streams do not match up (see the second example above), which would
    -- make the 'Exact' hint wrong. With a elements in the first stream and
    -- b in the second, the first stream ends the result after a pairs when
    -- a <= b; otherwise the second stream ends it after b pairs plus one
    -- more element taken from the first.
    size = case (n1,n2) of
      (Exact a, Exact b)
        | a <= b    -> Exact (2*a)
        | otherwise -> Exact (2*b + 1)
      _             -> n1 + n2

    -- State: (take from the second stream next?, first state, second state).
    {-# INLINE next #-}
    next (False,t1,t2) =
      do
        r <- next1 t1
        case r of
          Yield x t1' -> return $ Yield x (True ,t1',t2)
          Skip    t1' -> return $ Skip    (False,t1',t2)
          Done        -> return Done

    next (True,t1,t2) =
      do
        r <- next2 t2
        case r of
          Yield x t2' -> return $ Yield x (False,t1,t2')
          Skip    t2' -> return $ Skip    (True ,t1,t2')
          -- FIXME: the second stream ran out while the first may still
          -- have elements; this arguably should be an error.
          Done        -> return Done
150
151
-- | Merge two data streams under the control of a tag stream: for each tag,
--   a tag of 0 takes the next element from the first data stream and any
--   other tag takes it from the second.
--
--   If a data stream runs out while the tag stream still asks for one of
--   its elements, `error` is called. The size hint of the result is that of
--   the tag stream.
--
-- @
-- combine2ByTag [0,1,1,0,0,1] [1,2,3] [4,5,6]
--  = [1,4,5,2,3,6]
-- @
--
combine2ByTagS :: S.Stream Tag -> S.Stream a -> S.Stream a -> S.Stream a
{-# INLINE_STREAM combine2ByTagS #-}
combine2ByTagS (Stream nextTag tags0 size) (Stream nextX xs0 _)
                                           (Stream nextY ys0 _)
  = Stream step (Nothing,tags0,xs0,ys0) size
  where
    -- State: (tag waiting to be served, tag state, first data state,
    -- second data state).
    {-# INLINE_INNER step #-}
    -- No pending tag: fetch one.
    step (Nothing,ts,xs,ys)
      = nextTag ts >>= \r ->
          case r of
            Yield t ts' -> return $ Skip (Just t, ts',xs,ys)
            Skip    ts' -> return $ Skip (Nothing,ts',xs,ys)
            Done        -> return Done

    -- Pending tag 0: serve it from the first data stream.
    step (Just 0,ts,xs,ys)
      = nextX xs >>= \r ->
          case r of
            Yield x xs' -> return $ Yield x (Nothing,ts,xs',ys)
            Skip    xs' -> return $ Skip (Just 0, ts,xs',ys)
            Done        -> error "combine2ByTagS: stream 1 too short"

    -- Any other pending tag: serve it from the second data stream.
    step (Just t,ts,xs,ys)
      = nextY ys >>= \r ->
          case r of
            Yield y ys' -> return $ Yield y (Nothing,ts,xs,ys')
            Skip    ys' -> return $ Skip (Just t, ts,xs,ys')
            Done        -> error "combine2ByTagS: stream 2 too short"
193
194
-- | Concatenate a stream of inclusive integer ranges. Each pair in the
--   input stream holds the first and last value of one range; a pair whose
--   first value exceeds its last contributes nothing.
--
--   The first parameter is used as the 'Exact' size hint of the result.
--
-- @
-- enumFromToEach 14 [(2,5), (10,16), (20,22)]
--  = [2,3,4,5,10,11,12,13,14,15,16,20,21,22]
-- @
--
enumFromToEachS :: Int -> S.Stream (Int,Int) -> S.Stream Int
{-# INLINE_STREAM enumFromToEachS #-}
enumFromToEachS n (Stream next s _)
  = Stream step (Nothing,s) (Exact n)
  where
    -- State: (range in progress as (next value, last value), source state).
    {-# INLINE_INNER step #-}
    step (Nothing,st)
      = next st >>= \r ->
          case r of
            Done              -> return Done
            Skip st'          -> return $ Skip (Nothing,st')
            Yield (lo,hi) st' -> return $ Skip (Just (lo,hi),st')

    step (Just (cur,hi),st)
      = return $ if cur > hi
                   then Skip (Nothing,st)
                   else Yield cur (Just (cur+1,hi),st)
222
223
-- | Create a stream of integer ranges. The triples in the input stream
--   give the first value, increment, and length of each range. A triple
--   whose length is zero or negative contributes no elements.
--
--   The first parameter is used as the 'Exact' size hint of the resulting
--   stream.
--
-- @
-- enumFromStepLenEach 14 [(1,1,5), (10,2,4), (20,3,5)]
--  = [1,2,3,4,5,10,12,14,16,20,23,26,29,32]
-- @
--
enumFromStepLenEachS :: Int -> S.Stream (Int,Int,Int) -> S.Stream Int
{-# INLINE_STREAM enumFromStepLenEachS #-}
enumFromStepLenEachS len (Stream next s _)
  = Stream next' (Nothing,s) (Exact len)
  where
    -- State: (range in progress as (next value, increment, elements
    -- left), source state).
    {-# INLINE_INNER next' #-}
    next' (Nothing,st)
      = do
          r <- next st
          case r of
            Yield (from,step,count) st'
                     -> return $ Skip (Just (from,step,count),st')
            Skip st' -> return $ Skip (Nothing,st')
            Done     -> return Done

    next' (Just (from,step,count),st)
      -- Testing with (<=) rather than matching the literal 0 stops a
      -- negative length from producing an endless range.
      | count <= 0 = return $ Skip (Nothing,st)
      | otherwise  = return $ Yield from (Just (from+step,step,count-1),st)
251
252
-- | Segmented stream fold. The data stream is cut into consecutive segments
--   whose lengths come from the segment stream, and each segment is reduced
--   with a strict left fold starting from the initial element. One result
--   is produced per segment.
--
--   The size hint of the result is that of the segment-length stream.
--   If the data stream ends in the middle of a segment the result stream
--   simply ends and the partial fold is dropped.
--
-- @
-- foldSS (+) 0 [2, 3, 2] [10, 20, 30, 40, 50, 60, 70]
--  = [30,120,130]
-- @
--
foldSS  :: (a -> b -> a)        -- ^ function to perform the fold
        -> a                    -- ^ initial element of each fold
        -> S.Stream Int         -- ^ stream of segment lengths
        -> S.Stream b           -- ^ stream of input data
        -> S.Stream a           -- ^ stream of fold results

{-# INLINE_STREAM foldSS #-}
foldSS f z (Stream nextSeg segs0 size) (Stream nextVal vals0 _) =
  Stream step (Nothing,z,segs0,vals0) size
  where
    -- State: (elements left in the current segment, accumulator,
    -- segment-length state, data state).
    {-# INLINE step #-}
    -- Between segments: read the next segment length.
    step (Nothing,acc,segs,vals) =
      nextSeg segs >>= \r ->
        case r of
          Yield len segs' -> return $ Skip (Just len,z,  segs',vals)
          Skip      segs' -> return $ Skip (Nothing, acc,segs',vals)
          Done            -> return Done

    -- Segment complete: emit its result and reset the accumulator.
    step (Just 0,acc,segs,vals) =
      return $ Yield acc (Nothing,z,segs,vals)

    -- Inside a segment: fold in one more data element.
    step (Just len,acc,segs,vals) =
      nextVal vals >>= \r ->
        case r of
          Yield y vals' ->
            -- Force the new accumulator so no chain of thunks builds up.
            let acc' = f acc y
            in  acc' `seq` return (Skip (Just (len-1),acc',segs,vals'))
          Skip    vals' -> return $ Skip (Just len,acc,segs,vals')
          -- FIXME: the segment descriptor asks for more data than there
          -- is; this arguably should be an error
          -- ("Stream.Segmented.foldSS: invalid segment descriptor").
          Done          -> return Done
293
294
-- | Like `foldSS`, but instead of a separate initial element the first
--   member of each segment seeds the fold for that segment.
--
--   The size hint of the result is that of the segment-length stream.
--   If the data stream ends in the middle of a segment the result stream
--   simply ends.
--
--   NOTE: a segment of length zero is not treated specially. An element is
--   still taken as the seed and the remaining-length counter goes negative,
--   so that segment never reaches the "complete" state.
fold1SS :: (a -> a -> a) -> S.Stream Int -> S.Stream a -> S.Stream a
{-# INLINE_STREAM fold1SS #-}
fold1SS f (Stream nextSeg segs0 size) (Stream nextVal vals0 _) =
  Stream step (Nothing,Nothing,segs0,vals0) size
  where
    -- State: (elements left in the current segment, accumulator if the
    -- segment has been seeded, segment-length state, data state).
    {-# INLINE [0] step #-}
    -- Between segments: read the next segment length.
    step (Nothing,Nothing,segs,vals) =
      nextSeg segs >>= \r ->
        case r of
          Yield len segs' -> return $ Skip (Just len,Nothing,segs',vals)
          Skip      segs' -> return $ Skip (Nothing, Nothing,segs',vals)
          Done            -> return Done

    -- Segment started but not yet seeded: take its first element.
    step (Just !len,Nothing,segs,vals) =
      nextVal vals >>= \r ->
        case r of
          Yield x vals' -> return $ Skip (Just (len-1),Just x, segs,vals')
          Skip    vals' -> return $ Skip (Just len,    Nothing,segs,vals')
          Done          -> return Done   -- FIXME: error

    -- Segment complete: emit its result.
    step (Just 0,Just acc,segs,vals) =
      return $ Yield acc (Nothing,Nothing,segs,vals)

    -- Inside a seeded segment: fold in one more data element.
    step (Just len,Just acc,segs,vals) =
      nextVal vals >>= \r ->
        case r of
          Yield y vals' ->
            -- Force the new accumulator so no chain of thunks builds up.
            let acc' = f acc y
            in  acc' `seq` return (Skip (Just (len-1),Just acc',segs,vals'))
          Skip    vals' -> return $ Skip (Just len,Just acc,segs,vals')
          Done          -> return Done   -- FIXME: error
330
331
-- | Segmented stream combine. Like `combine2ByTagS`, except that each tag
--   selects an entire segment of one of the data streams rather than a
--   single element. A 'True' tag copies the next segment of the first data
--   stream, a 'False' tag the next segment of the second.
--
-- @
-- combineSS [True, True, False, True, False, False]
--           [2,1,3] [10,20,30,40,50,60]
--           [1,2,3] [11,22,33,44,55,66]
--  = [10,20,30,11,40,50,60,22,33,44,55,66]
-- @
--
--   This says take two elements from the first stream, then another one
--   element from the first stream, then one element from the second stream,
--   then three elements from the first stream...
--
--   The size hint of the result is the sum of the hints of the two data
--   streams. The result ends as soon as the tag stream, or the stream a tag
--   directs us to read from, is exhausted.
--
combineSS
        :: S.Stream Bool        -- ^ tag values
        -> S.Stream Int         -- ^ segment lengths for first data stream
        -> S.Stream a           -- ^ first data stream
        -> S.Stream Int         -- ^ segment lengths for second data stream
        -> S.Stream a           -- ^ second data stream
        -> S.Stream a

{-# INLINE_STREAM combineSS #-}
combineSS (Stream nextTag tags0 _)
          (Stream nextLenA lensA0 _) (Stream nextValA valsA0 sizeA)
          (Stream nextLenB lensB0 _) (Stream nextValB valsB0 sizeB)
  = Stream step (Nothing,True,tags0,lensA0,valsA0,lensB0,valsB0)
                (sizeA+sizeB)
  where
    -- State: (elements left in the current segment, whether that segment
    -- comes from the first data stream, tag state, then the length and
    -- data states of the first and of the second data stream).
    {-# INLINE step #-}
    -- Between segments: read a tag, then the matching segment length.
    step (Nothing,sel,ts,la,va,lb,vb) =
      nextTag ts >>= \r ->
        case r of
          Done        -> return Done
          Skip ts'    -> return $ Skip (Nothing,sel,ts',la,va,lb,vb)
          Yield c ts' ->
            if c
              then
                nextLenA la >>= \ra ->
                  case ra of
                    Done         -> return Done
                    -- The length stream skipped, so keep the OLD tag state:
                    -- the same tag is read again on the next step.
                    Skip la'     -> return $ Skip (Nothing,sel,ts,la',va,lb,vb)
                    Yield n la'  -> return $ Skip (Just n,c,ts',la',va,lb,vb)
              else
                nextLenB lb >>= \rb ->
                  case rb of
                    Done         -> return Done
                    -- As above: retry the same tag after the skip.
                    Skip lb'     -> return $ Skip (Nothing,sel,ts,la,va,lb',vb)
                    Yield n lb'  -> return $ Skip (Just n,c,ts',la,va,lb',vb)

    -- Segment complete: go back to reading tags.
    step (Just 0,_,ts,la,va,lb,vb) =
      return $ Skip (Nothing,True,ts,la,va,lb,vb)

    -- Copying a segment of the first data stream.
    step (Just n,True,ts,la,va,lb,vb) =
      nextValA va >>= \r ->
        case r of
          Done        -> return Done
          Skip va'    -> return $ Skip (Just n,True,ts,la,va',lb,vb)
          Yield x va' -> return $ Yield x (Just (n-1),True,ts,la,va',lb,vb)

    -- Copying a segment of the second data stream.
    step (Just n,False,ts,la,va,lb,vb) =
      nextValB vb >>= \r ->
        case r of
          Done        -> return Done
          Skip vb'    -> return $ Skip (Just n,False,ts,la,va,lb,vb')
          Yield x vb' -> return $ Yield x (Just (n-1),False,ts,la,va,lb,vb')
403
404
-- | Segmented stream append. Corresponding segments of the two data streams
--   are appended pairwise: first a segment of the first stream, then the
--   matching segment of the second, and so on.
--
-- @
-- appendSS [2, 1, 3] [10, 20, 30, 40, 50, 60]
--          [1, 3, 2] [11, 22, 33, 44, 55, 66]
--  = [10,20,11,30,22,33,44,40,50,60,55,66]
-- @
--
--   The size hint of the result is the sum of the hints of the two data
--   streams. The result ends when the first segment-length stream is
--   exhausted, or as soon as any stream we need to read from runs out.
--
appendSS
        :: S.Stream Int         -- ^ segment lengths for first data stream
        -> S.Stream a           -- ^ first data stream
        -> S.Stream Int         -- ^ segment lengths for second data stream
        -> S.Stream a           -- ^ second data stream
        -> S.Stream a

{-# INLINE_STREAM appendSS #-}
appendSS (Stream nextLenA lensA0 _) (Stream nextValA valsA0 sizeA)
         (Stream nextLenB lensB0 _) (Stream nextValB valsB0 sizeB)
  = Stream step (True,Nothing,lensA0,valsA0,lensB0,valsB0) (sizeA + sizeB)
  where
    -- State: (reading from the first stream?, elements left in the current
    -- segment, then the length and data states of the first and of the
    -- second data stream).
    {-# INLINE step #-}
    -- First stream, no segment in progress: read its next segment length.
    step (True,Nothing,la,va,lb,vb) =
      nextLenA la >>= \r ->
        case r of
          Yield n la' -> return $ Skip (True,Just n ,la',va,lb,vb)
          Skip    la' -> return $ Skip (True,Nothing,la',va,lb,vb)
          Done        -> return $ Done

    -- First stream's segment finished: switch to the second stream.
    step (True,Just 0,la,va,lb,vb)
      = return $ Skip (False,Nothing,la,va,lb,vb)

    -- Copying a segment of the first stream.
    step (True,Just n,la,va,lb,vb) =
      nextValA va >>= \r ->
        case r of
          Yield x va' -> return $ Yield x (True,Just (n-1),la,va',lb,vb)
          Skip    va' -> return $ Skip (True,Just n,la,va',lb,vb)
          Done        -> return Done         -- FIXME: error

    -- Second stream, no segment in progress: read its next segment length.
    step (False,Nothing,la,va,lb,vb) =
      nextLenB lb >>= \r ->
        case r of
          Yield n lb' -> return $ Skip (False,Just n,la,va,lb',vb)
          Skip    lb' -> return $ Skip (False,Nothing,la,va,lb',vb)
          Done        -> return Done         -- FIXME: error

    -- Second stream's segment finished: switch back to the first stream.
    step (False,Just 0,la,va,lb,vb)
      = return $ Skip (True,Nothing,la,va,lb,vb)

    -- Copying a segment of the second stream.
    step (False,Just n,la,va,lb,vb) =
      nextValB vb >>= \r ->
        case r of
          Yield x vb' -> return $ Yield x (False,Just (n-1),la,va,lb,vb')
          Skip    vb' -> return $ Skip (False,Just n,la,va,lb,vb')
          Done        -> return Done         -- FIXME: error
463
464
-- | Segmented stream fold with a fixed segment length.
--
--   Like `foldSS`, but every segment has the same, given length. One result
--   is produced per complete segment; if the data stream ends part-way
--   through a segment, that partial fold is dropped.
--
--   The size hint of the result is the data stream's hint divided by the
--   segment length.
--
--   NOTE: the segment length is presumably expected to be positive. With a
--   length of 0 the counter starts in the "segment complete" state and is
--   reset to 0 after every result, so the initial element is yielded without
--   any data being consumed.
--
foldValuesR
        :: (a -> b -> a)        -- ^ function to perform the fold
        -> a                    -- ^ initial element for fold
        -> Int                  -- ^ length of each segment
        -> S.Stream b           -- ^ data stream
        -> S.Stream a

{-# INLINE_STREAM foldValuesR #-}
foldValuesR f z segSize (Stream nextVal vals0 size) =
  Stream step (segSize,z,vals0) (size `divSize` segSize)
  where
    -- State: (elements left in the current segment, accumulator, data state).
    {-# INLINE step #-}
    -- Segment complete: emit its result and start the next segment afresh.
    step (0,acc,vals) = return $ Yield acc (segSize,z,vals)

    -- Inside a segment: fold in one more data element.
    step (left,acc,vals) =
      nextVal vals >>= \r ->
        case r of
          Yield y vals' ->
            -- Force the new accumulator so no chain of thunks builds up.
            let acc' = f acc y
            in  acc' `seq` return (Skip ((left-1),acc',vals'))
          Skip    vals' -> return $ Skip (left,acc,vals')
          Done          -> return Done
491
492
-- | Divide a size hint by an integer using `div`. An 'Unknown' hint stays
--   'Unknown'; the divisor is not inspected in that case.
divSize :: Size -> Int -> Size
divSize sz k =
  case sz of
    Exact n -> Exact (n `div` k)
    Max   n -> Max   (n `div` k)
    Unknown -> Unknown
498
499
-- | Segmented stream indices. For each segment length in the input stream,
--   produce that many consecutive indices.
--
--   The first parameter is used as the 'Exact' size hint of the result.
--   The second parameter is the index at which the FIRST segment starts
--   counting; every later segment starts counting from 0.
--
-- @
-- indicesSS 15 4 [3, 5, 7]
--  = [4,5,6,0,1,2,3,4,0,1,2,3,4,5,6]
-- @
--
-- TODO: Is that correct? Why does the first segment in the result start
--       from 4, unlike the others?
--
indicesSS
        :: Int
        -> Int
        -> S.Stream Int
        -> S.Stream Int

{-# INLINE_STREAM indicesSS #-}
indicesSS n i (Stream next s _) =
  Stream step (i,Nothing,s) (Exact n)
  where
    -- State: (next index to emit, elements left in the current segment,
    -- source state).
    {-# INLINE step #-}
    -- No segment in progress: read the next segment length, keeping the
    -- current index.
    step (ix,Nothing,st) =
      next st >>= \r ->
        case r of
          Yield len st' -> return $ Skip (ix,Just len,st')
          Skip      st' -> return $ Skip (ix,Nothing, st')
          Done          -> return Done

    -- Segment in progress: emit an index, or finish the segment and reset
    -- the index to 0 for the one that follows.
    step (ix,Just left,st) =
      return $ if left > 0
                 then Yield ix (ix+1,Just (left-1),st)
                 else Skip     (0   ,Nothing      ,st)
532