4ed38456e4487e09e991841db94b2efd6a5bbcc7
[packages/hoopl.git] / src / Compiler / Hoopl / Combinators.hs
1 {-# LANGUAGE RankNTypes, LiberalTypeSynonyms, ScopedTypeVariables, GADTs #-}
2
3 module Compiler.Hoopl.Combinators
4 ( thenFwdRw
5 , deepFwdRw3, deepFwdRw, iterFwdRw
6 , thenBwdRw
7 , deepBwdRw3, deepBwdRw, iterBwdRw
8 , pairFwd, pairBwd, pairLattice
9 )
10
11 where
12
13 import Control.Monad
14 import Data.Maybe
15
16 import Compiler.Hoopl.Collections
17 import Compiler.Hoopl.Dataflow
18 import Compiler.Hoopl.Fuel
19 import Compiler.Hoopl.Graph (Graph, C, O, Shape(..))
20 import Compiler.Hoopl.Label
21
22 ----------------------------------------------------------------
23
-- | Build a forward rewrite from one rewrite function per node shape
-- (first: C O, middle: O O, last: O C) and wrap it with 'iterFwdRw',
-- so rewriting is re-attempted on the graphs the rewrite produces.
deepFwdRw3 :: FuelMonad m
           => (n C O -> f -> m (Maybe (Graph n C O)))
           -> (n O O -> f -> m (Maybe (Graph n O O)))
           -> (n O C -> f -> m (Maybe (Graph n O C)))
           -> FwdRewrite m n f
deepFwdRw3 rwFirst rwMiddle rwLast =
  iterFwdRw (mkFRewrite3 rwFirst rwMiddle rwLast)

-- | Shape-polymorphic variant of 'deepFwdRw3': the one supplied function
-- is used for first, middle and last nodes alike.
deepFwdRw :: FuelMonad m
          => (forall e x . n e x -> f -> m (Maybe (Graph n e x)))
          -> FwdRewrite m n f
deepFwdRw rw = deepFwdRw3 rw rw rw
33
-- Sequential composition of forward rewrites: try the first rewrite and,
-- only if it declines (returns Nothing), fall back to the second.
--
-- N.B. rw3 and rw3' are whole 'FwdRewrite's (triples of functions, one per
-- node shape), whereas rw and rw' are the single functions that 'wrapFR2'
-- hands to 'thenrw'.
-- @ start comb1.tex
thenFwdRw :: Monad m
          => FwdRewrite m n f
          -> FwdRewrite m n f
          -> FwdRewrite m n f
-- @ end comb1.tex
thenFwdRw rw3 rw3' = wrapFR2 thenrw rw3 rw3'
  where
    -- When rw fires, rw3' is appended (via 'fadd_rw') to the follow-up
    -- rewrite it returned, so the second rewrite still gets a chance on
    -- the replacement graph.
    thenrw rw rw' n f =
      rw n f >>= maybe (rw' n f) (return . Just . fadd_rw rw3')
47
-- @ start iterf.tex
iterFwdRw :: Monad m
          => FwdRewrite m n f
          -> FwdRewrite m n f
-- @ end iterf.tex
iterFwdRw rw3 = wrapFR iter rw3
  where
    -- Whenever the underlying rewrite fires, the follow-up rewrite it
    -- returned is followed by the whole iterated rewrite again, so the
    -- replacement graph can itself be rewritten.
    iter rw n f = liftM (fmap (fadd_rw (iterFwdRw rw3))) (rw n f)
    -- Unused CPS formulation via 'frewrite_cps'; kept for reference.
    _iter = frewrite_cps (return . Just . fadd_rw (iterFwdRw rw3)) (return Nothing)
56
-- | Function inspired by 'rew' in the paper: run the rewrite @rw@ on a node
-- and its fact, then continue with @onRewrite@ if it produced a replacement
-- (graph plus follow-up rewrite), or with @onNoRewrite@ if it declined.
frewrite_cps :: Monad m
             => ((Graph n e x, FwdRewrite m n f) -> m a)
             -> m a
             -> (forall e x . n e x -> f -> m (Maybe (Graph n e x, FwdRewrite m n f)))
             -> n e x
             -> f
             -> m a
frewrite_cps onRewrite onNoRewrite rw node f =
  rw node f >>= maybe onNoRewrite onRewrite
