@@ -18,22 +18,23 @@ import CoreArity ( typeArity )
import CoreUtils ( exprIsHNF )
--import Outputable
import UnVarGraph
+import Demand

import Control.Arrow ( first, second )

{-
%************************************************************************
-%*                                                                     *
+%*                                                                      *
Call Arity Analyis
-%*                                                                     *
+%*                                                                      *
%************************************************************************

Note [Call Arity: The goal]
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The goal of this analysis is to find out if we can eta-expand a local function,
-based on how it is being called. The motivating example is code this this,
+based on how it is being called. The motivating example is this code,
which comes up when we implement foldl using foldr, and do list fusion:

let go = \x -> let d = case ... of
@@ -46,7 +47,7 @@ If we do not eta-expand `go` to have arity 2, we are going to allocate a lot of
partial function applications, which would be bad.

The function `go` has a type of arity two, but only one lambda is manifest.
-Further more, an analysis that only looks at the RHS of go cannot be sufficient
+Furthermore, an analysis that only looks at the RHS of go cannot be sufficient
to eta-expand go: If `go` is ever called with one argument (and the result used
multiple times), we would be doing the work in `...` multiple times.

@@ -150,7 +151,7 @@ The interesting cases of the analysis:
Return (alt₁ ∪ alt₂ ∪...)
* App e₁ e₂ (and analogously Case scrut alts):
We get the results from both sides. Additionally, anything called by e₁ can
-   possibly called with anything from e₂.
+   possibly be called with anything from e₂.
Return: C(e₁) ∪ C(e₂) ∪ (fv e₁) × (fv e₂)
* Let v = rhs in body:
In addition to the results from the subexpressions, add all co-calls from
@@ -304,18 +305,25 @@ called, i.e. variables bound in a pattern match. So interesting are variables th
* top-level or let bound
* and possibly functions (typeArity > 0)

-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Note [Taking boring variables into account]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If we decide that the variable bound in `let x = e1 in e2` is not interesting,
the analysis of `e2` will not report anything about `x`. To ensure that
-`callArityBind` does still do the right thing we have to extend the result from
-`e2` with a safe approximation.
-
-This is done using `fakeBoringCalls` and has the effect of analysing
-   x `seq` x `seq` e2
-instead, i.e. with `both` the result from `e2` with the most conservative
+`callArityBind` does still do the right thing we have to take that into account
+everytime we would be lookup up `x` in the analysis result of `e2`.
+  * Instead of calling lookupCallArityRes, we return (0, True), indicating
+    that this variable might be called many times with no variables.
+  * Instead of checking `calledWith x`, we assume that everything can be called
+    with it.
+  * In the recursive case, when calclulating the `cross_calls`, if there is
+    any boring variable in the recursive group, we ignore all co-call-results
+    and directly go to a very conservative assumption.
+
+The last point has the nice side effect that the relatively expensive
+integration of co-call results in a recursive groups is often skipped. This
+helped to avoid the compile time blowup in some real-world code with large
+recursive groups (#10293).

Note [Recursion and fixpointing]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -347,7 +355,7 @@ t1) in the follwing code:
t2 = if ... then go 1 else ...
in go 0

-Detecting this would reqiure finding out what variables are only ever called
+Detecting this would require finding out what variables are only ever called
from thunks. While this is certainly possible, we yet have to see this to be
relevant in the wild.

@@ -360,6 +368,28 @@ to them. The plan is as follows: Treat the top-level binds as nested lets around
a body representing “all external calls”, which returns a pessimistic
CallArityRes (the co-call graph is the complete graph, all arityies 0).

+Note [Trimming arity]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In the Call Arity papers, we are working on an untyped lambda calculus with no
+other id annotations, where eta-expansion is always possible. But this is not
+the case for Core!
+ 1. We need to ensure the invariant
+      callArity e <= typeArity (exprType e)
+    for the same reasons that exprArity needs this invariant (see Note
+    [exprArity invariant] in CoreArity).
+
+    If we are not doing that, a too-high arity annotation will be stored with
+    the id, confusing the simplifier later on.
+
+ 2. Eta-expanding a right hand side might invalidate existing annotations. In
+    particular, if an id has a strictness annotation of <...><...>b, then
+    passing two arguments to it will definitely bottom out, so the simplifier
+    will throw away additional parameters. This conflicts with Call Arity! So
+    we ensure that we never eta-expand such a value beyond the number of
+    arguments mentioned in the strictness signature.
+    See #10176 for a real-world-example.
+
-}

