BlockId: remove BlockMap and BlockSet synonyms
[ghc.git] / compiler / cmm / CmmCommonBlockElim.hs
1 {-# LANGUAGE GADTs, BangPatterns #-}
2 module CmmCommonBlockElim
3 ( elimCommonBlocks
4 )
5 where
6
7
8 import BlockId
9 import Cmm
10 import CmmUtils
11 import CmmSwitch (eqSwitchTargetWith)
12 import CmmContFlowOpt
13 -- import PprCmm ()
14 import Prelude hiding (iterate, succ, unzip, zip)
15
16 import Hoopl hiding (ChangeFlag)
17 import Data.Bits
18 import Data.Maybe (mapMaybe)
19 import qualified Data.List as List
20 import Data.Word
21 import qualified Data.Map as M
22 import Outputable
23 import UniqFM
24 import UniqDFM
25 import qualified TrieMap as TM
26 import Unique
27 import Control.Arrow (first, second)
28
29 -- -----------------------------------------------------------------------------
30 -- Eliminate common blocks
31
32 -- If two blocks are identical except for the label on the first node,
33 -- then we can eliminate one of the blocks. To ensure that the semantics
34 -- of the program are preserved, we have to rewrite each predecessor of the
35 -- eliminated block to proceed with the block we keep.
36
37 -- The algorithm iterates over the blocks in the graph,
38 -- checking whether it has seen another block that is equal modulo labels.
39 -- If so, then it adds an entry in a map indicating that the new block
40 -- is made redundant by the old block.
41 -- Otherwise, it is added to the useful blocks.
42
43 -- To avoid comparing every block with every other block repeatedly, we group
44 -- them by
45 -- * a hash of the block, ignoring labels (explained below)
46 -- * the list of outgoing labels
47 -- The hash is invariant under relabeling, so we only ever compare within
48 -- the same group of blocks.
49 --
50 -- The list of outgoing labels is updated as we merge blocks (that is why they
51 -- are not included in the hash, which we want to calculate only once).
52 --
53 -- All in all, two blocks should never be compared if they have different
54 -- hashes, and at most once otherwise. Previously, we were slower, and people
55 -- rightfully complained: #10397
56
-- | Common block elimination on a Cmm graph.
--
-- The blocks of the graph (in 'postorderDfs' order) are bucketed by
-- 'hash_block', every block is keyed by its successor labels, and 'iterate'
-- computes the substitution of merged labels. That substitution is then
-- handed to 'copyTicks' and 'replaceLabels'.
--
-- TODO: Use optimization fuel
elimCommonBlocks :: CmmGraph -> CmmGraph
elimCommonBlocks g = replaceLabels subst (copyTicks subst g)
  where
    -- Blocks bucketed by their label-insensitive hash.
    hashBuckets = groupByInt hash_block (postorderDfs g)

    -- Every block starts out alone, keyed by its list of successor labels.
    keyedBuckets =
      [ [ (successors blk, [blk]) | blk <- bucket ]
      | bucket <- hashBuckets
      ]

    -- Substitution accumulated by merging until nothing new is found.
    subst = iterate mapEmpty keyedBuckets
64
-- | A list of blocks that are pairwise distinct.
-- Invariant: members never need to be compared with each other again.
type DistinctBlocks = [CmmBlock]

-- | Grouping key of a block: its list of outgoing labels.
type Key = [Label]

-- | Substitution built up by merging: maps the label of a dropped block to
-- the label of the block that replaces it.
type Subst = LabelMap BlockId
70
-- | Merge equal blocks until a round produces no new substitutions.
--
-- The outer list groups by hash. We retain this grouping throughout; only
-- blocks inside the same hash group and with the same key are compared.
iterate :: Subst -> [[(Key, DistinctBlocks)]] -> Subst
iterate subst blocks
  | mapNull fresh = subst
  | otherwise     = iterate extended rekeyed
  where
    -- Inside every hash group, collect the entries that share a key.
    byKey :: [[(Key, [DistinctBlocks])]]
    byKey = map groupByLabel blocks

    -- Merge each collection of entries, threading the substitutions found
    -- in this round through both levels of the structure.
    merged :: [[(Key, DistinctBlocks)]]
    (fresh, merged) = List.mapAccumL (List.mapAccumL step) mapEmpty byKey

    step !acc (key, candidates) =
      let (found, survivors) = mergeBlockList subst candidates
      in (acc `mapUnion` found, (key, survivors))

    -- Substitution for the next round, and the keys rewritten under it so
    -- that blocks whose successors were just merged can meet next time.
    extended = subst `mapUnion` fresh
    rekeyed  = map (map (first (map (lookupBid extended)))) merged
89
-- | Merge the blocks of @new@ into @existing@.
--
-- A block of @new@ whose body equals that of some block in @existing@
-- (modulo @subst@) is dropped, and a mapping from its entry label to the
-- entry label of the matching block is recorded. All other blocks of @new@
-- are put in front of @existing@. Blocks of @new@ are only compared with
-- @existing@, never with each other.
mergeBlocks :: Subst -> DistinctBlocks -> DistinctBlocks -> (Subst, DistinctBlocks)
mergeBlocks subst existing new = foldr absorb (mapEmpty, existing) new
  where
    sameBody = eqBlockBodyWith (eqBid subst)

    absorb blk = case List.find (sameBody blk) existing of
      -- This block is a duplicate. Drop it, and add it to the substitution
      Just kept -> first (mapInsert (entryLabel blk) (entryLabel kept))
      -- This block is not a duplicate, keep it.
      Nothing   -> second (blk :)
99
-- | Merge a non-empty list of block lists, left to right, with 'mergeBlocks',
-- collecting the union of all substitutions found along the way.
--
-- Panics (via 'pprPanic') when given an empty list.
mergeBlockList :: Subst -> [DistinctBlocks] -> (Subst, DistinctBlocks)
mergeBlockList _     []            = pprPanic "mergeBlockList" empty
mergeBlockList subst (seed : rest) = List.foldl' absorb (mapEmpty, seed) rest
  where
    -- Fold the next list into the blocks kept so far.
    absorb (!found, kept) next =
      let (found', kept') = mergeBlocks subst kept next
      in (found `mapUnion` found', kept')
109
110
111 -- -----------------------------------------------------------------------------
112 -- Hashing and equality on blocks
113
114 -- Below here is mostly boilerplate: hashing blocks ignoring labels,
115 -- and comparing blocks modulo a label mapping.
116
117 -- To speed up comparisons, we hash each basic block modulo jump labels.
118 -- The hashing is a bit arbitrary (the numbers are completely arbitrary),
119 -- but it should be fast and good enough.
120
121 -- We want to get as many small buckets as possible, as comparing blocks is
122 -- expensive. So include as much as possible in the hash. Ideally everything
123 -- that is compared with (==) in eqBlockBodyWith.
124
type HashCode = Int

-- | Hash a block modulo labels.
--
-- The entry node, branch targets, and the nodes selected by 'dont_care'
-- (comments and ticks) do not influence the result, so blocks that differ
-- only in those respects end up with the same hash. The nodes are folded
-- back to front with 'foldBlockNodesB3'; each contributing node shifts the
-- running hash left by one bit and adds its own hash.
--
-- The result is masked to 31 bits so that it is never negative.
hash_block :: CmmBlock -> HashCode
hash_block block =
  fromIntegral (foldBlockNodesB3 (hash_fst, hash_mid, hash_lst) block (0 :: Word32) .&. (0x7fffffff :: Word32))
  -- UniqFM doesn't like negative Ints
  where hash_fst _ h = h

        -- Comments and ticks are skipped entirely (no shift either):
        -- eqBlockBodyWith filters them out before comparing, so they must
        -- not perturb the hash, or blocks that are equal modulo those nodes
        -- would land in different groups and never be merged.
        hash_mid m h
          | dont_care m = h
          | otherwise   = hash_node m + h `shiftL` 1
        hash_lst m h = hash_node m + h `shiftL` 1

        -- Hash of a single middle or last node; labels are ignored.
        hash_node :: CmmNode O x -> Word32
        hash_node n | dont_care n = 0 -- don't care
        hash_node (CmmUnwind _ e) = hash_e e
        hash_node (CmmAssign r e) = hash_reg r + hash_e e
        hash_node (CmmStore e e') = hash_e e + hash_e e'
        hash_node (CmmUnsafeForeignCall t _ as) = hash_tgt t + hash_list hash_e as
        hash_node (CmmBranch _) = 23 -- NB. ignore the label
        hash_node (CmmCondBranch p _ _ _) = hash_e p
        hash_node (CmmCall e _ _ _ _ _) = hash_e e
        hash_node (CmmForeignCall t _ _ _ _ _ _) = hash_tgt t
        hash_node (CmmSwitch e _) = hash_e e
        hash_node _ = error "hash_node: unknown Cmm node!"

        hash_reg :: CmmReg -> Word32
        hash_reg (CmmLocal localReg) = hash_unique localReg -- important for performance, see #10397
        hash_reg (CmmGlobal _)       = 19

        hash_e :: CmmExpr -> Word32
        hash_e (CmmLit l)        = hash_lit l
        hash_e (CmmLoad e _)     = 67 + hash_e e
        hash_e (CmmReg r)        = hash_reg r
        hash_e (CmmMachOp _ es)  = hash_list hash_e es -- pessimal - no operator check
        hash_e (CmmRegOff r i)   = hash_reg r + cvt i
        hash_e (CmmStackSlot _ _) = 13

        hash_lit :: CmmLit -> Word32
        hash_lit (CmmInt i _)            = fromInteger i
        hash_lit (CmmFloat r _)          = truncate r
        hash_lit (CmmVec ls)             = hash_list hash_lit ls
        hash_lit (CmmLabel _)            = 119 -- ugh
        hash_lit (CmmLabelOff _ i)       = cvt $ 199 + i
        hash_lit (CmmLabelDiffOff _ _ i) = cvt $ 299 + i
        hash_lit (CmmBlock _)            = 191 -- ugh
        hash_lit (CmmHighStackMark)      = cvt 313

        hash_tgt (ForeignTarget e _) = hash_e e
        hash_tgt (PrimTarget _)      = 31 -- lots of these

        -- Order-insensitive sum of the element hashes. The strict left fold
        -- avoids building a chain of suspended additions.
        hash_list f = List.foldl' (\z x -> f x + z) (0::Word32)

        cvt = fromInteger . toInteger

        hash_unique :: Uniquable a => a -> Word32
        hash_unique = cvt . getKey . getUnique
179
-- | Node types that are ignored for equality: comments and ticks.
dont_care :: CmmNode O x -> Bool
dont_care node = case node of
  CmmComment {} -> True
  CmmTick {}    -> True
  _other        -> False
185
186 -- Utilities: equality and substitution on the graph.
187
-- | Equality of block ids under a substitution: two ids are equal when
-- 'lookupBid' resolves both of them to the same id.
eqBid :: LabelMap BlockId -> BlockId -> BlockId -> Bool
eqBid subst bid bid' = canonical bid == canonical bid'
  where
    canonical = lookupBid subst
-- | Resolve a block id through the substitution, following the chain of
-- mappings until an id is reached that the substitution does not map.
lookupBid :: LabelMap BlockId -> BlockId -> BlockId
lookupBid subst bid = maybe bid (lookupBid subst) (mapLookup bid subst)
195
-- | Equality on middle nodes, modulo an equality on block ids.
--
-- Middle nodes and expressions can contain BlockIds, in particular in
-- CmmStackSlot and CmmBlock, so we have to use a special equality for
-- these.
--
-- Only assignments, stores and unsafe foreign calls can be equal; every
-- other combination of nodes (including two nodes of any other kind) is
-- treated as different.
eqMiddleWith :: (BlockId -> BlockId -> Bool)
             -> CmmNode O O -> CmmNode O O -> Bool
eqMiddleWith eqBid (CmmAssign r1 e1) (CmmAssign r2 e2)
  = r1 == r2 && eqExprWith eqBid e1 e2
eqMiddleWith eqBid (CmmStore l1 r1) (CmmStore l2 r2)
  = eqExprWith eqBid l1 l2 && eqExprWith eqBid r1 r2
eqMiddleWith eqBid (CmmUnsafeForeignCall t1 r1 a1)
                   (CmmUnsafeForeignCall t2 r2 a2)
  = t1 == t2 && r1 == r2 && sameArgs a1 a2
  where
    -- Pairwise comparison that also requires both argument lists to have
    -- the same length; zipWith would silently drop the surplus arguments
    -- of the longer call and report the calls as equal.
    sameArgs (x:xs) (y:ys) = eqExprWith eqBid x y && sameArgs xs ys
    sameArgs []     []     = True
    sameArgs _      _      = False
eqMiddleWith _ _ _ = False
210
-- | Equality on expressions, modulo an equality on block ids.
--
-- The supplied @eqBid@ is used wherever a block id occurs: in 'CmmBlock'
-- literals and in the 'Young' area of a 'CmmStackSlot'. Expressions built
-- from different constructors are never equal.
eqExprWith :: (BlockId -> BlockId -> Bool)
           -> CmmExpr -> CmmExpr -> Bool
eqExprWith eqBid = eq
 where
  CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
  -- NOTE(review): the type of the load is ignored here, so loads of
  -- different types from the same address compare equal -- confirm that
  -- this is intended.
  CmmLoad e1 _       `eq` CmmLoad e2 _       = e1 `eq` e2
  CmmReg r1          `eq` CmmReg r2          = r1==r2
  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1==r2 && i1==i2
  CmmMachOp op1 es1  `eq` CmmMachOp op2 es2  = op1==op2 && es1 `eqs` es2
  CmmStackSlot a1 i1 `eq` CmmStackSlot a2 i2 = eqArea a1 a2 && i1==i2
  _e1                `eq` _e2                = False

  -- Element-wise equality of operand lists. Lists of different lengths are
  -- different; zipWith would truncate to the shorter list and so accept a
  -- list that is merely a prefix of the other.
  (x:xs) `eqs` (y:ys) = x `eq` y && xs `eqs` ys
  []     `eqs` []     = True
  _      `eqs` _      = False

  eqLit (CmmBlock id1) (CmmBlock id2) = eqBid id1 id2
  eqLit l1 l2 = l1 == l2

  eqArea Old             Old             = True
  eqArea (Young id1)     (Young id2)     = eqBid id1 id2
  eqArea _ _ = False
231
-- | Equality on the body of a block, modulo a function mapping block
-- IDs to block IDs.
--
-- The entry nodes are not compared. Middle nodes selected by 'dont_care'
-- are removed from both blocks first; the remaining middle nodes must
-- match one to one (same number, pairwise equal under 'eqMiddleWith'), and
-- the last nodes must be equal under 'eqLastWith'.
eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
eqBlockBodyWith eqBid block block'
  {-
  | equal     = pprTrace "equal" (vcat [ppr block, ppr block']) True
  | otherwise = pprTrace "not equal" (vcat [ppr block, ppr block']) False
  -}
  = equal
  where (_,m,l)   = blockSplit block
        nodes     = filter (not . dont_care) (blockToList m)
        (_,m',l') = blockSplit block'
        nodes'    = filter (not . dont_care) (blockToList m')

        equal = sameMiddles nodes nodes' &&
                eqLastWith eqBid l l'

        -- Pairwise comparison that fails when one block has more middle
        -- nodes than the other. With zipWith the surplus nodes would be
        -- ignored, and a block whose middle nodes are a strict prefix of
        -- another block's would be reported as equal to it.
        sameMiddles (x:xs) (y:ys) = eqMiddleWith eqBid x y && sameMiddles xs ys
        sameMiddles []     []     = True
        sameMiddles _      _      = False
248
249
-- | Equality on the last node of a block, comparing the block ids that the
-- nodes mention with @eqBid@. Branches, conditional branches, calls and
-- switches can be equal to a node of the same kind; every other pairing is
-- treated as different.
eqLastWith :: (BlockId -> BlockId -> Bool) -> CmmNode O C -> CmmNode O C -> Bool
eqLastWith eqBid lastA lastB = case (lastA, lastB) of
  (CmmBranch bid1, CmmBranch bid2) ->
    eqBid bid1 bid2
  (CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
    and [c1 == c2, l1 == l2, eqBid t1 t2, eqBid f1 f2]
  (CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
    and [ t1 == t2, eqMaybeWith eqBid c1 c2
        , a1 == a2, r1 == r2, u1 == u2, g1 == g2 ]
  (CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
    e1 == e2 && eqSwitchTargetWith eqBid ids1 ids2
  _ ->
    False
259
-- | Lift an element comparison to 'Maybe': two 'Just's are compared with
-- the given function, two 'Nothing's are equal, and a 'Just' never equals a
-- 'Nothing'.
eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
eqMaybeWith eltEq lhs rhs = case (lhs, rhs) of
  (Just e,  Just e') -> eltEq e e'
  (Nothing, Nothing) -> True
  _                  -> False
264
-- | Given a block map, ensure that all "target" blocks are covered by
-- the same ticks as the respective "source" blocks. This not only
-- means copying ticks, but also adjusting tick scopes where
-- necessary.
--
-- The graph is returned untouched when the map is empty.
copyTicks :: LabelMap BlockId -> CmmGraph -> CmmGraph
copyTicks env g
  | mapNull env = g
  | otherwise   = ofBlockMap (g_entry g) (mapMap inherit blocksByLabel)
  where
    blocksByLabel = toBlockMap g

    -- Reverse block merge map: target label -> labels that map to it.
    mergedInto = mapFoldWithKey addSource M.empty env
    addSource src tgt = M.insertWith (++) tgt [src]

    -- Copy ticks and scopes into the given block from every source block
    -- that is still present in the graph.
    inherit block =
      maybe block
            (foldr copy block . mapMaybe (`mapLookup` blocksByLabel))
            (M.lookup (entryLabel block) mergedInto)

    -- Prepend the ticks of @from@ to the code of @to@ and combine the tick
    -- scopes of the two entry nodes.
    copy from to =
      let CmmEntry _ fromScope         = firstNode from
          (CmmEntry lbl toScope, rest) = blockSplitHead to
          withTicks = foldr (blockCons . CmmTick) rest (blockTicks from)
      in CmmEntry lbl (combineTickScopes fromScope toScope)
           `blockJoinHead` withTicks
287
-- | Group entries by their key (a list of labels). All values that share a
-- key are collected into one list, paired with that key.
groupByLabel :: [(Key, a)] -> [(Key, [a])]
groupByLabel entries = TM.foldTM (:) grouped []
  where
    -- Trie indexed by the uniques of the key's labels.
    grouped =
      List.foldl'
        (\trie (key, val) -> TM.alterTM (map getUnique key) (bump key val) trie)
        (TM.emptyTM :: TM.ListMap UniqDFM a)
        entries

    -- Start a new group, or put the value in front of an existing one.
    bump key val slot = case slot of
      Nothing        -> Just (key, [val])
      Just (_, vals) -> Just (key, val : vals)
297
298
-- | Group the elements by the 'Int' that @f@ assigns to them. The groups
-- are returned in the order produced by 'nonDetEltsUFM'.
groupByInt :: (a -> Int) -> [a] -> [[a]]
groupByInt f xs = nonDetEltsUFM buckets
   -- See Note [Unique Determinism and code generation]
  where
    buckets = List.foldl' file emptyUFM xs

    -- Put an element in front of the bucket selected by @f@.
    file acc x = alterUFM (push x) acc (f x)

    push x slot = Just $ case slot of
      Nothing -> [x]
      Just ys -> x : ys