dph-prim-seq: Add lockedZipWiths up to 8
[packages/dph.git] / dph-test / fusion / Stream.hs
1 {-# LANGUAGE FlexibleContexts #-}
2 module Stream where
3
4 import Data.Vector.Generic as G
5 import Data.Vector.Generic.Base as G
6 import Data.Vector.Fusion.Stream.Monadic as S
7 import Data.Vector.Fusion.Stream.Size as S
8 import Data.Vector.Fusion.Util as S
9 import qualified Data.Vector.Generic.New as New
10 import Data.Vector.Generic.New ( New )
11
12
13 -- Swallow --------------------------------------------------------------------
-- | Stream the elements of a vector with no end-of-vector check.
--
--   This is like 'stream', except that the resulting stream never finishes:
--   it keeps reading index 0, 1, 2, ... with 'basicUnsafeIndexM'.  Something
--   else in the context must know how long the vector is and stop demanding
--   elements in time.  The size hint is 'Unknown'.
swallow :: (Monad m, Vector v a) => v a -> Stream m a
swallow vec
  = vec `seq` len `seq` S.sized (S.unfoldr next 0) Unknown
  where
    len = G.length vec

    {-# INLINE next #-}
    next ix
      = case basicUnsafeIndexM vec ix of
          Box x -> Just (x, ix + 1)
{-# INLINE [1] swallow #-}
28
29
-- | Marker wrapped around a stream that has replaced a 'swallow'ed vector.
--
--   Operationally this is the identity.  It exists so that rewrite rules
--   (such as "swallowS/replicate") can recognise streams that came from the
--   "swallow/new/unstream" rule below and rewrite them further.  The explicit
--   signature avoids the inferred @a -> a@ type.  The parameter is not named
--   @stream@ to avoid shadowing the unqualified 'G.stream'.
swallowS :: Stream m a -> Stream m a
swallowS s = s
{-# INLINE [1] swallowS #-}

-- Swallowing a vector that is built directly from a stream does not need
-- the vector at all: hand the producing stream to the consumer instead.
{-# RULES "swallow/new/unstream"
  forall s .
  swallow (new (New.unstream s)) = swallowS s
  #-}
37
38
39 -- Repeat ---------------------------------------------------------------------
-- | An endless stream that runs the action @x@ once for every element it
--   yields.  It never returns 'Done', so the consumer must decide when to
--   stop.  The size hint is 'Unknown'.
repeatM :: Monad m => m a -> Stream m a
repeatM x
  = Stream next () Unknown
  where
    {-# INLINE [0] next #-}
    next _ = x >>= \v -> return (Yield v ())
{-# INLINE [1] repeatM #-}


-- A swallowed 'S.replicateM' does not need its own length bound: the
-- swallowing context presumably already counts elements (e.g. 'map2S'),
-- so the action can simply be repeated.
{-# RULES "swallowS/replicate"
  forall len x .
  swallowS (S.replicateM len x) = repeatM x
  #-}
56
57
58 -- Locked Streamers -----------------------------------------------------------
-- | Stream two vectors in lock step, pairing elements at the same index.
--
--   The result length is taken from the first vector.  Both vectors are read
--   through 'swallow', i.e. with unchecked indexing, so the second vector
--   must be at least as long as the first.  In practice they should have the
--   same length.
stream2 :: (Monad m, Vector v a, Vector v b)
        => v a -> v b
        -> Stream m (a, b)
stream2 aa bb
  = map2S n (swallow aa) (swallow bb)
  where
    n = G.length aa
{-# INLINE [1] stream2 #-}
66
67
-- | Stream three vectors in lock step, producing triples of elements at the
--   same index.
--
--   The result length is taken from the first vector.  All three vectors are
--   read through 'swallow', i.e. with unchecked indexing, so the other two
--   must be at least as long as the first.
stream3 :: (Monad m, Vector v a, Vector v b, Vector v c)
        => v a -> v b -> v c
        -> Stream m (a, b, c)
stream3 aa bb cc
  = map3S n (swallow aa) (swallow bb) (swallow cc)
  where
    n = G.length aa
{-# INLINE [1] stream3 #-}
76
77
-- When one argument of 'stream3' is built with 'G.new', swallow that vector
-- on its own, so that rules such as "swallow/new/unstream" can fire on it.
-- Zip it against the other two vectors, streamed together by 'swallow2',
-- and put the triple back into argument order afterwards.  The length is
-- always read from a vector that is not built with 'G.new'.

-- First argument is freshly built.
{-# RULES "stream3/new_1"
  forall as bs cs .
  stream3 (G.new as) bs cs
    = S.map (\((b, c), a) -> (a, b, c))
            (map2S (G.length bs) (swallow2 bs cs) (swallow (G.new as)))
  #-}

-- Second argument is freshly built.
{-# RULES "stream3/new_2"
  forall as bs cs .
  stream3 as (G.new bs) cs
    = S.map (\((a, c), b) -> (a, b, c))
            (map2S (G.length as) (swallow2 as cs) (swallow (G.new bs)))
  #-}

-- Third argument is freshly built.
{-# RULES "stream3/new_3"
  forall as bs cs .
  stream3 as bs (G.new cs)
    = S.map (\((a, b), c) -> (a, b, c))
            (map2S (G.length as) (swallow2 as bs) (swallow (G.new cs)))
  #-}
97 #-}
98
99
100
101 -- Locked Swallowers ---------------------------------------------------------
-- | Swallow two vectors in lock step, producing pairs of elements at the
--   same index.
--
--   There is no end-of-vector check: both vectors are indexed with
--   'basicUnsafeIndexM' at 0, 1, 2, ... forever.  The consuming context must
--   know how many elements to demand.  The size hint is 'Unknown'.
swallow2
  :: (Monad m, Vector v a, Vector v b)
  => v a -> v b
  -> Stream m (a, b)
swallow2 aa bb
  = aa `seq` bb `seq` count `seq` S.sized (S.unfoldr pull 0) Unknown
  where
    count = G.length aa

    {-# INLINE [0] pull #-}
    pull ix
      = case basicUnsafeIndexM aa ix of
          Box x -> case basicUnsafeIndexM bb ix of
            Box y -> Just ((x, y), ix + 1)
{-# INLINE [1] swallow2 #-}
120
121
122 -- Locked maps ----------------------------------------------------------------
-- | Zip two streams in lock step, yielding (at most) @len@ pairs.
--
--   Both inputs are stepped together, once per output element, and their
--   own size hints are ignored.  The bound is checked /before/ the inputs are
--   stepped, so neither input is advanced, and none of its monadic effects
--   run, past the @len@-th element.  This matters for unbounded inputs:
--
--   * 'swallow' would otherwise index one element past the end of its vector;
--   * 'repeatM' would otherwise run its action one extra time.
--
--   Every input is expected to 'Yield' on every step.  If either input
--   returns anything else before @len@ elements, the result stops early, and
--   the 'S.Exact' size hint then overstates its length.
map2S :: Monad m
      => Int
      -> Stream m a -> Stream m b
      -> Stream m (a, b)
map2S len
      (Stream mkStep1 sa1 _)
      (Stream mkStep2 sa2 _)
  = Stream step (sa1, sa2, 0) (S.Exact len)
  where
    {-# INLINE [0] step #-}
    step (s1, s2, i)
      -- Stop before demanding anything beyond the requested length.
      | i >= len
      = return Done

      | otherwise
      = do step1 <- mkStep1 s1
           step2 <- mkStep2 s2
           return $ case (step1, step2) of
             (Yield x1 s1', Yield x2 s2')
               -> Yield (x1, x2) (s1', s2', i + 1)
             _ -> Done
{-# INLINE [1] map2S #-}
143
144
-- | Zip three streams in lock step, yielding (at most) @len@ triples.
--
--   All inputs are stepped together, once per output element, and their own
--   size hints are ignored.  The bound is checked /before/ the inputs are
--   stepped, so no input is advanced, and none of its monadic effects run,
--   past the @len@-th element.  For example, 'swallow' would otherwise index
--   one element past the end of its vector.
--
--   Every input is expected to 'Yield' on every step.  If any input returns
--   anything else before @len@ elements, the result stops early, and the
--   'S.Exact' size hint then overstates its length.
map3S
  :: Monad m
  => Int
  -> Stream m a -> Stream m b -> Stream m c
  -> Stream m (a, b, c)
map3S len
      (Stream mkStep1 sa1 _)
      (Stream mkStep2 sa2 _)
      (Stream mkStep3 sa3 _)
  = Stream step (sa1, sa2, sa3, 0) (S.Exact len)
  where
    {-# INLINE [0] step #-}
    step (s1, s2, s3, i)
      -- Stop before demanding anything beyond the requested length.
      | i >= len
      = return Done

      | otherwise
      = do step1 <- mkStep1 s1
           step2 <- mkStep2 s2
           step3 <- mkStep3 s3
           return $ case (step1, step2, step3) of
             (Yield x1 s1', Yield x2 s2', Yield x3 s3')
               -> Yield (x1, x2, x3) (s1', s2', s3', i + 1)
             _ -> Done
{-# INLINE [1] map3S #-}
169