@@ -18,22 +18,23 @@ import CoreArity ( typeArity )
import CoreUtils ( exprIsHNF )
--import Outputable
import UnVarGraph
+import Demand

import Control.Arrow ( first, second )

{-
%************************************************************************
-%*                                                                     *
+%*                                                                      *
Call Arity Analyis
-%*                                                                     *
+%*                                                                      *
%************************************************************************

Note [Call Arity: The goal]
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The goal of this analysis is to find out if we can eta-expand a local function,
-based on how it is being called. The motivating example is code this this,
+based on how it is being called. The motivating example is this code,
which comes up when we implement foldl using foldr, and do list fusion:

let go = \x -> let d = case ... of
@@ -46,7 +47,7 @@ If we do not eta-expand `go` to have arity 2, we are going to allocate a lot of
partial function applications, which would be bad.

The function `go` has a type of arity two, but only one lambda is manifest.
-Further more, an analysis that only looks at the RHS of go cannot be sufficient
+Furthermore, an analysis that only looks at the RHS of go cannot be sufficient
to eta-expand go: If `go` is ever called with one argument (and the result used
multiple times), we would be doing the work in `...` multiple times.

@@ -150,7 +151,7 @@ The interesting cases of the analysis:
Return (alt₁ ∪ alt₂ ∪...)
* App e₁ e₂ (and analogously Case scrut alts):
We get the results from both sides. Additionally, anything called by e₁ can
-   possibly called with anything from e₂.
+   possibly be called with anything from e₂.
Return: C(e₁) ∪ C(e₂) ∪ (fv e₁) × (fv e₂)
* Let v = rhs in body:
In addition to the results from the subexpressions, add all co-calls from
@@ -168,7 +169,7 @@ The interesting cases of the analysis:
cardinality consistent with the final result (this is the fixed-pointing).
Again we can use the results from all subexpressions.
In addition, for every variable vᵢ, we need to find out what it is called
-   with (calls this set Sᵢ). There are two cases:
+   with (call this set Sᵢ). There are two cases:
* If vᵢ is a function, we need to go through all right-hand-sides and bodies,
and collect every variable that is called together with any variable from V:
Sᵢ = {v' | j ∈ {1,...,n},      {v',vⱼ} ∈ C'(rhs₁) ∪ ... ∪ C'(rhsₙ) ∪ C(body) }
@@ -304,6 +305,26 @@ called, i.e. variables bound in a pattern match. So interesting are variables th
* top-level or let bound
* and possibly functions (typeArity > 0)

+Note [Taking boring variables into account]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+If we decide that the variable bound in `let x = e1 in e2` is not interesting,
+the analysis of `e2` will not report anything about `x`. To ensure that
+`callArityBind` does still do the right thing we have to take that into account
+everytime we would be lookup up `x` in the analysis result of `e2`.
+  * Instead of calling lookupCallArityRes, we return (0, True), indicating
+    that this variable might be called many times with no variables.
+  * Instead of checking `calledWith x`, we assume that everything can be called
+    with it.
+  * In the recursive case, when calclulating the `cross_calls`, if there is
+    any boring variable in the recursive group, we ignore all co-call-results
+    and directly go to a very conservative assumption.
+
+The last point has the nice side effect that the relatively expensive
+integration of co-call results in a recursive groups is often skipped. This
+helped to avoid the compile time blowup in some real-world code with large
+recursive groups (#10293).
+
Note [Recursion and fixpointing]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -319,6 +340,26 @@ For a mutually recursive let, we begin by
5. If nothing had to be reanalized, we are done.
Otherwise, repeat from step 3.

+
+Note [Thunks in recursive groups]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We never eta-expand a thunk in a recursive group, on the grounds that if it is
+part of a recursive group, then it will be called multipe times.
+
+This is not necessarily true, e.g.  it would be safe to eta-expand t2 (but not
+t1) in the follwing code:
+
+  let go x = t1
+      t1 = if ... then t2 else ...
+      t2 = if ... then go 1 else ...
+  in go 0
+
+Detecting this would require finding out what variables are only ever called
+from thunks. While this is certainly possible, we yet have to see this to be
+relevant in the wild.
+
+
Note [Analysing top-level binds]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -327,6 +368,28 @@ to them. The plan is as follows: Treat the top-level binds as nested lets around
a body representing “all external calls”, which returns a pessimistic
CallArityRes (the co-call graph is the complete graph, all arityies 0).

+Note [Trimming arity]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In the Call Arity papers, we are working on an untyped lambda calculus with no
+other id annotations, where eta-expansion is always possible. But this is not
+the case for Core!
+ 1. We need to ensure the invariant
+      callArity e <= typeArity (exprType e)
+    for the same reasons that exprArity needs this invariant (see Note
+    [exprArity invariant] in CoreArity).
+
+    If we are not doing that, a too-high arity annotation will be stored with
+    the id, confusing the simplifier later on.
+
+ 2. Eta-expanding a right hand side might invalidate existing annotations. In
+    particular, if an id has a strictness annotation of <...><...>b, then
+    passing two arguments to it will definitely bottom out, so the simplifier
+    will throw away additional parameters. This conflicts with Call Arity! So
+    we ensure that we never eta-expand such a value beyond the number of
+    arguments mentioned in the strictness signature.
+    See #10176 for a real-world-example.
+
-}

