Prepare source-tree for base-4.13 MFP bump
[ghc.git] / compiler / simplCore / Exitify.hs
index 2d3b5af..3e7d503 100644 (file)
@@ -48,58 +48,66 @@ import VarEnv
 import CoreFVs
 import FastString
 import Type
+import Util( mapSnd )
 
 import Data.Bifunctor
 import Control.Monad
 
 -- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
+-- The really interesting function is exitifyRec
 exitifyProgram :: CoreProgram -> CoreProgram
 exitifyProgram binds = map goTopLvl binds
   where
     goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
     goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
+      -- Top-level bindings are never join points
 
     in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds
 
     go :: InScopeSet -> CoreExpr -> CoreExpr
-    go _ e@(Var{})       = e
-    go _ e@(Lit {})      = e
-    go _ e@(Type {})     = e
-    go _ e@(Coercion {}) = e
+    go _    e@(Var{})       = e
+    go _    e@(Lit {})      = e
+    go _    e@(Type {})     = e
+    go _    e@(Coercion {}) = e
+    go in_scope (Cast e' c) = Cast (go in_scope e') c
+    go in_scope (Tick t e') = Tick t (go in_scope e')
+    go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)
 
-    go in_scope (Lam v e')  = Lam v (go in_scope' e')
+    go in_scope (Lam v e')
+      = Lam v (go in_scope' e')
       where in_scope' = in_scope `extendInScopeSet` v
-    go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)
+
     go in_scope (Case scrut bndr ty alts)
-        = Case (go in_scope scrut) bndr ty (map (goAlt in_scope') alts)
-      where in_scope' = in_scope `extendInScopeSet` bndr
-    go in_scope (Cast e' c) = Cast (go in_scope e') c
-    go in_scope (Tick t e') = Tick t (go in_scope e')
-    go in_scope (Let bind body) = goBind in_scope bind (go in_scope' body)
-      where in_scope' = in_scope `extendInScopeSetList` bindersOf bind
-
-    goAlt :: InScopeSet -> CoreAlt -> CoreAlt
-    goAlt in_scope (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
-      where in_scope' = in_scope `extendInScopeSetList` pats
-
-    goBind :: InScopeSet -> CoreBind -> (CoreExpr -> CoreExpr)
-    goBind in_scope (NonRec v rhs) = Let (NonRec v (go in_scope rhs))
-    goBind in_scope (Rec pairs)
-        | is_join_rec = exitify in_scope' pairs'
-        | otherwise   = Let (Rec pairs')
-      where pairs' = map (second (go in_scope')) pairs
-            is_join_rec = any (isJoinId . fst) pairs
-            in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
+      = Case (go in_scope scrut) bndr ty (map go_alt alts)
+      where
+        in_scope1 = in_scope `extendInScopeSet` bndr
+        go_alt (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
+           where in_scope' = in_scope1 `extendInScopeSetList` pats
+
+    go in_scope (Let (NonRec bndr rhs) body)
+      = Let (NonRec bndr (go in_scope rhs)) (go in_scope' body)
+      where
+        in_scope' = in_scope `extendInScopeSet` bndr
+
+    go in_scope (Let (Rec pairs) body)
+      | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
+      | otherwise   = Let (Rec pairs') body'
+      where
+        is_join_rec = any (isJoinId . fst) pairs
+        in_scope'   = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
+        pairs'      = mapSnd (go in_scope') pairs
+        body'       = go in_scope' body
+
+
+-- | State Monad used inside `exitify`
+type ExitifyM =  State [(JoinId, CoreExpr)]
 
 -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
 --   join-points outside the joinrec.
-exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
-exitify in_scope pairs =
-    \body ->mkExitLets exits (mkLetRec pairs' body)
+exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
+exitifyRec in_scope pairs
+  = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
   where
-    mkExitLets ((exitId, exitRhs):exits') = mkLetNonRec exitId exitRhs . mkExitLets exits'
-    mkExitLets [] = id
-
     -- We need the set of free variables of many subexpressions here, so
     -- annotate the AST with them
     -- see Note [Calculating free variables]
@@ -116,64 +124,31 @@ exitify in_scope pairs =
             let rhs' = mkLams args body'
             return (x, rhs')
 
-    -- main working function. Goes through the RHS (tail-call positions only),
+    ---------------------
+    -- 'go' is the main working function.
+    -- It goes through the RHS (tail-call positions only),
     -- checks if there are no more recursive calls, if so, abstracts over
     -- variables bound on the way and lifts it out as a join point.
     --
-    -- It uses a state monad to keep track of floated binds
-    go :: [Var]           -- ^ variables to abstract over
-       -> CoreExprWithFVs -- ^ current expression in tail position
-       -> State [(Id, CoreExpr)] CoreExpr
-
+    -- ExitifyM is a state monad to keep track of floated binds
+    go :: [Var]           -- ^ Variables that are in-scope here, but
+                          -- not in scope at the joinrec; that is,
+                          -- we must potentially abstract over them.
+                          -- Invariant: they are kept in dependency order
+       -> CoreExprWithFVs -- ^ Current expression in tail position
+       -> ExitifyM CoreExpr
+
+    -- We first look at the expression (no matter what it shape is)
+    -- and determine if we can turn it into a exit join point
     go captured ann_e
-        -- Do not touch an expression that is already a join jump where all arguments
-        -- are captured variables. See Note [Idempotency]
-        -- But _do_ float join jumps with interesting arguments.
-        -- See Note [Jumps can be interesting]
-        | (Var f, args) <- collectArgs e
-        , isJoinId f
-        , all isCapturedVarArg args
-        = return e
-
-        -- Do not touch a boring expression (see Note [Interesting expression])
-        | is_exit, not is_interesting = return e
-
-        -- Cannot float out if local join points are used, as
-        -- we cannot abstract over them
-        | is_exit, captures_join_points = return e
-
-        -- We have something to float out!
-        | is_exit = do
-            -- Assemble the RHS of the exit join point
-            let rhs = mkLams args e
-                ty = exprType rhs
-            let avoid = in_scope `extendInScopeSetList` captured
-            -- Remember this binding under a suitable name
-            v <- addExit avoid ty (length args) rhs
-            -- And jump to it from here
-            return $ mkVarApps (Var v) args
-      where
-        -- An exit expression has no recursive calls
-        is_exit = disjointVarSet fvs recursive_calls
-
-        -- Used to detect exit expressoins that are already proper exit jumps
-        isCapturedVarArg (Var v) = v `elem` captured
-        isCapturedVarArg _ = False
-
-        -- An interesting exit expression has free, non-imported
-        -- variables from outside the recursive group
-        -- See Note [Interesting expression]
-        is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)
-
-        -- The possible arguments of this exit join point
-        args = filter (`elemVarSet` fvs) captured
-
-        -- We cannot abstract over join points
-        captures_join_points = any isJoinId args
-
-        e = deAnnotate ann_e
-        fvs = dVarSetToVarSet (freeVarsOf ann_e)
+        | -- An exit expression has no recursive calls
+          let fvs = dVarSetToVarSet (freeVarsOf ann_e)
+        , disjointVarSet fvs recursive_calls
+        = go_exit captured (deAnnotate ann_e) fvs
 
+    -- We could not turn it into a exit joint point. So now recurse
+    -- into all expression where eligible exit join points might sit,
+    -- i.e. into all tail-call positions:
 
     -- Case right hand sides are in tail-call position
     go captured (_, AnnCase scrut bndr ty alts) = do
@@ -211,8 +186,73 @@ exitify in_scope pairs =
              return $ Let bind body'
       where bind = deAnnBind ann_bind
 
+    -- Cannot be turned into an exit join point, but also has no
+    -- tail-call subexpression. Nothing to do here.
     go _ ann_e = return (deAnnotate ann_e)
 
+    ---------------------
+    go_exit :: [Var]      -- Variables captured locally
+            -> CoreExpr   -- An exit expression
+            -> VarSet     -- Free vars of the expression
+            -> ExitifyM CoreExpr
+    -- go_exit deals with a tail expression that is floatable
+    -- out as an exit point; that is, it mentions no recursive calls
+    go_exit captured e fvs
+      -- Do not touch an expression that is already a join jump where all arguments
+      -- are captured variables. See Note [Idempotency]
+      -- But _do_ float join jumps with interesting arguments.
+      -- See Note [Jumps can be interesting]
+      | (Var f, args) <- collectArgs e
+      , isJoinId f
+      , all isCapturedVarArg args
+      = return e
+
+      -- Do not touch a boring expression (see Note [Interesting expression])
+      | not is_interesting
+      = return e
+
+      -- Cannot float out if local join points are used, as
+      -- we cannot abstract over them
+      | captures_join_points
+      = return e
+
+      -- We have something to float out!
+      | otherwise
+      = do { -- Assemble the RHS of the exit join point
+             let rhs   = mkLams abs_vars e
+                 avoid = in_scope `extendInScopeSetList` captured
+             -- Remember this binding under a suitable name
+           ; v <- addExit avoid (length abs_vars) rhs
+             -- And jump to it from here
+           ; return $ mkVarApps (Var v) abs_vars }
+
+      where
+        -- Used to detect exit expressoins that are already proper exit jumps
+        isCapturedVarArg (Var v) = v `elem` captured
+        isCapturedVarArg _ = False
+
+        -- An interesting exit expression has free, non-imported
+        -- variables from outside the recursive group
+        -- See Note [Interesting expression]
+        is_interesting = anyVarSet isLocalId $
+                         fvs `minusVarSet` mkVarSet captured
+
+        -- The arguments of this exit join point
+        -- See Note [Picking arguments to abstract over]
+        abs_vars = snd $ foldr pick (fvs, []) captured
+          where
+            pick v (fvs', acc) | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
+                               | otherwise           = (fvs',               acc)
+
+        -- We are going to abstract over these variables, so we must
+        -- zap any IdInfo they have; see Trac #15005
+        -- cf. SetLevels.abstractVars
+        zap v | isId v = setIdInfo v vanillaIdInfo
+              | otherwise = v
+
+        -- We cannot abstract over join points
+        captures_join_points = any isJoinId abs_vars
+
 
 -- Picks a new unique, which is disjoint from
 --  * the free variables of the whole joinrec
@@ -227,30 +267,19 @@ mkExitJoinId in_scope ty join_arity = do
   where
     exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
                     `asJoinId` join_arity
-                    `setIdOccInfo` exit_occ_info
 
-    -- See Note [Do not inline exit join points]
-    exit_occ_info =
-        OneOcc { occ_in_lam = True
-               , occ_one_br = True
-               , occ_int_cxt = False
-               , occ_tail = AlwaysTailCalled join_arity }
-
-addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
-addExit in_scope ty join_arity rhs = do
+addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
+addExit in_scope join_arity rhs = do
     -- Pick a suitable name
+    let ty = exprType rhs
     v <- mkExitJoinId in_scope ty join_arity
     fs <- get
     put ((v,rhs):fs)
     return v
 
-
-type ExitifyM =  State [(JoinId, CoreExpr)]
-
 {-
 Note [Interesting expression]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 We do not want this to happen:
 
   joinrec go 0     x y = x
@@ -291,7 +320,6 @@ non-imported variable.
 
 Note [Jumps can be interesting]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 A jump to a join point can be interesting, if its arguments contain free
 non-exported variables (z in the following example):
 
@@ -304,16 +332,34 @@ non-exported variables (z in the following example):
           go (n-1) x y = jump go (n-1) (x+y)
 
 
-The join point itself can be interesting, even if none if
-its arguments are (assume `g` to be an imported function that, on its own, does
-not make this interesting):
+The join point itself can be interesting, even if none if its
+arguments have free variables free in the joinrec.  For example
+
+  join j p = case p of (x,y) -> x+y
+  joinrec go 0     x y = jump j (x,y)
+          go (n-1) x y = jump go (n-1) (x+y) y
+  in …
+
+Here, `j` would not be inlined because we do not inline something that looks
+like an exit join point (see Note [Do not inline exit join points]). But
+if we exitify the 'jump j (x,y)' we get
+
+  join j p = case p of (x,y) -> x+y
+  join exit x y = jump j (x,y)
+  joinrec go 0     x y = jump exit x y
+          go (n-1) x y = jump go (n-1) (x+y) y
+  in …
+
+and now 'j' can inline, and we get rid of the pair. Here's another
+example (assume `g` to be an imported function that, on its own,
+does not make this interesting):
 
   join j y = map f y
   joinrec go 0     x y = jump j (map g x)
           go (n-1) x y = jump go (n-1) (x+y)
   in …
 
-Here, `j` would not be inlined because we do not inline something that looks
+Again, `j` would not be inlined because we do not inline something that looks
 like an exit join point (see Note [Do not inline exit join points]).
 
 But after exitification we have
@@ -353,7 +399,6 @@ interesting expressions.
 
 Note [Calculating free variables]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 We have two options where to annotate the tree with free variables:
 
  A) The whole tree.
@@ -366,10 +411,11 @@ joinrecs are nested.
 Further downside of A: If the exitify function returns annotated expressions,
 it would have to ensure that the annotations are correct.
 
+We therefore choose B, and calculate the free variables in `exitify`.
+
 
 Note [Do not inline exit join points]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 When we have
 
   let t = foo bar
@@ -385,7 +431,8 @@ To prevent this, we need to recognize exit join points, and then disable
 inlining.
 
 Exit join points, recognizeable using `isExitJoinId` are join points with an
-occurence in a recursive group, and can be recognized using `isExitJoinId`.
+occurence in a recursive group, and can be recognized (after the occurence
+analyzer ran!) using `isExitJoinId`.
 This function detects joinpoints with `occ_in_lam (idOccinfo id) == True`,
 because the lambdas of a non-recursive join point are not considered for
 `occ_in_lam`.  For example, in the following code, `j1` is /not/ marked
@@ -394,16 +441,13 @@ occ_in_lam, because `j2` is called only once.
   join j1 x = x+1
   join j2 y = join j1 (y+2)
 
-We create exit join point ids with such an `OccInfo`, see `exit_occ_info`.
-
-To prevent inlining, we check for that in `preInlineUnconditionally` directly.
-For `postInlineUnconditionally` and unfolding-based inlining, the function
-`simplLetUnfolding` simply gives exit join points no unfolding, which prevents
-this kind of inlining.
+To prevent inlining, we check for isExitJoinId
+* In `preInlineUnconditionally` directly.
+* In `simplLetUnfolding` we simply give exit join points no unfolding, which
+  prevents inlining in `postInlineUnconditionally` and call sites.
 
 Note [Placement of the exitification pass]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 I (Joachim) experimented with multiple positions for the Exitification pass in
 the Core2Core pipeline:
 
@@ -439,4 +483,17 @@ Positions C and D have their advantages: C decreases allocations in simpl, but D
 Assuming we have a budget of _one_ run of Exitification, then C wins (but we
 could get more from running it multiple times, as seen in fish).
 
+Note [Picking arguments to abstract over]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When we create an exit join point, so we need to abstract over those of its
+free variables that are be out-of-scope at the destination of the exit join
+point. So we go through the list `captured` and pick those that are actually
+free variables of the join point.
+
+We do not just `filter (`elemVarSet` fvs) captured`, as there might be
+shadowing, and `captured` may contain multiple variables with the same Unique. I
+these cases we want to abstract only over the last occurence, hence the `foldr`
+(with emphasis on the `r`). This is #15110.
+
 -}