In Exitify, zap idInfo of abstracted variables (fixes #15005)
module Exitify ( exitifyProgram ) where

{-
Note [Exitification]
~~~~~~~~~~~~~~~~~~~~

This module implements Exitification. The goal is to pull as much code out of
recursive functions as possible, as the simplifier is better at inlining into
call-sites that are not in recursive functions.

Example:

  let t = foo bar
  joinrec go 0     x y = t (x*x)
          go (n-1) x y = jump go (n-1) (x+y)
  in …

We’d like to inline `t`, but that does not happen: because `t` is a thunk and
is used in a recursive function, inlining it might lose sharing in general. In
this case, however, `t` is on the _exit path_ of `go`, so it is called at most
once. How do we make this clearly visible to the simplifier?

A code path (i.e., an expression in a tail-recursive position) in a recursive
function is an exit path if it does not contain a recursive call. We can bind
this expression outside the recursive function, as a join point.

Example result:

  let t = foo bar
  join exit x = t (x*x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

Now `t` is no longer in a recursive function, and good things happen!
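
To make the definition of an exit path concrete, here is a toy model. It is
not GHC's code: `Expr`, `Exit`, `freeVars` and `exitifyTail` are hypothetical,
`If` stands in for `case`, `Prim` marks imported operators (which never make an
expression interesting), and the model has no binders, no check for captured
join points and no `sortQuantVars`. It walks the tail positions of a join point
body, floats every exit path that mentions a variable other than the join
point's parameters (skipping jumps whose arguments are all parameters, see
Note [Idempotency]), and abstracts over the parameters the exit path uses.

```haskell
import Data.List (nub)

-- Hypothetical mini-language, just rich enough for the example above.
data Expr
  = Var String            -- a local variable
  | Lit Integer
  | Prim String [Expr]    -- an imported operator such as "*"; not a local variable
  | App Expr [Expr]       -- application of a local function or thunk
  | Jump String [Expr]    -- jump to a join point
  | If Expr Expr Expr     -- both branches are in tail position
  deriving (Show, Eq)

-- An exit join point: name, parameters, right-hand side.
type Exit = (String, [String], Expr)

-- Free local variables (there are no binders, so every Var is free).
freeVars :: Expr -> [String]
freeVars = nub . go
  where
    go (Var v)     = [v]
    go (Lit _)     = []
    go (Prim _ es) = concatMap go es
    go (App f es)  = go f ++ concatMap go es
    go (Jump j es) = j : concatMap go es
    go (If c t f)  = go c ++ go t ++ go f

-- Walk the tail positions of a join point body and float out exit paths.
exitifyTail
  :: [String]  -- recursive binders of the joinrec (e.g. ["go"])
  -> [String]  -- captured variables: the join point's parameters
  -> [Exit]    -- exits created so far (most recent first)
  -> Expr
  -> (Expr, [Exit])
exitifyTail recs captured exits e
  | isExit, interesting, not trivialJump =
      let name   = "exit" ++ show (length exits)
          params = filter (`elem` fvs) captured   -- abstract over used parameters
      in (Jump name (map Var params), (name, params, e) : exits)
  | not isExit, If c t f <- e =                   -- recurse into tail positions
      let (t', exits1) = exitifyTail recs captured exits t
          (f', exits2) = exitifyTail recs captured exits1 f
      in (If c t' f', exits2)
  | otherwise = (e, exits)
  where
    fvs         = freeVars e
    isExit      = all (`notElem` recs) fvs        -- no recursive call
    interesting = any (`notElem` captured) fvs    -- mentions an outer local
    trivialJump = case e of
      Jump _ args -> all isCapturedVar args
      _           -> False
    isCapturedVar (Var v) = v `elem` captured
    isCapturedVar _       = False
```

Encoding the body of `go` as `If (Prim "==" [Var "n", Lit 0]) …` and calling
`exitifyTail ["go"] ["n", "x", "y"] []` replaces the branch `t (x*x)` by
`Jump "exit0" [Var "x"]` and returns it as the single exit
`("exit0", ["x"], App (Var "t") [Prim "*" [Var "x", Var "x"]])`, while the
recursive branch is left alone.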
-}

import GhcPrelude
import Var
import Id
import IdInfo
import CoreSyn
import CoreUtils
import State
import Unique
import VarSet
import VarEnv
import CoreFVs
import FastString
import Type
import MkCore ( sortQuantVars )

import Data.Bifunctor
import Control.Monad

-- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
-- The really interesting function is exitify
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
  where
    goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
    goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
      -- Top-level bindings are never join points

    in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds

    go :: InScopeSet -> CoreExpr -> CoreExpr
    go _    e@(Var{})       = e
    go _    e@(Lit {})      = e
    go _    e@(Type {})     = e
    go _    e@(Coercion {}) = e

    go in_scope (Lam v e') = Lam v (go in_scope' e')
      where in_scope' = in_scope `extendInScopeSet` v
    go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)
    go in_scope (Case scrut bndr ty alts)
      = Case (go in_scope scrut) bndr ty (map (goAlt in_scope') alts)
      where in_scope' = in_scope `extendInScopeSet` bndr
    go in_scope (Cast e' c) = Cast (go in_scope e') c
    go in_scope (Tick t e') = Tick t (go in_scope e')
    go in_scope (Let bind body) = goBind in_scope bind (go in_scope' body)
      where in_scope' = in_scope `extendInScopeSetList` bindersOf bind

    goAlt :: InScopeSet -> CoreAlt -> CoreAlt
    goAlt in_scope (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
      where in_scope' = in_scope `extendInScopeSetList` pats

    goBind :: InScopeSet -> CoreBind -> (CoreExpr -> CoreExpr)
    goBind in_scope (NonRec v rhs) = Let (NonRec v (go in_scope rhs))
    goBind in_scope (Rec pairs)
        | is_join_rec = exitify in_scope' pairs'
        | otherwise   = Let (Rec pairs')
      where pairs' = map (second (go in_scope')) pairs
            is_join_rec = any (isJoinId . fst) pairs
            in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)


-- | State Monad used inside `exitify`
type ExitifyM = State [(JoinId, CoreExpr)]

-- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
--   join points outside the joinrec.
exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
exitify in_scope pairs =
    \body -> mkExitLets exits (mkLetRec pairs' body)
  where
    mkExitLets ((exitId, exitRhs):exits') = mkLetNonRec exitId exitRhs . mkExitLets exits'
    mkExitLets [] = id

    -- We need the set of free variables of many subexpressions here, so
    -- annotate the AST with them
    -- see Note [Calculating free variables]
    ann_pairs = map (second freeVars) pairs

    -- Which are the recursive calls?
    recursive_calls = mkVarSet $ map fst pairs

    (pairs',exits) = (`runState` []) $ do
        forM ann_pairs $ \(x,rhs) -> do
            -- go past the lambdas of the join point
            let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
            body' <- go args body
            let rhs' = mkLams args body'
            return (x, rhs')

    -- Main working function. Goes through the RHS (tail-call positions only)
    -- and checks whether there are no more recursive calls; if so, it abstracts
    -- over the variables bound on the way and lifts the expression out as a
    -- join point.
    --
    -- ExitifyM is a state monad to keep track of floated binds
    go :: [Var]           -- ^ variables to abstract over
       -> CoreExprWithFVs -- ^ current expression in tail position
       -> ExitifyM CoreExpr

    -- We first look at the expression (no matter what its shape is)
    -- and determine if we can turn it into an exit join point
    go captured ann_e
        -- Do not touch an expression that is already a join jump where all arguments
        -- are captured variables. See Note [Idempotency]
        -- But _do_ float join jumps with interesting arguments.
        -- See Note [Jumps can be interesting]
        | (Var f, args) <- collectArgs e
        , isJoinId f
        , all isCapturedVarArg args
        = return e

        -- Do not touch a boring expression (see Note [Interesting expression])
        | is_exit, not is_interesting = return e

        -- Cannot float out if local join points are used, as
        -- we cannot abstract over them
        | is_exit, captures_join_points = return e

        -- We have something to float out!
        | is_exit = do
            -- Assemble the RHS of the exit join point
            let rhs = mkLams abs_vars e
                ty = exprType rhs
            let avoid = in_scope `extendInScopeSetList` captured
            -- Remember this binding under a suitable name
            v <- addExit avoid ty (length abs_vars) rhs
            -- And jump to it from here
            return $ mkVarApps (Var v) abs_vars
      where
        -- An exit expression has no recursive calls
        is_exit = disjointVarSet fvs recursive_calls

        -- Used to detect exit expressions that are already proper exit jumps
        isCapturedVarArg (Var v) = v `elem` captured
        isCapturedVarArg _ = False

        -- An interesting exit expression has free, non-imported
        -- variables from outside the recursive group
        -- See Note [Interesting expression]
        is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)

        -- The possible arguments of this exit join point
        abs_vars =
            map zap $
            sortQuantVars $
            filter (`elemVarSet` fvs) captured

        -- cf. SetLevels.abstractVars
        zap v | isId v = setIdInfo v vanillaIdInfo
              | otherwise = v

        -- We cannot abstract over join points
        captures_join_points = any isJoinId abs_vars

        e = deAnnotate ann_e
        fvs = dVarSetToVarSet (freeVarsOf ann_e)

    -- We could not turn it into an exit join point. So now recurse
    -- into all expressions where eligible exit join points might sit,
    -- i.e. into all tail-call positions:

    -- Case right-hand sides are in tail-call position
    go captured (_, AnnCase scrut bndr ty alts) = do
        alts' <- forM alts $ \(dc, pats, rhs) -> do
            rhs' <- go (captured ++ [bndr] ++ pats) rhs
            return (dc, pats, rhs')
        return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
        -- join point, RHS and body are in tail-call position
        | AnnNonRec j rhs <- ann_bind
        , Just join_arity <- isJoinId_maybe j
        = do let (params, join_body) = collectNAnnBndrs join_arity rhs
             join_body' <- go (captured ++ params) join_body
             let rhs' = mkLams params join_body'
             body' <- go (captured ++ [j]) body
             return $ Let (NonRec j rhs') body'

        -- rec join point, RHSs and body are in tail-call position
        | AnnRec pairs <- ann_bind
        , isJoinId (fst (head pairs))
        = do let js = map fst pairs
             pairs' <- forM pairs $ \(j,rhs) -> do
                 let join_arity = idJoinArity j
                     (params, join_body) = collectNAnnBndrs join_arity rhs
                 join_body' <- go (captured ++ js ++ params) join_body
                 let rhs' = mkLams params join_body'
                 return (j, rhs')
             body' <- go (captured ++ js) body
             return $ Let (Rec pairs') body'

        -- normal Let, only the body is in tail-call position
        | otherwise
        = do body' <- go (captured ++ bindersOf bind) body
             return $ Let bind body'
      where bind = deAnnBind ann_bind

    -- Cannot be turned into an exit join point, but also has no
    -- tail-call subexpression. Nothing to do here.
    go _ ann_e = return (deAnnotate ann_e)


-- Picks a new unique, which is disjoint from
--  * the free variables of the whole joinrec
--  * any bound variables (captured)
--  * any exit join points created so far.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = do
    fs <- get
    let avoid = in_scope `extendInScopeSetList` (map fst fs)
                         `extendInScopeSet` exit_id_tmpl -- just cosmetics
    return (uniqAway avoid exit_id_tmpl)
  where
    exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
                    `asJoinId` join_arity

addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope ty join_arity rhs = do
    -- Pick a suitable name
    v <- mkExitJoinId in_scope ty join_arity
    fs <- get
    put ((v,rhs):fs)
    return v


{-
Note [Interesting expression]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We do not want this to happen:

  joinrec go 0     x y = x
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = x
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

because the floated exit path (`x`) is simply a parameter of `go`; no useful
interactions are exposed this way.

Neither do we want this to happen:

  joinrec go 0     x y = x+x
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = x+x
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

where the floated expression `x+x` is a bit more complicated, but still not
interesting.

Expressions are interesting when they move an occurrence of a variable outside
the recursive `go` that can benefit from being obviously called once, for example:
 * a local thunk that can then be inlined (see example in Note [Exitification])
 * the parameter of a function, where the demand analyzer can then
   see that it is called at most once, and hence improve the function’s
   strictness signature

So we only hoist an exit expression out if it mentions at least one free,
non-imported variable that is not bound inside the recursive group.
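
The criterion can be sketched as follows, mirroring `is_interesting` in
`exitify`. `V` and `isInteresting` are hypothetical: `V` stands in for a Core
variable and records whether it is defined in this module (what `isLocalId`
tests), and plain lists replace GHC's `VarSet`.

```haskell
-- Hypothetical stand-in for a Core variable.
data V = V
  { vName    :: String
  , vIsLocal :: Bool    -- True unless imported (cf. isLocalId)
  } deriving (Eq, Show)

-- An exit expression is interesting if one of its free variables is
-- neither captured (bound inside the joinrec) nor imported.
isInteresting
  :: [V]   -- captured variables
  -> [V]   -- free variables of the candidate exit expression
  -> Bool
isInteresting captured fvs = any vIsLocal (filter (`notElem` captured) fvs)
```

With captured variables `[n, x, y]`, the free variables of `x` and of `x+x`
(where `+` is imported) give `False`, while those of `t (x*x)`, with `t` a
local thunk, give `True`.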

Note [Jumps can be interesting]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A jump to a join point can be interesting if its arguments contain free
non-imported variables (z in the following example):

  joinrec go 0     x y = jump j (x+z)
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = jump j (x+z)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

The join point itself can be interesting, even if none of its
arguments mentions a variable that is free in the joinrec. For example

  join j p = case p of (x,y) -> x+y
  joinrec go 0     x y = jump j (x,y)
          go (n-1) x y = jump go (n-1) (x+y) y
  in …

Here, `j` would not be inlined because we do not inline something that looks
like an exit join point (see Note [Do not inline exit join points]). But
if we exitify the 'jump j (x,y)' we get

  join j p = case p of (x,y) -> x+y
  join exit x y = jump j (x,y)
  joinrec go 0     x y = jump exit x y
          go (n-1) x y = jump go (n-1) (x+y) y
  in …

and now 'j' can inline, and we get rid of the pair. Here's another
example (assume `g` to be an imported function that, on its own,
does not make this interesting):

  join j y = map f y
  joinrec go 0     x y = jump j (map g x)
          go (n-1) x y = jump go (n-1) (x+y)
  in …

Again, `j` would not be inlined because we do not inline something that looks
like an exit join point (see Note [Do not inline exit join points]).

But after exitification we have

  join j y = map f y
  join exit x = jump j (map g x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

and now we can inline `j` and this will allow `map/map` to fire.


Note [Idempotency]
~~~~~~~~~~~~~~~~~~

We do not want this to happen, where we replace the floated expression with
essentially the same expression:

  join exit x = t (x*x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = t (x*x)
  join exit' x = jump exit x
  joinrec go 0     x y = jump exit' x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

So when the RHS is a join jump, and all of its arguments are captured variables,
then we leave it in place.

Note that `jump exit x` in this example looks interesting, as `exit` is a free
variable. Therefore, idempotency does not simply follow from floating only
interesting expressions.
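
The check corresponds to the first guard of `go` in `exitify`, which uses
`collectArgs` and `isJoinId`. Here it is sketched on a hypothetical, simplified
expression type `E` (not GHC's `CoreExpr`); `isTrivialJump` is likewise a
hypothetical name.

```haskell
-- Hypothetical simplified expressions: only what the check needs.
data E
  = EVar String
  | EJump String [E]   -- a jump to a join point with its arguments
  | EOther String      -- any other expression, e.g. "x+z"
  deriving (Show, Eq)

-- True for a jump whose arguments are all captured variables.
-- Such an expression is left in place instead of being floated again.
isTrivialJump :: [String] -> E -> Bool
isTrivialJump captured (EJump _ args) = all isCapturedVar args
  where
    isCapturedVar (EVar v) = v `elem` captured
    isCapturedVar _        = False
isTrivialJump _ _ = False
```

With captured variables `[n, x, y]`, `jump exit x` from the example above is
trivial and stays put, whereas `jump j (x+z)` from Note [Jumps can be
interesting] is not, so it may still be floated.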

Note [Calculating free variables]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We have two options for where to annotate the tree with free variables:

 A) The whole tree.
 B) Each individual joinrec as we come across it.

Downside of A: We pay the price on the whole module, even outside any joinrecs.
Downside of B: We pay the price per joinrec, possibly multiple times when
joinrecs are nested.

Further downside of A: If the exitify function returns annotated expressions,
it would have to ensure that the annotations are correct.

We therefore choose B, and calculate the free variables in `exitify`.
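
Option B amounts to annotating each node of one joinrec's right-hand side with
its free variables, as `freeVars` does when it produces a `CoreExprWithFVs`.
A sketch on a hypothetical mini-language with a single binder form (`Lam`),
using lists from `Data.List` instead of GHC's `DVarSet`; `AnnExpr`, `annotate`
and `freeVarsOf` here are illustrative, not GHC's definitions.

```haskell
import Data.List (delete, union)

-- Hypothetical mini-language.
data Expr = Var String | App Expr Expr | Lam String Expr
  deriving (Show, Eq)

-- Every node carries the free variables of the subtree below it.
data AnnExpr
  = AVar [String] String
  | AApp [String] AnnExpr AnnExpr
  | ALam [String] String AnnExpr
  deriving (Show, Eq)

freeVarsOf :: AnnExpr -> [String]
freeVarsOf (AVar fvs _)   = fvs
freeVarsOf (AApp fvs _ _) = fvs
freeVarsOf (ALam fvs _ _) = fvs

-- One bottom-up pass: each subtree's free variables are computed once and
-- reused by its parent, so later queries are just field lookups.
annotate :: Expr -> AnnExpr
annotate (Var v)   = AVar [v] v
annotate (App f a) = AApp (freeVarsOf f' `union` freeVarsOf a') f' a'
  where f' = annotate f
        a' = annotate a
annotate (Lam x b) = ALam (delete x (freeVarsOf b')) x b'
  where b' = annotate b
```

For `\x -> t (x y)` the root is annotated with `["t","y"]` and the lambda body
with `["t","x","y"]`; doing this per joinrec (option B) confines the cost to the
joinrecs themselves.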


Note [Do not inline exit join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When we have

  let t = foo bar
  join exit x = t (x*x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

we do not want the simplifier to simply inline `exit` back in (which it happily
would).

To prevent this, we need to recognize exit join points, and then disable
inlining.

Exit join points are join points with an occurrence in a recursive group; after
the occurrence analyzer has run, they can be recognized using `isExitJoinId`.
This function detects join points with `occ_in_lam (idOccInfo id) == True`,
because the lambdas of a non-recursive join point are not considered for
`occ_in_lam`. For example, in the following code, `j1` is /not/ marked
occ_in_lam, because `j2` is called only once.

  join j1 x = x+1
  join j2 y = jump j1 (y+2)

To prevent inlining, we check for isExitJoinId
* In `preInlineUnconditionally` directly.
* In `simplLetUnfolding` we simply give exit join points no unfolding, which
  prevents inlining in `postInlineUnconditionally` and call sites.
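
The test described above can be sketched on hypothetical, simplified stand-ins
`Occ` and `Ident` for GHC's `OccInfo` and `Id`. The real types have many more
fields, and `isExitJoinIdSketch` is a hypothetical name: the real
`isExitJoinId` may check more than this.

```haskell
-- Hypothetical, simplified occurrence information: only the flag that
-- matters here (cf. occ_in_lam).
data Occ = Occ
  { occInLam :: Bool   -- is the occurrence inside a lambda?
  } deriving (Show, Eq)

-- Hypothetical, simplified identifier.
data Ident = Ident
  { identName   :: String
  , identIsJoin :: Bool  -- cf. isJoinId
  , identOcc    :: Occ   -- filled in by the occurrence analyzer
  } deriving (Show, Eq)

-- A join point whose occurrence is marked as inside a lambda. Because the
-- lambdas of a non-recursive join point do not count for occ_in_lam, the
-- flag singles out jumps from inside a recursive group.
isExitJoinIdSketch :: Ident -> Bool
isExitJoinIdSketch i = identIsJoin i && occInLam (identOcc i)
```

In the `j1`/`j2` example, `j1` has the flag unset and is therefore not treated
as an exit join point, while `exit` in the example at the top of this Note,
jumped to from inside `go`, would be.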

Note [Placement of the exitification pass]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
I (Joachim) experimented with multiple positions for the Exitification pass in
the Core2Core pipeline:

 A) Before the `simpl_phases`
 B) Between the `simpl_phases` and the "main" simplifier pass
 C) After demand_analyser
 D) Before the final simplification phase

Here is the table (this is without inlining exit join points in the final
simplifier run):

Program         |                  Allocs                  |                  Instrs
                | ABCD.log   A.log   B.log   C.log   D.log | ABCD.log   A.log   B.log   C.log   D.log
----------------|------------------------------------------|-----------------------------------------
fannkuch-redux  |   -99.9%   +0.0%  -99.9%  -99.9%  -99.9% |    -3.9%   +0.5%   -3.0%   -3.9%   -3.9%
fasta           |    -0.0%   +0.0%   +0.0%   -0.0%   -0.0% |    -8.5%   +0.0%   +0.0%   -0.0%   -8.5%
fem             |     0.0%    0.0%    0.0%    0.0%   +0.0% |    -2.2%   -0.1%   -0.1%   -2.1%   -2.1%
fish            |     0.0%    0.0%    0.0%    0.0%   +0.0% |    -3.1%   +0.0%   -1.1%   -1.1%   -0.0%
k-nucleotide    |   -91.3%  -91.0%  -91.0%  -91.3%  -91.3% |    -6.3%  +11.4%  +11.4%   -6.3%   -6.2%
scs             |    -0.0%   -0.0%   -0.0%   -0.0%   -0.0% |    -3.4%   -3.0%   -3.1%   -3.3%   -3.3%
simple          |    -6.0%    0.0%   -6.0%   -6.0%   +0.0% |    -3.4%   +0.0%   -5.2%   -3.4%   -0.1%
spectral-norm   |    -0.0%    0.0%    0.0%   -0.0%   +0.0% |    -2.7%   +0.0%   -2.7%   -5.4%   -5.4%
----------------|------------------------------------------|-----------------------------------------
Min             |   -95.0%  -91.0%  -95.0%  -95.0%  -95.0% |    -8.5%   -3.0%   -5.2%   -6.3%   -8.5%
Max             |    +0.2%   +0.2%   +0.2%   +0.2%   +1.5% |    +0.4%  +11.4%  +11.4%   +0.4%   +1.5%
Geometric Mean  |    -4.7%   -2.1%   -4.7%   -4.7%   -4.6% |    -0.4%   +0.1%   -0.1%   -0.3%   -0.2%

Position A is disqualified, as it does not get rid of the allocations in
fannkuch-redux.
Positions A and B are disqualified because they increase instructions in
k-nucleotide.
Positions C and D have their advantages: C decreases allocations in simple,
while D decreases instructions in fasta.

Assuming we have a budget of _one_ run of Exitification, then C wins (but we
could get more from running it multiple times, as seen in fish).

-}