-- Main entry point
@@ -348,8 +411,7 @@ callArityTopLvl exported int1 (b:bs)
exported' = filter isExportedId int2 ++ exported
(ae1, bs') = callArityTopLvl exported' int' bs
-    ae1' = fakeBoringCalls int' b ae1
-    (ae2, b')  = callArityBind ae1' int1 b
+    (ae2, b')  = callArityBind (boringBinds b) ae1 int1 b

callArityRHS :: CoreExpr -> CoreExpr
@@ -410,7 +472,6 @@ callArityAnal arity int (App e1 e2)
where
(ae1, e1') = callArityAnal (arity + 1) int e1
(ae2, e2') = callArityAnal 0           int e2
-    -- See Note [Case and App: Which side to take?]
final_ae = ae1 `both` ae2

-- Case expression.
@@ -424,7 +485,6 @@ callArityAnal arity int (Case scrut bndr ty alts)
in  (ae, (dc, bndrs, e'))
alt_ae = lubRess alt_aes
(scrut_ae, scrut') = callArityAnal 0 int scrut
-    -- See Note [Case and App: Which side to take?]
final_ae = scrut_ae `both` alt_ae

-- For lets, use callArityBind
@@ -435,72 +495,74 @@ callArityAnal arity int (Let bind e)
where
(ae_body, e') = callArityAnal arity int_body e
-    ae_body' = fakeBoringCalls int_body bind ae_body
-    (final_ae, bind') = callArityBind ae_body' int bind
-
--- This is a variant of callArityAnal that is additionally told whether
--- the expression is called once or multiple times, and treats thunks appropriately.
--- It also returns the actual arity that can be used for this expression.
-callArityBound :: Bool -> Arity -> VarSet -> CoreExpr -> (CallArityRes, Arity, CoreExpr)
-callArityBound called_once arity int e
-    = -- pprTrace "callArityBound" (vcat [ppr (called_once, arity), ppr is_thunk, ppr safe_arity]) \$
-      (final_ae, safe_arity, e')
- where
-    is_thunk = not (exprIsHNF e)
-
-    safe_arity | called_once = arity
-               | is_thunk    = 0      -- A thunk! Do not eta-expand
-               | otherwise   = arity
-
-    (ae, e') = callArityAnal safe_arity int e
-
-    final_ae | called_once     = ae
-             | safe_arity == 0 = ae -- If it is not a function, its body is evaluated only once
-             | otherwise       = calledMultipleTimes ae
-
+    (final_ae, bind') = callArityBind (boringBinds bind) ae_body int bind

-- Which bindings should we look at?
-- See Note [Which variables are interesting]
+isInteresting :: Var -> Bool
+isInteresting v = 0 < length (typeArity (idType v))
+
interestingBinds :: CoreBind -> [Var]
-interestingBinds = filter go . bindersOf
-  where go v = 0 < length (typeArity (idType v))
+interestingBinds = filter isInteresting . bindersOf
+
+boringBinds :: CoreBind -> VarSet
+boringBinds = mkVarSet . filter (not . isInteresting) . bindersOf

addInterestingBinds :: VarSet -> CoreBind -> VarSet
= int `delVarSetList`    bindersOf bind -- Possible shadowing
`extendVarSetList` interestingBinds bind

--- For every boring variable in the binder, this amends the CallArityRes to
--- report safe information about them (co-called with everything else, arity 0).
-fakeBoringCalls :: VarSet -> CoreBind -> CallArityRes -> CallArityRes
-fakeBoringCalls int bind res
-    = addCrossCoCalls (domRes boring) (domRes res) \$ (boring `lubRes` res)
-  where
-    boring = ( emptyUnVarGraph
-             ,  mkVarEnv [ (v, 0) | v <- bindersOf bind, not (v `elemVarSet` int)])
-
-
-- Used for both local and top-level binds
--- First argument is the demand from the body
-callArityBind :: CallArityRes -> VarSet -> CoreBind -> (CallArityRes, CoreBind)
+-- Second argument is the demand from the body
+callArityBind :: VarSet -> CallArityRes -> VarSet -> CoreBind -> (CallArityRes, CoreBind)
-- Non-recursive let
-callArityBind ae_body int (NonRec v rhs)
+callArityBind boring_vars ae_body int (NonRec v rhs)
| otherwise
= -- pprTrace "callArityBind:NonRec"
--          (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr safe_arity])
(final_ae, NonRec v' rhs')
where
-    (arity, called_once)  = lookupCallArityRes ae_body v
-    (ae_rhs, safe_arity, rhs') = callArityBound called_once arity int rhs
-    final_ae = callArityNonRecEnv v ae_rhs ae_body
-    v' = v `setIdCallArity` safe_arity
+    is_thunk = not (exprIsHNF rhs)
+    -- If v is boring, we will not find it in ae_body, but always assume (0, False)
+    boring = v `elemVarSet` boring_vars
+
+    (arity, called_once)
+        | boring    = (0, False) -- See Note [Taking boring variables into account]
+        | otherwise = lookupCallArityRes ae_body v
+    safe_arity | called_once = arity
+               | is_thunk    = 0      -- A thunk! Do not eta-expand
+               | otherwise   = arity
+
+    -- See Note [Trimming arity]
+    trimmed_arity = trimArity v safe_arity
+
+    (ae_rhs, rhs') = callArityAnal trimmed_arity int rhs
+
+
+    ae_rhs'| called_once     = ae_rhs
+           | safe_arity == 0 = ae_rhs -- If it is not a function, its body is evaluated only once
+           | otherwise       = calledMultipleTimes ae_rhs
+
+    called_by_v = domRes ae_rhs'
+    called_with_v
+        | boring    = domRes ae_body
+        | otherwise = calledWith ae_body v `delUnVarSet` v
+    final_ae = addCrossCoCalls called_by_v called_with_v \$ ae_rhs' `lubRes` resDel v ae_body
+
+    v' = v `setIdCallArity` trimmed_arity
+

-- Recursive let. See Note [Recursion and fixpointing]
-callArityBind ae_body int b@(Rec binds)
-  = -- pprTrace "callArityBind:Rec"
-    --           (vcat [ppr (Rec binds'), ppr ae_body, ppr int, ppr ae_rhs]) \$
+callArityBind boring_vars ae_body int b@(Rec binds)
+  = -- (if length binds > 300 then
+    -- pprTrace "callArityBind:Rec"
+    --           (vcat [ppr (Rec binds'), ppr ae_body, ppr int, ppr ae_rhs]) else id) \$
(final_ae, Rec binds')
where
+    -- See Note [Taking boring variables into account]
+    any_boring = any (`elemVarSet` boring_vars) [ i | (i, _) <- binds]
+
(ae_rhs, binds') = fix initial_binds
final_ae = bindersOf b `resDelList` ae_rhs
@@ -516,7 +578,7 @@ callArityBind ae_body int b@(Rec binds)
= (ae, map (\(i, _, e) -> (i, e)) ann_binds')
where
aes_old = [ (i,ae) | (i, Just (_,_,ae), _) <- ann_binds ]
-        ae = callArityRecEnv aes_old ae_body
+        ae = callArityRecEnv any_boring aes_old ae_body

rerun (i, mbLastRun, rhs)
| i `elemVarSet` int_body && not (i `elemUnVarSet` domRes ae)
@@ -531,35 +593,47 @@ callArityBind ae_body int b@(Rec binds)

| otherwise
-- We previously analized this with a different arity (or not at all)
-            = let (ae_rhs, safe_arity, rhs') = callArityBound called_once new_arity int_body rhs
-              in (True, (i `setIdCallArity` safe_arity, Just (called_once, new_arity, ae_rhs), rhs'))
+            = let is_thunk = not (exprIsHNF rhs)
+
+                  safe_arity | is_thunk    = 0  -- See Note [Thunks in recursive groups]
+                             | otherwise   = new_arity
+
+                  -- See Note [Trimming arity]
+                  trimmed_arity = trimArity i safe_arity
+
+                  (ae_rhs, rhs') = callArityAnal trimmed_arity int_body rhs
+
+                  ae_rhs' | called_once     = ae_rhs
+                          | safe_arity == 0 = ae_rhs -- If it is not a function, its body is evaluated only once
+                          | otherwise       = calledMultipleTimes ae_rhs
+
+              in (True, (i `setIdCallArity` trimmed_arity, Just (called_once, new_arity, ae_rhs'), rhs'))
where
-            (new_arity, called_once)  = lookupCallArityRes ae i
+            -- See Note [Taking boring variables into account]
+            (new_arity, called_once) | i `elemVarSet` boring_vars = (0, False)
+                                     | otherwise                  = lookupCallArityRes ae i

(changes, ann_binds') = unzip \$ map rerun ann_binds
any_change = or changes

--- Combining the results from body and rhs, non-recursive case
--- See Note [Analysis II: The Co-Called analysis]
-callArityNonRecEnv :: Var -> CallArityRes -> CallArityRes -> CallArityRes
-callArityNonRecEnv v ae_rhs ae_body
-    = addCrossCoCalls called_by_v called_with_v \$ ae_rhs `lubRes` resDel v ae_body
-  where
-    called_by_v = domRes ae_rhs
-    called_with_v = calledWith ae_body v `delUnVarSet` v
-
-- Combining the results from body and rhs, (mutually) recursive case
-- See Note [Analysis II: The Co-Called analysis]
-callArityRecEnv :: [(Var, CallArityRes)] -> CallArityRes -> CallArityRes
-callArityRecEnv ae_rhss ae_body
-    = -- pprTrace "callArityRecEnv" (vcat [ppr ae_rhss, ppr ae_body, ppr ae_new])
+callArityRecEnv :: Bool -> [(Var, CallArityRes)] -> CallArityRes -> CallArityRes
+callArityRecEnv any_boring ae_rhss ae_body
+    = -- (if length ae_rhss > 300 then pprTrace "callArityRecEnv" (vcat [ppr ae_rhss, ppr ae_body, ppr ae_new]) else id) \$
ae_new
where
vars = map fst ae_rhss

ae_combined = lubRess (map snd ae_rhss) `lubRes` ae_body

-    cross_calls = unionUnVarGraphs \$ map cross_call ae_rhss
+    cross_calls
+        -- See Note [Taking boring variables into account]
+        | any_boring          = completeGraph (domRes ae_combined)
+        -- Also, calculating cross_calls is expensive. Simply be conservative
+        -- if the mutually recursive group becomes too large.
+        | length ae_rhss > 25 = completeGraph (domRes ae_combined)
+        | otherwise           = unionUnVarGraphs \$ map cross_call ae_rhss
cross_call (v, ae_rhs) = completeBipartiteGraph called_by_v called_with_v
where
is_thunk = idCallArity v == 0
@@ -576,6 +650,17 @@ callArityRecEnv ae_rhss ae_body

ae_new = first (cross_calls `unionUnVarGraph`) ae_combined

+-- See Note [Trimming arity]
+trimArity :: Id -> Arity -> Arity
+trimArity v a = minimum [a, max_arity_by_type, max_arity_by_strsig]
+  where
+    max_arity_by_type = length (typeArity (idType v))
+    max_arity_by_strsig
+        | isBotRes result_info = length demands
+        | otherwise = a
+
+    (demands, result_info) = splitStrictSig (idStrictness v)
+
---------------------------------------
-- Functions related to CallArityRes --
---------------------------------------