-- Main entry point
@@ -381,8 +411,7 @@ callArityTopLvl exported int1 (b:bs)
exported' = filter isExportedId int2 ++ exported
(ae1, bs') = callArityTopLvl exported' int' bs
-    ae1' = fakeBoringCalls int' b ae1 -- See Note [Information about boring variables]
-    (ae2, b')  = callArityBind ae1' int1 b
+    (ae2, b')  = callArityBind (boringBinds b) ae1 int1 b

callArityRHS :: CoreExpr -> CoreExpr
@@ -443,7 +472,6 @@ callArityAnal arity int (App e1 e2)
where
(ae1, e1') = callArityAnal (arity + 1) int e1
(ae2, e2') = callArityAnal 0           int e2
-    -- See Note [Case and App: Which side to take?]
final_ae = ae1 `both` ae2

-- Case expression.
@@ -457,7 +485,6 @@ callArityAnal arity int (Case scrut bndr ty alts)
in  (ae, (dc, bndrs, e'))
alt_ae = lubRess alt_aes
(scrut_ae, scrut') = callArityAnal 0 int scrut
-    -- See Note [Case and App: Which side to take?]
final_ae = scrut_ae `both` alt_ae

-- For lets, use callArityBind
@@ -468,63 +495,74 @@ callArityAnal arity int (Let bind e)
where
(ae_body, e') = callArityAnal arity int_body e
-    ae_body' = fakeBoringCalls int_body bind ae_body -- See Note [Information about boring variables]
-    (final_ae, bind') = callArityBind ae_body' int bind
+    (final_ae, bind') = callArityBind (boringBinds bind) ae_body int bind

-- Which bindings should we look at?
-- See Note [Which variables are interesting]
+isInteresting :: Var -> Bool
+isInteresting v = 0 < length (typeArity (idType v))
+
interestingBinds :: CoreBind -> [Var]
-interestingBinds = filter go . bindersOf
-  where go v = 0 < length (typeArity (idType v))
+interestingBinds = filter isInteresting . bindersOf
+
+boringBinds :: CoreBind -> VarSet
+boringBinds = mkVarSet . filter (not . isInteresting) . bindersOf

addInterestingBinds :: VarSet -> CoreBind -> VarSet
= int `delVarSetList`    bindersOf bind -- Possible shadowing
`extendVarSetList` interestingBinds bind

--- For every boring variable in the binder, add a safe approximation
--- See Note [Information about boring variables]
-fakeBoringCalls :: VarSet -> CoreBind -> CallArityRes -> CallArityRes
-fakeBoringCalls int bind res = boring `both` res
-  where
-    boring = calledMultipleTimes \$
-        ( emptyUnVarGraph
-        ,  mkVarEnv [ (v, 0) | v <- bindersOf bind, not (v `elemVarSet` int)])
-
-
-- Used for both local and top-level binds
--- First argument is the demand from the body
-callArityBind :: CallArityRes -> VarSet -> CoreBind -> (CallArityRes, CoreBind)
+-- Second argument is the demand from the body
+callArityBind :: VarSet -> CallArityRes -> VarSet -> CoreBind -> (CallArityRes, CoreBind)
-- Non-recursive let
-callArityBind ae_body int (NonRec v rhs)
+callArityBind boring_vars ae_body int (NonRec v rhs)
| otherwise
= -- pprTrace "callArityBind:NonRec"
--          (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr safe_arity])
(final_ae, NonRec v' rhs')
where
is_thunk = not (exprIsHNF rhs)
+    -- If v is boring, we will not find it in ae_body, but always assume (0, False)
+    boring = v `elemVarSet` boring_vars

-    (arity, called_once)  = lookupCallArityRes ae_body v
+    (arity, called_once)
+        | boring    = (0, False) -- See Note [Taking boring variables into account]
+        | otherwise = lookupCallArityRes ae_body v
safe_arity | called_once = arity
| is_thunk    = 0      -- A thunk! Do not eta-expand
| otherwise   = arity
-    (ae_rhs, rhs') = callArityAnal safe_arity int rhs
+
+    -- See Note [Trimming arity]
+    trimmed_arity = trimArity v safe_arity
+
+    (ae_rhs, rhs') = callArityAnal trimmed_arity int rhs
+

ae_rhs'| called_once     = ae_rhs
| safe_arity == 0 = ae_rhs -- If it is not a function, its body is evaluated only once
| otherwise       = calledMultipleTimes ae_rhs

-    final_ae = callArityNonRecEnv v ae_rhs' ae_body
-    v' = v `setIdCallArity` safe_arity
+    called_by_v = domRes ae_rhs'
+    called_with_v
+        | boring    = domRes ae_body
+        | otherwise = calledWith ae_body v `delUnVarSet` v
+    final_ae = addCrossCoCalls called_by_v called_with_v \$ ae_rhs' `lubRes` resDel v ae_body

+    v' = v `setIdCallArity` trimmed_arity

-- Recursive let. See Note [Recursion and fixpointing]
-callArityBind ae_body int b@(Rec binds)
-  = -- pprTrace "callArityBind:Rec"
-    --           (vcat [ppr (Rec binds'), ppr ae_body, ppr int, ppr ae_rhs]) \$
+callArityBind boring_vars ae_body int b@(Rec binds)
+  = -- (if length binds > 300 then
+    -- pprTrace "callArityBind:Rec"
+    --           (vcat [ppr (Rec binds'), ppr ae_body, ppr int, ppr ae_rhs]) else id) \$
(final_ae, Rec binds')
where
+    -- See Note [Taking boring variables into account]
+    any_boring = any (`elemVarSet` boring_vars) [ i | (i, _) <- binds]
+
(ae_rhs, binds') = fix initial_binds
final_ae = bindersOf b `resDelList` ae_rhs
@@ -540,7 +578,7 @@ callArityBind ae_body int b@(Rec binds)
= (ae, map (\(i, _, e) -> (i, e)) ann_binds')
where
aes_old = [ (i,ae) | (i, Just (_,_,ae), _) <- ann_binds ]
-        ae = callArityRecEnv aes_old ae_body
+        ae = callArityRecEnv any_boring aes_old ae_body

rerun (i, mbLastRun, rhs)
| i `elemVarSet` int_body && not (i `elemUnVarSet` domRes ae)
@@ -560,40 +598,42 @@ callArityBind ae_body int b@(Rec binds)
safe_arity | is_thunk    = 0  -- See Note [Thunks in recursive groups]
| otherwise   = new_arity

-                  (ae_rhs, rhs') = callArityAnal safe_arity int_body rhs
+                  -- See Note [Trimming arity]
+                  trimmed_arity = trimArity i safe_arity
+
+                  (ae_rhs, rhs') = callArityAnal trimmed_arity int_body rhs

ae_rhs' | called_once     = ae_rhs
| safe_arity == 0 = ae_rhs -- If it is not a function, its body is evaluated only once
| otherwise       = calledMultipleTimes ae_rhs

-              in (True, (i `setIdCallArity` safe_arity, Just (called_once, new_arity, ae_rhs'), rhs'))
+              in (True, (i `setIdCallArity` trimmed_arity, Just (called_once, new_arity, ae_rhs'), rhs'))
where
-            (new_arity, called_once)  = lookupCallArityRes ae i
+            -- See Note [Taking boring variables into account]
+            (new_arity, called_once) | i `elemVarSet` boring_vars = (0, False)
+                                     | otherwise                  = lookupCallArityRes ae i

(changes, ann_binds') = unzip \$ map rerun ann_binds
any_change = or changes

--- Combining the results from body and rhs, non-recursive case
--- See Note [Analysis II: The Co-Called analysis]
-callArityNonRecEnv :: Var -> CallArityRes -> CallArityRes -> CallArityRes
-callArityNonRecEnv v ae_rhs ae_body
-    = addCrossCoCalls called_by_v called_with_v \$ ae_rhs `lubRes` resDel v ae_body
-  where
-    called_by_v = domRes ae_rhs
-    called_with_v = calledWith ae_body v `delUnVarSet` v
-
-- Combining the results from body and rhs, (mutually) recursive case
-- See Note [Analysis II: The Co-Called analysis]
-callArityRecEnv :: [(Var, CallArityRes)] -> CallArityRes -> CallArityRes
-callArityRecEnv ae_rhss ae_body
-    = -- pprTrace "callArityRecEnv" (vcat [ppr ae_rhss, ppr ae_body, ppr ae_new])
+callArityRecEnv :: Bool -> [(Var, CallArityRes)] -> CallArityRes -> CallArityRes
+callArityRecEnv any_boring ae_rhss ae_body
+    = -- (if length ae_rhss > 300 then pprTrace "callArityRecEnv" (vcat [ppr ae_rhss, ppr ae_body, ppr ae_new]) else id) \$
ae_new
where
vars = map fst ae_rhss

ae_combined = lubRess (map snd ae_rhss) `lubRes` ae_body

-    cross_calls = unionUnVarGraphs \$ map cross_call ae_rhss
+    cross_calls
+        -- See Note [Taking boring variables into account]
+        | any_boring          = completeGraph (domRes ae_combined)
+        -- Also, calculating cross_calls is expensive. Simply be conservative
+        -- if the mutually recursive group becomes too large.
+        | length ae_rhss > 25 = completeGraph (domRes ae_combined)
+        | otherwise           = unionUnVarGraphs \$ map cross_call ae_rhss
cross_call (v, ae_rhs) = completeBipartiteGraph called_by_v called_with_v
where
is_thunk = idCallArity v == 0
@@ -610,6 +650,17 @@ callArityRecEnv ae_rhss ae_body

ae_new = first (cross_calls `unionUnVarGraph`) ae_combined

+-- See Note [Trimming arity]
+trimArity :: Id -> Arity -> Arity
+trimArity v a = minimum [a, max_arity_by_type, max_arity_by_strsig]
+  where
+    max_arity_by_type = length (typeArity (idType v))
+    max_arity_by_strsig
+        | isBotRes result_info = length demands
+        | otherwise = a
+
+    (demands, result_info) = splitStrictSig (idStrictness v)
+
---------------------------------------
-- Functions related to CallArityRes --
---------------------------------------