1 module Exitify ( exitifyProgram ) where
3 {-
4 Note [Exitification]
5 ~~~~~~~~~~~~~~~~~~~~
7 This module implements Exitification. The goal is to pull as much code out of
8 recursive functions as possible, as the simplifier is better at inlining into
9 call-sites that are not in recursive functions.
11 Example:
13 let t = foo bar
14 joinrec go 0 x y = t (x*x)
15 go (n-1) x y = jump go (n-1) (x+y)
16 in …
18 We’d like to inline `t`, but that does not happen: Because t is a thunk and is
19 used in a recursive function, doing so might lose sharing in general. In
20 this case, however, `t` is on the _exit path_ of `go`, so called at most once.
21 How do we make this clearly visible to the simplifier?
23 A code path (i.e., an expression in a tail-recursive position) in a recursive
24 function is an exit path if it does not contain a recursive call. We can bind
25 this expression outside the recursive function, as a join-point.
27 Example result:
29 let t = foo bar
30 join exit x = t (x*x)
31 joinrec go 0 x y = jump exit x
32 go (n-1) x y = jump go (n-1) (x+y)
33 in …
35 Now `t` is no longer in a recursive function, and good things happen!
36 -}
38 import GhcPrelude
39 import Var
40 import Id
41 import IdInfo
42 import CoreSyn
43 import CoreUtils
44 import State
45 import Unique
46 import VarSet
47 import VarEnv
48 import CoreFVs
49 import FastString
50 import Type
51 import Util( mapSnd )
53 import Data.Bifunctor
-- | Entry point of the pass. Traverses the whole program merely to find all
-- @joinrec@s and hand each of them to 'exitifyRec', which does the real work.
-- Every other construct is rebuilt unchanged once its sub-expressions have
-- been visited.
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map exitifyTopBind binds
  where
    -- All top-level binders are in scope throughout the module.
    top_in_scope = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds

    -- Top-level bindings are never join points, so we only descend into
    -- their right-hand sides.
    exitifyTopBind (NonRec v rhs) = NonRec v (walk top_in_scope rhs)
    exitifyTopBind (Rec prs)      = Rec (mapSnd (walk top_in_scope) prs)

    -- Rebuild an expression while tracking the in-scope set, exitifying
    -- every recursive binding group that binds join points.
    walk :: InScopeSet -> CoreExpr -> CoreExpr
    walk in_scope expr = case expr of
      Var {}      -> expr
      Lit {}      -> expr
      Type {}     -> expr
      Coercion {} -> expr
      Cast e co   -> Cast (walk in_scope e) co
      Tick tk e   -> Tick tk (walk in_scope e)
      App fun arg -> App (walk in_scope fun) (walk in_scope arg)

      Lam b e -> Lam b (walk (in_scope `extendInScopeSet` b) e)

      Case scrut b ty alts ->
        let alt_scope = in_scope `extendInScopeSet` b
            walk_alt (con, vs, rhs) =
              (con, vs, walk (alt_scope `extendInScopeSetList` vs) rhs)
        in Case (walk in_scope scrut) b ty (map walk_alt alts)

      Let (NonRec b rhs) body ->
        Let (NonRec b (walk in_scope rhs))
            (walk (in_scope `extendInScopeSet` b) body)

      Let (Rec prs) body ->
        -- The binders of a recursive group scope over all RHSs and the body.
        let rec_scope = in_scope `extendInScopeSetList` (map fst prs)
            prs'      = mapSnd (walk rec_scope) prs
            body'     = walk rec_scope body
        in if any (isJoinId . fst) prs
             then mkLets (exitifyRec rec_scope prs') body'
             else Let (Rec prs') body'
-- | The monad 'exitifyRec' runs in. Its state is the list of exit join
-- points floated out so far, each paired with its right-hand side; new
-- entries are consed onto the front of the list (see 'addExit').
type ExitifyM = State [ (JoinId, CoreExpr) ]
-- | Given a recursive group of a joinrec, identifies “exit paths” and binds
-- them as join points outside the joinrec.
--
-- The result lists the new exit join points (each as a 'NonRec' binding)
-- followed by the rewritten recursive group itself.
--
-- The 'InScopeSet' is used, together with the variables bound locally on the
-- way to an exit path, to choose fresh names for the exit join points. The
-- caller in this module ('exitifyProgram') passes the in-scope set at the
-- joinrec, already extended with the group's own binders.
exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
exitifyRec in_scope pairs
  = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
  where
    -- We need the set of free variables of many subexpressions here, so
    -- annotate the AST with them
    -- see Note [Calculating free variables]
    ann_pairs = map (second freeVars) pairs

    -- Which are the recursive calls?
    recursive_calls = mkVarSet $ map fst pairs

    -- Rewrite each RHS of the group, starting with nothing captured; the
    -- state collects the exit join points floated out along the way.
    (pairs',exits) = (`runState` []) $
        forM ann_pairs $ \(x,rhs) -> do
            rhs' <- go_join_rhs [] x rhs
            return (x, rhs')

    ---------------------
    -- Shared by every place where we meet the RHS of a join point: go past
    -- its join-arity many lambdas (those parameters are bound locally, so
    -- they are appended to the captured variables), process the body, which
    -- is in tail position, and rebuild the lambdas.
    go_join_rhs :: [Var] -> JoinId -> CoreExprWithFVs -> ExitifyM CoreExpr
    go_join_rhs captured j rhs = do
        let (params, join_body) = collectNAnnBndrs (idJoinArity j) rhs
        join_body' <- go (captured ++ params) join_body
        return (mkLams params join_body')

    ---------------------
    -- 'go' is the main working function.
    -- It goes through the RHS (tail-call positions only),
    -- checks if there are no more recursive calls, if so, abstracts over
    -- variables bound on the way and lifts it out as a join point.
    --
    -- ExitifyM is a state monad to keep track of floated binds
    go :: [Var]           -- ^ Variables that are in-scope here, but
                          -- not in scope at the joinrec; that is,
                          -- we must potentially abstract over them.
                          -- Invariant: they are kept in dependency order
       -> CoreExprWithFVs -- ^ Current expression in tail position
       -> ExitifyM CoreExpr

    -- We first look at the expression (no matter what its shape is)
    -- and determine if we can turn it into an exit join point
    go captured ann_e
        | -- An exit expression has no recursive calls
          let fvs = dVarSetToVarSet (freeVarsOf ann_e)
        , disjointVarSet fvs recursive_calls
        = go_exit captured (deAnnotate ann_e) fvs

    -- We could not turn it into an exit join point. So now recurse
    -- into all expressions where eligible exit join points might sit,
    -- i.e. into all tail-call positions:

    -- Case right hand sides are in tail-call position
    go captured (_, AnnCase scrut bndr ty alts) = do
        alts' <- forM alts $ \(dc, pats, rhs) -> do
            rhs' <- go (captured ++ [bndr] ++ pats) rhs
            return (dc, pats, rhs')
        return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
        -- join point, RHS and body are in tail-call position
        | AnnNonRec j rhs <- ann_bind
        , isJoinId j
        = do rhs'  <- go_join_rhs captured j rhs
             body' <- go (captured ++ [j]) body
             return $ Let (NonRec j rhs') body'

        -- rec join point, RHSs and body are in tail-call position.
        -- Checked with 'any' rather than by inspecting the 'head' of the
        -- group, which is partial (same test as in 'exitifyProgram').
        | AnnRec join_pairs <- ann_bind
        , any (isJoinId . fst) join_pairs
        = do let js = map fst join_pairs
             join_pairs' <- forM join_pairs $ \(j,rhs) -> do
                 rhs' <- go_join_rhs (captured ++ js) j rhs
                 return (j, rhs')
             body' <- go (captured ++ js) body
             return $ Let (Rec join_pairs') body'

        -- normal Let, only the body is in tail-call position
        | otherwise
        = do body' <- go (captured ++ bindersOf bind) body
             return $ Let bind body'
      where
        bind = deAnnBind ann_bind

    -- Cannot be turned into an exit join point, but also has no
    -- tail-call subexpression. Nothing to do here.
    go _ ann_e = return (deAnnotate ann_e)

    ---------------------
    go_exit :: [Var]      -- Variables captured locally
            -> CoreExpr   -- An exit expression
            -> VarSet     -- Free vars of the expression
            -> ExitifyM CoreExpr
    -- go_exit deals with a tail expression that is floatable
    -- out as an exit point; that is, it mentions no recursive calls
    go_exit captured e fvs
        -- Do not touch an expression that is already a join jump where all
        -- arguments are captured variables. See Note [Idempotency]
        -- But _do_ float join jumps with interesting arguments.
        -- See Note [Jumps can be interesting]
        | (Var f, args) <- collectArgs e
        , isJoinId f
        , all isCapturedVarArg args
        = return e

        -- Do not touch a boring expression (see Note [Interesting expression])
        | not is_interesting
        = return e

        -- Cannot float out if local join points are used, as
        -- we cannot abstract over them
        | captures_join_points
        = return e

        -- We have something to float out!
        | otherwise
        = do -- Assemble the RHS of the exit join point
             let rhs   = mkLams abs_vars e
                 avoid = in_scope `extendInScopeSetList` captured
             -- Remember this binding under a suitable name
             v <- addExit avoid (length abs_vars) rhs
             -- ... and jump to it from here
             return $ mkVarApps (Var v) abs_vars

      where
        -- Used to detect exit expressions that are already proper exit jumps
        isCapturedVarArg (Var v) = v `elem` captured
        isCapturedVarArg _       = False

        -- An interesting exit expression has free, non-imported
        -- variables from outside the recursive group
        -- See Note [Interesting expression]
        is_interesting = anyVarSet isLocalId $
                         fvs `minusVarSet` mkVarSet captured

        -- The arguments of this exit join point
        -- See Note [Picking arguments to abstract over]
        abs_vars = snd $ foldr pick (fvs, []) captured
          where
            pick v (fvs', acc) | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
                               | otherwise           = (fvs',               acc)

        -- We are going to abstract over these variables, so we must
        -- zap any IdInfo they have; see Trac #15005
        -- cf. SetLevels.abstractVars
        zap v | isId v    = setIdInfo v vanillaIdInfo
              | otherwise = v

        -- We cannot abstract over join points
        captures_join_points = any isJoinId abs_vars
-- | Chooses a fresh name for an exit join point of the given type and join
-- arity. The chosen unique is disjoint from
--   * everything in the supplied in-scope set (the caller passes the scope
--     at the joinrec plus the locally captured variables), and
--   * every exit join point recorded in the state so far.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = do
    floated <- get
    let template = mkSysLocal (fsLit "exit") initExitJoinUnique ty
                     `asJoinId` join_arity
        taken    = extendInScopeSet
                     (extendInScopeSetList in_scope (map fst floated))
                     template -- just cosmetics
    return $ uniqAway taken template
-- | Records a new exit join point with the given join arity and right-hand
-- side, consing it onto the front of the state, and returns its freshly
-- chosen name (see 'mkExitJoinId' for how the name is picked).
addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope join_arity rhs = do
    exit_id <- mkExitJoinId in_scope (exprType rhs) join_arity
    get >>= put . ((exit_id, rhs) :)
    return exit_id
280 {-
281 Note [Interesting expression]
282 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
283 We do not want this to happen:
285 joinrec go 0 x y = x
286 go (n-1) x y = jump go (n-1) (x+y)
287 in …
288 ==>
289 join exit x = x
290 joinrec go 0 x y = jump exit x
291 go (n-1) x y = jump go (n-1) (x+y)
292 in …
294 because the floated exit path (`x`) is simply a parameter of `go`; there are
295 no useful interactions exposed this way.
297 Neither do we want this to happen
299 joinrec go 0 x y = x+x
300 go (n-1) x y = jump go (n-1) (x+y)
301 in …
302 ==>
303 join exit x = x+x
304 joinrec go 0 x y = jump exit x
305 go (n-1) x y = jump go (n-1) (x+y)
306 in …
308 where the floated expression `x+x` is a bit more complicated, but still not
309 interesting.
311 Expressions are interesting when they move an occurrence of a variable outside
312 the recursive `go` that can benefit from being obviously called once, for example:
313 * a local thunk that can then be inlined (see example in note [Exitification])
314 * the parameter of a function, where the demand analyzer can then
315 see that it is called at most once, and hence improve the function’s
316 strictness signature
318 So we only hoist an exit expression out if it mentions at least one free,
319 non-imported variable.
321 Note [Jumps can be interesting]
322 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
323 A jump to a join point can be interesting, if its arguments contain free
324 non-imported variables (z in the following example):
326 joinrec go 0 x y = jump j (x+z)
327 go (n-1) x y = jump go (n-1) (x+y)
328 in …
329 ==>
330 join exit x = jump j (x+z)
331 joinrec go 0 x y = jump exit x
332 go (n-1) x y = jump go (n-1) (x+y)
335 The join point itself can be interesting, even if none of its
336 arguments have free variables free in the joinrec. For example
338 join j p = case p of (x,y) -> x+y
339 joinrec go 0 x y = jump j (x,y)
340 go (n-1) x y = jump go (n-1) (x+y) y
341 in …
343 Here, `j` would not be inlined because we do not inline something that looks
344 like an exit join point (see Note [Do not inline exit join points]). But
345 if we exitify the 'jump j (x,y)' we get
347 join j p = case p of (x,y) -> x+y
348 join exit x y = jump j (x,y)
349 joinrec go 0 x y = jump exit x y
350 go (n-1) x y = jump go (n-1) (x+y) y
351 in …
353 and now 'j' can inline, and we get rid of the pair. Here's another
354 example (assume `g` to be an imported function that, on its own,
355 does not make this interesting):
357 join j y = map f y
358 joinrec go 0 x y = jump j (map g x)
359 go (n-1) x y = jump go (n-1) (x+y)
360 in …
362 Again, `j` would not be inlined because we do not inline something that looks
363 like an exit join point (see Note [Do not inline exit join points]).
365 But after exitification we have
367 join j y = map f y
368 join exit x = jump j (map g x)
369 joinrec go 0 x y = jump exit x
370 go (n-1) x y = jump go (n-1) (x+y)
371 in …
373 and now we can inline `j` and this will allow `map/map` to fire.
376 Note [Idempotency]
377 ~~~~~~~~~~~~~~~~~~
379 We do not want this to happen, where we replace the floated expression with
380 essentially the same expression:
382 join exit x = t (x*x)
383 joinrec go 0 x y = jump exit x
384 go (n-1) x y = jump go (n-1) (x+y)
385 in …
386 ==>
387 join exit x = t (x*x)
388 join exit' x = jump exit x
389 joinrec go 0 x y = jump exit' x
390 go (n-1) x y = jump go (n-1) (x+y)
391 in …
393 So when the RHS is a join jump, and all of its arguments are captured variables,
394 then we leave it in place.
396 Note that `jump exit x` in this example looks interesting, as `exit` is a free
397 variable. Therefore, idempotency does not simply follow from floating only
398 interesting expressions.
400 Note [Calculating free variables]
401 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
402 We have two options where to annotate the tree with free variables:
404 A) The whole tree.
405 B) Each individual joinrec as we come across it.
407 Downside of A: We pay the price on the whole module, even outside any joinrecs.
408 Downside of B: We pay the price per joinrec, possibly multiple times when
409 joinrecs are nested.
411 Further downside of A: If the exitify function returns annotated expressions,
412 it would have to ensure that the annotations are correct.
414 We therefore choose B, and calculate the free variables in `exitify`.
417 Note [Do not inline exit join points]
418 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419 When we have
421 let t = foo bar
422 join exit x = t (x*x)
423 joinrec go 0 x y = jump exit x
424 go (n-1) x y = jump go (n-1) (x+y)
425 in …
427 we do not want the simplifier to simply inline `exit` back in (which it happily
428 would).
430 To prevent this, we need to recognize exit join points, and then disable
431 inlining.
433 Exit join points are join points with an occurrence in a recursive group,
434 and can be recognized (after the occurrence analyzer ran!) using
435 `isExitJoinId`.
436 This function detects join points with `occ_in_lam (idOccInfo id) == True`,
437 because the lambdas of a non-recursive join point are not considered for
438 `occ_in_lam`. For example, in the following code, `j1` is /not/ marked
439 occ_in_lam, because `j2` is called only once.
441 join j1 x = x+1
442 join j2 y = jump j1 (y+2)
444 To prevent inlining, we check for isExitJoinId
445 * In `preInlineUnconditionally` directly.
446 * In `simplLetUnfolding` we simply give exit join points no unfolding, which
447 prevents inlining in `postInlineUnconditionally` and call sites.
449 Note [Placement of the exitification pass]
450 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
451 I (Joachim) experimented with multiple positions for the Exitification pass in
452 the Core2Core pipeline:
454 A) Before the `simpl_phases`
455 B) Between the `simpl_phases` and the "main" simplifier pass
456 C) After demand_analyser
457 D) Before the final simplification phase
459 Here is the table (this is without inlining join exit points in the final
460 simplifier run):
462 Program | Allocs | Instrs
463 | ABCD.log A.log B.log C.log D.log | ABCD.log A.log B.log C.log D.log
464 ----------------|---------------------------------------------------|-------------------------------------------------
465 fannkuch-redux | -99.9% +0.0% -99.9% -99.9% -99.9% | -3.9% +0.5% -3.0% -3.9% -3.9%
466 fasta | -0.0% +0.0% +0.0% -0.0% -0.0% | -8.5% +0.0% +0.0% -0.0% -8.5%
467 fem | 0.0% 0.0% 0.0% 0.0% +0.0% | -2.2% -0.1% -0.1% -2.1% -2.1%
468 fish | 0.0% 0.0% 0.0% 0.0% +0.0% | -3.1% +0.0% -1.1% -1.1% -0.0%
469 k-nucleotide | -91.3% -91.0% -91.0% -91.3% -91.3% | -6.3% +11.4% +11.4% -6.3% -6.2%
470 scs | -0.0% -0.0% -0.0% -0.0% -0.0% | -3.4% -3.0% -3.1% -3.3% -3.3%
471 simple | -6.0% 0.0% -6.0% -6.0% +0.0% | -3.4% +0.0% -5.2% -3.4% -0.1%
472 spectral-norm | -0.0% 0.0% 0.0% -0.0% +0.0% | -2.7% +0.0% -2.7% -5.4% -5.4%
473 ----------------|---------------------------------------------------|-------------------------------------------------
474 Min | -95.0% -91.0% -95.0% -95.0% -95.0% | -8.5% -3.0% -5.2% -6.3% -8.5%
475 Max | +0.2% +0.2% +0.2% +0.2% +1.5% | +0.4% +11.4% +11.4% +0.4% +1.5%
476 Geometric Mean | -4.7% -2.1% -4.7% -4.7% -4.6% | -0.4% +0.1% -0.1% -0.3% -0.2%
478 Position A is disqualified, as it does not get rid of the allocations in
479 fannkuch-redux.
480 Positions A and B are disqualified because they increase instructions in k-nucleotide.
481 Positions C and D have their advantages: C decreases allocations in simple, but D decreases instructions in fasta.
483 Assuming we have a budget of _one_ run of Exitification, then C wins (but we
484 could get more from running it multiple times, as seen in fish).
486 Note [Picking arguments to abstract over]
487 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
489 When we create an exit join point, we need to abstract over those of its
490 free variables that are out of scope at the destination of the exit join
491 point. So we go through the list `captured` and pick those that are actually
492 free variables of the join point.
494 We do not just `filter (`elemVarSet` fvs) captured`, as there might be
495 shadowing, and `captured` may contain multiple variables with the same Unique. In
496 these cases we want to abstract only over the last occurrence, hence the `foldr`
497 (with emphasis on the `r`). This is #15110.
499 -}