d60dd7b315de93447f409e15b21d032699892ae8
[packages/hoopl.git] / src / Compiler / Hoopl / Combinators.hs
1 {-# LANGUAGE RankNTypes, LiberalTypeSynonyms, ScopedTypeVariables #-}
2
3 module Compiler.Hoopl.Combinators
4 ( SimpleFwdRewrite, SimpleFwdRewrite', noFwdRewrite, thenFwdRw
5 , shallowFwdRw3, shallowFwdRwPoly, deepFwdRw3, deepFwdRwPoly, iterFwdRw
6 , SimpleBwdRewrite, SimpleBwdRewrite', noBwdRewrite, thenBwdRw
7 , shallowBwdRw, shallowBwdRw', deepBwdRw, deepBwdRw', iterBwdRw
8 , productFwd, productBwd
9 )
10
11 where
12
13 import Control.Monad
14 import Data.Function
15 import Data.Maybe
16
17 import Compiler.Hoopl.Collections
18 import Compiler.Hoopl.Dataflow
19 import Compiler.Hoopl.Graph (Graph, C, O)
20 import Compiler.Hoopl.Label
21
-- | Short aliases for the forward and backward rewrite types.
type FR m n f = FwdRewrite m n f
type BR m n f = BwdRewrite m n f

-- | A triple with one component per node shape, in the order
-- closed/open (@C O@), open/open (@O O@), open/closed (@O C@)
-- (the entry/middle/exit triple).
type ExTriple a = (a C O, a O O, a O C)

-- | A simple forward rewrite at one node shape: given the node and the
-- incoming fact, optionally return a replacement graph.
type SFRW m n f e x = n e x -> f -> m (Maybe (Graph n e x))

-- | A full forward rewrite at one node shape, producing a 'FwdRes'.
type FRW m n f e x = n e x -> f -> m (FwdRes m n f e x)

-- | One simple forward rewrite per node shape.
type SimpleFwdRewrite m n f = ExTriple (SFRW m n f)

-- | A simple forward rewrite that works uniformly at every node shape.
type SimpleFwdRewrite' m n f = forall e x . SFRW m n f e x

-- | Lifts a simple rewrite to a full rewrite at one shape.
type LiftFRW m n f e x = SFRW m n f e x -> FRW m n f e x

-- | Transforms a full rewrite at one shape.
type MapFRW m n f e x = FRW m n f e x -> FRW m n f e x

-- | Combines two full rewrites at one shape.
type MapFRW2 m n f e x = FRW m n f e x -> FRW m n f e x -> FRW m n f e x
33
34 ----------------------------------------------------------------
35 -- common operations on triples
36
-- | Feed the three components of a triple, in order, to a curried
-- three-argument function.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 fn triple =
  case triple of
    (p, q, r) -> fn p q r
39
-- | Apply a triple of functions component-wise to a triple of arguments.
apply :: (a -> b, d -> e, g -> h) -> (a, d, g) -> (b, e, h)
apply fns args =
  case (fns, args) of
    ((fa, fd, fg), (a, d, g)) -> (fa a, fd d, fg g)
42
-- | Apply a triple of binary functions component-wise to two triples of
-- arguments: component @k@ of the result is @fk xk yk@.
applyBinary :: (a -> b -> c, d -> e -> f, g -> h -> i)
            -> (a, d, g) -> (b, e, h) -> (c, f, i)
applyBinary fns lefts rights =
  case fns of
    (fa, fd, fg) ->
      case lefts of
        (a, d, g) ->
          case rights of
            (b, e, h) -> (fa a b, fd d e, fg g h)
46
47
48 ----------------------------------------------------------------
49
-- | Lift each simple forward rewrite in a shape-indexed triple with the
-- lifter for the same shape, then package the three results with
-- 'mkFRewrite'.
wrapSFRewrites :: ExTriple (LiftFRW m n f) -> SimpleFwdRewrite m n f -> FR m n f
wrapSFRewrites lifters simple =
  case apply lifters simple of
    (first, middle, final) -> mkFRewrite first middle final
52
-- | Transform each per-shape component of a 'FwdRewrite' with the matching
-- function from the triple and rebuild the rewrite.
wrapFRewrites :: ExTriple (MapFRW m n f) -> FR m n f -> FR m n f
wrapFRewrites fns = uncurry3 mkFRewrite . apply fns . getFRewrites
55
-- | Combine two forward rewrites shape by shape, using the matching binary
-- function from the triple for each shape, and rebuild the rewrite.
wrapFRewrites2 :: ExTriple (MapFRW2 m n f) -> FR m n f -> FR m n f -> FR m n f
wrapFRewrites2 fns frw1 frw2 =
  uncurry3 mkFRewrite (applyBinary fns (getFRewrites frw1) (getFRewrites frw2))
59
60
-- Combinators for higher-rank rewriting functions: each takes a single
-- shape-polymorphic function and uses it at all three node shapes.

-- | 'wrapSFRewrites' with the same lifter at every shape.
wrapSFRewrites' :: (forall e x . LiftFRW m n f e x) -> SimpleFwdRewrite m n f -> FR m n f
wrapSFRewrites' lifter simple = wrapSFRewrites (lifter, lifter, lifter) simple
64
-- | 'wrapFRewrites' with the same transformation at every shape.
wrapFRewrites' :: (forall e x . MapFRW m n f e x) -> FR m n f -> FR m n f
wrapFRewrites' fn frw = wrapFRewrites (fn, fn, fn) frw
-- XXX (NR): it is ugly that this cannot simply be written as
--   wrapFRewrites' = mkFRewrite'
-- It would be nice to refactor this.
70
71
-- | 'wrapFRewrites2' with the same combining function at every shape.
wrapFRewrites2' :: (forall e x . MapFRW2 m n f e x) -> FR m n f -> FR m n f -> FR m n f
wrapFRewrites2' combine frw1 frw2 = wrapFRewrites2 (combine, combine, combine) frw1 frw2
74
75 ----------------------------------------------------------------
76
-- | The forward rewrite that declines ('NoFwdRes') every node at every shape.
noFwdRewrite :: Monad m => FwdRewrite m n f
noFwdRewrite = mkFRewrite' decline
  where decline _node _fact = return NoFwdRes
79
-- | Turn a triple of simple forward rewrites into a 'FwdRewrite'.
--
-- A simple rewrite returning @Nothing@ becomes 'NoFwdRes'.  One returning
-- @Just g@ becomes @FwdRes g noFwdRewrite@: the continuation declines
-- every node.
shallowFwdRw3 :: forall m n f . Monad m => SimpleFwdRewrite m n f -> FwdRewrite m n f
shallowFwdRw3 rws = wrapSFRewrites' liftSimple rws
  where
    liftSimple simple node fact = do
      found <- simple node fact
      return (toResult found)
    toResult found = maybe NoFwdRes (\g -> FwdRes g noFwdRewrite) found
85
-- | 'shallowFwdRw3' for a simple rewrite that works at every node shape.
shallowFwdRwPoly :: Monad m => SimpleFwdRewrite' m n f -> FwdRewrite m n f
shallowFwdRwPoly rw = shallowFwdRw3 (rw, rw, rw)
88
-- | 'shallowFwdRw3' wrapped in 'iterFwdRw', so the rewrites are reapplied
-- after each success.
deepFwdRw3 :: Monad m => SimpleFwdRewrite m n f -> FwdRewrite m n f
deepFwdRw3 = iterFwdRw . shallowFwdRw3

-- | Shape-polymorphic counterpart of 'deepFwdRw3'.
deepFwdRwPoly :: Monad m => SimpleFwdRewrite' m n f -> FwdRewrite m n f
deepFwdRwPoly rw = iterFwdRw (shallowFwdRwPoly rw)
93
-- | Sequence two forward rewrites.
--
-- At every node shape, @rw1@ is tried first.  If it declines ('NoFwdRes'),
-- @rw2@'s component for that shape is tried on the same node and fact.  If
-- @rw1@ rewrites, its graph is kept and its continuation is chained (again
-- with 'thenFwdRw') onto the complete @rw2@.
thenFwdRw :: Monad m => FwdRewrite m n f -> FwdRewrite m n f -> FwdRewrite m n f
thenFwdRw rw1 rw2 = wrapFRewrites2' step rw1 rw2
  where
    -- 'first' and 'second' are the per-shape components of rw1 and rw2.
    -- The continuation 'cont' is a whole 'FwdRewrite' (all three shapes),
    -- so it must be chained onto the whole rw2, not onto 'second'.
    step first second node fact =
      first node fact >>= \res -> case res of
        NoFwdRes      -> second node fact
        FwdRes g cont -> return (FwdRes g (cont `thenFwdRw` rw2))
101
-- | Iterate a forward rewrite.
--
-- @iterFwdRw rw@ behaves like @rw@, except that whenever @rw@ rewrites a
-- node, its continuation is followed (via 'thenFwdRw') by @iterFwdRw rw@.
iterFwdRw :: Monad m => FwdRewrite m n f -> FwdRewrite m n f
iterFwdRw rw = wrapFRewrites' iterShape rw
  where
    iterShape shapeRw node fact = do
      res <- shapeRw node fact
      return (again res)
    again NoFwdRes        = NoFwdRes
    again (FwdRes g cont) = FwdRes g (cont `thenFwdRw` iterFwdRw rw)
107
108 ----------------------------------------------------------------
109
-- | A simple backward rewrite at one node shape: given the node and the
-- fact indexed by its exit shape (@Fact x f@), optionally return a
-- replacement graph.
type SBRW m n f e x = n e x -> Fact x f -> m (Maybe (Graph n e x))

-- | A full backward rewrite at one node shape, producing a 'BwdRes'.
type BRW m n f e x = n e x -> Fact x f -> m (BwdRes m n f e x)

-- | Lifts a simple backward rewrite to a full one at one shape.
type LiftBRW m n f e x = SBRW m n f e x -> BRW m n f e x

-- | Transforms a full backward rewrite at one shape.
type MapBRW m n f e x = BRW m n f e x -> BRW m n f e x

-- | Combines two full backward rewrites at one shape.
type MapBRW2 m n f e x = BRW m n f e x -> BRW m n f e x -> BRW m n f e x

-- | One simple backward rewrite per node shape.
type SimpleBwdRewrite m n f = ExTriple (SBRW m n f)

-- | A simple backward rewrite that works uniformly at every node shape.
type SimpleBwdRewrite' m n f = forall e x . SBRW m n f e x
117
118 ----------------------------------------------------------------
119
-- | Lift each simple backward rewrite in a shape-indexed triple with the
-- lifter for the same shape, then package the results with 'mkBRewrite'.
wrapSBRewrites :: ExTriple (LiftBRW m n f) -> SimpleBwdRewrite m n f -> BwdRewrite m n f
wrapSBRewrites lifters = uncurry3 mkBRewrite . apply lifters
122
-- | Transform each per-shape component of a 'BwdRewrite' with the matching
-- function from the triple and rebuild the rewrite.
wrapBRewrites :: ExTriple (MapBRW m n f) -> BwdRewrite m n f -> BwdRewrite m n f
wrapBRewrites fns rw =
  case apply fns (getBRewrites rw) of
    (first, middle, final) -> mkBRewrite first middle final
125
-- | Combine two backward rewrites shape by shape, using the matching binary
-- function from the triple for each shape, and rebuild the rewrite.
wrapBRewrites2 :: ExTriple (MapBRW2 m n f) -> BR m n f -> BR m n f -> BR m n f
wrapBRewrites2 fns rw1 rw2 = uncurry3 mkBRewrite combined
  where combined = (applyBinary fns `on` getBRewrites) rw1 rw2
129
-- Combinators for higher-rank rewriting functions: each takes a single
-- shape-polymorphic function and uses it at all three node shapes.

-- | 'wrapSBRewrites' with the same lifter at every shape.
wrapSBRewrites' :: (forall e x . LiftBRW m n f e x) -> SimpleBwdRewrite m n f -> BR m n f
wrapSBRewrites' lifter simple = wrapSBRewrites (lifter, lifter, lifter) simple
133
-- | 'wrapBRewrites' with the same transformation at every shape.
wrapBRewrites' :: (forall e x . MapBRW m n f e x) -> BwdRewrite m n f -> BwdRewrite m n f
wrapBRewrites' fn rw = wrapBRewrites (fn, fn, fn) rw
136
-- | 'wrapBRewrites2' with the same combining function at every shape.
wrapBRewrites2' :: (forall e x . MapBRW2 m n f e x) -> BR m n f -> BR m n f -> BR m n f
wrapBRewrites2' combine rw1 rw2 = wrapBRewrites2 (combine, combine, combine) rw1 rw2
139
140 ----------------------------------------------------------------
141
-- | The backward rewrite that declines ('NoBwdRes') every node at every shape.
noBwdRewrite :: Monad m => BwdRewrite m n f
noBwdRewrite = mkBRewrite' (\_node _fact -> return NoBwdRes)
144
-- | Turn a triple of simple backward rewrites into a 'BwdRewrite'.
--
-- A simple rewrite returning @Nothing@ becomes 'NoBwdRes'.  One returning
-- @Just g@ becomes @BwdRes g noBwdRewrite@: the continuation declines
-- every node.
shallowBwdRw :: Monad m => SimpleBwdRewrite m n f -> BwdRewrite m n f
shallowBwdRw rws = wrapSBRewrites' liftSimple rws
  where
    liftSimple simple node fact = simple node fact >>= return . toResult
    toResult found = case found of
      Nothing -> NoBwdRes
      Just g  -> BwdRes g noBwdRewrite
150
-- | 'shallowBwdRw' for a simple rewrite that works at every node shape.
shallowBwdRw' :: Monad m => SimpleBwdRewrite' m n f -> BwdRewrite m n f
shallowBwdRw' rw = shallowBwdRw (rw, rw, rw)
153
-- | 'shallowBwdRw' wrapped in 'iterBwdRw', so the rewrites are reapplied
-- after each success.
deepBwdRw :: Monad m => SimpleBwdRewrite m n f -> BwdRewrite m n f
deepBwdRw rws = iterBwdRw (shallowBwdRw rws)

-- | Shape-polymorphic counterpart of 'deepBwdRw'.
deepBwdRw' :: Monad m => SimpleBwdRewrite' m n f -> BwdRewrite m n f
deepBwdRw' rw = iterBwdRw (shallowBwdRw' rw)
158
159
-- | Sequence two backward rewrites.
--
-- At every node shape, @rw1@ is tried first.  If it declines ('NoBwdRes'),
-- @rw2@'s component for that shape is tried on the same node and fact.  If
-- @rw1@ rewrites, its graph is kept and its continuation is chained (again
-- with 'thenBwdRw') onto the complete @rw2@.
thenBwdRw :: Monad m => BwdRewrite m n f -> BwdRewrite m n f -> BwdRewrite m n f
thenBwdRw rw1 rw2 = wrapBRewrites2' step rw1 rw2
  where
    -- 'first' and 'second' are the per-shape components of rw1 and rw2;
    -- the continuation is a whole 'BwdRewrite', so it pairs with rw2.
    step first second node fact = first node fact >>= onResult
      where
        onResult NoBwdRes        = second node fact
        onResult (BwdRes g cont) = return (BwdRes g (cont `thenBwdRw` rw2))
