Some cleanup of the Exitification code
authorJoachim Breitner <mail@joachim-breitner.de>
Fri, 6 Apr 2018 21:26:45 +0000 (17:26 -0400)
committerJoachim Breitner <mail@joachim-breitner.de>
Mon, 9 Apr 2018 15:25:06 +0000 (11:25 -0400)
based on a thorough review by Simon in comments
https://ghc.haskell.org/trac/ghc/ticket/14152#comment:33
through 37.

The changes are:

 * `isExitJoinId` is moved to `SimplUtils`, because
   it is only valid when occurrence information is up-to-date.
 * Abstracted variables are properly sorted using `sortQuantVars`
 * Exitification does not set occ info.

 And then minor quibles to notes and avoiding some unhelpful shadowing
 of local names.

 Differential Revision: https://phabricator.haskell.org/D4576

compiler/basicTypes/Id.hs
compiler/simplCore/Exitify.hs
compiler/simplCore/SimplUtils.hs

index 709bea4..bf48bad 100644 (file)
@@ -74,7 +74,7 @@ module Id (
         DictId, isDictId, isEvVar,
 
         -- ** Join variables
-        JoinId, isJoinId, isJoinId_maybe, idJoinArity, isExitJoinId,
+        JoinId, isJoinId, isJoinId_maybe, idJoinArity,
         asJoinId, asJoinId_maybe, zapJoinId,
 
         -- ** Inline pragma stuff
@@ -498,10 +498,6 @@ isJoinId_maybe id
                 _            -> Nothing
  | otherwise = Nothing
 
--- See Note [Exitification] and Note [Do not inline exit join points] in Exitify.hs
-isExitJoinId :: Var -> Bool
-isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id)
-
 idDataCon :: Id -> DataCon
 -- ^ Get from either the worker or the wrapper 'Id' to the 'DataCon'. Currently used only in the desugarer.
 --
index cf6a930..570186e 100644 (file)
@@ -48,16 +48,19 @@ import VarEnv
 import CoreFVs
 import FastString
 import Type
+import MkCore ( sortQuantVars )
 
 import Data.Bifunctor
 import Control.Monad
 
 -- | Traverses the AST, simply to find all joinrecs and call 'exitify' on them.
+-- The really interesting function is exitify
 exitifyProgram :: CoreProgram -> CoreProgram
 exitifyProgram binds = map goTopLvl binds
   where
     goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
     goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
+      -- Top-level bindings are never join points
 
     in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds
 
@@ -91,6 +94,10 @@ exitifyProgram binds = map goTopLvl binds
             is_join_rec = any (isJoinId . fst) pairs
             in_scope' = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
 
+
+-- | State Monad used inside `exitify`
+type ExitifyM =  State [(JoinId, CoreExpr)]
+
 -- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
 --   join-points outside the joinrec.
 exitify :: InScopeSet -> [(Var,CoreExpr)] -> (CoreExpr -> CoreExpr)
@@ -120,11 +127,13 @@ exitify in_scope pairs =
     -- checks if there are no more recursive calls, if so, abstracts over
     -- variables bound on the way and lifts it out as a join point.
     --
-    -- It uses a state monad to keep track of floated binds
+    -- ExitifyM is a state monad to keep track of floated binds
     go :: [Var]           -- ^ variables to abstract over
        -> CoreExprWithFVs -- ^ current expression in tail position
-       -> State [(Id, CoreExpr)] CoreExpr
+       -> ExitifyM CoreExpr
 
+    -- We first look at the expression (no matter what it shape is)
+    -- and determine if we can turn it into a exit join point
     go captured ann_e
         -- Do not touch an expression that is already a join jump where all arguments
         -- are captured variables. See Note [Idempotency]
@@ -145,13 +154,13 @@ exitify in_scope pairs =
         -- We have something to float out!
         | is_exit = do
             -- Assemble the RHS of the exit join point
-            let rhs = mkLams args e
+            let rhs = mkLams abs_vars e
                 ty = exprType rhs
             let avoid = in_scope `extendInScopeSetList` captured
             -- Remember this binding under a suitable name
-            v <- addExit avoid ty (length args) rhs
+            v <- addExit avoid ty (length abs_vars) rhs
             -- And jump to it from here
-            return $ mkVarApps (Var v) args
+            return $ mkVarApps (Var v) abs_vars
       where
         -- An exit expression has no recursive calls
         is_exit = disjointVarSet fvs recursive_calls
