Call arity: Handle type application correctly
[ghc.git] / compiler / simplCore / CallArity.hs
1 --
2 -- Copyright (c) 2014 Joachim Breitner
3 --
4
5 module CallArity
6 ( callArityAnalProgram
7 , callArityRHS -- for testing
8 ) where
9
10 import VarSet
11 import VarEnv
12 import DynFlags ( DynFlags )
13
14 import BasicTypes
15 import CoreSyn
16 import Id
17 import CoreArity ( exprArity, typeArity )
18 import CoreUtils ( exprIsHNF )
19 import Outputable
20
21 import Control.Arrow ( first, second )
22
23
24 {-
25 %************************************************************************
26 %* *
27 Call Arity Analyis
28 %* *
29 %************************************************************************
30
31 Note [Call Arity: The goal]
32 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
33
34 The goal of this analysis is to find out if we can eta-expand a local function,
35 based on how it is being called. The motivating example is code this this,
36 which comes up when we implement foldl using foldr, and do list fusion:
37
38 let go = \x -> let d = case ... of
39 False -> go (x+1)
40 True -> id
41 in \z -> d (x + z)
42 in go 1 0
43
44 If we do not eta-expand `go` to have arity 2, we are going to allocate a lot of
45 partial function applications, which would be bad.
46
47 The function `go` has a type of arity two, but only one lambda is manifest.
48 Further more, an analysis that only looks at the RHS of go cannot be sufficient
49 to eta-expand go: If `go` is ever called with one argument (and the result used
50 multiple times), we would be doing the work in `...` multiple times.
51
52 So `callArityAnalProgram` looks at the whole let expression to figure out if
53 all calls are nice, i.e. have a high enough arity. It then stores the result in
54 the `calledArity` field of the `IdInfo` of `go`, which the next simplifier
55 phase will eta-expand.
56
57 The specification of the `calledArity` field is:
58
59 No work will be lost if you eta-expand me to the arity in `calledArity`.
60
61 The specification of the analysis
62 ---------------------------------
63
64 The analysis only does a conservative approximation, there are plenty of
65 situations where eta-expansion would be ok, but we do not catch it. We are
66 content if all the code that foldl-via-foldr generates is being optimized
67 sufficiently.
68
69 The work-hourse of the analysis is the function `callArityAnal`, with the
70 following type:
71
72 data Count = Many | OnceAndOnly
73 type CallCount = (Count, Arity)
74 type CallArityEnv = VarEnv (CallCount, Arity)
75 callArityAnal ::
76 Arity -> -- The arity this expression is called with
77 VarSet -> -- The set of interesting variables
78 CoreExpr -> -- The expression to analyse
79 (CallArityEnv, CoreExpr)
80
81 and the following specification:
82
83 (callArityEnv, expr') = callArityEnv arity interestingIds expr
84
85 <=>
86
87 Assume the expression `expr` is being passed `arity` arguments. Then it calls
88 the functions mentioned in `interestingIds` according to `callArityEnv`:
89 * The domain of `callArityEnv` is a subset of `interestingIds`.
90 * Any variable from interestingIds that is not mentioned in the `callArityEnv`
91 is absent, i.e. not called at all.
92 * Of all the variables that are mapped to OnceAndOnly by the `callArityEnv`,
93 at most one is being called, at most once, with at least that many
94 arguments.
95 * Variables mapped to Many are called an unknown number of times, but if they
96 are called, then with at least that many arguments.
97 Furthermore, expr' is expr with the callArity field of the `IdInfo` updated.
98
99 The (pointwise) domain is a product domain:
100
101 Many 0
102 | × |
103 OneAndOnly 1
104 |
105 ...
106
107 The at-most-once is important for various reasons:
108
109 1. Consider:
110
111 let n = case .. of .. -- A thunk!
112 in n 0 + n 1
113
114 vs.
115
116 let n = case .. of ..
117 in case .. of T -> n 0
118 F -> n 1
119
120 We are only allowed to eta-expand `n` if it is going to be called at most
121 once in the body of the outer let. So we need to know, for each variable
122 individually, that it is going to be called at most once.
123
124 2. We need to know it for non-thunks as well, because they might call a thunk:
125
126 let n = case .. of ..
127 f x = n (x+1)
128 in f 1 + f 2
129
130 vs.
131
132 let n = case .. of ..
133 f x = n (x+1)
134 in case .. of T -> f 0
135 F -> f 1
136
137 Here, the body of f calls n exactly once, but f itself is being called
138 multiple times, so eta-expansion is not allowed.
139
140 3. We need to know that at most one of the interesting functions is being
141 called, because of recursion. Consider:
142
143 let n = case .. of ..
144 in case .. of
145 True -> let go = \y -> case .. of
146 True -> go (y + n 1)
147 False > n
148 in go 1
149 False -> n
150
151 vs.
152
153 let n = case .. of ..
154 in case .. of
155 True -> let go = \y -> case .. of
156 True -> go (y+1)
157 False > n
158 in go 1
159 False -> n
160
161 In both cases, the body and the rhs of the inner let call n at most once.
162 But only in the second case that holds for the whole expression! The
163 crucial difference is that in the first case, the rhs of `go` can call
164 *both* `go` and `n`, and hence can call `n` multiple times as it recurses,
165 while in the second case it calls `go` or `n`, but not both.
166
167 Note [Which variables are interesting]
168 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
169
170 Unfortunately, the set of interesting variables is not irrelevant for the
171 precision of the analysis. Consider this example (and ignore the pointlessnes
172 of `d` recursing into itself):
173
174 let n = ... :: Int
175 in let d = let d = case ... of
176 False -> d
177 True -> id
178 in \z -> d (x + z)
179 in d 0
180
181 Of course, `d` should be interesting. If we consider `n` as interesting as
182 well, then the body of the second let will return
183 { go |-> (Many, 1) , n |-> (OnceAndOnly, 0) }
184 or
185 { go |-> (OnceAndOnly, 1), n |-> (Many, 0)}.
186 Only the latter is useful, but it is hard to decide that locally.
187 (Returning OnceAndOnly for both would be wrong, as both are being called.)
188
189 So the heuristics is:
190
191 Variables are interesting if their RHS has a lower exprArity than
192 typeArity.
193
194 (which is precisely the those variables where this analysis can actually cause
195 some eta-expansion.)
196
197 But this is not uniformly a win. Consider:
198
199 let go = \x -> let d = case ... of
200 False -> go (x+1)
201 True -> id
202 n x = d (x+1)
203 in \z -> n (x + z)
204 in go n 0
205
206 Now `n` is not going to be considered interesting (its type is `Int -> Int`).
207 But this will prevent us from detecting how often the body of the let calls
208 `d`, and we will not find out anything.
209
210 It might be possible to be smarter here; this needs find-tuning as we find more
211 examples.
212
213
214 Note [Recursion and fixpointing]
215 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
216
217 For a recursive let, we begin by analysing the body, using the same incoming
218 arity as for the whole expression.
219 * We use the arity from the body on the variable as the incoming demand on the
220 rhs. Then we check if the rhs calls itself with the same arity.
221 - If so, we are done.
222 - If not, we re-analise the rhs with the reduced arity. We do that until
223 we are down to the exprArity, which then is certainly correct.
224 * If the rhs calls itself many times, we must (conservatively) pass the result
225 through forgetOnceCalls.
226 * Similarly, if the body calls the variable many times, we must pass the
227 result of the fixpointing through forgetOnceCalls.
228 * Then we can `lubEnv` the results from the body and the rhs: If all mentioned
229 calls are OnceAndOnly calls, then the body calls *either* the rhs *or* one
230 of the other mentioned variables. Similarly, the rhs calls *either* itself
231 again *or* one of the other mentioned variables. This precision is required!
232
233 We do not analyse mutually recursive functions. This can be done once we see it
234 in the wild.
235
236 Note [Case and App: Which side to take?]
237 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
238
239 Combining the case branches is easy, just `lubEnv` them – at most one branch is
240 taken.
241
242 But how to combine that with the information coming from the scrunitee? Very
243 similarly, how to combine the information from the callee and argument of an
244 `App`?
245
246 It would not be correct to just `lubEnv` then: `f n` obviously calls *both* `f`
247 and `n`. We need to forget about the cardinality of calls from one side using
248 `forgetOnceCalls`. But which one?
249
250 Both are correct, and sometimes one and sometimes the other is more precise
251 (also see example in [Which variables are interesting]).
252
253 So currently, we first check the scrunitee (resp. the callee) if the returned
254 value has any usesful information, and if so, we use that; otherwise we use the
255 information from the alternatives (resp. the argument).
256
257 It might be smarter to look for “more important” variables first, i.e. the
258 innermost recursive variable.
259
260 Note [Analysing top-level binds]
261 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
262
263 We can eta-expand top-level-binds if they are not exported, as we see all calls
264 to them. The plan is as follows: Treat the top-level binds as nested lets around
265 a body representing “all external calls”, which returns a CallArityEnv that calls
266 every exported function with the top of the lattice.
267
268 This means that the incoming arity on all top-level binds will have a Many
269 attached, and we will never eta-expand CAFs. Which is good.
270
271 -}
272
273 callArityAnalProgram :: DynFlags -> CoreProgram -> CoreProgram
274 callArityAnalProgram _dflags binds = binds'
275 where
276 (_, binds') = callArityTopLvl [] emptyVarSet binds
277
278 -- See Note [Analysing top-level-binds]
279 callArityTopLvl :: [Var] -> VarSet -> [CoreBind] -> (CallArityEnv, [CoreBind])
280 callArityTopLvl exported _ []
281 = (mkVarEnv $ zip exported (repeat topCallCount), [])
282 callArityTopLvl exported int1 (b:bs)
283 = (ae2, b':bs')
284 where
285 int2 = interestingBinds b
286 exported' = filter isExportedId int2 ++ exported
287 int' = int1 `extendVarSetList` int2
288 (ae1, bs') = callArityTopLvl exported' int' bs
289 (ae2, b') = callArityBind ae1 int1 b
290
291
292 callArityRHS :: CoreExpr -> CoreExpr
293 callArityRHS = snd . callArityAnal 0 emptyVarSet
294
295
296 data Count = Many | OnceAndOnly deriving (Eq, Ord)
297 type CallCount = (Count, Arity)
298
299 topCallCount :: CallCount
300 topCallCount = (Many, 0)
301
302 type CallArityEnv = VarEnv CallCount
303
304 callArityAnal ::
305 Arity -> -- The arity this expression is called with
306 VarSet -> -- The set of interesting variables
307 CoreExpr -> -- The expression to analyse
308 (CallArityEnv, CoreExpr)
309 -- How this expression uses its interesting variables
310 -- and the expression with IdInfo updated
311
312 -- The trivial base cases
313 callArityAnal _ _ e@(Lit _)
314 = (emptyVarEnv, e)
315 callArityAnal _ _ e@(Type _)
316 = (emptyVarEnv, e)
317 callArityAnal _ _ e@(Coercion _)
318 = (emptyVarEnv, e)
319 -- The transparent cases
320 callArityAnal arity int (Tick t e)
321 = second (Tick t) $ callArityAnal arity int e
322 callArityAnal arity int (Cast e co)
323 = second (\e -> Cast e co) $ callArityAnal arity int e
324
325 -- The interesting case: Variables, Lambdas, Lets, Applications, Cases
326 callArityAnal arity int e@(Var v)
327 | v `elemVarSet` int
328 = (unitVarEnv v (OnceAndOnly, arity), e)
329 | otherwise
330 = (emptyVarEnv, e)
331
332 -- Non-value lambdas are ignored
333 callArityAnal arity int (Lam v e) | not (isId v)
334 = second (Lam v) $ callArityAnal arity int e
335
336 -- We have a lambda that we are not sure to call. Tail calls therein
337 -- are no longer OneAndOnly calls
338 callArityAnal 0 int (Lam v e)
339 = (ae', Lam v e')
340 where
341 (ae, e') = callArityAnal 0 int e
342 ae' = forgetOnceCalls ae
343 -- We have a lambda that we are calling. decrease arity.
344 callArityAnal arity int (Lam v e)
345 = (ae, Lam v e')
346 where
347 (ae, e') = callArityAnal (arity - 1) int e
348
349 -- For lets, use callArityBind
350 callArityAnal arity int (Let bind e)
351 = -- pprTrace "callArityAnal:Let"
352 -- (vcat [ppr v, ppr arity, ppr n, ppr final_ae ])
353 (final_ae, Let bind' e')
354 where
355 int_body = int `extendVarSetList` interestingBinds bind
356 (ae_body, e') = callArityAnal arity int_body e
357 (final_ae, bind') = callArityBind ae_body int bind
358
359
360 -- Application. Increase arity for the called expresion, nothing to know about
361 -- the second
362 callArityAnal arity int (App e (Type t))
363 = second (\e -> App e (Type t)) $ callArityAnal arity int e
364 callArityAnal arity int (App e1 e2)
365 = (final_ae, App e1' e2')
366 where
367 (ae1, e1') = callArityAnal (arity + 1) int e1
368 (ae2, e2') = callArityAnal 0 int e2
369 -- See Note [Case and App: Which side to take?]
370 final_ae = ae1 `useBetterOf` ae2
371
372 -- Case expression. Here we decide whether
373 -- we want to look at calls from the scrunitee or the alternatives;
374 -- one of them we set to Nothing.
375 -- Naive idea: If there are interesting calls in the scrunitee,
376 -- zap the alternatives
377 callArityAnal arity int (Case scrut bndr ty alts)
378 = -- pprTrace "callArityAnal:Case"
379 -- (vcat [ppr scrut, ppr final_ae])
380 (final_ae, Case scrut' bndr ty alts')
381 where
382 (alt_aes, alts') = unzip $ map go alts
383 go (dc, bndrs, e) = let (ae, e') = callArityAnal arity int e
384 in (ae, (dc, bndrs, e'))
385 alt_ae = foldl lubEnv emptyVarEnv alt_aes
386 (scrut_ae, scrut') = callArityAnal 0 int scrut
387 -- See Note [Case and App: Which side to take?]
388 final_ae = scrut_ae `useBetterOf` alt_ae
389
390 -- Which bindings should we look at?
391 -- See Note [Which variables are interesting]
392 interestingBinds :: CoreBind -> [Var]
393 interestingBinds bind =
394 map fst $ filter go $ case bind of (NonRec v e) -> [(v,e)]
395 (Rec ves) -> ves
396 where
397 go (v,e) = exprArity e < length (typeArity (idType v))
398
399 -- Used for both local and top-level binds
400 -- First argument is the demand from the body
401 callArityBind :: CallArityEnv -> VarSet -> CoreBind -> (CallArityEnv, CoreBind)
402
403 -- Non-recursive let
404 callArityBind ae_body int (NonRec v rhs)
405 = -- pprTrace "callArityBind:NonRec"
406 -- (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr safe_arity])
407 (final_ae, NonRec v' rhs')
408 where
409 callcount = lookupWithDefaultVarEnv ae_body topCallCount v
410 (ae_rhs, safe_arity, rhs') = callArityBound callcount int rhs
411 final_ae = ae_rhs `lubEnv` (ae_body `delVarEnv` v)
412 v' = v `setIdCallArity` safe_arity
413
414 -- Recursive let. See Note [Recursion and fixpointing]
415 callArityBind ae_body int b@(Rec [(v,rhs)])
416 = -- pprTrace "callArityBind:Rec"
417 -- (vcat [ppr v, ppr ae_body, ppr int, ppr ae_rhs, ppr new_arity])
418 (final_ae, Rec [(v',rhs')])
419 where
420 int_body = int `extendVarSetList` interestingBinds b
421 callcount = lookupWithDefaultVarEnv ae_body topCallCount v
422 (ae_rhs, new_arity, rhs') = callArityFix callcount int_body v rhs
423 final_ae = (ae_rhs `lubEnv` ae_body) `delVarEnv` v
424 v' = v `setIdCallArity` new_arity
425
426
427 -- Mutual recursion. Do nothing serious here, for now
428 callArityBind ae_body int (Rec binds)
429 = (final_ae, Rec binds')
430 where
431 (aes, binds') = unzip $ map go binds
432 go (i,e) = let (ae, _, e') = callArityBound topCallCount int e
433 in (ae, (i,e'))
434 final_ae = foldl lubEnv ae_body aes `delVarEnvList` map fst binds
435
436
437 callArityFix :: CallCount -> VarSet -> Id -> CoreExpr -> (CallArityEnv, Arity, CoreExpr)
438 callArityFix arity int v e
439
440 | arity `lteCallCount` min_arity
441 -- The incoming arity is already lower than the exprArity, so we can
442 -- ignore the arity coming from the RHS
443 = (ae `delVarEnv` v, 0, e')
444
445 | otherwise
446 = if new_arity `ltCallCount` arity
447 -- RHS puts a lower arity on itself, so try that
448 then callArityFix new_arity int v e
449
450 -- RHS calls itself with at least as many arguments as the body of the let: Great!
451 else (ae `delVarEnv` v, safe_arity, e')
452 where
453 (ae, safe_arity, e') = callArityBound arity int e
454 new_arity = lookupWithDefaultVarEnv ae topCallCount v
455 min_arity = (Many, exprArity e)
456
457 -- This is a variant of callArityAnal that takes a CallCount (i.e. an arity with a
458 -- cardinality) and adjust the resulting environment accordingly. It is to be used
459 -- on bound expressions that can possibly be shared.
460 -- It also returns the safe arity used: For a thunk that is called multiple
461 -- times, this will be 0!
462 callArityBound :: CallCount -> VarSet -> CoreExpr -> (CallArityEnv, Arity, CoreExpr)
463 callArityBound (count, arity) int e = (final_ae, safe_arity, e')
464 where
465 is_thunk = not (exprIsHNF e)
466
467 safe_arity | OnceAndOnly <- count = arity
468 | is_thunk = 0 -- A thunk! Do not eta-expand
469 | otherwise = arity
470
471 (ae, e') = callArityAnal safe_arity int e
472
473 final_ae | OnceAndOnly <- count = ae
474 | otherwise = forgetOnceCalls ae
475
476
477 anyGoodCalls :: CallArityEnv -> Bool
478 anyGoodCalls = foldVarEnv ((||) . isOnceCall) False
479
480 isOnceCall :: CallCount -> Bool
481 isOnceCall (OnceAndOnly, _) = True
482 isOnceCall (Many, _) = False
483
484 forgetOnceCalls :: CallArityEnv -> CallArityEnv
485 forgetOnceCalls = mapVarEnv (first (const Many))
486
487 -- See Note [Case and App: Which side to take?]
488 useBetterOf :: CallArityEnv -> CallArityEnv -> CallArityEnv
489 useBetterOf ae1 ae2 | anyGoodCalls ae1 = ae1 `lubEnv` forgetOnceCalls ae2
490 useBetterOf ae1 ae2 | otherwise = forgetOnceCalls ae1 `lubEnv` ae2
491
492 lubCallCount :: CallCount -> CallCount -> CallCount
493 lubCallCount (count1, arity1) (count2, arity2)
494 = (count1 `lubCount` count2, arity1 `min` arity2)
495
496 lubCount :: Count -> Count -> Count
497 lubCount OnceAndOnly OnceAndOnly = OnceAndOnly
498 lubCount _ _ = Many
499
500 lteCallCount :: CallCount -> CallCount -> Bool
501 lteCallCount (count1, arity1) (count2, arity2)
502 = count1 <= count2 && arity1 <= arity2
503
504 ltCallCount :: CallCount -> CallCount -> Bool
505 ltCallCount c1 c2 = c1 `lteCallCount` c2 && c1 /= c2
506
507 -- Used when combining results from alternative cases; take the minimum
508 lubEnv :: CallArityEnv -> CallArityEnv -> CallArityEnv
509 lubEnv = plusVarEnv_C lubCallCount
510
511 instance Outputable Count where
512 ppr Many = text "Many"
513 ppr OnceAndOnly = text "OnceAndOnly"