6698a749f8ca6c818873c34a97d0385a55e1d9b4
[packages/dph.git] / dph-test / fusion / Stream.hs
1 {-# LANGUAGE FlexibleContexts #-}
2 module Stream where
3
4 import Data.Vector.Generic as G
5 import Data.Vector.Generic.Base as G
6 import Data.Vector.Fusion.Stream.Monadic as S
7 import Data.Vector.Fusion.Stream.Size as S
8 import Data.Vector.Fusion.Util as S
9 import qualified Data.Vector.Generic.New as New
10 import Data.Vector.Generic.New ( New )
11
12
13 -- Swallow --------------------------------------------------------------------
-- | A variant of 'G.stream' that never tests for the end of the vector.
--
-- The surrounding context is trusted to already know how many elements it
-- needs and to stop pulling in time. Each step reads the next element with
-- 'basicUnsafeIndexM' (which performs no bounds check) and yields it. The
-- resulting stream never reports 'Done', and its size hint is 'Unknown'.
swallow :: (Monad m, Vector v a) => v a -> Stream m a
swallow vec
  = vec `seq` len `seq` S.sized (S.unfoldr fetch 0) Unknown
  where
    -- The length is forced up front together with the vector but is not
    -- otherwise used.
    len = G.length vec

    {-# INLINE fetch #-}
    fetch ix =
      case basicUnsafeIndexM vec ix of
        Box x -> Just (x, ix + 1)
{-# INLINE [1] swallow #-}
28
29
-- | Identity on its argument. It exists only as a marker for rewrite rules:
-- "swallow/new/unstream" produces it, and "swallowS/replicate" matches on it.
-- Its INLINE [1] pragma keeps it from being inlined before phase 1, which
-- gives those rules a chance to see it first.
swallowS :: s -> s
swallowS str = str
{-# INLINE [1] swallowS #-}

-- A vector that is built straight from a stream and then immediately
-- swallowed does not need to exist: swallow the producing stream instead.
{-# RULES "swallow/new/unstream"
    forall str.
    swallow (G.new (New.unstream str)) = swallowS str
  #-}
37
38
39 -- Repeat ---------------------------------------------------------------------
-- | An endless stream that runs the given action at every step and yields
-- its result. It never reports 'Done', so whatever consumes it has to know
-- when to stop, for example one of the locked zippers.
repeatM :: Monad m => m a -> Stream m a
repeatM act
  = Stream pull () Unknown
  where
    {-# INLINE [0] pull #-}
    pull _ = act >>= \x -> return (Yield x ())
{-# INLINE [1] repeatM #-}

-- Once a replicated stream has been swallowed, its own element count is
-- redundant (the swallowing context already knows the length), so it can
-- be replaced by an endless repetition of the same action.
{-# RULES "swallowS/replicate"
    forall n act.
    swallowS (S.replicateM n act) = repeatM act
  #-}
56
57
58 -- Locked Streamers -----------------------------------------------------------
-- | Stream two vectors side by side as a stream of pairs.
--
-- The length of the first vector fixes how many pairs are produced. Both
-- vectors are read through 'swallow', which does no bounds checking, so the
-- second vector must be at least as long as the first.
lockedStream2
  :: (Monad m, Vector v a, Vector v b)
  => v a -> v b
  -> Stream m (a, b)
lockedStream2 xs ys
  = lockedZip2S n (swallow xs) (swallow ys)
  where
    n = G.length xs
{-# INLINE [1] lockedStream2 #-}
67
68
-- | Stream three vectors side by side as a stream of triples.
--
-- The length of the first vector fixes how many triples are produced. All
-- three are read through 'swallow', which does no bounds checking, so the
-- second and third vectors must be at least as long as the first.
lockedStream3
  :: (Monad m, Vector v a, Vector v b, Vector v c)
  => v a -> v b -> v c
  -> Stream m (a, b, c)
lockedStream3 xs ys zs
  = lockedZip3S n (swallow xs) (swallow ys) (swallow zs)
  where
    n = G.length xs
{-# INLINE [1] lockedStream3 #-}

-- Each rule below handles one argument position that is itself built with
-- 'G.new'. That argument is swallowed on its own, leaving a
-- @swallow (G.new ...)@ term that "swallow/new/unstream" can match
-- (presumably the point: fuse away the intermediate vector). The other two
-- arguments are swallowed together with 'lockedSwallow2', and the lock
-- length is taken from one of them rather than from the 'G.new' argument.
-- The pair is then shuffled back into the original triple order.

{-# RULES "lockedStream3/new_1"
    forall xs ys zs.
    lockedStream3 (G.new xs) ys zs
      = S.map (\((y, z), x) -> (x, y, z))
              (lockedZip2S (G.length ys)
                           (lockedSwallow2 ys zs)
                           (swallow (G.new xs)))
  #-}

{-# RULES "lockedStream3/new_2"
    forall xs ys zs.
    lockedStream3 xs (G.new ys) zs
      = S.map (\((x, z), y) -> (x, y, z))
              (lockedZip2S (G.length xs)
                           (lockedSwallow2 xs zs)
                           (swallow (G.new ys)))
  #-}

{-# RULES "lockedStream3/new_3"
    forall xs ys zs.
    lockedStream3 xs ys (G.new zs)
      = S.map (\((x, y), z) -> (x, y, z))
              (lockedZip2S (G.length xs)
                           (lockedSwallow2 xs ys)
                           (swallow (G.new zs)))
  #-}
100
101
102
103 -- Locked Swallowers ---------------------------------------------------------
104
-- | Swallow two vectors in lock-step as a stream of pairs.
--
-- As with 'swallow', there is no end-of-vector check. Each step reads the
-- current index of the first vector and then of the second with
-- 'basicUnsafeIndexM', yields the pair and advances. The stream never
-- reports 'Done'. The context must know how many pairs to demand, and both
-- vectors must hold at least that many elements.
lockedSwallow2
  :: (Monad m, Vector v a, Vector v b)
  => v a -> v b
  -> Stream m (a, b)
lockedSwallow2 xs ys
  = xs `seq` ys `seq` len `seq` S.sized (S.unfoldr fetch 0) Unknown
  where
    -- Forced up front together with the vectors; not otherwise used.
    len = G.length xs

    {-# INLINE [0] fetch #-}
    fetch ix =
      case basicUnsafeIndexM xs ix of
        Box x ->
          case basicUnsafeIndexM ys ix of
            Box y -> Just ((x, y), ix + 1)
{-# INLINE [1] lockedSwallow2 #-}
123
124
125 -- Locked Stream Zippers -----------------------------------------------------
-- | Zip two streams in lock-step, producing exactly @len@ pairs.
--
-- The length is supplied by the caller. The argument streams' own size
-- hints are ignored, and each stream is assumed to have at least @len@
-- elements. This is meant for streams such as those from 'swallow',
-- 'lockedSwallow2' and 'repeatM', which never report 'Done' themselves.
--
-- The index is compared with @len@ /before/ either argument stream is
-- stepped. If the streams were stepped first and the index tested
-- afterwards, one extra element would be pulled from each stream at the end
-- of the loop. For a swallowed vector that extra pull is an unchecked read
-- one past its last element; for 'repeatM' it runs the action an extra time.
--
-- If either argument reports 'Skip' or 'Done', the zipped stream ends. The
-- arguments are therefore assumed never to 'Skip'.
lockedZip2S
  :: Monad m
  => Int
  -> Stream m a -> Stream m b
  -> Stream m (a, b)
lockedZip2S len (Stream mkStep1 sa1 _) (Stream mkStep2 sa2 _)
  = Stream step (sa1, sa2, 0) (S.Exact len)
  where
    {-# INLINE [0] step #-}
    step (s1, s2, i)
      -- Stop before touching the argument streams once len pairs are out.
      | i >= len  = return Done
      | otherwise = do
          step1 <- mkStep1 s1
          step2 <- mkStep2 s2
          return $ case (step1, step2) of
            (Yield x1 s1', Yield x2 s2') -> Yield (x1, x2) (s1', s2', i + 1)
            _                            -> Done
{-# INLINE [1] lockedZip2S #-}
147
148
-- | Zip three streams in lock-step, producing exactly @len@ triples.
--
-- The length is supplied by the caller. The argument streams' own size
-- hints are ignored, and each stream is assumed to have at least @len@
-- elements. This is meant for streams such as those from 'swallow',
-- 'lockedSwallow2' and 'repeatM', which never report 'Done' themselves.
--
-- The index is compared with @len@ /before/ any argument stream is stepped.
-- If the streams were stepped first and the index tested afterwards, one
-- extra element would be pulled from each stream at the end of the loop.
-- For a swallowed vector that extra pull is an unchecked read one past its
-- last element; for 'repeatM' it runs the action an extra time.
--
-- If any argument reports 'Skip' or 'Done', the zipped stream ends. The
-- arguments are therefore assumed never to 'Skip'.
lockedZip3S
  :: Monad m
  => Int
  -> Stream m a -> Stream m b -> Stream m c
  -> Stream m (a, b, c)
lockedZip3S len
            (Stream mkStep1 sa1 _)
            (Stream mkStep2 sa2 _)
            (Stream mkStep3 sa3 _)
  = Stream step (sa1, sa2, sa3, 0) (S.Exact len)
  where
    {-# INLINE [0] step #-}
    step (s1, s2, s3, i)
      -- Stop before touching the argument streams once len triples are out.
      | i >= len  = return Done
      | otherwise = do
          step1 <- mkStep1 s1
          step2 <- mkStep2 s2
          step3 <- mkStep3 s3
          return $ case (step1, step2, step3) of
            (Yield x1 s1', Yield x2 s2', Yield x3 s3')
              -> Yield (x1, x2, x3) (s1', s2', s3', i + 1)
            _ -> Done
{-# INLINE [1] lockedZip3S #-}
173