Combine identical case alternatives in CSE
authorSimon Peyton Jones <simonpj@microsoft.com>
Tue, 28 Feb 2017 21:00:49 +0000 (16:00 -0500)
committerDavid Feuer <David.Feuer@gmail.com>
Tue, 28 Feb 2017 21:00:50 +0000 (16:00 -0500)
See Note [Combine case alternatives] in CSE.  This opportunity
surfaced when I was was studying early inlining.  It's easy (and
cheap) to exploit, and sometimes makes a worthwhile saving.

Reviewers: austin, bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D3194

compiler/simplCore/CSE.hs
testsuite/tests/numeric/should_compile/T7116.stdout

index 80013a3..31f0901 100644 (file)
@@ -12,17 +12,19 @@ module CSE (cseProgram, cseOneExpr) where
 
 import CoreSubst
 import Var              ( Var, isJoinId )
-import Id               ( Id, idType, idUnfolding, idInlineActivation
-                        , zapIdOccInfo, zapIdUsageInfo )
-import CoreUtils        ( mkAltExpr
+import VarEnv           ( elemInScopeSet )
+import Id               ( Id, idType, idInlineActivation, isDeadBinder
+                        , zapIdOccInfo, zapIdUsageInfo, idInlinePragma )
+import CoreUtils        ( mkAltExpr, eqExpr
                         , exprIsLiteralString
                         , stripTicksE, stripTicksT, mkTicks )
 import Literal          ( litIsTrivial )
 import Type             ( tyConAppArgs )
 import CoreSyn
 import Outputable
-import BasicTypes       ( isAlwaysActive )
+import BasicTypes       ( isAlwaysActive, isAnyInlinePragma )
 import TrieMap
+import Util             ( compareLength, filterOut )
 import Data.List        ( mapAccumL )
 
 {-
@@ -258,6 +260,27 @@ We could try and be careful by tracking which join points are still valid at
 each subexpression, but since join points aren't allocated or shared, there's
 less to gain by trying to CSE them.
 
+Note [CSE for recursive bindings]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+  f = \x ... f....
+  g = \y ... g ...
+where the "..." are identical.  Could we CSE them?  In full generality
+with mutual recursion it's quite hard; but for self-recursive bindings
+(which are very common) it's rather easy:
+
+* Maintain a separate cs_rec_map, that maps
+      (\f. (\x. ...f...) ) -> f
+  Note the \f in the domain of the mapping!
+
+* When we come across the binding for 'g', look up (\g. (\y. ...g...))
+  Bingo we get a hit.  So we can repace the 'g' binding with
+     g = f
+
+We can't use cs_map for this, because the key isn't an expression of
+the program; it's a kind of synthetic key for recursive bindings.
+
+
 ************************************************************************
 *                                                                      *
 \section{Common subexpression}
@@ -276,6 +299,26 @@ cseBind toplevel env (NonRec b e)
     (env1, b1) = addBinder env b
     (env2, b2) = addBinding env1 b b1 e1
 
+cseBind _ env (Rec [(in_id, rhs)])
+  | noCSE in_id
+  = (env1, Rec [(out_id, rhs')])
+
+  -- See Note [CSE for recursive bindings]
+  | Just previous <- lookupCSRecEnv env out_id rhs''
+  , let previous' = mkTicks ticks previous
+  = (extendCSSubst env1 in_id previous', NonRec out_id previous')
+
+  | otherwise
+  = (extendCSRecEnv env1 out_id rhs'' id_expr', Rec [(zapped_id, rhs')])
+
+  where
+    (env1, [out_id]) = addRecBinders env [in_id]
+    rhs'  = cseExpr env1 rhs
+    rhs'' = stripTicksE tickishFloatable rhs'
+    ticks = stripTicksT tickishFloatable rhs'
+    id_expr'  = varToCoreExpr out_id
+    zapped_id = zapIdUsageInfo out_id
+
 cseBind toplevel env (Rec pairs)
   = (env2, Rec pairs')
   where
@@ -296,9 +339,9 @@ addBinding :: CSEnv                      -- Includes InId->OutId cloning
 -- Extend the CSE env with a mapping [rhs -> out-id]
 -- unless we can instead just substitute [in-id -> rhs]
 addBinding env in_id out_id rhs'
-  | no_cse    = (env,                              out_id)
-  | use_subst = (extendCSSubst env in_id rhs',     out_id)
-  | otherwise = (extendCSEnv env rhs' id_expr', zapped_id)
+  | noCSE in_id = (env,                              out_id)
+  | use_subst   = (extendCSSubst env in_id rhs',     out_id)
+  | otherwise   = (extendCSEnv env rhs' id_expr', zapped_id)
   where
     id_expr'  = varToCoreExpr out_id
     zapped_id = zapIdUsageInfo out_id
@@ -312,13 +355,6 @@ addBinding env in_id out_id rhs'
        -- it is bad for performance if you don't do late demand
        -- analysis
 
-    no_cse = not (isAlwaysActive (idInlineActivation out_id))
-             -- See Note [CSE for INLINE and NOINLINE]
-          || isStableUnfolding (idUnfolding out_id)
-             -- See Note [CSE for stable unfoldings]
-          || isJoinId in_id
-             -- See Note [CSE for join points?]
-
     -- Should we use SUBSTITUTE or EXTEND?
     -- See Note [CSE for bindings]
     use_subst = case rhs' of
@@ -326,6 +362,16 @@ addBinding env in_id out_id rhs'
                    Lit l  -> litIsTrivial l
                    _      -> False
 
+noCSE :: InId -> Bool
+noCSE id = not (isAlwaysActive (idInlineActivation id))
+             -- See Note [CSE for INLINE and NOINLINE]
+         || isAnyInlinePragma (idInlinePragma id)
+             --isStableUnfolding (idUnfolding id)
+             -- See Note [CSE for stable unfoldings]
+         || isJoinId id
+             -- See Note [CSE for join points?]
+
+
 {-
 Note [Take care with literal strings]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -373,7 +419,7 @@ tryForCSE toplevel env expr
     -- top of the replaced sub-expression. This is probably not too
     -- useful in practice, but upholds our semantics.
 
-cseOneExpr :: CoreExpr -> CoreExpr
+cseOneExpr :: InExpr -> OutExpr
 cseOneExpr = cseExpr emptyCSEnv
 
 cseExpr :: CSEnv -> InExpr -> OutExpr
@@ -392,7 +438,8 @@ cseExpr env (Case e bndr ty alts) = cseCase env e bndr ty alts
 
 cseCase :: CSEnv -> InExpr -> InId -> InType -> [InAlt] -> OutExpr
 cseCase env scrut bndr ty alts
-  = Case scrut1 bndr3 ty' (map cse_alt alts)
+  = Case scrut1 bndr3 ty' $
+    combineAlts alt_env (map cse_alt alts)
   where
     ty' = substTy (csEnvSubst env) ty
     scrut1 = tryForCSE False env scrut
@@ -429,7 +476,42 @@ cseCase env scrut bndr ty alts
         where
           (env', args') = addBinders alt_env args
 
-{-
+combineAlts :: CSEnv -> [InAlt] -> [InAlt]
+-- See Note [Combine case alternatives]
+combineAlts env ((_,bndrs1,rhs1) : rest_alts)
+  | all isDeadBinder bndrs1
+  = (DEFAULT, [], rhs1) : filtered_alts
+  where
+    in_scope = substInScope (csEnvSubst env)
+    filtered_alts = filterOut identical rest_alts
+    identical (_con, bndrs, rhs) = all ok bndrs && eqExpr in_scope rhs1 rhs
+    ok bndr = isDeadBinder bndr || not (bndr `elemInScopeSet` in_scope)
+
+combineAlts _ alts = alts  -- Default case
+
+{- Note [Combine case alternatives]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+combineAlts is just a more heavyweight version of the use of
+combineIdentialAlts in SimplUtils.prepareAlts.  The basic idea is
+to transform
+
+    DEFAULT -> e1
+    K x     -> e1
+    W y z   -> e2
+===>
+   DEFAULT -> e1
+   W y z   -> e2
+
+In the simplifier we use cheapEqExpr, because it is called a lot.
+But here in CSE we use the full eqExpr.  After all, two alterantives usually
+differ near the root, so it probably isn't expensive to compare the full
+alternative.  It seems like the the same kind of thing that CSE is supposed
+to be doing, which is why I put it here.
+
+I acutally saw some examples in the wild, where some inlining made e1 too
+big for cheapEqExpr to catch it.
+
+
 ************************************************************************
 *                                                                      *
 \section{The CSE envt}
@@ -445,10 +527,14 @@ data CSEnv
        , cs_map   :: CoreMap OutExpr   -- The reverse mapping
             -- Maps a OutExpr to a /trivial/ OutExpr
             -- The key of cs_map is stripped of all Ticks
+
+       , cs_rec_map :: CoreMap OutExpr
+            -- See Note [CSE for recursive bindings]
        }
 
 emptyCSEnv :: CSEnv
-emptyCSEnv = CS { cs_map = emptyCoreMap, cs_subst = emptySubst }
+emptyCSEnv = CS { cs_map = emptyCoreMap, cs_rec_map = emptyCoreMap
+                , cs_subst = emptySubst }
 
 lookupCSEnv :: CSEnv -> OutExpr -> Maybe OutExpr
 lookupCSEnv (CS { cs_map = csmap }) expr
@@ -460,6 +546,16 @@ extendCSEnv cse expr triv_expr
   where
     sexpr = stripTicksE tickishFloatable expr
 
+extendCSRecEnv :: CSEnv -> OutId -> OutExpr -> OutExpr -> CSEnv
+-- See Note [CSE for recursive bindings]
+extendCSRecEnv cse bndr expr triv_expr
+  = cse { cs_rec_map = extendCoreMap (cs_map cse) (Lam bndr expr) triv_expr }
+
+lookupCSRecEnv :: CSEnv -> OutId -> OutExpr -> Maybe OutExpr
+-- See Note [CSE for recursive bindings]
+lookupCSRecEnv (CS { cs_rec_map = csmap }) bndr expr
+  = lookupCoreMap csmap (Lam bndr expr)
+
 csEnvSubst :: CSEnv -> Subst
 csEnvSubst = cs_subst
 
index ee136c2..681d171 100644 (file)
@@ -1,7 +1,7 @@
 
 ==================== Tidy Core ====================
 Result size of Tidy Core
-  = {terms: 50, types: 25, coercions: 0, joins: 0/0}
+  = {terms: 36, types: 19, coercions: 0, joins: 0/0}
 
 -- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
 T7116.$trModule4 :: GHC.Prim.Addr#
@@ -64,7 +64,7 @@ dr
   = \ (x :: Double) ->
       case x of { GHC.Types.D# x1 -> GHC.Types.D# (GHC.Prim.+## x1 x1) }
 
--- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
 dl :: Double -> Double
 [GblId,
  Arity=1,
@@ -75,9 +75,7 @@ dl :: Double -> Double
          Guidance=ALWAYS_IF(arity=1,unsat_ok=True,boring_ok=False)
          Tmpl= \ (x [Occ=Once!] :: Double) ->
                  case x of { GHC.Types.D# y -> GHC.Types.D# (GHC.Prim.+## y y) }}]
-dl
-  = \ (x :: Double) ->
-      case x of { GHC.Types.D# y -> GHC.Types.D# (GHC.Prim.+## y y) }
+dl = dr
 
 -- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
 fr :: Float -> Float
@@ -98,7 +96,7 @@ fr
       GHC.Types.F# (GHC.Prim.plusFloat# x1 x1)
       }
 
--- RHS size: {terms: 8, types: 3, coercions: 0, joins: 0/0}
+-- RHS size: {terms: 1, types: 0, coercions: 0, joins: 0/0}
 fl :: Float -> Float
 [GblId,
  Arity=1,
@@ -111,11 +109,7 @@ fl :: Float -> Float
                  case x of { GHC.Types.F# y ->
                  GHC.Types.F# (GHC.Prim.plusFloat# y y)
                  }}]
-fl
-  = \ (x :: Float) ->
-      case x of { GHC.Types.F# y ->
-      GHC.Types.F# (GHC.Prim.plusFloat# y y)
-      }
+fl = fr