cmm/CBE: Collapse blocks equivalent up to alpha renaming of local registers
authorBen Gamari <bgamari.foss@gmail.com>
Tue, 19 Sep 2017 20:57:41 +0000 (16:57 -0400)
committerBen Gamari <ben@smart-cactus.org>
Tue, 19 Sep 2017 20:57:43 +0000 (16:57 -0400)
As noted in #14226, the common block elimination pass currently
implements an extremely strict equivalence relation, demanding that two
blocks are equivalent including the names of their local registers. This
is quite restrictive and severely hampers the effectiveness of the pass.

Here we allow the CBE pass to collapse blocks which are equivalent up to
alpha renaming of locally-bound local registers. This is completely safe
and catches many more duplicate blocks.

Test Plan: Validate

Reviewers: austin, simonmar, michalt

Reviewed By: michalt

Subscribers: rwbarton, thomie

GHC Trac Issues: #14226

Differential Revision: https://phabricator.haskell.org/D3973

compiler/cmm/CmmCommonBlockElim.hs
compiler/cmm/Hoopl/Block.hs

index f635520..aca39bc 100644 (file)
@@ -126,39 +126,106 @@ mergeBlockList subst (b:bs) = go mapEmpty b bs
 -- expensive. So include as much as possible in the hash. Ideally everything
 -- that is compared with (==) in eqBlockBodyWith.
 
