Refactor the story around switches (#10137)
[ghc.git] / compiler / cmm / CmmCommonBlockElim.hs
1 {-# LANGUAGE GADTs #-}
2 module CmmCommonBlockElim
3 ( elimCommonBlocks
4 )
5 where
6
7
8 import BlockId
9 import Cmm
10 import CmmUtils
11 import CmmSwitch (eqSwitchTargetWith)
12 import CmmContFlowOpt
13 import Prelude hiding (iterate, succ, unzip, zip)
14
15 import Hoopl hiding (ChangeFlag)
16 import Data.Bits
17 import Data.Maybe (mapMaybe)
18 import qualified Data.List as List
19 import Data.Word
20 import qualified Data.Map as M
21 import Outputable
22 import UniqFM
23
-- Debug tracing for this pass.  Set 'debugCommonBlockElim' to 'True' to
-- print the block substitutions discovered by 'common_block'.
my_trace :: String -> SDoc -> a -> a
my_trace
  | debugCommonBlockElim = pprTrace
  | otherwise            = \_ _ a -> a
  where debugCommonBlockElim = False
26
27 -- -----------------------------------------------------------------------------
28 -- Eliminate common blocks
29
30 -- If two blocks are identical except for the label on the first node,
31 -- then we can eliminate one of the blocks. To ensure that the semantics
32 -- of the program are preserved, we have to rewrite each predecessor of the
33 -- eliminated block to proceed with the block we keep.
34
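-- The algorithm iterates over the blocks in the graph, checking for each
-- block whether it has already seen another block that is equal modulo
-- labels. If so, it adds an entry to a map recording that the new block is
-- made redundant by the old one. Otherwise, the new block is remembered as a
-- candidate for comparison with later blocks. Because every new entry can
-- make further blocks equal, the whole pass is repeated until it adds no
-- new entries.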
40
-- Entry point of the pass: compute the substitution from redundant blocks to
-- the blocks that replace them, copy the ticks of each redundant block onto
-- its replacement ('copyTicks'), and then redirect labels ('replaceLabels').
--
-- TODO: Use optimization fuel
elimCommonBlocks :: CmmGraph -> CmmGraph
elimCommonBlocks g = replaceLabels env (copyTicks env g)
  where
    env = iterate [ (hash_block blk, blk) | blk <- postorderDfs g ] mapEmpty
47
-- Iterate over the blocks until convergence.  Each pass starts with an empty
-- hash table and the substitution found so far; a pass that records a new
-- substitution may make further blocks equal, so we repeat until a pass
-- reports no change.
iterate :: [(HashCode,CmmBlock)] -> BlockEnv BlockId -> BlockEnv BlockId
iterate blocks subst =
  -- Strict left fold: a lazy 'foldl' would build a chain of 'common_block'
  -- thunks as long as the whole block list before evaluating any of them.
  case List.foldl' common_block (False, emptyUFM, subst) blocks of
    (changed, _, subst')
      | changed   -> iterate blocks subst'
      | otherwise -> subst'
55
-- Whether a pass over the blocks recorded a new substitution.
type ChangeFlag = Bool

-- Hash of a block computed modulo labels (see 'hash_block').
type HashCode   = Int

-- Accumulator threaded through 'common_block': the change flag, the blocks
-- seen so far grouped by their hash, and the current substitution from
-- redundant block ids to the ids replacing them.
type State = (ChangeFlag, UniqFM [CmmBlock], BlockEnv BlockId)
60
-- Try to find a block that is equal (or ``common'') to b among the blocks
-- already seen with the same hash.  If one is found and the substitution
-- does not already map b to it, record that mapping and set the change
-- flag.  If none is found, b is remembered in its hash bucket.
common_block :: State -> (HashCode, CmmBlock) -> State
common_block (changed, seen, subst) (hash, b) =
  case lookupUFM seen hash of
    Nothing -> (changed, addToUFM seen hash [b], subst)
    Just bucket ->
      case List.find (eqBlockBodyWith (eqBid subst) b) bucket of
        Nothing -> (changed, addToUFM seen hash (b : bucket), subst)
        Just b'
          | mapLookup bid subst == Just bid' -> (changed, seen, subst)
          | otherwise ->
              my_trace "found new common block" (ppr bid <> char '=' <> ppr bid') $
                (True, seen, mapInsert bid bid' subst)
          where bid' = entryLabel b'
  where
    bid = entryLabel b
75
76
77 -- -----------------------------------------------------------------------------
78 -- Hashing and equality on blocks
79
80 -- Below here is mostly boilerplate: hashing blocks ignoring labels,
81 -- and comparing blocks modulo a label mapping.
82
-- To speed up comparisons, we hash each basic block modulo labels.
-- The hashing is a bit arbitrary (the numbers are completely arbitrary),
-- but it should be fast and good enough.
--
-- Middle nodes selected by 'dont_care' (comments and ticks) leave the hash
-- untouched.  'eqBlockBodyWith' filters those nodes out before comparing,
-- so they must not be able to put otherwise-equal blocks into different
-- hash buckets, where they would never be compared at all.
hash_block :: CmmBlock -> HashCode
hash_block block =
  fromIntegral (foldBlockNodesB3 (hash_fst, hash_mid, hash_lst) block (0 :: Word32) .&. (0x7fffffff :: Word32))
  -- UniqFM doesn't like negative Ints
  where hash_fst _ h = h
        hash_mid m h
          | dont_care m = h   -- skip entirely, see comment above
          | otherwise   = hash_node m + h `shiftL` 1
        hash_lst m h = hash_node m + h `shiftL` 1

        hash_node :: CmmNode O x -> Word32
        hash_node n | dont_care n = 0 -- don't care
        hash_node (CmmUnwind _ e) = hash_e e
        hash_node (CmmAssign r e) = hash_reg r + hash_e e
        hash_node (CmmStore e e') = hash_e e + hash_e e'
        hash_node (CmmUnsafeForeignCall t _ as) = hash_tgt t + hash_list hash_e as
        hash_node (CmmBranch _) = 23 -- NB. ignore the label
        hash_node (CmmCondBranch p _ _) = hash_e p
        hash_node (CmmCall e _ _ _ _ _) = hash_e e
        hash_node (CmmForeignCall t _ _ _ _ _ _) = hash_tgt t
        hash_node (CmmSwitch e _) = hash_e e
        hash_node _ = error "hash_node: unknown Cmm node!"

        hash_reg :: CmmReg -> Word32
        hash_reg (CmmLocal _) = 117
        hash_reg (CmmGlobal _) = 19

        hash_e :: CmmExpr -> Word32
        hash_e (CmmLit l) = hash_lit l
        hash_e (CmmLoad e _) = 67 + hash_e e
        hash_e (CmmReg r) = hash_reg r
        hash_e (CmmMachOp _ es) = hash_list hash_e es -- pessimal - no operator check
        hash_e (CmmRegOff r i) = hash_reg r + cvt i
        hash_e (CmmStackSlot _ _) = 13

        hash_lit :: CmmLit -> Word32
        hash_lit (CmmInt i _) = fromInteger i
        hash_lit (CmmFloat r _) = truncate r
        hash_lit (CmmVec ls) = hash_list hash_lit ls
        hash_lit (CmmLabel _) = 119 -- ugh
        hash_lit (CmmLabelOff _ i) = cvt $ 199 + i
        hash_lit (CmmLabelDiffOff _ _ i) = cvt $ 299 + i
        hash_lit (CmmBlock _) = 191 -- ugh
        hash_lit (CmmHighStackMark) = cvt 313

        hash_tgt :: ForeignTarget -> Word32
        hash_tgt (ForeignTarget e _) = hash_e e
        hash_tgt (PrimTarget _) = 31 -- lots of these

        -- Order-insensitive sum of element hashes (wraps modulo 2^32).
        hash_list :: (a -> Word32) -> [a] -> Word32
        hash_list f = List.foldl' (\z x -> f x + z) (0::Word32)

        cvt :: Int -> Word32
        cvt = fromInteger . toInteger
135
-- | Node types that are ignored when comparing blocks for equality (and to
-- which 'hash_node' assigns hash 0): comments and ticks.
dont_care :: CmmNode O x -> Bool
dont_care node = case node of
  CmmComment {} -> True
  CmmTick {}    -> True
  _             -> False
141
142 -- Utilities: equality and substitution on the graph.
143
-- Given a map ``subst'' from BlockID -> BlockID, we define equality:
-- two block ids are equal when following the substitution from each of
-- them ends at the same block id.
eqBid :: BlockEnv BlockId -> BlockId -> BlockId -> Bool
eqBid subst x y = resolve x == resolve y
  where resolve = lookupBid subst

-- Follow the substitution chain starting at a block id until reaching an
-- id that the substitution does not map any further.
lookupBid :: BlockEnv BlockId -> BlockId -> BlockId
lookupBid subst bid = maybe bid (lookupBid subst) (mapLookup bid subst)
151
-- Middle nodes and expressions can contain BlockIds, in particular in
-- CmmStackSlot and CmmBlock, so we have to use a special equality for
-- these.
--
-- Only assignments, stores and unsafe foreign calls can compare equal; any
-- other pair of middle nodes is conservatively treated as different.
eqMiddleWith :: (BlockId -> BlockId -> Bool)
             -> CmmNode O O -> CmmNode O O -> Bool
eqMiddleWith eqBid (CmmAssign r1 e1) (CmmAssign r2 e2)
  = r1 == r2 && eqExprWith eqBid e1 e2
eqMiddleWith eqBid (CmmStore l1 r1) (CmmStore l2 r2)
  = eqExprWith eqBid l1 l2 && eqExprWith eqBid r1 r2
eqMiddleWith eqBid (CmmUnsafeForeignCall t1 r1 a1)
                   (CmmUnsafeForeignCall t2 r2 a2)
  = t1 == t2 && r1 == r2 && eqArgs a1 a2
  where
    -- Unlike @and (zipWith ...)@, this fails when the argument lists have
    -- different lengths instead of silently ignoring the extra arguments.
    eqArgs (x:xs) (y:ys) = eqExprWith eqBid x y && eqArgs xs ys
    eqArgs []     []     = True
    eqArgs _      _      = False
eqMiddleWith _ _ _ = False
166
-- Equality on expressions modulo a function that decides when two block ids
-- are equal.  Block ids are compared with that function where they occur in
-- 'CmmBlock' literals and in 'Young' stack areas; everything else is
-- compared exactly.  Expression forms not listed below never compare equal.
eqExprWith :: (BlockId -> BlockId -> Bool)
           -> CmmExpr -> CmmExpr -> Bool
eqExprWith eqBid = eq
 where
  CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
  -- Compare the loaded type as well as the address: two loads of different
  -- types from the same address are not interchangeable.
  CmmLoad e1 t1      `eq` CmmLoad e2 t2      = t1 `cmmEqType` t2 && e1 `eq` e2
  CmmReg r1          `eq` CmmReg r2          = r1==r2
  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1==r2 && i1==i2
  CmmMachOp op1 es1  `eq` CmmMachOp op2 es2  = op1==op2 && es1 `eqs` es2
  CmmStackSlot a1 i1 `eq` CmmStackSlot a2 i2 = eqArea a1 a2 && i1==i2
  _e1                `eq` _e2                = False

  -- Element-wise equality that also requires both lists to have the same
  -- length (and (zipWith ...) would ignore trailing elements).
  (x:xs) `eqs` (y:ys) = x `eq` y && xs `eqs` ys
  []     `eqs` []     = True
  _      `eqs` _      = False

  eqLit (CmmBlock id1) (CmmBlock id2) = eqBid id1 id2
  eqLit l1 l2 = l1 == l2

  eqArea Old Old = True
  eqArea (Young id1) (Young id2) = eqBid id1 id2
  eqArea _ _ = False
187
-- Equality on the body of a block, modulo the given equality on block ids.
-- Entry nodes are ignored, and middle nodes selected by 'dont_care'
-- (comments and ticks) are filtered out before comparing.
eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
eqBlockBodyWith eqBid block block'
  = eqListWith (eqMiddleWith eqBid) nodes nodes' &&
    eqLastWith eqBid l l'
  where (_,m,l)   = blockSplit block
        nodes     = filter (not . dont_care) (blockToList m)
        (_,m',l') = blockSplit block'
        nodes'    = filter (not . dont_care) (blockToList m')

        -- Pairwise equality that, unlike @and (zipWith ...)@, fails when the
        -- lists have different lengths.  Otherwise a block would be judged
        -- equal to any block whose middle nodes merely extend its own.
        eqListWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
        eqListWith f (x:xs) (y:ys) = f x y && eqListWith f xs ys
        eqListWith _ []     []     = True
        eqListWith _ _      _      = False
198
199
200
-- Equality on the last node of a block, modulo the given equality on block
-- ids.  Embedded expressions (branch condition, call target, switch
-- scrutinee) are compared with 'eqExprWith', the same label-aware equality
-- used for middle nodes, so that block ids inside them are also compared
-- modulo the substitution.  Foreign calls and mismatched node kinds are
-- conservatively treated as different.
eqLastWith :: (BlockId -> BlockId -> Bool) -> CmmNode O C -> CmmNode O C -> Bool
eqLastWith eqBid (CmmBranch bid1) (CmmBranch bid2) = eqBid bid1 bid2
eqLastWith eqBid (CmmCondBranch c1 t1 f1) (CmmCondBranch c2 t2 f2) =
  eqExprWith eqBid c1 c2 && eqBid t1 t2 && eqBid f1 f2
eqLastWith eqBid (CmmCall t1 c1 g1 a1 r1 u1) (CmmCall t2 c2 g2 a2 r2 u2) =
  eqExprWith eqBid t1 t2 && eqMaybeWith eqBid c1 c2 &&
  a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
eqLastWith eqBid (CmmSwitch e1 ids1) (CmmSwitch e2 ids2) =
  eqExprWith eqBid e1 e2 && eqSwitchTargetWith eqBid ids1 ids2
eqLastWith _ _ _ = False
210
-- Lift an element equality to 'Maybe': two 'Just's are compared with the
-- given function, two 'Nothing's are equal, and any mix is unequal.
eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
eqMaybeWith eltEq mx my =
  case (mx, my) of
    (Just x, Just y)   -> eltEq x y
    (Nothing, Nothing) -> True
    _                  -> False
215
-- | Given a block map, ensure that all "target" blocks are covered by
-- the same ticks as the respective "source" blocks. This not only
-- means copying ticks, but also adjusting tick scopes where
-- necessary.
copyTicks :: BlockEnv BlockId -> CmmGraph -> CmmGraph
copyTicks env g
  | mapNull env = g
  | otherwise   = ofBlockMap (g_entry g) (mapMap addTicksFrom blocks)
  where
    blocks = toBlockMap g

    -- Reverse of the merge map: each surviving block id is mapped to the
    -- ids of the blocks that were merged into it.
    merged = mapFoldWithKey (\src dst -> M.insertWith (const (src :)) dst [src])
                            M.empty env

    -- Copy ticks and scopes of all blocks merged into this one.
    addTicksFrom blk =
      maybe blk
            (foldr copy blk . mapMaybe (\lbl -> mapLookup lbl blocks))
            (M.lookup (entryLabel blk) merged)

    -- Prepend src's ticks to dst and combine the two entry tick scopes.
    copy src dst =
      let ticks                     = blockTicks src
          CmmEntry _ scp0           = firstNode src
          (CmmEntry lbl scp1, code) = blockSplitHead dst
      in CmmEntry lbl (combineTickScopes scp0 scp1) `blockJoinHead`
           foldr (blockCons . CmmTick) code ticks