69
70
71
-- | Function inspired by 'add' in the paper: keep the replacement graph and
-- extend its follow-up rewrite with @rw2@ (via 'thenFwdRw'), so @rw2@ is
-- tried whenever that follow-up rewrite declines.
fadd_rw :: Monad m
        => FwdRewrite m n f
        -> (Graph n e x, FwdRewrite m n f)
        -> (Graph n e x, FwdRewrite m n f)
fadd_rw rw2 = fmap (`thenFwdRw` rw2)
78
79 ----------------------------------------------------------------
80
-- | Backward analogue of 'deepFwdRw3': build a rewrite from one function per
-- node shape and wrap it with 'iterBwdRw'. The last-node function receives
-- a 'FactBase' rather than a single fact.
deepBwdRw3 :: FuelMonad m
           => (n C O -> f -> m (Maybe (Graph n C O)))
           -> (n O O -> f -> m (Maybe (Graph n O O)))
           -> (n O C -> FactBase f -> m (Maybe (Graph n O C)))
           -> BwdRewrite m n f
deepBwdRw3 rwFirst rwMiddle rwLast =
  iterBwdRw (mkBRewrite3 rwFirst rwMiddle rwLast)

-- | Shape-polymorphic variant of 'deepBwdRw3': the fact argument has type
-- @Fact x f@, matching the node's exit shape.
deepBwdRw :: FuelMonad m
          => (forall e x . n e x -> Fact x f -> m (Maybe (Graph n e x)))
          -> BwdRewrite m n f