@@ -166,14 +175,17 @@ exitify in_scope pairs =
         is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)
 
         -- The possible arguments of this exit join point
-        args = filter (`elemVarSet` fvs) captured
+        abs_vars = sortQuantVars $ filter (`elemVarSet` fvs) captured
 
         -- We cannot abstract over join points
-        captures_join_points = any isJoinId args
+        captures_join_points = any isJoinId abs_vars
 
         e = deAnnotate ann_e
         fvs = dVarSetToVarSet (freeVarsOf ann_e)
 
+    -- We could not turn it into a exit joint point. So now recurse
+    -- into all expression where eligible exit join points might sit,
+    -- i.e. into all tail-call positions:
 
     -- Case right hand sides are in tail-call position
     go captured (_, AnnCase scrut bndr ty alts) = do
@@ -211,6 +223,8 @@ exitify in_scope pairs =
              return $ Let bind body'
       where bind = deAnnBind ann_bind
 
+    -- Cannot be turned into an exit join point, but also has no
+    -- tail-call subexpression. Nothing to do here.
     go _ ann_e = return (deAnnotate ann_e)
 
 
@@ -227,14 +241,6 @@ mkExitJoinId in_scope ty join_arity = do
   where
     exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
                     `asJoinId` join_arity
-                    `setIdOccInfo` exit_occ_info
-
-    -- See Note [Do not inline exit join points]
-    exit_occ_info =
-        OneOcc { occ_in_lam = True
-               , occ_one_br = True
-               , occ_int_cxt = False
-               , occ_tail = AlwaysTailCalled join_arity }
 
 addExit :: InScopeSet -> Type -> JoinArity -> CoreExpr -> ExitifyM JoinId
 addExit in_scope ty join_arity rhs = do
@@ -245,8 +251,6 @@ addExit in_scope ty join_arity rhs = do
     return v
 
 
-type ExitifyM =  State [(JoinId, CoreExpr)]
-
 {-
 Note [Interesting expression]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -381,6 +385,8 @@ joinrecs are nested.
 Further downside of A: If the exitify function returns annotated expressions,
 it would have to ensure that the annotations are correct.
 
+We therefore choose B, and calculate the free variables in `exitify`.
+
 
 Note [Do not inline exit join points]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -399,7 +405,8 @@ To prevent this, we need to recognize exit join points, and then disable
 inlining.
 
 Exit join points, recognizeable using `isExitJoinId` are join points with an
-occurence in a recursive group, and can be recognized using `isExitJoinId`.
+occurence in a recursive group, and can be recognized (after the occurence
+analyzer ran!) using `isExitJoinId`.
 This function detects joinpoints with `occ_in_lam (idOccinfo id) == True`,
 because the lambdas of a non-recursive join point are not considered for
 `occ_in_lam`.  For example, in the following code, `j1` is /not/ marked
@@ -408,8 +415,6 @@ occ_in_lam, because `j2` is called only once.
   join j1 x = x+1
   join j2 y = join j1 (y+2)
 
-We create exit join point ids with such an `OccInfo`, see `exit_occ_info`.
-
 To prevent inlining, we check for isExitJoinId
 * In `preInlineUnconditionally` directly.
 * In `simplLetUnfolding` we simply give exit join points no unfolding, which
index db26af4..7c0689d 100644 (file)
@@ -30,7 +30,10 @@ module SimplUtils (
         addValArgTo, addCastTo, addTyArgTo,
         argInfoExpr, argInfoAppArgs, pushSimplifiedArgs,
 
-        abstractFloats
+        abstractFloats,
+
+        -- Utilities
+        isExitJoinId
     ) where
 
 #include "HsVersions.h"
@@ -2199,6 +2202,13 @@ in PrelRules)
 mkCase3 _dflags scrut bndr alts_ty alts
   = return (Case scrut bndr alts_ty alts)
 
+-- See Note [Exitification] and Note [Do not inline exit join points] in Exitify.hs
+-- This lives here (and not in Id) becuase occurrence info is only valid on
+-- InIds, so it's crucial that isExitJoinId is only called on freshly
+-- occ-analysed code. It's not a generic function you can call anywhere.
+isExitJoinId :: Var -> Bool
+isExitJoinId id = isJoinId id && isOneOcc (idOccInfo id) && occ_in_lam (idOccInfo id)
+
 {-
 Note [Dead binders]
 ~~~~~~~~~~~~~~~~~~~~