Call Arity: Now also done on Top-Level binds
authorJoachim Breitner <mail@joachim-breitner.de>
Tue, 18 Feb 2014 10:53:22 +0000 (10:53 +0000)
committerJoachim Breitner <mail@joachim-breitner.de>
Tue, 18 Feb 2014 14:50:37 +0000 (14:50 +0000)
compiler/simplCore/CallArity.hs

index b1ad34e..975c703 100644 (file)
@@ -257,14 +257,37 @@ information from the alternatives (resp. the argument).
 It might be smarter to look for “more important” variables first, i.e. the
 innermost recursive variable.
 
+Note [Analysing top-level binds]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We can eta-expand top-level-binds if they are not exported, as we see all calls
+to them. The plan is as follows: Treat the top-level binds as nested lets around
+a body representing “all external calls”, which returns a CallArityEnv that calls
+every exported function with the top of the lattice.
+
+This means that the incoming arity on all top-level binds will have a Many
+attached, and we will never eta-expand CAFs. Which is good.
+
 -}
 
 callArityAnalProgram :: DynFlags -> CoreProgram -> CoreProgram
-callArityAnalProgram _dflags = map callArityBind
+callArityAnalProgram _dflags binds = binds'
+  where
+    (_, binds') = callArityTopLvl [] emptyVarSet binds
+
+-- See Note [Analysing top-level-binds]
+callArityTopLvl :: [Var] -> VarSet -> [CoreBind] -> (CallArityEnv, [CoreBind])
+callArityTopLvl exported _ []
+    = (mkVarEnv $ zip exported (repeat topCallCount), [])
+callArityTopLvl exported int1 (b:bs)
+    = (ae2, b':bs')
+  where
+    int2 = interestingBinds b
+    exported' = filter isExportedId int2 ++ exported
+    int' = int1 `extendVarSetList` int2
+    (ae1, bs') = callArityTopLvl exported' int' bs
+    (ae2, b')  = callArityBind ae1 int1 b
 
-callArityBind :: CoreBind -> CoreBind
-callArityBind (NonRec id rhs) = NonRec id (callArityRHS rhs)
-callArityBind (Rec binds) = Rec $ map (\(id,rhs) -> (id, callArityRHS rhs)) binds
 
 callArityRHS :: CoreExpr -> CoreExpr
 callArityRHS = snd . callArityAnal 0 emptyVarSet
@@ -319,67 +342,16 @@ callArityAnal arity int (Lam v e)
   where
     (ae, e') = callArityAnal (arity - 1) int e
 
--- Boring non-recursive let, i.e. no eta expansion possible. do not be smart about this
--- See Note [Which variables are interesting]
-callArityAnal arity int (Let (NonRec v rhs) e)
-    | exprArity rhs >= length (typeArity (idType v))
-    = (ae_final, Let (NonRec v rhs') e')
-  where
-    (ae_rhs, rhs') = callArityAnal 0 int rhs
-    (ae_body, e')  = callArityAnal arity int e
-    ae_body' = ae_body `delVarEnv` v
-    ae_final = forgetOnceCalls ae_rhs `lubEnv` ae_body'
-
--- Non-recursive let. Find out how the body calls the rhs, analise that,
--- and combine the results, convervatively using both
-callArityAnal arity int (Let (NonRec v rhs) e)
-  = -- pprTrace "callArityAnal:LetNonRec"
+-- For lets, use callArityBind
+callArityAnal arity int (Let bind e)
+  = -- pprTrace "callArityAnal:Let"
     --          (vcat [ppr v, ppr arity, ppr n, ppr final_ae ])
-    (final_ae, Let (NonRec v' rhs') e')
+    (final_ae, Let bind' e')
   where
-    int_body = int `extendVarSet` v
+    int_body = int `extendVarSetList` interestingBinds bind
     (ae_body, e') = callArityAnal arity int_body e
-    callcount = lookupWithDefaultVarEnv ae_body topCallCount v
+    (final_ae, bind') = callArityBind ae_body int bind
 
-    (ae_rhs, safe_arity, rhs') = callArityBound callcount int rhs
-    final_ae = ae_rhs `lubEnv` (ae_body `delVarEnv` v)
-    v' = v `setIdCallArity` safe_arity
-
--- Boring recursive let, i.e. no eta expansion possible. do not be smart about this
-callArityAnal arity int (Let (Rec [(v,rhs)]) e)
-    | exprArity rhs >= length (typeArity (idType v))
-    = (ae_final, Let (Rec [(v,rhs')]) e')
-  where
-    (ae_rhs, rhs') = callArityAnal 0 int rhs
-    (ae_body, e')  = callArityAnal arity int e
-    ae_final = (forgetOnceCalls ae_rhs `lubEnv` ae_body) `delVarEnv` v
-
--- Recursive let.
--- See Note [Recursion and fixpointing]
-callArityAnal arity int (Let (Rec [(v,rhs)]) e)
-  = -- pprTrace "callArityAnal:LetRec"
-    --         (vcat [ppr v, ppr arity, ppr safe_arity, ppr rhs_arity', ppr final_ae ])
-    (final_ae, Let (Rec [(v',rhs')]) e')
-  where
-    int_body = int `extendVarSet` v
-    (ae_body, e') = callArityAnal arity int_body e
-    callcount = lookupWithDefaultVarEnv ae_body topCallCount v
-
-    (ae_rhs, new_arity, rhs') = callArityFix callcount int_body v rhs
-    final_ae = (ae_rhs `lubEnv` ae_body) `delVarEnv` v
-    v' = v `setIdCallArity` new_arity
-
-
-
--- Mutual recursion. Do nothing serious here, for now
-callArityAnal arity int (Let (Rec binds) e)
-    = (final_ae, Let (Rec binds') e')
-  where
-    (aes, binds') = unzip $ map go binds
-    go (i,e) = let (ae,e') = callArityAnal 0 int e
-               in (forgetOnceCalls ae, (i,e'))
-    (ae, e') = callArityAnal arity int e
-    final_ae = foldl lubEnv ae aes `delVarEnvList` map fst binds
 
 -- Application. Increase arity for the called expresion, nothing to know about
 -- the second
@@ -409,6 +381,53 @@ callArityAnal arity int (Case scrut bndr ty alts)
     -- See Note [Case and App: Which side to take?]
     final_ae = scrut_ae `useBetterOf` alt_ae
 
+-- Which bindings should we look at?
+-- See Note [Which variables are interesting]
+interestingBinds :: CoreBind -> [Var]
+interestingBinds bind =
+    map fst $ filter go $ case bind of (NonRec v e) -> [(v,e)]
+                                       (Rec ves)    -> ves
+  where
+    go (v,e) = exprArity e < length (typeArity (idType v))
+
+-- Used for both local and top-level binds
+-- First argument is the demand from the body
+callArityBind :: CallArityEnv -> VarSet -> CoreBind -> (CallArityEnv, CoreBind)
+
+-- Non-recursive let
+callArityBind ae_body int (NonRec v rhs)
+  = -- pprTrace "callArityBind:NonRec"
+    --          (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr safe_arity])
+    (final_ae, NonRec v' rhs')
+  where
+    callcount = lookupWithDefaultVarEnv ae_body topCallCount v
+    (ae_rhs, safe_arity, rhs') = callArityBound callcount int rhs
+    final_ae = ae_rhs `lubEnv` (ae_body `delVarEnv` v)
+    v' = v `setIdCallArity` safe_arity
+
+-- Recursive let. See Note [Recursion and fixpointing]
+callArityBind ae_body int b@(Rec [(v,rhs)])
+  = -- pprTrace "callArityBind:Rec"
+    --          (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr new_arity])
+    (final_ae, Rec [(v',rhs')])
+  where
+    int_body = int `extendVarSetList` interestingBinds b
+    callcount = lookupWithDefaultVarEnv ae_body topCallCount v
+    (ae_rhs, new_arity, rhs') = callArityFix callcount int_body v rhs
+    final_ae = (ae_rhs `lubEnv` ae_body) `delVarEnv` v
+    v' = v `setIdCallArity` new_arity
+
+
+-- Mutual recursion. Do nothing serious here, for now
+callArityBind ae_body int (Rec binds)
+  = (final_ae, Rec binds')
+  where
+    (aes, binds') = unzip $ map go binds
+    go (i,e) = let (ae, _, e') = callArityBound topCallCount int e
+               in (ae, (i,e'))
+    final_ae = foldl lubEnv ae_body aes `delVarEnvList` map fst binds
+
+
 callArityFix :: CallCount -> VarSet -> Id -> CoreExpr -> (CallArityEnv, Arity, CoreExpr)
 callArityFix arity int v e