+{-
+Note [Equivalence up to local registers in CBE]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+CBE treats two blocks which are equivalent up to alpha-renaming of locally-bound
+local registers as equivalent. This was not always the case (see #14226) but is
+quite important for effective CBE. For instance, consider the blocks,
+
+    c2VZ: // global
+        _c2Yd::I64 = _s2Se::I64 + 1;
+        _s2Sx::I64 = _c2Yd::I64;
+        _s2Se::I64 = _s2Sx::I64;
+        goto c2TE;
+
+    c2VY: // global
+        _c2Yb::I64 = _s2Se::I64 + 1;
+        _s2Sw::I64 = _c2Yb::I64;
+        _s2Se::I64 = _s2Sw::I64;
+        goto c2TE;
+
+These clearly implement precisely the same logic, differing only register
+naming. This happens quite often in the code produced by GHC.
+
+This alpha-equivalence relation must be accounted for in two places:
+
+ 1. the block hash function (hash_block), which we use for approximate "binning"
+ 2. the exact block comparison function, which computes pair-wise equivalence
+
+In (1) we maintain a de Bruijn numbering of each block's locally-bound local
+registers and compute the hash relative to this numbering.
+
+For (2) we maintain a substitution which maps the local registers of one block
+onto those of the other. We then compare local registers modulo this
+substitution.
+
+-}
+
 type HashCode = Int
 
+type LocalRegEnv a = UniqFM a
+type DeBruijn = Int
+
+-- | Maintains a de Bruijn numbering of local registers bound within a block.
+--
+-- See Note [Equivalence up to local registers in CBE]
+data HashEnv = HashEnv { localRegHashEnv :: !(LocalRegEnv DeBruijn)
+                       , nextIndex       :: !DeBruijn
+                       }
+
 hash_block :: CmmBlock -> HashCode
 hash_block block =
-  fromIntegral (foldBlockNodesB3 (hash_fst, hash_mid, hash_lst) block (0 :: Word32) .&. (0x7fffffff :: Word32))
-  -- UniqFM doesn't like negative Ints
-  where hash_fst _ h = h
-        hash_mid m h = hash_node m + h `shiftL` 1
-        hash_lst m h = hash_node m + h `shiftL` 1
-
-        hash_node :: CmmNode O x -> Word32
-        hash_node n | dont_care n = 0 -- don't care
-        hash_node (CmmAssign r e) = hash_reg r + hash_e e
-        hash_node (CmmStore e e') = hash_e e + hash_e e'
-        hash_node (CmmUnsafeForeignCall t _ as) = hash_tgt t + hash_list hash_e as
-        hash_node (CmmBranch _) = 23 -- NB. ignore the label
-        hash_node (CmmCondBranch p _ _ _) = hash_e p
-        hash_node (CmmCall e _ _ _ _ _) = hash_e e
-        hash_node (CmmForeignCall t _ _ _ _ _ _) = hash_tgt t
-        hash_node (CmmSwitch e _) = hash_e e
-        hash_node _ = error "hash_node: unknown Cmm node!"
-
-        hash_reg :: CmmReg -> Word32
-        hash_reg   (CmmLocal localReg) = hash_unique localReg -- important for performance, see #10397
-        hash_reg   (CmmGlobal _)    = 19
-
-        hash_e :: CmmExpr -> Word32
-        hash_e (CmmLit l) = hash_lit l
-        hash_e (CmmLoad e _) = 67 + hash_e e
-        hash_e (CmmReg r) = hash_reg r
-        hash_e (CmmMachOp _ es) = hash_list hash_e es -- pessimal - no operator check
-        hash_e (CmmRegOff r i) = hash_reg r + cvt i
-        hash_e (CmmStackSlot _ _) = 13
+  --pprTrace "hash_block" (ppr (entryLabel block) $$ ppr hash)
+  hash
+  where hash_fst _ (env, h) = (env, h)
+        hash_mid m (env, h) = let (env', h') = hash_node env m
+                              in (env', h' + h `shiftL` 1)
+        hash_lst m (env, h) = let (env', h') = hash_node env m
+                              in (env', h' + h `shiftL` 1)
+
+        hash =
+            let (_, raw_hash) =
+                    foldBlockNodesF3 (hash_fst, hash_mid, hash_lst)
+                                     block
+                                     (emptyEnv, 0 :: Word32)
+                emptyEnv = HashEnv mempty 0
+            in fromIntegral (raw_hash .&. (0x7fffffff :: Word32))
+               -- UniqFM doesn't like negative Ints
+
+        hash_node :: HashEnv -> CmmNode O x -> (HashEnv, Word32)
+        hash_node env n =
+            case n of
+              n | dont_care n -> pure_ 0  -- don't care
+              CmmAssign (CmmLocal r) e -> (bind_local_reg r env, hash_e env e)
+              CmmAssign r e   -> pure_ $ hash_reg env r + hash_e env e
+              CmmStore e e'   -> pure_ $ hash_e env e + hash_e env e'
+              CmmUnsafeForeignCall t _ as
+                              -> pure_ $ hash_tgt env t + hash_list (hash_e env) as
+              CmmBranch _     -> pure_ 23 -- NB. ignore the label
+              CmmCondBranch p _ _ _ -> pure_ $ hash_e env p
+              CmmCall e _ _ _ _ _   -> pure_ $ hash_e env e
+              CmmForeignCall t _ _ _ _ _ _ -> pure_ $ hash_tgt env t
+              CmmSwitch e _   -> pure_ $ hash_e env e
+              _               -> error "hash_node: unknown Cmm node!"
+          where pure_ x = (env, x)
+
+        hash_reg :: HashEnv -> CmmReg -> Word32
+        hash_reg env (CmmLocal localReg)
+          | Just idx <- lookupUFM (localRegHashEnv env) localReg
+          = fromIntegral idx
+          | otherwise
+          = hash_unique localReg -- important for performance, see #10397
+        hash_reg _  (CmmGlobal _)    = 19
+
+        hash_e :: HashEnv -> CmmExpr -> Word32
+        hash_e _   (CmmLit l) = hash_lit l
+        hash_e env (CmmLoad e _) = 67 + hash_e env e
+        hash_e env (CmmReg r) = hash_reg env r
+        hash_e env (CmmMachOp _ es) = hash_list (hash_e env) es -- pessimal - no operator check
+        hash_e env (CmmRegOff r i) = hash_reg env r + cvt i
+        hash_e _   (CmmStackSlot _ _) = 13
 
         hash_lit :: CmmLit -> Word32
         hash_lit (CmmInt i _) = fromInteger i
@@ -170,13 +237,21 @@ hash_block block =
         hash_lit (CmmBlock _) = 191 -- ugh
         hash_lit (CmmHighStackMark) = cvt 313
 
-        hash_tgt (ForeignTarget e _) = hash_e e
-        hash_tgt (PrimTarget _) = 31 -- lots of these
+        hash_tgt :: HashEnv -> ForeignTarget -> Word32
+        hash_tgt env (ForeignTarget e _) = hash_e env e
+        hash_tgt _   (PrimTarget _) = 31 -- lots of these
 
-        hash_list f = foldl (\z x -> f x + z) (0::Word32)
+        hash_list f = List.foldl' (\z x -> f x + z) (0::Word32)
 
         cvt = fromInteger . toInteger
 
+        bind_local_reg :: LocalReg -> HashEnv -> HashEnv
+        bind_local_reg reg env =
+            env { localRegHashEnv =
+                    addToUFM (localRegHashEnv env) reg (nextIndex env)
+                , nextIndex = nextIndex env + 1
+                }
+
         hash_unique :: Uniquable a => a -> Word32
         hash_unique = cvt . getKey . getUnique
 
@@ -197,35 +272,65 @@ lookupBid subst bid = case mapLookup bid subst of
                         Just bid  -> lookupBid subst bid
                         Nothing -> bid
 
+-- | Maps the local registers of one block to those of another
+--
+-- See Note [Equivalence up to local registers in CBE]
+type LocalRegMapping = LocalRegEnv LocalReg
+
 -- Middle nodes and expressions can contain BlockIds, in particular in
 -- CmmStackSlot and CmmBlock, so we have to use a special equality for
 -- these.
 --
 eqMiddleWith :: (BlockId -> BlockId -> Bool)
-             -> CmmNode O O -> CmmNode O O -> Bool
-eqMiddleWith eqBid (CmmAssign r1 e1) (CmmAssign r2 e2)
-  = r1 == r2 && eqExprWith eqBid e1 e2
-eqMiddleWith eqBid (CmmStore l1 r1) (CmmStore l2 r2)
-  = eqExprWith eqBid l1 l2 && eqExprWith eqBid r1 r2
-eqMiddleWith eqBid (CmmUnsafeForeignCall t1 r1 a1)
-                   (CmmUnsafeForeignCall t2 r2 a2)
-  = t1 == t2 && r1 == r2 && and (zipWith (eqExprWith eqBid) a1 a2)
-eqMiddleWith _ _ _ = False
+             -> LocalRegMapping
+             -> CmmNode O O -> CmmNode O O
+             -> (LocalRegMapping, Bool)
+eqMiddleWith eqBid env a b =
+  case (a, b) of
+    (CmmAssign (CmmLocal r1) e1,  CmmAssign (CmmLocal r2) e2) ->
+        let eq = eqExprWith eqBid env e1 e2
+            env' = addToUFM env r1 r2
+        in (env', eq)
+
+    (CmmAssign r1 e1,  CmmAssign r2 e2) ->
+        let eq = r1 == r2
+              && eqExprWith eqBid env e1 e2
+        in (env, eq)
+
+    (CmmStore l1 r1,  CmmStore l2 r2) ->
+        let eq = eqExprWith eqBid env l1 l2
+              && eqExprWith eqBid env r1 r2
+        in (env, eq)
+
+    (CmmUnsafeForeignCall t1 r1 a1,  CmmUnsafeForeignCall t2 r2 a2) ->
+        let eq = t1 == t2
+              && r1 == r2
+              && and (zipWith (eqExprWith eqBid env) a1 a2)
+        in (env, eq)
+
+    _ -> (env, False)
 
 eqExprWith :: (BlockId -> BlockId -> Bool)
+           -> LocalRegMapping
            -> CmmExpr -> CmmExpr -> Bool
-eqExprWith eqBid = eq
+eqExprWith eqBid env = eq
  where
   CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
   CmmLoad e1 _       `eq` CmmLoad e2 _       = e1 `eq` e2
-  CmmReg r1          `eq` CmmReg r2          = r1==r2
-  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1==r2 && i1==i2
+  CmmReg r1          `eq` CmmReg r2          = r1 `eqReg` r2
+  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1 `eqReg` r2 && i1==i2
   CmmMachOp op1 es1  `eq` CmmMachOp op2 es2  = op1==op2 && es1 `eqs` es2
   CmmStackSlot a1 i1 `eq` CmmStackSlot a2 i2 = eqArea a1 a2 && i1==i2
   _e1                `eq` _e2                = False
 
   xs `eqs` ys = and (zipWith eq xs ys)
 
+  -- See Note [Equivalence up to local registers in CBE]
+  CmmLocal a `eqReg` CmmLocal b
+    | Just a' <- lookupUFM env a
+    = a' == b
+  a `eqReg` b = a == b
+
   eqLit (CmmBlock id1) (CmmBlock id2) = eqBid id1 id2
   eqLit l1 l2 = l1 == l2
 
@@ -241,25 +346,41 @@ eqBlockBodyWith eqBid block block'
   | equal     = pprTrace "equal" (vcat [ppr block, ppr block']) True
   | otherwise = pprTrace "not equal" (vcat [ppr block, ppr block']) False
   -}
-  = equal
+  = equal_go emptyUFM nodes nodes'
   where (_,m,l)   = blockSplit block
         nodes     = filter (not . dont_care) (blockToList m)
         (_,m',l') = blockSplit block'
         nodes'    = filter (not . dont_care) (blockToList m')
 
-        equal = and (zipWith (eqMiddleWith eqBid) nodes nodes') &&
-                eqLastWith eqBid l l'
-
-
-eqLastWith :: (BlockId -> BlockId -> Bool) -> CmmNode O C -> CmmNode O C -> Bool
-eqLastWith eqBid (CmmBranch bid1) (CmmBranch bid2) = eqBid bid1 bid2
-eqLastWith eqBid (CmmCondBranch c1 t1 f1 l1) (CmmCondBranch c2 t2 f2 l2) =
-  c1 == c2 && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
-eqLastWith eqBid (CmmCall t1 c1 g1 a1 r1 u1) (CmmCall t2 c2 g2 a2 r2 u2) =
-  t1 == t2 && eqMaybeWith eqBid c1 c2 && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
-eqLastWith eqBid (CmmSwitch e1 ids1) (CmmSwitch e2 ids2) =
-  e1 == e2 && eqSwitchTargetWith eqBid ids1 ids2
-eqLastWith _ _ _ = False
+        -- Compare middle nodes, accumulating a local register mapping as we go.
+        -- We also must ensure that the lists are of equal length. Finally,
+        -- compare the last nodes.
+        equal_go :: LocalRegMapping -> [CmmNode O O] -> [CmmNode O O] -> Bool
+        equal_go acc (a:as) (b:bs)
+          | let (acc', eq) = eqMiddleWith eqBid acc a b
+          , eq
+          = equal_go acc' as bs
+        equal_go acc [] [] = eqLastWith eqBid acc l l'
+        equal_go _   _  _  = False
+
+
+eqLastWith :: (BlockId -> BlockId -> Bool) -> LocalRegMapping
+           -> CmmNode O C -> CmmNode O C -> Bool
+eqLastWith eqBid env a b =
+  case (a, b) of
+    (CmmBranch bid1, CmmBranch bid2) ->
+        eqBid bid1 bid2
+    (CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
+        eqExprWith eqBid env c1 c2
+        && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
+    (CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
+        eqExprWith eqBid env t1 t2
+        && eqMaybeWith eqBid c1 c2
+        && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
+    (CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
+        eqExprWith eqBid env e1 e2
+        && eqSwitchTargetWith eqBid ids1 ids2
+    _ -> False
 
 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
index c4ff179..4561fef 100644 (file)
@@ -24,6 +24,7 @@ module Hoopl.Block
     , foldBlockNodesB
     , foldBlockNodesB3
     , foldBlockNodesF
+    , foldBlockNodesF3
     , isEmptyBlock
     , lastNode
     , mapBlock