Allow CSE'ing of work-wrapped bindings (#14186)
[ghc.git] / compiler / simplCore / LiberateCase.hs
1 {-
2 (c) The AQUA Project, Glasgow University, 1994-1998
3
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5 -}
6
7 {-# LANGUAGE CPP #-}
8 module LiberateCase ( liberateCase ) where
9
10 #include "HsVersions.h"
11
12 import DynFlags
13 import CoreSyn
14 import CoreUnfold ( couldBeSmallEnoughToInline )
15 import Id
16 import VarEnv
17 import Util ( notNull )
18
19 {-
20 The liberate-case transformation
21 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
22 This module walks over @Core@, and looks for @case@ on free variables.
23 The criterion is:
24 if there is a case on a free variable on the route to the recursive call,
25 then the recursive call is replaced with an unfolding.
26
27 Example
28
29 f = \ t -> case v of
30 V a b -> a : f t
31
32 => the inner f is replaced.
33
34 f = \ t -> case v of
35 V a b -> a : (letrec
36 f = \ t -> case v of
37 V a b -> a : f t
38 in f) t
39 (note the NEED for shadowing)
40
41 => Simplify
42
43 f = \ t -> case v of
44 V a b -> a : (letrec
45 f = \ t -> a : f t
46 in f t)
47
48 Better code, because 'a' is free inside the inner letrec, rather
49 than needing projection from v.
50
51 Note that this deals with *free variables*. SpecConstr deals with
52 *arguments* that are of known form. E.g.
53
54 last [] = error
55 last (x:[]) = x
56 last (x:xs) = last xs
57
58
59 Note [Scrutinee with cast]
60 ~~~~~~~~~~~~~~~~~~~~~~~~~~
61 Consider this:
62 f = \ t -> case (v `cast` co) of
63 V a b -> a : f t
64
65 Exactly the same optimisation (unrolling one call to f) will work here,
66 despite the cast. See mk_alt_env in the Case branch of libCase.
67
68
69 Note [Only functions!]
70 ~~~~~~~~~~~~~~~~~~~~~~
71 Consider the following code
72
73 f = g (case v of V a b -> a : t f)
74
75 where g is expensive. If we aren't careful, liberate case will turn this into
76
77 f = g (case v of
78 V a b -> a : t (letrec f = g (case v of V a b -> a : t f)
79 in f)
80 )
81
82 Yikes! We evaluate g twice. This leads to an O(2^n) explosion
83 if g calls back to the same code recursively.
84
85 Solution: make sure that we only do the liberate-case thing on *functions*
86
87 To think about (Apr 94)
88 ~~~~~~~~~~~~~~
89 Main worry: duplicating code excessively. At the moment we duplicate
90 the entire binding group once at each recursive call. But there may
91 be a group of recursive calls which share a common set of evaluated
92 free variables, in which case the duplication is a plain waste.
93
94 Another thing we could consider adding is some unfold-threshold thing,
95 so that we'll only duplicate if the size of the group rhss isn't too
96 big.
97
98 Data types
99 ~~~~~~~~~~
100 The ``level'' of a binder tells how many
101 recursive defns lexically enclose the binding
102 A recursive defn "encloses" its RHS, not its
103 scope. For example:
104 \begin{verbatim}
105 letrec f = let g = ... in ...
106 in
107 let h = ...
108 in ...
109 \end{verbatim}
110 Here, the level of @f@ is zero, the level of @g@ is one,
111 and the level of @h@ is zero (NB not one).
112
113
114 ************************************************************************
115 * *
116 Top-level code
117 * *
118 ************************************************************************
119 -}
120
-- | Apply the liberate-case transformation to a whole program.
--
-- The top-level bindings are visited in order; the environment returned
-- for one binding is the environment used for all the bindings after it.
liberateCase :: DynFlags -> CoreProgram -> CoreProgram
liberateCase dflags binds = go (initEnv dflags) binds
  where
    go _   []         = []
    go env (b : rest) =
      let (env_next, b') = libCaseBind env b
      in  b' : go env_next rest
128
129 {-
130 ************************************************************************
131 * *
132 Main payload
133 * *
134 ************************************************************************
135
136 Bindings
137 ~~~~~~~~
138 -}
139
-- | Transform a single binding group.
--
-- Returns the environment to use for the scope of the binding, paired
-- with the transformed binding itself.
libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)

libCaseBind env (NonRec bndr rhs) =
  let rhs' = libCase env rhs
  in  (addBinders env [bndr], NonRec bndr rhs')

libCaseBind env (Rec prs) = (scope_env, Rec new_prs)
  where
    -- Environment for the scope of the group: the binders are recorded
    -- at the current level, but the level is not bumped and nothing is
    -- added to the rec-env.
    scope_env = addBinders env (map fst prs)

    new_prs = [ (bndr, libCase rhs_env rhs) | (bndr, rhs) <- prs ]

    -- The rec-env is extended by binding each Id to its rhs.  That rhs
    -- is first processed with the *un-extended* environment (scope_env),
    -- so that the same process doesn't go on for ever!
    --
    -- localiseId: see Note [Need to localiseId in libCaseBind]
    rhs_env = addRecBinds env
                [ (localiseId bndr, libCase scope_env rhs)
                | (bndr, rhs) <- prs
                , worth_liberating bndr rhs ]

    -- Candidates are chosen item by item: see Note [Small enough]
    worth_liberating bndr rhs
      =  idArity bndr > 0           -- Note [Only functions!]
      && within_threshold rhs

    -- With no threshold configured, every rhs is accepted
    within_threshold rhs = case bombOutSize env of
      Nothing    -> True
      Just limit -> couldBeSmallEnoughToInline (lc_dflags env) limit rhs
167
168 {-
169 Note [Need to localiseId in libCaseBind]
170 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
171 The call to localiseId is needed for two subtle reasons
172 (a) Reset the export flags on the binders so
173 that we don't get name clashes on exported things if the
174 local binding floats out to top level. This is most unlikely
175 to happen, since the whole point concerns free variables.
176 But resetting the export flag is right regardless.
177
178 (b) Make the name an Internal one. External Names should never be
179 nested; if it were floated to the top level, we'd get a name
180 clash at code generation time.
181
182 Note [Small enough]
183 ~~~~~~~~~~~~~~~~~~~
184 Consider
185 \fv. letrec
186 f = \x. BIG...(case fv of { (a,b) -> ...g.. })...
187 g = \y. SMALL...f...
188 Then we *can* do liberate-case on g (small RHS) but not for f (too big).
189 But we can choose on an item-by-item basis, and that's what the
190 rhs_small_enough call in the comprehension for env_rhs does.
191
192 Expressions
193 ~~~~~~~~~~~
194 -}
195
-- | Transform an expression.
--
-- Variable occurrences and applications headed by a variable are handed
-- to 'libCaseApp'; every other form is rebuilt with its sub-expressions
-- transformed, extending the environment at each binding site.
libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr
libCase env expr = case expr of
  Var v      -> libCaseApp env v []
  Lit _      -> expr
  Type _     -> expr
  Coercion _ -> expr

  App fun arg
    -- An application whose head is a variable is treated as a whole,
    -- with all of its arguments
    | (Var v, args) <- collectArgs expr
                -> libCaseApp env v args
    | otherwise -> App (libCase env fun) (libCase env arg)

  Tick tickish body -> Tick tickish (libCase env body)
  Cast e co         -> Cast (libCase env e) co

  Lam bndr body -> Lam bndr (libCase (addBinders env [bndr]) body)

  Let bind body ->
    let (env_body, bind') = libCaseBind env bind
    in  Let bind' (libCase env_body body)

  Case scrut bndr ty alts ->
    let env_alts = addBinders (scrut_env scrut) [bndr]
    in  Case (libCase env scrut) bndr ty
             [ libCaseAlt env_alts alt | alt <- alts ]
  where
    -- Environment for the alternatives of a case on the given scrutinee:
    -- a scrutinised variable is recorded, looking through casts
    -- (see Note [Scrutinee with cast]); anything else leaves env alone.
    scrut_env (Var scrut_var) = addScrutedVar env scrut_var
    scrut_env (Cast inner _)  = scrut_env inner
    scrut_env _               = env
226
-- | Transform one case alternative: the right-hand side is processed
-- with the alternative's pattern binders added to the environment.
libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
           -> (AltCon, [CoreBndr], CoreExpr)
libCaseAlt env (con, bndrs, rhs) = (con, bndrs, rhs')
  where
    rhs' = libCase (addBinders env bndrs) rhs
230
231 {-
232 Ids
233 ~~~
234
235 To unfold, we can't just wrap the id itself in its binding if it's a join point:
236
237 jump j a b c => (joinrec j x y z = ... in jump j) a b c -- wrong!!!
238
239 Every jump must provide all arguments, so we have to be careful to wrap the
240 whole jump instead:
241
242 jump j a b c => joinrec j x y z = ... in jump j a b c -- right
243
244 -}
245
-- | Transform an application of the variable @v@ to @args@ (possibly
-- no arguments at all).
--
-- If @v@ has an entry in the rec-env and some free variable has been
-- scrutinised between the binding of @v@ and this occurrence, the whole
-- call is wrapped in a copy of @v@'s binding group.  Otherwise the call
-- is just rebuilt with transformed arguments.
libCaseApp :: LibCaseEnv -> Id -> [CoreExpr] -> CoreExpr
libCaseApp env v args =
  case lookupRecId env v of
    Just the_bind                   -- It's a use of a recursive thing
      | has_free_scruts             -- with free vars scrutinised in RHS
      -> Let the_bind call
    _ -> call
  where
    call            = mkApps (Var v) (map (libCase env) args)
    has_free_scruts = notNull (freeScruts env (lookupLevel env v))
259
-- | The Ids that have been scrutinised between the binding of a
-- recursive Id and the current point.
freeScruts :: LibCaseEnv
           -> LibCaseLevel      -- Level of the recursive Id
           -> [Id]              -- Ids that are scrutinised between the binding
                                -- of the recursive Id and here
freeScruts env rec_bind_lvl = concatMap select (lc_scruts env)
  where
    -- Keep an Id when it is bound no deeper than the recursive Id but
    -- scrutinised strictly deeper than it.
    --   Note [When to specialise]
    --   Note [Avoiding fruitless liberate-case]
    select (v, scrut_bind_lvl, scrut_at_lvl)
      | scrut_bind_lvl <= rec_bind_lvl
      , scrut_at_lvl   >  rec_bind_lvl = [v]
      | otherwise                      = []
270
271 {-
272 Note [When to specialise]
273 ~~~~~~~~~~~~~~~~~~~~~~~~~
274 Consider
275 f = \x. letrec g = \y. case x of
276 True -> ... (f a) ...
277 False -> ... (g b) ...
278
279 We get the following levels
280 f 0
281 x 1
282 g 1
283 y 2
284
285 Then 'x' is being scrutinised at a deeper level than its binding, so
286 it's added to lc_scruts: [(x,1,2)]
287
288 We do *not* want to specialise the call to 'f', because 'x' is not free
289 in 'f'. So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
290
291 We *do* want to specialise the call to 'g', because 'x' is free in g.
292 Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
293
294 Note [Avoiding fruitless liberate-case]
295 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
296 Consider also:
297 f = \x. case top_lvl_thing of
298 I# _ -> let g = \y. ... g ...
299 in ...
300
301 Here, top_lvl_thing is scrutinised at a level (1) deeper than its
302 binding site (0). Nevertheless, we do NOT want to specialise the call
303 to 'g' because all the structure in its free variables is already
304 visible at the definition site for g. Hence, when considering specialising
305 an occurrence of 'g', we want to check that there's a scruted-var v st
306
307 a) v's binding site is *outside* g
308 b) v's scrutinisation site is *inside* g
309
310
311 ************************************************************************
312 * *
313 Utility functions
314 * *
315 ************************************************************************
316 -}
317
-- | Record the given binders in the level environment at the current
-- level.  The current level itself is left unchanged.
addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
addBinders env bndrs
  = env { lc_lvl_env = extendVarEnvList (lc_lvl_env env) leveled }
  where
    cur_lvl = lc_lvl env
    leveled = [ (bndr, cur_lvl) | bndr <- bndrs ]
323
-- | Enter the right-hand sides of a recursive group.
--
-- The current level goes up by one, each binder is recorded at the
-- *old* level, and each binder is mapped in the rec-env to the whole
-- group (as a 'Rec' binding).
addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
addRecBinds env pairs
  = env { lc_lvl     = old_lvl + 1
        , lc_lvl_env = extendVarEnvList (lc_lvl_env env)
                                        (zip rec_ids (repeat old_lvl))
        , lc_rec_env = extendVarEnvList (lc_rec_env env)
                                        (zip rec_ids (repeat rec_group))
        }
  where
    old_lvl   = lc_lvl env
    rec_ids   = map fst pairs
    rec_group = Rec pairs
332
-- | Note that a variable is being scrutinised at the current level.
addScrutedVar :: LibCaseEnv
              -> Id             -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
addScrutedVar env scrut_var
  | bind_lvl < scrut_lvl
      -- Add to lc_scruts iff the scrut_var is being scrutinised at
      -- a deeper level than its defn
  = env { lc_scruts = (scrut_var, bind_lvl, scrut_lvl) : lc_scruts env }

  | otherwise
  = env
  where
    scrut_lvl = lc_lvl env
    -- Ids missing from the level environment count as top level
    bind_lvl  = lookupLevel env scrut_var
350
-- | Look an Id up in the rec-env, giving its binding group if it has one.
lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
lookupRecId env v = lc_rec_env env `lookupVarEnv` v
353
-- | The level recorded for an Id, or 'topLevel' if the Id is not in
-- the level environment.
lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
lookupLevel env v = maybe topLevel id (lookupVarEnv (lc_lvl_env env) v)
359
360 {-
361 ************************************************************************
362 * *
363 The environment
364 * *
365 ************************************************************************
366 -}
367
-- | A level, as used throughout this module, is a plain 'Int'.
type LibCaseLevel = Int

-- | The level used for Ids that have no entry in the level environment,
-- and the level at which the traversal starts.
topLevel :: LibCaseLevel
topLevel = 0
372
-- | The environment carried through the liberate-case traversal.
data LibCaseEnv
  = LibCaseEnv
      { lc_dflags  :: DynFlags
          -- Compiler flags; consulted when deciding whether a
          -- right-hand side is small enough to duplicate

      , lc_lvl     :: LibCaseLevel
          -- Current level.
          -- The level is incremented when (and only when) going
          -- inside the RHS of a (sufficiently small) recursive
          -- function.

      , lc_lvl_env :: IdEnv LibCaseLevel
          -- Binds all non-top-level in-scope Ids (top-level and
          -- imported things have a level of zero)

      , lc_rec_env :: IdEnv CoreBind
          -- Binds *only* recursively defined ids, to their own
          -- binding group, and *only* in their own RHSs

      , lc_scruts  :: [(Id, LibCaseLevel, LibCaseLevel)]
          -- Each of these Ids was scrutinised by an enclosing
          -- case expression, at a level deeper than its binding
          -- level.
          --
          -- The first LibCaseLevel is the *binding level* of
          -- the scrutinised Id;
          -- the second is the level *at which it was scrutinised*
          -- (see Note [Avoiding fruitless liberate-case]).
          -- The former is a bit redundant, since you could always
          -- look it up in lc_lvl_env, but it's just cached here.
          --
          -- The order is insignificant; it's a bag really.
          --
          -- There's one element per scrutinisation;
          -- in principle the same Id may appear multiple times,
          -- although that'd be unusual:
          --    case x of { (a,b) -> ....(case x of ...) .. }
      }
409
-- | The starting environment: top level, with nothing bound and
-- nothing scrutinised yet.
initEnv :: DynFlags -> LibCaseEnv
initEnv dflags = LibCaseEnv
  { lc_dflags  = dflags
  , lc_lvl     = topLevel
  , lc_lvl_env = emptyVarEnv
  , lc_rec_env = emptyVarEnv
  , lc_scruts  = []
  }
417
-- Bomb-out size for deciding if potential liberatees are too big
-- (passed in from cmd-line args).  Nothing means no limit is applied
-- by the size check in libCaseBind.
bombOutSize :: LibCaseEnv -> Maybe Int
bombOutSize env = liberateCaseThreshold (lc_dflags env)