deepBwdRw rw = deepBwdRw3 rw rw rw
91
92
-- | Sequential composition of backward rewrites: try @rw1@ and, only if it
-- declines, fall back to @rw2@. When @rw1@ fires, @rw2@ is appended (via
-- 'badd_rw') to the follow-up rewrite it returned.
thenBwdRw :: Monad m => BwdRewrite m n f -> BwdRewrite m n f -> BwdRewrite m n f
thenBwdRw rw1 rw2 = wrapBR2 thenrw rw1 rw2
  where
    -- The first argument supplied by 'wrapBR2' is not needed here.
    thenrw _ rw rw' n fact =
      rw n fact >>= maybe (rw' n fact) (return . Just . badd_rw rw2)
100
-- | Iterate a backward rewrite: whenever it fires, the follow-up rewrite it
-- returned is followed by the whole iterated rewrite again.
iterBwdRw :: Monad m => BwdRewrite m n f -> BwdRewrite m n f
iterBwdRw rw = wrapBR iter rw
  where
    -- The first argument supplied by 'wrapBR' is not needed here.
    iter _ rw' n fact = do
      result <- rw' n fact
      return (fmap (badd_rw (iterBwdRw rw)) result)
104
-- | Function inspired by 'add' in the paper (backward version): keep the
-- replacement graph and extend its follow-up rewrite with @rw2@ (via
-- 'thenBwdRw'), so @rw2@ is tried whenever that follow-up declines.
badd_rw :: Monad m
        => BwdRewrite m n f
        -> (Graph n e x, BwdRewrite m n f)
        -> (Graph n e x, BwdRewrite m n f)
badd_rw rw2 = fmap (`thenBwdRw` rw2)
111
112
-- Run two forward passes side by side. The lattice is the product of the
-- two lattices, and transfer functions are applied componentwise. The
-- rewrite tries pass1's rewrite before pass2's, each seeing only its own
-- component of the paired fact.
-- @ start pairf.tex
pairFwd :: Monad m
        => FwdPass m n f
        -> FwdPass m n f'
        -> FwdPass m n (f, f')
-- @ end pairf.tex
pairFwd pass1 pass2 = FwdPass lattice transfer rewrite
  where
    lattice = pairLattice (fp_lattice pass1) (fp_lattice pass2)
    transfer = mkFTransfer3 (tf tf1 tf2) (tf tm1 tm2) (tfb tl1 tl2)
      where
        -- First and middle nodes: apply each transfer to its own component.
        tf t1 t2 n (f1, f2) = (t1 n f1, t2 n f2)
        -- Last nodes: each pass yields a FactBase. A label missing from
        -- one pass's FactBase is filled with that pass's bottom fact. Labels
        -- present in EITHER FactBase are kept; keeping only pass1's labels
        -- would silently drop pass2's facts for labels pass1 does not mention.
        -- Each half below is already fully paired, and both halves agree on
        -- shared labels, so the result does not depend on mapUnion's bias.
        tfb t1 t2 n (f1, f2) =
          mapUnion (mapMapWithKey with2 fb1) (mapMapWithKey with1 fb2)
          where fb1 = t1 n f1
                fb2 = t2 n f2
                with2 l a = (a, fromMaybe bot2 $ lookupFact l fb2)
                with1 l b = (fromMaybe bot1 $ lookupFact l fb1, b)
                bot1 = fact_bot (fp_lattice pass1)
                bot2 = fact_bot (fp_lattice pass2)
        (tf1, tm1, tl1) = getFTransfer3 (fp_transfer pass1)
        (tf2, tm2, tl2) = getFTransfer3 (fp_transfer pass2)
    rewrite = lift fst (fp_rewrite pass1) `thenFwdRw` lift snd (fp_rewrite pass2)
      where
        -- Adapt a rewrite over one component to a rewrite over the pair,
        -- re-lifting any follow-up rewrite it returns.
        lift proj = wrapFR project
          where project rw = \n pair -> liftM (liftM repair) $ rw n (proj pair)
                repair (g, rw') = (g, lift proj rw')
137
-- | Run two backward passes side by side over the product lattice.
-- Transfer functions are applied componentwise; for last nodes, the paired
-- 'FactBase' is projected onto each component. The rewrite tries pass1's
-- rewrite before pass2's, each seeing only its own component.
pairBwd :: forall m n f f' .
           Monad m => BwdPass m n f -> BwdPass m n f' -> BwdPass m n (f, f')
pairBwd pass1 pass2 = BwdPass lattice transfer rewrite
  where
    lattice = pairLattice (bp_lattice pass1) (bp_lattice pass2)

    transfer = mkBTransfer3 (both tFirst1 tFirst2)
                            (both tMid1 tMid2)
                            (splitBase tLast1 tLast2)
      where
        -- First and middle nodes: apply each transfer to its own component.
        both t1 t2 node (fact1, fact2) = (t1 node fact1, t2 node fact2)
        -- Last nodes: project the paired FactBase onto each component.
        splitBase t1 t2 node fb =
          (t1 node (mapMap fst fb), t2 node (mapMap snd fb))
        (tFirst1, tMid1, tLast1) = getBTransfer3 (bp_transfer pass1)
        (tFirst2, tMid2, tLast2) = getBTransfer3 (bp_transfer pass2)

    rewrite = lift fst (bp_rewrite pass1) `thenBwdRw` lift snd (bp_rewrite pass2)
      where
        -- Adapt a rewrite over one component to a rewrite over the pair,
        -- re-lifting any follow-up rewrite it returns.
        lift :: forall f1 .
                ((f, f') -> f1) -> BwdRewrite m n f1 -> BwdRewrite m n (f, f')
        lift proj = wrapBR project
          where
            project :: forall e x . Shape x
                    -> (n e x ->
                        Fact x f1 -> m (Maybe (Graph n e x, BwdRewrite m n f1)))
                    -> (n e x ->
                        Fact x (f,f') -> m (Maybe (Graph n e x, BwdRewrite m n (f,f'))))
            -- Open exit: the incoming fact is a single pair.
            project Open =
              \rw node pair -> liftM (liftM repair) $ rw node (proj pair)
            -- Closed exit: the incoming fact is a FactBase of pairs.
            project Closed =
              \rw node fb -> liftM (liftM repair) $ rw node (mapMap proj fb)
            repair (g, rw') = (g, lift proj rw')
            -- XXX specialize repair so that the cost
            -- of discriminating is one per combinator not one
            -- per rewrite
167
-- | Product of two dataflow lattices. Bottom is the pair of bottoms, and
-- joins are performed componentwise. The join reports 'NoChange' only when
-- both component joins report 'NoChange'.
pairLattice :: forall f f' .
               DataflowLattice f -> DataflowLattice f' -> DataflowLattice (f, f')
pairLattice l1 l2 =
  DataflowLattice
    { fact_name = fact_name l1 ++ " x " ++ fact_name l2
    , fact_bot  = (fact_bot l1, fact_bot l2)
    , fact_join = pairJoin
    }
  where
    pairJoin lbl (OldFact (old1, old2)) (NewFact (new1, new2)) =
        (combine ch1 ch2, (joined1, joined2))
      where
        (ch1, joined1) = fact_join l1 lbl (OldFact old1) (NewFact new1)
        (ch2, joined2) = fact_join l2 lbl (OldFact old2) (NewFact new2)
    -- The second flag is inspected only when the first is NoChange.
    combine NoChange NoChange = NoChange
    combine _        _        = SomeChange