Make Hoopl work with MonoLocalBinds.
[packages/hoopl.git] / src / Compiler / Hoopl / Combinators.hs
1 {-# LANGUAGE RankNTypes, LiberalTypeSynonyms, ScopedTypeVariables, GADTs #-}
2
3 module Compiler.Hoopl.Combinators
4 ( thenFwdRw
5 , deepFwdRw3, deepFwdRw, iterFwdRw
6 , thenBwdRw
7 , deepBwdRw3, deepBwdRw, iterBwdRw
8 , pairFwd, pairBwd, pairLattice
9 )
10
11 where
12
13 import Control.Monad
14 import Data.Maybe
15
16 import Compiler.Hoopl.Collections
17 import Compiler.Hoopl.Dataflow
18 import Compiler.Hoopl.Fuel
19 import Compiler.Hoopl.Graph (Graph, C, O, Shape(..))
20 import Compiler.Hoopl.Label
21
22 ----------------------------------------------------------------
23
-- | Build a forward rewrite from three shape-specific node rewriters (for
-- first, middle and last nodes) and iterate it with 'iterFwdRw', so that a
-- replacement graph produced by a successful rewrite is rewritten again.
deepFwdRw3 :: FuelMonad m
           => (n C O -> f -> m (Maybe (Graph n C O)))
           -> (n O O -> f -> m (Maybe (Graph n O O)))
           -> (n O C -> f -> m (Maybe (Graph n O C)))
           -> FwdRewrite m n f
deepFwdRw3 firstRw middleRw lastRw =
  iterFwdRw (mkFRewrite3 firstRw middleRw lastRw)

-- | Like 'deepFwdRw3', but with one shape-polymorphic node rewriter used
-- for first, middle and last nodes alike.
deepFwdRw :: FuelMonad m
          => (forall e x . n e x -> f -> m (Maybe (Graph n e x)))
          -> FwdRewrite m n f
deepFwdRw rw = deepFwdRw3 rw rw rw
33
-- Sequential composition of forward rewrites.  The first rewrite is tried;
-- if it declines (yields Nothing) the second one is tried on the same node
-- and fact.  If the first succeeds, the rewrite it returns for the
-- replacement graph is composed with the second rewrite (via 'fadd_rw').
--
-- Note: rw3 and rw3' are whole 'FwdRewrite' values; inside the worker,
-- 'primary' and 'fallback' are the per-node functions unwrapped by 'wrapFR2'.
-- @ start comb1.tex
thenFwdRw :: forall m n f. Monad m
          => FwdRewrite m n f
          -> FwdRewrite m n f
          -> FwdRewrite m n f
-- @ end comb1.tex
thenFwdRw rw3 rw3' = wrapFR2 thenrw rw3 rw3'
  where
    -- The explicit polymorphic signature is required: the worker mentions
    -- rw3', so under MonoLocalBinds (implied by GADTs) it would otherwise
    -- not be generalised over the node shape e/x that 'wrapFR2' demands.
    thenrw :: forall m1 e x t t1.
              Monad m1
           => (t -> t1 -> m1 (Maybe (Graph n e x, FwdRewrite m n f)))
           -> (t -> t1 -> m1 (Maybe (Graph n e x, FwdRewrite m n f)))
           -> t
           -> t1
           -> m1 (Maybe (Graph n e x, FwdRewrite m n f))
    thenrw primary fallback node fact =
      primary node fact >>= maybe (fallback node fact)
                                  (return . Just . fadd_rw rw3')
54
-- Iterate a forward rewrite.  Whenever the underlying rewrite succeeds, the
-- rewrite it returns for the replacement graph is followed (via 'fadd_rw')
-- by the iterated rewrite itself, so the replacement graph is rewritten
-- again.
-- @ start iterf.tex
iterFwdRw :: forall m n f. Monad m
          => FwdRewrite m n f
          -> FwdRewrite m n f
-- @ end iterf.tex
iterFwdRw rw3 = wrapFR iter rw3
  where
    -- The rewrite chained after every successful step.
    again = iterFwdRw rw3

    iter :: forall a m1 m2 e x t.
            (Monad m2, Monad m1)
         => (t -> a -> m1 (m2 (Graph n e x, FwdRewrite m n f)))
         -> t
         -> a
         -> m1 (m2 (Graph n e x, FwdRewrite m n f))
    iter rw node fact = liftM (liftM (fadd_rw again)) (rw node fact)
68
-- | Continuation-passing application of a forward node rewrite (inspired by
-- 'rew' in the paper).  The rewrite @rw@ is run on @node@ and @f@; on
-- success the resulting (graph, rewrite) pair is passed to the continuation
-- @j@, and when the rewrite declines the fallback action @n@ is run instead.
frewrite_cps :: Monad m
             => ((Graph n e x, FwdRewrite m n f) -> m a)
             -> m a
             -> (forall e x . n e x -> f -> m (Maybe (Graph n e x, FwdRewrite m n f)))
             -> n e x
             -> f
             -> m a
frewrite_cps j n rw node f = rw node f >>= maybe n j
81
82
83
-- | Function inspired by 'add' in the paper.  Given a rewrite result (a
-- replacement graph paired with the rewrite to use on it), compose that
-- rewrite with @rw2@ using 'thenFwdRw'; the graph is left unchanged.
fadd_rw :: Monad m
        => FwdRewrite m n f
        -> (Graph n e x, FwdRewrite m n f)
        -> (Graph n e x, FwdRewrite m n f)
fadd_rw rw2 = \(graph, rw1) -> (graph, thenFwdRw rw1 rw2)
90
91 ----------------------------------------------------------------
92
-- | Backward analogue of 'deepFwdRw3': build a backward rewrite from three
-- shape-specific node rewriters and iterate it with 'iterBwdRw'.  The
-- rewriter for last nodes receives a 'FactBase' rather than a single fact.
deepBwdRw3 :: FuelMonad m
           => (n C O -> f -> m (Maybe (Graph n C O)))
           -> (n O O -> f -> m (Maybe (Graph n O O)))
           -> (n O C -> FactBase f -> m (Maybe (Graph n O C)))
           -> BwdRewrite m n f
deepBwdRw3 firstRw middleRw lastRw =
  iterBwdRw (mkBRewrite3 firstRw middleRw lastRw)

-- | Like 'deepBwdRw3', but with one shape-polymorphic node rewriter used for
-- first, middle and last nodes alike.
deepBwdRw :: FuelMonad m
          => (forall e x . n e x -> Fact x f -> m (Maybe (Graph n e x)))
          -> BwdRewrite m n f
deepBwdRw rw = deepBwdRw3 rw rw rw
103
104
-- | Sequential composition of backward rewrites; the backward counterpart of
-- 'thenFwdRw'.  The first rewrite is tried; if it declines the second one is
-- tried on the same node and fact.  If the first succeeds, the rewrite it
-- returns is composed with @rw2@ via 'badd_rw'.
thenBwdRw :: forall m n f. Monad m
          => BwdRewrite m n f -> BwdRewrite m n f -> BwdRewrite m n f
thenBwdRw rw1 rw2 = wrapBR2 orElse rw1 rw2
  where
    -- The leading argument supplied by 'wrapBR2' is ignored here
    -- (presumably the node's exit shape, as with 'wrapBR' in 'pairBwd').
    orElse :: forall t t1 t2 m1 e x.
              Monad m1
           => t
           -> (t1 -> t2 -> m1 (Maybe (Graph n e x, BwdRewrite m n f)))
           -> (t1 -> t2 -> m1 (Maybe (Graph n e x, BwdRewrite m n f)))
           -> t1
           -> t2
           -> m1 (Maybe (Graph n e x, BwdRewrite m n f))
    orElse _ primary fallback node fact =
      primary node fact >>= maybe (fallback node fact)
                                  (return . Just . badd_rw rw2)
120
-- | Iterate a backward rewrite; the backward counterpart of 'iterFwdRw'.
-- After each successful rewrite, the returned rewrite is followed (via
-- 'badd_rw') by the iterated rewrite itself.
iterBwdRw :: forall m n f. Monad m => BwdRewrite m n f -> BwdRewrite m n f
iterBwdRw rw = wrapBR step rw
  where
    -- The rewrite chained after every successful step.
    again = iterBwdRw rw

    step :: forall t m1 m2 e x t1 t2.
            (Monad m2, Monad m1)
         => t
         -> (t1 -> t2 -> m1 (m2 (Graph n e x, BwdRewrite m n f)))
         -> t1
         -> t2
         -> m1 (m2 (Graph n e x, BwdRewrite m n f))
    step _ inner node fact = liftM (liftM (badd_rw again)) (inner node fact)
131
-- | Function inspired by 'add' in the paper; backward counterpart of
-- 'fadd_rw'.  Composes the rewrite paired with a replacement graph with
-- @rw2@ using 'thenBwdRw'; the graph is left unchanged.
badd_rw :: Monad m
        => BwdRewrite m n f
        -> (Graph n e x, BwdRewrite m n f)
        -> (Graph n e x, BwdRewrite m n f)
badd_rw rw2 = \(graph, rw1) -> (graph, thenBwdRw rw1 rw2)
138
139
-- Run two forward passes side by side over the product lattice
-- ('pairLattice').  Each component of the fact pair is transferred by its
-- own pass, and rewriting tries the first pass's rewrite before the
-- second's (see 'thenFwdRw').
-- @ start pairf.tex
pairFwd :: forall m n f f'. Monad m
        => FwdPass m n f
        -> FwdPass m n f'
        -> FwdPass m n (f, f')
-- @ end pairf.tex
pairFwd pass1 pass2 = FwdPass lattice transfer rewrite
  where
    lattice = pairLattice (fp_lattice pass1) (fp_lattice pass2)
    transfer = mkFTransfer3 (tf tf1 tf2) (tf tm1 tm2) (tfb tl1 tl2)
      where
        -- Componentwise transfer for first and middle nodes.  It is used at
        -- two node shapes, so it carries a polymorphic signature.
        tf :: forall t t1 t2 t3 t4.
              (t4 -> t -> t2) -> (t4 -> t1 -> t3) -> t4 -> (t, t1) -> (t2, t3)
        tf t1 t2 n (f1, f2) = (t1 n f1, t2 n f2)
        -- Transfer for last nodes.  Each pass produces a FactBase, and a
        -- label missing from a FactBase stands for that lattice's bottom.
        -- The paired FactBase must therefore cover every label present in
        -- EITHER result, padding the absent side with bottom.  Mapping over
        -- fb1's keys alone would silently drop facts that only pass2
        -- produced.  The union's bias is irrelevant: for a label in both
        -- maps the two entries are equal.
        tfb t1 t2 n (f1, f2) =
          mapUnion (mapMapWithKey withfb2 fb1) (mapMapWithKey withfb1 fb2)
          where fb1 = t1 n f1
                fb2 = t2 n f2
                withfb2 :: forall t. Label -> t -> (t, f')
                withfb2 l a = (a, fromMaybe bot2 $ lookupFact l fb2)
                withfb1 :: forall t. Label -> t -> (f, t)
                withfb1 l b = (fromMaybe bot1 $ lookupFact l fb1, b)
                bot1 = fact_bot (fp_lattice pass1)
                bot2 = fact_bot (fp_lattice pass2)
        (tf1, tm1, tl1) = getFTransfer3 (fp_transfer pass1)
        (tf2, tm2, tl2) = getFTransfer3 (fp_transfer pass2)
    rewrite = lift fst (fp_rewrite pass1) `thenFwdRw` lift snd (fp_rewrite pass2)
      where
        -- Adapt a rewrite over one component to a rewrite over the pair: the
        -- incoming pair fact is projected with @proj@.  The rewrite returned
        -- for a replacement graph is lifted the same way, so the adaptation
        -- persists across iterations.
        lift :: forall f m' n' f'.
                Monad m' =>
                (f' -> f) -> FwdRewrite m' n' f -> FwdRewrite m' n' f'
        lift proj = wrapFR project
          where project :: forall m m1 t t1.
                           (Monad m1, Monad m) =>
                           (t1 -> f -> m (m1 (t, FwdRewrite m' n' f)))
                           -> t1
                           -> f'
                           -> m (m1 (t, FwdRewrite m' n' f'))
                project rw = \n pair -> liftM (liftM repair) $ rw n (proj pair)
                repair :: forall t.
                          (t, FwdRewrite m' n' f) -> (t, FwdRewrite m' n' f')
                repair (g, rw') = (g, lift proj rw')
178
-- | Run two backward passes side by side over the product lattice
-- ('pairLattice').  Each component of the fact pair is transferred by its
-- own pass, and rewriting tries the first pass's rewrite before the
-- second's (see 'thenBwdRw').
pairBwd :: forall m n f f' .
           Monad m => BwdPass m n f -> BwdPass m n f' -> BwdPass m n (f, f')
pairBwd pass1 pass2 = BwdPass lattice transfer rewrite
  where
    lattice = pairLattice (bp_lattice pass1) (bp_lattice pass2)

    transfer = mkBTransfer3 (componentwise first1 first2)
                            (componentwise middle1 middle2)
                            (splitLast last1 last2)
      where
        (first1, middle1, last1) = getBTransfer3 (bp_transfer pass1)
        (first2, middle2, last2) = getBTransfer3 (bp_transfer pass2)

        -- Apply each pass's transfer function to its own component.  Used
        -- for both first and middle nodes, so it must stay polymorphic.
        componentwise :: (node -> a -> a') -> (node -> b -> b')
                      -> node -> (a, b) -> (a', b')
        componentwise g h node (a, b) = (g node a, h node b)

        -- A last node sees a map of pair facts; split it into one map per
        -- pass before calling each pass's transfer function.
        splitLast :: IsMap map
                  => (node -> map a -> a') -> (node -> map b -> b')
                  -> node -> map (a, b) -> (a', b')
        splitLast g h node facts =
          (g node (mapMap fst facts), h node (mapMap snd facts))

    rewrite = lift fst (bp_rewrite pass1) `thenBwdRw` lift snd (bp_rewrite pass2)
      where
        -- Adapt a rewrite over one component to a rewrite over the pair by
        -- projecting the incoming fact with @proj@.  The rewrite returned
        -- for a replacement graph is lifted the same way.
        lift :: forall f1 .
                ((f, f') -> f1) -> BwdRewrite m n f1 -> BwdRewrite m n (f, f')
        lift proj = wrapBR project
          where
            project :: forall e x . Shape x
                    -> (n e x ->
                        Fact x f1 -> m (Maybe (Graph n e x, BwdRewrite m n f1)))
                    -> (n e x ->
                        Fact x (f,f') -> m (Maybe (Graph n e x, BwdRewrite m n (f,f'))))
            project shape = \rw node fact ->
              liftM (liftM repair) $ rw node (toComponent fact)
              where toComponent = projFact shape

            -- An open exit carries a single pair fact; a closed exit carries
            -- a FactBase of pairs, projected pointwise.
            projFact :: forall x . Shape x -> Fact x (f, f') -> Fact x f1
            projFact Open   = proj
            projFact Closed = mapMap proj

            repair :: forall t.
                      (t, BwdRewrite m n f1) -> (t, BwdRewrite m n (f, f'))
            repair (g, rw') = (g, lift proj rw')
            -- XXX specialize repair so that the cost
            -- of discriminating is one per combinator not one
            -- per rewrite
217
-- | Product of two dataflow lattices.  The name concatenates both names,
-- bottom is the pair of bottoms, and the join works componentwise,
-- reporting 'SomeChange' if either component changed.
pairLattice :: forall f f' .
               DataflowLattice f -> DataflowLattice f' -> DataflowLattice (f, f')
pairLattice l1 l2 =
  DataflowLattice
    { fact_name = fact_name l1 ++ " x " ++ fact_name l2
    , fact_bot  = (fact_bot l1, fact_bot l2)
    , fact_join = joinPair
    }
  where
    -- Named joinPair rather than join to avoid shadowing Control.Monad.join.
    joinPair lbl (OldFact (old1, old2)) (NewFact (new1, new2)) =
      (combineChange c1 c2, (j1, j2))
      where
        (c1, j1) = fact_join l1 lbl (OldFact old1) (NewFact new1)
        (c2, j2) = fact_join l2 lbl (OldFact old2) (NewFact new2)

    -- NoChange only when both components report NoChange; the second flag
    -- is inspected only if the first is NoChange.
    combineChange NoChange NoChange = NoChange
    combineChange _        _        = SomeChange
231 (NoChange, NoChange) -> NoChange
232 _ -> SomeChange