167
-- | Iterate a backward rewrite.
--
-- @iterBwdRw rw@ behaves like @rw@, except that whenever @rw@ rewrites a
-- node, its continuation is followed (via 'thenBwdRw') by @iterBwdRw rw@.
iterBwdRw :: Monad m => BwdRewrite m n f -> BwdRewrite m n f
iterBwdRw rw = wrapBRewrites' iterShape rw
  where
    iterShape shapeRw node fact = shapeRw node fact >>= return . again
    again NoBwdRes        = NoBwdRes
    again (BwdRes g cont) = BwdRes g (cont `thenBwdRw` iterBwdRw rw)
173
-- | Run two forward passes side by side over the product of their lattices.
--
-- Transfer functions run component-wise.  At an open/closed node each pass
-- yields a fact base keyed by label.  The two fact bases are merged over the
-- union of their labels.  A label reported by only one pass gets the other
-- lattice's bottom as its missing component, so neither pass's facts are
-- dropped.
--
-- The rewrite tries pass 1's rewrite (which sees only the first component
-- of the fact) and, if that declines, pass 2's rewrite (second component).
productFwd :: forall m n f f' . Monad m => FwdPass m n f -> FwdPass m n f' -> FwdPass m n (f, f')
productFwd pass1 pass2 = FwdPass lattice transfer rewrite
  where
    lattice = productLattice (fp_lattice pass1) (fp_lattice pass2)

    transfer = mkFTransfer (tf tf1 tf2) (tf tm1 tm2) (tfb tl1 tl2)
      where
        tf t1 t2 n (f1, f2) = (t1 n f1, t2 n f2)
        -- Previously only fb1's labels were kept, silently discarding any
        -- fact pass 2 produced for a label that pass 1 did not mention.
        -- For a label in both bases, both sides build the same pair, so
        -- the bias of 'mapUnion' does not matter.
        tfb t1 t2 n (f1, f2) =
          mapUnion (mapMapWithKey withFb2 fb1) (mapMapWithKey withFb1 fb2)
          where fb1 = t1 n f1
                fb2 = t2 n f2
                withFb2 l fact1 = (fact1, fromMaybe bot2 $ lookupFact l fb2)
                withFb1 l fact2 = (fromMaybe bot1 $ lookupFact l fb1, fact2)
        bot1 = fact_bot (fp_lattice pass1)
        bot2 = fact_bot (fp_lattice pass2)
        (tf1, tm1, tl1) = getFTransfers (fp_transfer pass1)
        (tf2, tm2, tl2) = getFTransfers (fp_transfer pass2)

    rewrite = liftRW (fp_rewrite pass1) fst `thenFwdRw` liftRW (fp_rewrite pass2) snd
      where
        -- Run a rewrite over one component of the product fact, re-wrapping
        -- its continuation so it keeps working on the product.
        liftRW :: forall f1 . FwdRewrite m n f1 -> ((f, f') -> f1) -> FwdRewrite m n (f, f')
        liftRW rws proj = mkFRewrite (lift rwF) (lift rwM) (lift rwL)
          where lift rw n fact = liftM projRewrite $ rw n (proj fact)
                projRewrite NoFwdRes = NoFwdRes
                projRewrite (FwdRes g rws') = FwdRes g $ liftRW rws' proj
                (rwF, rwM, rwL) = getFRewrites rws
195
-- | Run two backward passes side by side over the product of their lattices.
--
-- Transfer functions run component-wise.  For the open/closed shape the
-- incoming fact map of pairs is projected with 'mapMap' 'fst' / 'snd'
-- before being handed to each pass.  The rewrite tries pass 1's rewrite
-- (first component) and, if that declines, pass 2's rewrite (second).
productBwd :: forall m n f f' . Monad m => BwdPass m n f -> BwdPass m n f' -> BwdPass m n (f, f')
productBwd pass1 pass2 = BwdPass lattice transfer rewrite
  where
    lattice = productLattice (bp_lattice pass1) (bp_lattice pass2)

    transfer = mkBTransfer (pairwise tf1 tf2) (pairwise tm1 tm2) (split tl1 tl2)
      where
        -- Open exits: each pass sees its own component of the pair.
        pairwise t1 t2 node (fact1, fact2) = (t1 node fact1, t2 node fact2)
        -- Closed exit: project each component out of the map of pairs.
        split t1 t2 node facts = (t1 node (mapMap fst facts), t2 node (mapMap snd facts))
        (tf1, tm1, tl1) = getBTransfers (bp_transfer pass1)
        (tf2, tm2, tl2) = getBTransfers (bp_transfer pass2)

    rewrite = thenBwdRw (liftRW (bp_rewrite pass1) fst) (liftRW (bp_rewrite pass2) snd)
      where
        liftRW :: forall f1 . BwdRewrite m n f1 -> ((f, f') -> f1) -> BwdRewrite m n (f, f')
        liftRW rws proj = mkBRewrite (narrow proj rwF) (narrow proj rwM) (narrow (mapMap proj) rwL)
          where
            narrow project rw node fact = do
              res <- rw node (project fact)
              return (reproject res)
            reproject NoBwdRes        = NoBwdRes
            reproject (BwdRes g rws') = BwdRes g (liftRW rws' proj)
            (rwF, rwM, rwL) = getBRewrites rws
214
-- | The product of two dataflow lattices.
--
-- The name joins both names with @" x "@, bottom is the pair of bottoms,
-- and logging is enabled if either lattice logs.  Extension extends each
-- component with its own lattice.  It reports 'NoChange' only when both
-- components report 'NoChange'.
productLattice :: forall f f' . DataflowLattice f -> DataflowLattice f' -> DataflowLattice (f, f')
productLattice l1 l2 =
  DataflowLattice
    { fact_name       = concat [fact_name l1, " x ", fact_name l2]
    , fact_bot        = (fact_bot l1, fact_bot l2)
    , fact_extend     = extendBoth
    , fact_do_logging = fact_do_logging l1 || fact_do_logging l2
    }
  where
    extendBoth lbl (OldFact (old1, old2)) (NewFact (new1, new2)) =
      (combineFlags flag1 flag2, (fact1, fact2))
      where
        (flag1, fact1) = fact_extend l1 lbl (OldFact old1) (NewFact new1)
        (flag2, fact2) = fact_extend l2 lbl (OldFact old2) (NewFact new2)
    combineFlags NoChange NoChange = NoChange
    combineFlags _        _        = SomeChange