Implement late lambda lift
authorSebastian Graf <sebastian.graf@kit.edu>
Fri, 23 Nov 2018 15:24:18 +0000 (16:24 +0100)
committerSebastian Graf <sebastian.graf@kit.edu>
Fri, 23 Nov 2018 15:26:02 +0000 (16:26 +0100)
Summary:
This implements a selective lambda-lifting pass late in the STG
pipeline.

Lambda lifting has the effect of avoiding closure allocation at the cost
of having to make former free vars available at call sites, possibly
enlarging closures surrounding call sites in turn.

We identify beneficial cases by means of an analysis that estimates
closure growth.

There's a Wiki page at
https://ghc.haskell.org/trac/ghc/wiki/LateLamLift.

Reviewers: simonpj, bgamari, simonmar

Reviewed By: simonpj

Subscribers: rwbarton, carter

GHC Trac Issues: #9476

Differential Revision: https://phabricator.haskell.org/D5224

24 files changed:
compiler/basicTypes/Demand.hs
compiler/basicTypes/Id.hs
compiler/codeGen/StgCmm.hs
compiler/codeGen/StgCmmBind.hs
compiler/codeGen/StgCmmExpr.hs
compiler/ghc.cabal.in
compiler/main/DynFlags.hs
compiler/simplStg/SimplStg.hs
compiler/simplStg/StgCse.hs
compiler/simplStg/StgLiftLams.hs [new file with mode: 0644]
compiler/simplStg/StgLiftLams/Analysis.hs [new file with mode: 0644]
compiler/simplStg/StgLiftLams/LiftM.hs [new file with mode: 0644]
compiler/simplStg/StgLiftLams/Transformation.hs [new file with mode: 0644]
compiler/simplStg/StgStats.hs
compiler/simplStg/UnariseStg.hs
compiler/stgSyn/CoreToStg.hs
compiler/stgSyn/StgFVs.hs
compiler/stgSyn/StgLint.hs
compiler/stgSyn/StgSubst.hs [new file with mode: 0644]
compiler/stgSyn/StgSyn.hs
docs/users_guide/using-optimisation.rst
inplace/test [deleted file]
inplace/test spaces [deleted symlink]
testsuite/tests/perf/join_points/all.T

index 4707be7..8884542 100644 (file)
@@ -10,7 +10,7 @@
 module Demand (
         StrDmd, UseDmd(..), Count,
 
-        Demand, CleanDemand, getStrDmd, getUseDmd,
+        Demand, DmdShell, CleanDemand, getStrDmd, getUseDmd,
         mkProdDmd, mkOnceUsedDmd, mkManyUsedDmd, mkHeadStrict, oneifyDmd,
         toCleanDmd,
         absDmd, topDmd, botDmd, seqDmd,
@@ -48,9 +48,9 @@ module Demand (
         deferAfterIO,
         postProcessUnsat, postProcessDmdType,
 
-        splitProdDmd_maybe, peelCallDmd, mkCallDmd, mkWorkerDemand,
-        dmdTransformSig, dmdTransformDataConSig, dmdTransformDictSelSig,
-        argOneShots, argsOneShots, saturatedByOneShots,
+        splitProdDmd_maybe, peelCallDmd, peelManyCalls, mkCallDmd,
+        mkWorkerDemand, dmdTransformSig, dmdTransformDataConSig,
+        dmdTransformDictSelSig, argOneShots, argsOneShots, saturatedByOneShots,
         trimToType, TypeShape(..),
 
         useCount, isUsedOnce, reuseEnv,
@@ -787,7 +787,7 @@ botDmd = JD { sd = strBot, ud = useBot }
 seqDmd :: Demand
 seqDmd = JD { sd = Str VanStr HeadStr, ud = Use One UHead }
 
-oneifyDmd :: Demand -> Demand
+oneifyDmd :: JointDmd s (Use u) -> JointDmd s (Use u)
 oneifyDmd (JD { sd = s, ud = Use _ a }) = JD { sd = s, ud = Use One a }
 oneifyDmd jd                            = jd
 
@@ -796,7 +796,7 @@ isTopDmd :: Demand -> Bool
 isTopDmd (JD {sd = Lazy, ud = Use Many Used}) = True
 isTopDmd _                                    = False
 
-isAbsDmd :: Demand -> Bool
+isAbsDmd :: JointDmd (Str s) (Use u) -> Bool
 isAbsDmd (JD {ud = Abs}) = True   -- The strictness part can be HyperStr
 isAbsDmd _               = False  -- for a bottom demand
 
@@ -804,7 +804,7 @@ isSeqDmd :: Demand -> Bool
 isSeqDmd (JD {sd = Str VanStr HeadStr, ud = Use _ UHead}) = True
 isSeqDmd _                                                = False
 
-isUsedOnce :: Demand -> Bool
+isUsedOnce :: JointDmd (Str s) (Use u) -> Bool
 isUsedOnce (JD { ud = a }) = case useCount a of
                                One  -> True
                                Many -> False
@@ -817,7 +817,7 @@ seqDemandList :: [Demand] -> ()
 seqDemandList [] = ()
 seqDemandList (d:ds) = seqDemand d `seq` seqDemandList ds
 
-isStrictDmd :: Demand -> Bool
+isStrictDmd :: JointDmd (Str s) (Use u) -> Bool
 -- See Note [Strict demands]
 isStrictDmd (JD {ud = Abs})  = False
 isStrictDmd (JD {sd = Lazy}) = False
index 78518ee..5e91d26 100644 (file)
@@ -897,9 +897,10 @@ zapStableUnfolding id
 {-
 Note [transferPolyIdInfo]
 ~~~~~~~~~~~~~~~~~~~~~~~~~
-This transfer is used in two places:
+This transfer is used in three places:
         FloatOut (long-distance let-floating)
         SimplUtils.abstractFloats (short-distance let-floating)
+        StgLiftLams (selectively lambda-lift local functions to top-level)
 
 Consider the short-distance let-floating:
 
index 59ceba8..acd2aee 100644 (file)
@@ -45,7 +45,7 @@ import Module
 import Outputable
 import Stream
 import BasicTypes
-import VarSet ( isEmptyVarSet )
+import VarSet ( isEmptyDVarSet )
 
 import OrdList
 import MkGraph
@@ -156,7 +156,7 @@ cgTopRhs dflags _rec bndr (StgRhsCon _cc con args)
       -- see Note [Post-unarisation invariants] in UnariseStg
 
 cgTopRhs dflags rec bndr (StgRhsClosure fvs cc upd_flag args body)
-  = ASSERT(isEmptyVarSet fvs)    -- There should be no free variables
+  = ASSERT(isEmptyDVarSet fvs)    -- There should be no free variables
     cgTopRhsClosure dflags rec bndr cc upd_flag args body
 
 
index dba122f..9e14311 100644 (file)
@@ -44,7 +44,7 @@ import Name
 import Module
 import ListSetOps
 import Util
-import UniqSet ( nonDetEltsUniqSet )
+import VarSet
 import BasicTypes
 import Outputable
 import FastString
@@ -209,10 +209,7 @@ cgRhs id (StgRhsCon cc con args)
 {- See Note [GC recovery] in compiler/codeGen/StgCmmClosure.hs -}
 cgRhs id (StgRhsClosure fvs cc upd_flag args body)
   = do dflags <- getDynFlags
-       mkRhsClosure dflags id cc (nonVoidIds (nonDetEltsUniqSet fvs)) upd_flag args body
-       -- It's OK to use nonDetEltsUniqSet here because we're not aiming for
-       -- bit-for-bit determinism.
-       -- See Note [Unique Determinism and code generation]
+       mkRhsClosure dflags id cc (nonVoidIds (dVarSetElems fvs)) upd_flag args body
 
 ------------------------------------------------------------------------
 --              Non-constructor right hand sides
index 2430a0d..5844161 100644 (file)
@@ -81,8 +81,8 @@ cgExpr (StgTick t e)         = cgTick t >> cgExpr e
 cgExpr (StgLit lit)       = do cmm_lit <- cgLit lit
                                emitReturn [CmmLit cmm_lit]
 
-cgExpr (StgLet binds expr)             = do { cgBind binds;     cgExpr expr }
-cgExpr (StgLetNoEscape binds expr) =
+cgExpr (StgLet _ binds expr) = do { cgBind binds;     cgExpr expr }
+cgExpr (StgLetNoEscape binds expr) =
   do { u <- newUnique
      ; let join_id = mkBlockId u
      ; cgLneBinds join_id binds
index 893f959..a99c6e7 100644 (file)
@@ -433,6 +433,11 @@ Library
         SimplStg
         StgStats
         StgCse
+        StgLiftLams
+        StgLiftLams.Analysis
+        StgLiftLams.LiftM
+        StgLiftLams.Transformation
+        StgSubst
         UnariseStg
         RepType
         Rules
index a93da7b..b574ba9 100644 (file)
@@ -465,6 +465,7 @@ data GeneralFlag
    | Opt_StaticArgumentTransformation
    | Opt_CSE
    | Opt_StgCSE
+   | Opt_StgLiftLams
    | Opt_LiberateCase
    | Opt_SpecConstr
    | Opt_SpecConstrKeen
@@ -672,6 +673,7 @@ optimisationFlags = EnumSet.fromList
    , Opt_StaticArgumentTransformation
    , Opt_CSE
    , Opt_StgCSE
+   , Opt_StgLiftLams
    , Opt_LiberateCase
    , Opt_SpecConstr
    , Opt_SpecConstrKeen
@@ -903,6 +905,13 @@ data DynFlags = DynFlags {
   floatLamArgs          :: Maybe Int,   -- ^ Arg count for lambda floating
                                         --   See CoreMonad.FloatOutSwitches
 
+  liftLamsRecArgs       :: Maybe Int,   -- ^ Maximum number of arguments after lambda lifting a
+                                        --   recursive function.
+  liftLamsNonRecArgs    :: Maybe Int,   -- ^ Maximum number of arguments after lambda lifting a
+                                        --   non-recursive function.
+  liftLamsKnown         :: Bool,        -- ^ Lambda lift even when this turns a known call
+                                        --   into an unknown call.
+
   cmmProcAlignment      :: Maybe Int,   -- ^ Align Cmm functions at this boundary or use default.
 
   historySize           :: Int,         -- ^ Simplification history size
@@ -1865,6 +1874,9 @@ defaultDynFlags mySettings (myLlvmTargets, myLlvmPasses) =
         specConstrRecursive     = 3,
         liberateCaseThreshold   = Just 2000,
         floatLamArgs            = Just 0, -- Default: float only if no fvs
+        liftLamsRecArgs         = Just 5, -- Default: the number of available argument hardware registers on x86_64
+        liftLamsNonRecArgs      = Just 5, -- Default: the number of available argument hardware registers on x86_64
+        liftLamsKnown           = False,  -- Default: don't turn known calls into unknown ones
         cmmProcAlignment        = Nothing,
 
         historySize             = 20,
@@ -3522,6 +3534,18 @@ dynamic_flags_deps = [
       (intSuffix (\n d -> d { floatLamArgs = Just n }))
   , make_ord_flag defFlag "ffloat-all-lams"
       (noArg (\d -> d { floatLamArgs = Nothing }))
+  , make_ord_flag defFlag "fstg-lift-lams-rec-args"
+      (intSuffix (\n d -> d { liftLamsRecArgs = Just n }))
+  , make_ord_flag defFlag "fstg-lift-lams-rec-args-any"
+      (noArg (\d -> d { liftLamsRecArgs = Nothing }))
+  , make_ord_flag defFlag "fstg-lift-lams-non-rec-args"
+      (intSuffix (\n d -> d { liftLamsRecArgs = Just n }))
+  , make_ord_flag defFlag "fstg-lift-lams-non-rec-args-any"
+      (noArg (\d -> d { liftLamsRecArgs = Nothing }))
+  , make_ord_flag defFlag "fstg-lift-lams-known"
+      (noArg (\d -> d { liftLamsKnown = True }))
+  , make_ord_flag defFlag "fno-stg-lift-lams-known"
+      (noArg (\d -> d { liftLamsKnown = False }))
   , make_ord_flag defFlag "fproc-alignment"
       (intSuffix (\n d -> d { cmmProcAlignment = Just n }))
   , make_ord_flag defFlag "fblock-layout-weights"
@@ -4016,6 +4040,7 @@ fFlagsDeps = [
   flagSpec "cmm-sink"                         Opt_CmmSink,
   flagSpec "cse"                              Opt_CSE,
   flagSpec "stg-cse"                          Opt_StgCSE,
+  flagSpec "stg-lift-lams"                    Opt_StgLiftLams,
   flagSpec "cpr-anal"                         Opt_CprAnal,
   flagSpec "defer-type-errors"                Opt_DeferTypeErrors,
   flagSpec "defer-typed-holes"                Opt_DeferTypedHoles,
@@ -4546,6 +4571,7 @@ optLevelFlags -- see Note [Documenting optimisation flags]
     , ([1,2],   Opt_CmmSink)
     , ([1,2],   Opt_CSE)
     , ([1,2],   Opt_StgCSE)
+    , ([2],     Opt_StgLiftLams)
     , ([1,2],   Opt_EnableRewriteRules)  -- Off for -O0; see Note [Scoping for Builtin rules]
                                          --              in PrelRules
     , ([1,2],   Opt_FloatIn)
index 830dd19..327f614 100644 (file)
@@ -5,6 +5,7 @@
 -}
 
 {-# LANGUAGE CPP #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
 
 module SimplStg ( stg2stg ) where
 
@@ -18,12 +19,25 @@ import StgLint          ( lintStgTopBindings )
 import StgStats         ( showStgStats )
 import UnariseStg       ( unarise )
 import StgCse           ( stgCse )
+import StgLiftLams      ( stgLiftLams )
 
 import DynFlags
 import ErrUtils
-import UniqSupply       ( mkSplitUniqSupply )
+import UniqSupply
 import Outputable
 import Control.Monad
+import Control.Monad.IO.Class
+import Control.Monad.Trans.State.Strict
+
+newtype StgM a = StgM { _unStgM :: StateT UniqSupply IO a }
+  deriving (Functor, Applicative, Monad, MonadIO)
+
+instance MonadUnique StgM where
+  getUniqueSupplyM = StgM (state splitUniqSupply)
+  getUniqueM = StgM (state takeUniqFromSupply)
+
+runStgM :: UniqSupply -> StgM a -> IO a
+runStgM us (StgM m) = evalStateT m us
 
 stg2stg :: DynFlags                  -- includes spec of what stg-to-stg passes to do
         -> [StgTopBinding]           -- input...
@@ -33,46 +47,56 @@ stg2stg dflags binds
   = do  { showPass dflags "Stg2Stg"
         ; us <- mkSplitUniqSupply 'g'
 
-                -- Do the main business!
-        ; dumpIfSet_dyn dflags Opt_D_dump_stg "Pre unarise:"
-                        (pprStgTopBindings binds)
+        -- Do the main business!
+        ; binds' <- runStgM us $
+            foldM do_stg_pass binds (getStgToDo dflags)
 
-        ; stg_linter False "Pre-unarise" binds
-        ; let un_binds = unarise us binds
-        ; stg_linter True "Unarise" un_binds
-         -- Important that unarisation comes first
-         -- See Note [StgCse after unarisation] in StgCse
+        ; dump_when Opt_D_dump_stg "STG syntax:" binds'
 
-        ; dumpIfSet_dyn dflags Opt_D_dump_stg "STG syntax:"
-                        (pprStgTopBindings un_binds)
-
-        ; foldM do_stg_pass un_binds (getStgToDo dflags)
-        }
+        ; return binds'
+   }
 
   where
-    stg_linter unarised
-      | gopt Opt_DoStgLinting dflags = lintStgTopBindings dflags unarised
+    stg_linter what
+      | gopt Opt_DoStgLinting dflags = lintStgTopBindings dflags what
       | otherwise                    = \ _whodunnit _binds -> return ()
 
     -------------------------------------------
+    do_stg_pass :: [StgTopBinding] -> StgToDo -> StgM [StgTopBinding]
     do_stg_pass binds to_do
       = case to_do of
-          D_stg_stats ->
-             trace (showStgStats binds) (return binds)
+          StgDoNothing ->
+            return binds
+
+          StgStats ->
+            trace (showStgStats binds) (return binds)
 
-          StgCSE ->
-             {-# SCC "StgCse" #-}
-             let
-                 binds' = stgCse binds
-             in
-             end_pass "StgCse" binds'
+          StgCSE -> do
+            let binds' = {-# SCC "StgCse" #-} stgCse binds
+            end_pass "StgCse" binds'
+
+          StgLiftLams -> do
+            us <- getUniqueSupplyM
+            let binds' = {-# SCC "StgLiftLams" #-} stgLiftLams dflags us binds
+            end_pass "StgLiftLams" binds'
+
+          StgUnarise -> do
+            dump_when Opt_D_dump_stg "Pre unarise:" binds
+            us <- getUniqueSupplyM
+            liftIO (stg_linter False "Pre-unarise" binds)
+            let binds' = unarise us binds
+            liftIO (stg_linter True "Unarise" binds')
+            return binds'
+
+    dump_when flag header binds
+      = liftIO (dumpIfSet_dyn dflags flag header (pprStgTopBindings binds))
 
     end_pass what binds2
-      = do -- report verbosely, if required
-           dumpIfSet_dyn dflags Opt_D_verbose_stg2stg what
-              (pprStgTopBindings binds2)
-           stg_linter True what binds2
-           return binds2
+      = liftIO $ do -- report verbosely, if required
+          dumpIfSet_dyn dflags Opt_D_verbose_stg2stg what
+            (vcat (map ppr binds2))
+          stg_linter False what binds2
+          return binds2
 
 -- -----------------------------------------------------------------------------
 -- StgToDo:  abstraction of stg-to-stg passes to run.
@@ -80,12 +104,31 @@ stg2stg dflags binds
 -- | Optional Stg-to-Stg passes.
 data StgToDo
   = StgCSE
-  | D_stg_stats
-
--- | Which optional Stg-to-Stg passes to run. Depends on flags, ways etc.
+  -- ^ Common subexpression elimination
+  | StgLiftLams
+  -- ^ Lambda lifting closure variables, trading stack/register allocation for
+  -- heap allocation
+  | StgStats
+  | StgUnarise
+  -- ^ Mandatory unarise pass, desugaring unboxed tuple and sum binders
+  | StgDoNothing
+  -- ^ Useful for building up 'getStgToDo'
+  deriving Eq
+
+-- | Which Stg-to-Stg passes to run. Depends on flags, ways etc.
 getStgToDo :: DynFlags -> [StgToDo]
-getStgToDo dflags
-  = [ StgCSE                   | gopt Opt_StgCSE dflags] ++
-    [ D_stg_stats              | stg_stats ]
-  where
-        stg_stats = gopt Opt_StgStats dflags
+getStgToDo dflags =
+  filter (/= StgDoNothing)
+    [ mandatory StgUnarise
+    -- Important that unarisation comes first
+    -- See Note [StgCse after unarisation] in StgCse
+    , optional Opt_StgCSE StgCSE
+    , optional Opt_StgLiftLams StgLiftLams
+    , optional Opt_StgStats StgStats
+    ] where
+      optional opt = runWhen (gopt opt dflags)
+      mandatory = id
+
+runWhen :: Bool -> StgToDo -> StgToDo
+runWhen True todo = todo
+runWhen _    _    = StgDoNothing
index a22a7c1..fbccf80 100644 (file)
@@ -331,14 +331,14 @@ stgCseExpr env (StgConApp dataCon args tys)
 -- The binding might be removed due to CSE (we do not want trivial bindings on
 -- the STG level), so use the smart constructor `mkStgLet` to remove the binding
 -- if empty.
-stgCseExpr env (StgLet binds body)
+stgCseExpr env (StgLet ext binds body)
     = let (binds', env') = stgCseBind env binds
           body' = stgCseExpr env' body
-      in mkStgLet StgLet binds' body'
-stgCseExpr env (StgLetNoEscape binds body)
+      in mkStgLet (StgLet ext) binds' body'
+stgCseExpr env (StgLetNoEscape ext binds body)
     = let (binds', env') = stgCseBind env binds
           body' = stgCseExpr env' body
-      in mkStgLet StgLetNoEscape binds' body'
+      in mkStgLet (StgLetNoEscape ext) binds' body'
 
 -- Case alternatives
 -- Extend the CSE environment
diff --git a/compiler/simplStg/StgLiftLams.hs b/compiler/simplStg/StgLiftLams.hs
new file mode 100644 (file)
index 0000000..d46e641
--- /dev/null
@@ -0,0 +1,102 @@
+-- | Implements a selective lambda lifter, running late in the optimisation
+-- pipeline.
+--
+-- The transformation itself is implemented in "StgLiftLams.Transformation".
+-- If you are interested in the cost model that is employed to decide whether
+-- to lift a binding or not, look at "StgLiftLams.Analysis".
+-- "StgLiftLams.LiftM" contains the transformation monad that hides away some
+-- plumbing of the transformation.
+module StgLiftLams (
+    -- * Late lambda lifting in STG
+    -- $note
+    Transformation.stgLiftLams
+  ) where
+
+import qualified StgLiftLams.Transformation as Transformation
+
+-- Note [Late lambda lifting in STG]
+-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+-- $note
+-- See also the <https://ghc.haskell.org/trac/ghc/wiki/LateLamLift wiki page>
+-- and Trac #9476.
+--
+-- The basic idea behind lambda lifting is to turn locally defined functions
+-- into top-level functions. Free variables are then passed as additional
+-- arguments at *call sites* instead of having a closure allocated for them at
+-- *definition site*. Example:
+--
+-- @
+--    let x = ...; y = ... in
+--    let f = {x y} \a -> a + x + y in
+--    let g = {f x} \b -> f b + x in
+--    g 5
+-- @
+--
+-- Lambda lifting @f@ would
+--
+--   1. Turn @f@'s free variables into formal parameters
+--   2. Update @f@'s call site within @g@ to @f x y b@
+--   3. Update @g@'s closure: Add @y@ as an additional free variable, while
+--      removing @f@, because @f@ no longer allocates and can be floated to
+--      top-level.
+--   4. Actually float the binding of @f@ to top-level, eliminating the @let@
+--      in the process.
+--
+-- This results in the following program (with free var annotations):
+--
+-- @
+--    f x y a = a + x + y;
+--    let x = ...; y = ... in
+--    let g = {x y} \b -> f x y b + x in
+--    g 5
+-- @
+--
+-- This optimisation is all about lifting only when it is beneficial to do so.
+-- The above seems like a worthwhile lift, judging from heap allocation:
+-- We eliminate @f@'s closure, saving to allocate a closure with 2 words, while
+-- not changing the size of @g@'s closure.
+--
+-- You can probably sense that there's some kind of cost model at play here.
+-- And you are right! But we also employ a couple of other heuristics for the
+-- lifting decision which are outlined in "StgLiftLams.Analysis#when".
+--
+-- The transformation is done in "StgLiftLams.Transformation", which calls out
+-- to 'StgLiftLams.Analysis.goodToLift' for its lifting decision.
+-- It relies on "StgLiftLams.LiftM", which abstracts some subtle STG invariants
+-- into a monadic substrate.
+--
+-- Suffice to say: We trade heap allocation for stack allocation.
+-- The additional arguments have to passed on the stack (or in registers,
+-- depending on architecture) every time we call the function to save a single
+-- heap allocation when entering the let binding. Nofib suggests a mean
+-- improvement of about 1% for this pass, so it seems like a worthwhile thing to
+-- do. Compile-times went up by 0.6%, so all in all a very modest change.
+--
+-- For a concrete example, look at @spectral/atom@. There's a call to 'zipWith'
+-- that is ultimately compiled to something like this
+-- (module desugaring/lowering to actual STG):
+--
+-- @
+--    propagate dt = ...;
+--    runExperiment ... =
+--      let xs = ... in
+--      let ys = ... in
+--      let go = {dt go} \xs ys -> case (xs, ys) of
+--            ([], []) -> []
+--            (x:xs', y:ys') -> propagate dt x y : go xs' ys'
+--      in go xs ys
+-- @
+--
+-- This will lambda lift @go@ to top-level, speeding up the resulting program
+-- by roughly one percent:
+--
+-- @
+--    propagate dt = ...;
+--    go dt xs ys = case (xs, ys) of
+--      ([], []) -> []
+--      (x:xs', y:ys') -> propagate dt x y : go dt xs' ys'
+--    runExperiment ... =
+--      let xs = ... in
+--      let ys = ... in
+--      in go dt xs ys
+-- @
diff --git a/compiler/simplStg/StgLiftLams/Analysis.hs b/compiler/simplStg/StgLiftLams/Analysis.hs
new file mode 100644 (file)
index 0000000..5b87f58
--- /dev/null
@@ -0,0 +1,566 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE DataKinds #-}
+
+-- | Provides the heuristics for when it's beneficial to lambda lift bindings.
+-- Most significantly, this employs a cost model to estimate impact on heap
+-- allocations, by looking at an STG expression's 'Skeleton'.
+module StgLiftLams.Analysis (
+    -- * #when# When to lift
+    -- $when
+
+    -- * #clogro# Estimating closure growth
+    -- $clogro
+
+    -- * AST annotation
+    Skeleton(..), BinderInfo(..), binderInfoBndr,
+    LlStgBinding, LlStgExpr, LlStgRhs, LlStgAlt, tagSkeletonTopBind,
+    -- * Lifting decision
+    goodToLift,
+    closureGrowth -- Exported just for the docs
+  ) where
+
+import GhcPrelude
+
+import BasicTypes
+import Demand
+import DynFlags
+import Id
+import SMRep ( WordOff )
+import StgSyn
+import qualified StgCmmArgRep
+import qualified StgCmmClosure
+import qualified StgCmmLayout
+import Outputable
+import Util
+import VarSet
+
+import Data.Maybe ( mapMaybe )
+
+-- Note [When to lift]
+-- ~~~~~~~~~~~~~~~~~~~
+-- $when
+-- The analysis proceeds in two steps:
+--
+--   1. It tags the syntax tree with analysis information in the form of
+--      'BinderInfo' at each binder and 'Skeleton's at each let-binding
+--      by 'tagSkeletonTopBind' and friends.
+--   2. The resulting syntax tree is treated by the "StgLiftLams.Transformation"
+--      module, calling out to 'goodToLift' to decide if a binding is worthwhile
+--      to lift.
+--      'goodToLift' consults argument occurrence information in 'BinderInfo'
+--      and estimates 'closureGrowth', for which it needs the 'Skeleton'.
+--
+-- So the annotations from 'tagSkeletonTopBind' ultimately fuel 'goodToLift',
+-- which employs a number of heuristics to identify and exclude lambda lifting
+-- opportunities deemed non-beneficial:
+--
+--  [Top-level bindings] can't be lifted.
+--  [Thunks] and data constructors shouldn't be lifted in order not to destroy
+--    sharing.
+--  [Argument occurrences] #arg_occs# of binders prohibit them to be lifted.
+--    Doing the lift would re-introduce the very allocation at call sites that
+--    we tried to get rid off in the first place. We capture analysis
+--    information in 'BinderInfo'. Note that we also consider a nullary
+--    application as argument occurrence, because it would turn into an n-ary
+--    partial application created by a generic apply function. This occurs in
+--    CPS-heavy code like the CS benchmark.
+--  [Join points] should not be lifted, simply because there's no reduction in
+--    allocation to be had.
+--  [Abstracting over join points] destroys join points, because they end up as
+--    arguments to the lifted function.
+--  [Abstracting over known local functions] turns a known call into an unknown
+--    call (e.g. some @stg_ap_*@), which is generally slower. Can be turned off
+--    with @-fstg-lift-lams-known@.
+--  [Calling convention] Don't lift when the resulting function would have a
+--    higher arity than available argument registers for the calling convention.
+--    Can be influenced with @-fstg-lift-(non)rec-args(-any)@.
+--  [Closure growth] introduced when former free variables have to be available
+--    at call sites may actually lead to an increase in overall allocations
+--  resulting from a lift. Estimating closure growth is described in
+--  "StgLiftLams.Analysis#clogro" and is what most of this module is ultimately
+--  concerned with.
+--
+-- There's a <https://ghc.haskell.org/trac/ghc/wiki/LateLamLift wiki page> with
+-- some more background and history.
+
+-- Note [Estimating closure growth]
+-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+-- $clogro
+-- We estimate closure growth by abstracting the syntax tree into a 'Skeleton',
+-- capturing only syntactic details relevant to 'closureGrowth', such as
+--
+--   * 'ClosureSk', representing closure allocation.
+--   * 'RhsSk', representing a RHS of a binding and how many times it's called
+--     by an appropriate 'DmdShell'.
+--   * 'AltSk', 'BothSk' and 'NilSk' for choice, sequence and empty element.
+--
+-- This abstraction is mostly so that the main analysis function 'closureGrowth'
+-- can stay simple and focused. Also, skeletons tend to be much smaller than
+-- the syntax tree they abstract, so it makes sense to construct them once and
+-- and operate on them instead of the actual syntax tree.
+--
+-- A more detailed treatment of computing closure growth, including examples,
+-- can be found in the paper referenced from the
+-- <https://ghc.haskell.org/trac/ghc/wiki/LateLamLift wiki page>.
+
+llTrace :: String -> SDoc -> a -> a
+llTrace _ _ c = c
+-- llTrace a b c = pprTrace a b c
+
+type instance BinderP      'LiftLams = BinderInfo
+type instance XRhsClosure  'LiftLams = DIdSet
+type instance XLet         'LiftLams = Skeleton
+type instance XLetNoEscape 'LiftLams = Skeleton
+
+freeVarsOfRhs :: (XRhsClosure pass ~ DIdSet) => GenStgRhs pass -> DIdSet
+freeVarsOfRhs (StgRhsCon _ _ args) = mkDVarSet [ id | StgVarArg id <- args ]
+freeVarsOfRhs (StgRhsClosure fvs _ _ _ _) = fvs
+
+-- | Captures details of the syntax tree relevant to the cost model, such as
+-- closures, multi-shot lambdas and case expressions.
+data Skeleton
+  = ClosureSk !Id !DIdSet {- ^ free vars -} !Skeleton
+  | RhsSk !DmdShell {- ^ how often the RHS was entered -} !Skeleton
+  | AltSk !Skeleton !Skeleton
+  | BothSk !Skeleton !Skeleton
+  | NilSk
+
+bothSk :: Skeleton -> Skeleton -> Skeleton
+bothSk NilSk b = b
+bothSk a NilSk = a
+bothSk a b     = BothSk a b
+
+altSk :: Skeleton -> Skeleton -> Skeleton
+altSk NilSk b = b
+altSk a NilSk = a
+altSk a b     = AltSk a b
+
+rhsSk :: DmdShell -> Skeleton -> Skeleton
+rhsSk _        NilSk = NilSk
+rhsSk body_dmd skel  = RhsSk body_dmd skel
+
+-- | The type used in binder positions in 'GenStgExpr's.
+data BinderInfo
+  = BindsClosure !Id !Bool -- ^ Let(-no-escape)-bound thing with a flag
+                           --   indicating whether it occurs as an argument
+                           --   or in a nullary application
+                           --   (see "StgLiftLams.Analysis#arg_occs").
+  | BoringBinder !Id       -- ^ Every other kind of binder
+
+-- | Gets the bound 'Id' out a 'BinderInfo'.
+binderInfoBndr :: BinderInfo -> Id
+binderInfoBndr (BoringBinder bndr)   = bndr
+binderInfoBndr (BindsClosure bndr _) = bndr
+
+-- | Returns 'Nothing' for 'BoringBinder's and 'Just' the flag indicating
+-- occurrences as argument or in a nullary applications otherwise.
+binderInfoOccursAsArg :: BinderInfo -> Maybe Bool
+binderInfoOccursAsArg BoringBinder{}     = Nothing
+binderInfoOccursAsArg (BindsClosure _ b) = Just b
+
+instance Outputable Skeleton where
+  ppr NilSk = text ""
+  ppr (AltSk l r) = vcat
+    [ text "{ " <+> ppr l
+    , text "ALT"
+    , text "  " <+> ppr r
+    , text "}"
+    ]
+  ppr (BothSk l r) = ppr l $$ ppr r
+  ppr (ClosureSk f fvs body) = ppr f <+> ppr fvs $$ nest 2 (ppr body)
+  ppr (RhsSk body_dmd body) = hcat
+    [ text "λ["
+    , ppr str
+    , text ", "
+    , ppr use
+    , text "]. "
+    , ppr body
+    ]
+    where
+      str
+        | isStrictDmd body_dmd = '1'
+        | otherwise = '0'
+      use
+        | isAbsDmd body_dmd = '0'
+        | isUsedOnce body_dmd = '1'
+        | otherwise = 'ω'
+
+instance Outputable BinderInfo where
+  ppr = ppr . binderInfoBndr
+
+instance OutputableBndr BinderInfo where
+  pprBndr b = pprBndr b . binderInfoBndr
+  pprPrefixOcc = pprPrefixOcc . binderInfoBndr
+  pprInfixOcc = pprInfixOcc . binderInfoBndr
+  bndrIsJoin_maybe = bndrIsJoin_maybe . binderInfoBndr
+
+mkArgOccs :: [StgArg] -> IdSet
+mkArgOccs = mkVarSet . mapMaybe stg_arg_var
+  where
+    stg_arg_var (StgVarArg occ) = Just occ
+    stg_arg_var _               = Nothing
+
+-- | Tags every binder with its 'BinderInfo' and let bindings with their
+-- 'Skeleton's.
+tagSkeletonTopBind :: CgStgBinding -> LlStgBinding
+-- NilSk is OK when tagging top-level bindings. Also, top-level things are never
+-- lambda-lifted, so no need to track their argument occurrences. They can also
+-- never be let-no-escapes (thus we pass False).
+tagSkeletonTopBind bind = bind'
+  where
+    (_, _, _, bind') = tagSkeletonBinding False NilSk emptyVarSet bind
+
+-- | Tags binders of an 'StgExpr' with its 'BinderInfo' and let bindings with
+-- their 'Skeleton's. Additionally, returns its 'Skeleton' and the set of binder
+-- occurrences in argument and nullary application position
+-- (cf. "StgLiftLams.Analysis#arg_occs").
+tagSkeletonExpr :: CgStgExpr -> (Skeleton, IdSet, LlStgExpr)
+tagSkeletonExpr (StgLit lit)
+  = (NilSk, emptyVarSet, StgLit lit)
+tagSkeletonExpr (StgConApp con args tys)
+  = (NilSk, mkArgOccs args, StgConApp con args tys)
+tagSkeletonExpr (StgOpApp op args ty)
+  = (NilSk, mkArgOccs args, StgOpApp op args ty)
+tagSkeletonExpr (StgApp f args)
+  = (NilSk, arg_occs, StgApp f args)
+  where
+    arg_occs
+      -- This checks for nullary applications, which we treat the same as
+      -- argument occurrences, see "StgLiftLams.Analysis#arg_occs".
+      | null args = unitVarSet f
+      | otherwise = mkArgOccs args
+tagSkeletonExpr (StgLam _ _) = pprPanic "stgLiftLams" (text "StgLam")
+tagSkeletonExpr (StgCase scrut bndr ty alts)
+  = (skel, arg_occs, StgCase scrut' bndr' ty alts')
+  where
+    (scrut_skel, scrut_arg_occs, scrut') = tagSkeletonExpr scrut
+    (alt_skels, alt_arg_occss, alts') = mapAndUnzip3 tagSkeletonAlt alts
+    skel = bothSk scrut_skel (foldr altSk NilSk alt_skels)
+    arg_occs = unionVarSets (scrut_arg_occs:alt_arg_occss) `delVarSet` bndr
+    bndr' = BoringBinder bndr
+tagSkeletonExpr (StgTick t e)
+  = (skel, arg_occs, StgTick t e')
+  where
+    (skel, arg_occs, e') = tagSkeletonExpr e
+tagSkeletonExpr (StgLet _ bind body) = tagSkeletonLet False body bind
+tagSkeletonExpr (StgLetNoEscape _ bind body) = tagSkeletonLet True body bind
+
+mkLet :: Bool -> Skeleton -> LlStgBinding -> LlStgExpr -> LlStgExpr
+mkLet True = StgLetNoEscape
+mkLet _    = StgLet
+
+tagSkeletonLet
+  :: Bool
+  -- ^ Is the binding a let-no-escape?
+  -> CgStgExpr
+  -- ^ Let body
+  -> CgStgBinding
+  -- ^ Binding group
+  -> (Skeleton, IdSet, LlStgExpr)
+  -- ^ RHS skeletons, argument occurrences and annotated binding
+tagSkeletonLet is_lne body bind
+  = (let_skel, arg_occs, mkLet is_lne scope bind' body')
+  where
+    (body_skel, body_arg_occs, body') = tagSkeletonExpr body
+    (let_skel, arg_occs, scope, bind')
+      = tagSkeletonBinding is_lne body_skel body_arg_occs bind
+
+tagSkeletonBinding
+  :: Bool
+  -- ^ Is the binding a let-no-escape?
+  -> Skeleton
+  -- ^ Let body skeleton
+  -> IdSet
+  -- ^ Argument occurrences in the body
+  -> CgStgBinding
+  -- ^ Binding group
+  -> (Skeleton, IdSet, Skeleton, LlStgBinding)
+  -- ^ Let skeleton, argument occurrences, scope skeleton of binding and
+  --   the annotated binding
+tagSkeletonBinding is_lne body_skel body_arg_occs (StgNonRec bndr rhs)
+  = (let_skel, arg_occs, scope, bind')
+  where
+    (rhs_skel, rhs_arg_occs, rhs') = tagSkeletonRhs bndr rhs
+    arg_occs = (body_arg_occs `unionVarSet` rhs_arg_occs) `delVarSet` bndr
+    bind_skel
+      | is_lne    = rhs_skel -- no closure is allocated for let-no-escapes
+      | otherwise = ClosureSk bndr (freeVarsOfRhs rhs) rhs_skel
+    let_skel = bothSk body_skel bind_skel
+    occurs_as_arg = bndr `elemVarSet` body_arg_occs
+    -- Compared to the recursive case, this exploits the fact that @bndr@ is
+    -- never free in @rhs@.
+    scope = body_skel
+    bind' = StgNonRec (BindsClosure bndr occurs_as_arg) rhs'
+tagSkeletonBinding is_lne body_skel body_arg_occs (StgRec pairs)
+  = (let_skel, arg_occs, scope, StgRec pairs')
+  where
+    (bndrs, _) = unzip pairs
+    -- Local recursive STG bindings also regard the defined binders as free
+    -- vars. We want to delete those for our cost model, as these are known
+    -- calls anyway when we add them to the same top-level recursive group as
+    -- the top-level binding currently being analysed.
+    skel_occs_rhss' = map (uncurry tagSkeletonRhs) pairs
+    rhss_arg_occs = map sndOf3 skel_occs_rhss'
+    scope_occs = unionVarSets (body_arg_occs:rhss_arg_occs)
+    arg_occs = scope_occs `delVarSetList` bndrs
+    -- @skel_rhss@ aren't yet wrapped in closures. We'll do that in a moment,
+    -- but we also need the un-wrapped skeletons for calculating the @scope@
+    -- of the group, as the outer closures don't contribute to closure growth
+    -- when we lift this specific binding.
+    scope = foldr (bothSk . fstOf3) body_skel skel_occs_rhss'
+    -- Now we can build the actual Skeleton for the expression just by
+    -- iterating over each bind pair.
+    (bind_skels, pairs') = unzip (zipWith single_bind bndrs skel_occs_rhss')
+    let_skel = foldr bothSk body_skel bind_skels
+    single_bind bndr (skel_rhs, _, rhs') = (bind_skel, (bndr', rhs'))
+      where
+        -- Here, we finally add the closure around each @skel_rhs@.
+        bind_skel
+          | is_lne    = skel_rhs -- no closure is allocated for let-no-escapes
+          | otherwise = ClosureSk bndr fvs skel_rhs
+        fvs = freeVarsOfRhs rhs' `dVarSetMinusVarSet` mkVarSet bndrs
+        bndr' = BindsClosure bndr (bndr `elemVarSet` scope_occs)
+
+tagSkeletonRhs :: Id -> CgStgRhs -> (Skeleton, IdSet, LlStgRhs)
+tagSkeletonRhs _ (StgRhsCon ccs dc args)
+  = (NilSk, mkArgOccs args, StgRhsCon ccs dc args)
+tagSkeletonRhs bndr (StgRhsClosure fvs ccs upd bndrs body)
+  = (rhs_skel, body_arg_occs, StgRhsClosure fvs ccs upd bndrs' body')
+  where
+    bndrs' = map BoringBinder bndrs
+    (body_skel, body_arg_occs, body') = tagSkeletonExpr body
+    rhs_skel = rhsSk (rhsDmdShell bndr) body_skel
+
+-- | How many times will the lambda body of the RHS bound to the given
+-- identifier be evaluated, relative to its defining context? This function
+-- computes the answer in form of a 'DmdShell'.
+rhsDmdShell :: Id -> DmdShell
+rhsDmdShell bndr
+  | is_thunk = oneifyDmd ds
+  | otherwise = peelManyCalls (idArity bndr) cd
+  where
+    is_thunk = idArity bndr == 0
+    -- Let's pray idDemandInfo is still OK after unarise...
+    (ds, cd) = toCleanDmd (idDemandInfo bndr) (idType bndr)
+
+tagSkeletonAlt :: CgStgAlt -> (Skeleton, IdSet, LlStgAlt)
+tagSkeletonAlt (con, bndrs, rhs)
+  = (alt_skel, arg_occs, (con, map BoringBinder bndrs, rhs'))
+  where
+    (alt_skel, alt_arg_occs, rhs') = tagSkeletonExpr rhs
+    arg_occs = alt_arg_occs `delVarSetList` bndrs
+
+-- | Combines several heuristics to decide whether to lambda-lift a given
+-- @let@-binding to top-level. See "StgLiftLams.Analysis#when" for details.
+goodToLift
+  :: DynFlags
+  -> TopLevelFlag
+  -> RecFlag
+  -> (DIdSet -> DIdSet) -- ^ An expander function, turning 'InId's into
+                        -- 'OutId's. See 'StgLiftLams.LiftM.liftedIdsExpander'.
+  -> [(BinderInfo, LlStgRhs)]
+  -> Skeleton
+  -> Maybe DIdSet       -- ^ @Just abs_ids@ <=> This binding is beneficial to
+                        -- lift and @abs_ids@ are the variables it would
+                        -- abstract over
+goodToLift dflags top_lvl rec_flag expander pairs scope = decide
+  [ ("top-level", isTopLevel top_lvl) -- keep in sync with Note [When to lift]
+  , ("memoized", any_memoized)
+  , ("argument occurrences", arg_occs)
+  , ("join point", is_join_point)
+  , ("abstracts join points", abstracts_join_ids)
+  , ("abstracts known local function", abstracts_known_local_fun)
+  , ("args spill on stack", args_spill_on_stack)
+  , ("increases allocation", inc_allocs)
+  ] where
+      decide deciders
+        | not (fancy_or deciders)
+        = llTrace "stgLiftLams:lifting"
+                  (ppr bndrs <+> ppr abs_ids $$
+                   ppr allocs $$
+                   ppr scope) $
+          Just abs_ids
+        | otherwise
+        = Nothing
+      ppr_deciders = vcat . map (text . fst) . filter snd
+      fancy_or deciders
+        = llTrace "stgLiftLams:goodToLift" (ppr bndrs $$ ppr_deciders deciders) $
+          any snd deciders
+
+      bndrs = map (binderInfoBndr . fst) pairs
+      bndrs_set = mkVarSet bndrs
+      rhss = map snd pairs
+
+      -- First objective: Calculate @abs_ids@, e.g. the former free variables
+      -- the lifted binding would abstract over. We have to merge the free
+      -- variables of all RHS to get the set of variables that will have to be
+      -- passed through parameters.
+      fvs = unionDVarSets (map freeVarsOfRhs rhss)
+      -- To lift the binding to top-level, we want to delete the lifted binders
+      -- themselves from the free var set. Local let bindings track recursive
+      -- occurrences in their free variable set. We neither want to apply our
+      -- cost model to them (see 'tagSkeletonRhs'), nor pass them as parameters
+      -- when lifted, as these are known calls. We call the resulting set the
+      -- identifiers we abstract over, thus @abs_ids@. These are all 'OutId's.
+      -- We will save the set in 'LiftM.e_expansions' for each of the variables
+      -- if we perform the lift.
+      abs_ids = expander (delDVarSetList fvs bndrs)
+
+      -- We don't lift updatable thunks or constructors
+      any_memoized = any is_memoized_rhs rhss
+      is_memoized_rhs StgRhsCon{} = True
+      is_memoized_rhs (StgRhsClosure _ _ upd _ _) = isUpdatable upd
+
+      -- Don't lift binders occuring as arguments. This would result in complex
+      -- argument expressions which would have to be given a name, reintroducing
+      -- the very allocation at each call site that we wanted to get rid off in
+      -- the first place.
+      arg_occs = or (mapMaybe (binderInfoOccursAsArg . fst) pairs)
+
+      -- These don't allocate anyway.
+      is_join_point = any isJoinId bndrs
+
+      -- Abstracting over join points/let-no-escapes spoils them.
+      abstracts_join_ids = any isJoinId (dVarSetElems abs_ids)
+
+      -- Abstracting over known local functions that aren't floated themselves
+      -- turns a known, fast call into an unknown, slow call:
+      --
+      --    let f x = ...
+      --        g y = ... f x ... -- this was a known call
+      --    in g 4
+      --
+      -- After lifting @g@, but not @f@:
+      --
+      --    l_g f y = ... f y ... -- this is now an unknown call
+      --    let f x = ...
+      --    in l_g f 4
+      --
+      -- We can abuse the results of arity analysis for this:
+      -- idArity f > 0 ==> known
+      known_fun id = idArity id > 0
+      abstracts_known_local_fun
+        = not (liftLamsKnown dflags) && any known_fun (dVarSetElems abs_ids)
+
+      -- Number of arguments of a RHS in the current binding group if we decide
+      -- to lift it
+      n_args
+        = length
+        . StgCmmClosure.nonVoidIds -- void parameters don't appear in Cmm
+        . (dVarSetElems abs_ids ++)
+        . rhsLambdaBndrs
+      max_n_args
+        | isRec rec_flag = liftLamsRecArgs dflags
+        | otherwise      = liftLamsNonRecArgs dflags
+      -- We have 5 hardware registers on x86_64 to pass arguments in. Any excess
+      -- args are passed on the stack, which means slow memory accesses
+      args_spill_on_stack
+        | Just n <- max_n_args = maximum (map n_args rhss) > n
+        | otherwise = False
+
+      -- We only perform the lift if allocations didn't increase.
+      -- Note that @clo_growth@ will be 'infinity' if there was positive growth
+      -- under a multi-shot lambda.
+      -- Also, abstracting over LNEs is unacceptable. LNEs might return
+      -- unlifted tuples, which idClosureFootprint can't cope with.
+      inc_allocs = abstracts_join_ids || allocs > 0
+      allocs = clo_growth + mkIntWithInf (negate closuresSize)
+      -- We calculate and then add up the size of each binding's closure.
+      -- GHC does not currently share closure environments, and we either lift
+      -- the entire recursive binding group or none of it.
+      closuresSize = sum $ flip map rhss $ \rhs ->
+        closureSize dflags
+        . dVarSetElems
+        . expander
+        . flip dVarSetMinusVarSet bndrs_set
+        $ freeVarsOfRhs rhs
+      clo_growth = closureGrowth expander (idClosureFootprint dflags) bndrs_set abs_ids scope
+
+rhsLambdaBndrs :: LlStgRhs -> [Id]
+rhsLambdaBndrs StgRhsCon{} = []
+rhsLambdaBndrs (StgRhsClosure _ _ _ bndrs _) = map binderInfoBndr bndrs
+
+-- | The size in words of a function closure closing over the given 'Id's,
+-- including the header.
+closureSize :: DynFlags -> [Id] -> WordOff
+closureSize dflags ids = words
+  where
+    (words, _, _)
+      -- Functions have a StdHeader (as opposed to ThunkHeader).
+      -- Note that mkVirtHeadOffsets will account for profiling headers, so
+      -- lifting decisions vary if we begin to profile stuff. Maybe we shouldn't
+      -- do this or deactivate profiling in @dflags@?
+      = StgCmmLayout.mkVirtHeapOffsets dflags StgCmmLayout.StdHeader
+      . StgCmmClosure.addIdReps
+      . StgCmmClosure.nonVoidIds
+      $ ids
+
+-- | The number of words a single 'Id' adds to a closure's size.
+-- Note that this can't handle unboxed tuples (which may still be present in
+-- let-no-escapes, even after Unarise), in which case
+-- @'StgCmmClosure.idPrimRep'@ will crash.
+idClosureFootprint:: DynFlags -> Id -> WordOff
+idClosureFootprint dflags
+  = StgCmmArgRep.argRepSizeW dflags
+  . StgCmmArgRep.idArgRep
+
+-- | @closureGrowth expander sizer f fvs@ computes the closure growth in words
+-- as a result of lifting @f@ to top-level. If there was any growing closure
+-- under a multi-shot lambda, the result will be 'infinity'.
+-- Also see "StgLiftLams.Analysis#clogro".
+closureGrowth
+  :: (DIdSet -> DIdSet)
+  -- ^ Expands outer free ids that were lifted to their free vars
+  -> (Id -> Int)
+  -- ^ Computes the closure footprint of an identifier
+  -> IdSet
+  -- ^ Binding group for which lifting is to be decided
+  -> DIdSet
+  -- ^ Free vars of the whole binding group prior to lifting it. These must be
+  --   available at call sites if we decide to lift the binding group.
+  -> Skeleton
+  -- ^ Abstraction of the scope of the function
+  -> IntWithInf
+  -- ^ Closure growth. 'infinity' indicates there was growth under a
+  --   (multi-shot) lambda.
+closureGrowth expander sizer group abs_ids = go
+  where
+    go NilSk = 0
+    go (BothSk a b) = go a + go b
+    go (AltSk a b) = max (go a) (go b)
+    go (ClosureSk _ clo_fvs rhs)
+      -- If no binder of the @group@ occurs free in the closure, the lifting
+      -- won't have any effect on it and we can omit the recursive call.
+      | n_occs == 0 = 0
+      -- Otherwise, we account the cost of allocating the closure and add it to
+      -- the closure growth of its RHS.
+      | otherwise   = mkIntWithInf cost + go rhs
+      where
+        n_occs = sizeDVarSet (clo_fvs' `dVarSetIntersectVarSet` group)
+        -- What we close over considering prior lifting decisions
+        clo_fvs' = expander clo_fvs
+        -- Variables that would additionally occur free in the closure body if
+        -- we lift @f@
+        newbies = abs_ids `minusDVarSet` clo_fvs'
+        -- Lifting @f@ removes @f@ from the closure but adds all @newbies@
+        cost = foldDVarSet (\id size -> sizer id + size) 0 newbies - n_occs
+    go (RhsSk body_dmd body)
+      -- The conservative assumption would be that
+      --   1. Every RHS with positive growth would be called multiple times,
+      --      modulo thunks.
+      --   2. Every RHS with negative growth wouldn't be called at all.
+      --
+      -- In the first case, we'd have to return 'infinity', while in the
+      -- second case, we'd have to return 0. But we can do far better
+      -- considering information from the demand analyser, which provides us
+      -- with conservative estimates on minimum and maximum evaluation
+      -- cardinality. The @body_dmd@ part of 'RhsSk' is the result of
+      -- 'rhsDmdShell' and accurately captures the cardinality of the RHSs body
+      -- relative to its defining context.
+      | isAbsDmd body_dmd   = 0
+      | cg <= 0             = if isStrictDmd body_dmd then cg else 0
+      | isUsedOnce body_dmd = cg
+      | otherwise           = infinity
+      where
+        cg = go body
diff --git a/compiler/simplStg/StgLiftLams/LiftM.hs b/compiler/simplStg/StgLiftLams/LiftM.hs
new file mode 100644 (file)
index 0000000..c9e520a
--- /dev/null
@@ -0,0 +1,349 @@
+{-# LANGUAGE CPP #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+
+-- | Hides away distracting bookkeeping while lambda lifting into a 'LiftM'
+-- monad.
+module StgLiftLams.LiftM (
+    decomposeStgBinding, mkStgBinding,
+    Env (..),
+    -- * #floats# Handling floats
+    -- $floats
+    FloatLang (..), collectFloats, -- Exported just for the docs
+    -- * Transformation monad
+    LiftM, runLiftM, withCaffyness,
+    -- ** Adding bindings
+    startBindingGroup, endBindingGroup, addTopStringLit, addLiftedBinding,
+    -- ** Substitution and binders
+    withSubstBndr, withSubstBndrs, withLiftedBndr, withLiftedBndrs,
+    -- ** Occurrences
+    substOcc, isLifted, formerFreeVars, liftedIdsExpander
+  ) where
+
+#include "HsVersions.h"
+
+import GhcPrelude
+
+import BasicTypes
+import CostCentre ( isCurrentCCS, dontCareCCS )
+import DynFlags
+import FastString
+import Id
+import IdInfo
+import Name
+import Outputable
+import OrdList
+import StgSubst
+import StgSyn
+import Type
+import UniqSupply
+import Util
+import VarEnv
+import VarSet
+
+import Control.Arrow ( second )
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.RWS.Strict ( RWST, runRWST )
+import qualified Control.Monad.Trans.RWS.Strict as RWS
+import Control.Monad.Trans.Cont ( ContT (..) )
+import Data.ByteString ( ByteString )
+import Data.List ( foldl' )
+
+-- | @uncurry 'mkStgBinding' . 'decomposeStgBinding' = id@
+decomposeStgBinding :: GenStgBinding pass -> (RecFlag, [(BinderP pass, GenStgRhs pass)])
+decomposeStgBinding (StgRec pairs) = (Recursive, pairs)
+decomposeStgBinding (StgNonRec bndr rhs) = (NonRecursive, [(bndr, rhs)])
+
+mkStgBinding :: RecFlag -> [(BinderP pass, GenStgRhs pass)] -> GenStgBinding pass
+mkStgBinding Recursive = StgRec
+mkStgBinding NonRecursive = uncurry StgNonRec . head
+
+-- | Environment threaded around in a scoped, @Reader@-like fashion.
+data Env
+  = Env
+  { e_dflags     :: !DynFlags
+  -- ^ Read-only.
+  , e_subst      :: !Subst
+  -- ^ We need to track the renamings of local 'InId's to their lifted 'OutId',
+  -- because shadowing might make a closure's free variables unavailable at its
+  -- call sites. Consider:
+  -- @
+  --    let f y = x + y in let x = 4 in f x
+  -- @
+  -- Here, @f@ can't be lifted to top-level, because its free variable @x@ isn't
+  -- available at its call site.
+  , e_expansions :: !(IdEnv DIdSet)
+  -- ^ Lifted 'Id's don't occur as free variables in any closure anymore, because
+  -- they are bound at the top-level. Every occurrence must supply the formerly
+  -- free variables of the lifted 'Id', so they in turn become free variables of
+  -- the call sites. This environment tracks this expansion from lifted 'Id's to
+  -- their free variables.
+  --
+  -- 'InId's to 'OutId's.
+  --
+  -- Invariant: 'Id's not present in this map won't be substituted.
+  , e_in_caffy_context :: !Bool
+  -- ^ Are we currently analysing within a caffy context (e.g. the containing
+  -- top-level binder's 'idCafInfo' is 'MayHaveCafRefs')? If not, we can safely
+  -- assume that functions we lift out aren't caffy either.
+  }
+
+emptyEnv :: DynFlags -> Env
+emptyEnv dflags = Env dflags emptySubst emptyVarEnv False
+
+
+-- Note [Handling floats]
+-- ~~~~~~~~~~~~~~~~~~~~~~
+-- $floats
+-- Consider the following expression:
+--
+-- @
+--     f x =
+--       let g y = ... f y ...
+--       in g x
+-- @
+--
+-- What happens when we want to lift @g@? Normally, we'd put the lifted @l_g@
+-- binding above the binding for @f@:
+--
+-- @
+--     g f y = ... f y ...
+--     f x = g f x
+-- @
+--
+-- But this very unnecessarily turns a known call to @f@ into an unknown one, in
+-- addition to complicating matters for the analysis.
+-- Instead, we'd really like to put both functions in the same recursive group,
+-- thereby preserving the known call:
+--
+-- @
+--     Rec {
+--       g y = ... f y ...
+--       f x = g x
+--     }
+-- @
+--
+-- But we don't want this to happen for just /any/ binding. That would create
+-- possibly huge recursive groups in the process, calling for an occurrence
+-- analyser on STG.
+-- So, we need to track when we lift a binding out of a recursive RHS and add
+-- the binding to the same recursive group as the enclosing recursive binding
+-- (which must have either already been at the top-level or decided to be
+-- lifted itself in order to preserve the known call).
+--
+-- This is done by expressing this kind of nesting structure as a 'Writer' over
+-- @['FloatLang']@ and flattening this expression in 'runLiftM' by a call to
+-- 'collectFloats'.
+-- API-wise, the analysis will not need to know about the whole 'FloatLang'
+-- business and will just manipulate it indirectly through actions in 'LiftM'.
+
+-- | We need to detect when we are lifting something out of the RHS of a
+-- recursive binding (c.f. "StgLiftLams.LiftM#floats"), in which case that
+-- binding needs to be added to the same top-level recursive group. This
+-- requires we detect a certain nesting structure, which is encoded by
+-- 'StartBindingGroup' and 'EndBindingGroup'.
+--
+-- Although 'collectFloats' will only ever care if the current binding to be
+-- lifted (through 'LiftedBinding') will occur inside such a binding group or
+-- not, e.g. doesn't care about the nesting level as long as its greater than 0.
+data FloatLang
+  = StartBindingGroup
+  | EndBindingGroup
+  | PlainTopBinding OutStgTopBinding
+  | LiftedBinding OutStgBinding
+
+instance Outputable FloatLang where
+  ppr StartBindingGroup = char '('
+  ppr EndBindingGroup = char ')'
+  ppr (PlainTopBinding StgTopStringLit{}) = text "<str>"
+  ppr (PlainTopBinding (StgTopLifted b)) = ppr (LiftedBinding b)
+  ppr (LiftedBinding bind) = (if isRec rec then char 'r' else char 'n') <+> ppr (map fst pairs)
+    where
+      (rec, pairs) = decomposeStgBinding bind
+
+-- | Flattens an expression in @['FloatLang']@ into an STG program, see #floats.
+-- Important pre-conditions: The nesting of opening 'StartBindinGroup's and
+-- closing 'EndBindinGroup's is balanced. Also, it is crucial that every binding
+-- group has at least one recursive binding inside. Otherwise there's no point
+-- in announcing the binding group in the first place and an @ASSERT@ will
+-- trigger.
+collectFloats :: [FloatLang] -> [OutStgTopBinding]
+collectFloats = go (0 :: Int) []
+  where
+    go 0 [] [] = []
+    go _ _ [] = pprPanic "collectFloats" (text "unterminated group")
+    go n binds (f:rest) = case f of
+      StartBindingGroup -> go (n+1) binds rest
+      EndBindingGroup
+        | n == 0 -> pprPanic "collectFloats" (text "no group to end")
+        | n == 1 -> StgTopLifted (merge_binds binds) : go 0 [] rest
+        | otherwise -> go (n-1) binds rest
+      PlainTopBinding top_bind
+        | n == 0 -> top_bind : go n binds rest
+        | otherwise -> pprPanic "collectFloats" (text "plain top binding inside group")
+      LiftedBinding bind
+        | n == 0 -> StgTopLifted (rm_cccs bind) : go n binds rest
+        | otherwise -> go n (bind:binds) rest
+
+    map_rhss f = uncurry mkStgBinding . second (map (second f)) . decomposeStgBinding
+    rm_cccs = map_rhss removeRhsCCCS
+    merge_binds binds = ASSERT( any is_rec binds )
+                        StgRec (concatMap (snd . decomposeStgBinding . rm_cccs) binds)
+    is_rec StgRec{} = True
+    is_rec _ = False
+
+-- | Omitting this makes for strange closure allocation schemes that crash the
+-- GC.
+removeRhsCCCS :: GenStgRhs pass -> GenStgRhs pass
+removeRhsCCCS (StgRhsClosure ext ccs upd bndrs body)
+  | isCurrentCCS ccs
+  = StgRhsClosure ext dontCareCCS upd bndrs body
+removeRhsCCCS (StgRhsCon ccs con args)
+  | isCurrentCCS ccs
+  = StgRhsCon dontCareCCS con args
+removeRhsCCCS rhs = rhs
+
+-- | The analysis monad consists of the following 'RWST' components:
+--
+--     * 'Env': Reader-like context. Contains a substitution, info about how
+--       how lifted identifiers are to be expanded into applications and details
+--       such as 'DynFlags' and a flag helping with determining if a lifted
+--       binding is caffy.
+--
+--     * @'OrdList' 'FloatLang'@: Writer output for the resulting STG program.
+--
+--     * No pure state component
+--
+--     * But wrapping around 'UniqSM' for generating fresh lifted binders.
+--       (The @uniqAway@ approach could give the same name to two different
+--       lifted binders, so this is necessary.)
+newtype LiftM a
+  = LiftM { unwrapLiftM :: RWST Env (OrdList FloatLang) () UniqSM a }
+  deriving (Functor, Applicative, Monad)
+
+instance HasDynFlags LiftM where
+  getDynFlags = LiftM (RWS.asks e_dflags)
+
+instance MonadUnique LiftM where
+  getUniqueSupplyM = LiftM (lift getUniqueSupplyM)
+  getUniqueM = LiftM (lift getUniqueM)
+  getUniquesM = LiftM (lift getUniquesM)
+
+runLiftM :: DynFlags -> UniqSupply -> LiftM () -> [OutStgTopBinding]
+runLiftM dflags us (LiftM m) = collectFloats (fromOL floats)
+  where
+    (_, _, floats) = initUs_ us (runRWST m (emptyEnv dflags) ())
+
+-- | Assumes a given caffyness for the execution of the passed action, which
+-- influences the 'cafInfo' of lifted bindings.
+withCaffyness :: Bool -> LiftM a -> LiftM a
+withCaffyness caffy action
+  = LiftM (RWS.local (\e -> e { e_in_caffy_context = caffy }) (unwrapLiftM action))
+
+-- | Writes a plain 'StgTopStringLit' to the output.
+addTopStringLit :: OutId -> ByteString -> LiftM ()
+addTopStringLit id = LiftM . RWS.tell . unitOL . PlainTopBinding . StgTopStringLit id
+
+-- | Starts a recursive binding group. See #floats# and 'collectFloats'.
+startBindingGroup :: LiftM ()
+startBindingGroup = LiftM $ RWS.tell $ unitOL $ StartBindingGroup
+
+-- | Ends a recursive binding group. See #floats# and 'collectFloats'.
+endBindingGroup :: LiftM ()
+endBindingGroup = LiftM $ RWS.tell $ unitOL $ EndBindingGroup
+
+-- | Lifts a binding to top-level. Depending on whether it's declared inside
+-- a recursive RHS (see #floats# and 'collectFloats'), this might be added to
+-- an existing recursive top-level binding group.
+addLiftedBinding :: OutStgBinding -> LiftM ()
+addLiftedBinding = LiftM . RWS.tell . unitOL . LiftedBinding
+
+-- | Takes a binder and a continuation which is called with the substituted
+-- binder. The continuation will be evaluated in a 'LiftM' context in which that
+-- binder is deemed in scope. Think of it as a 'RWS.local' computation: After
+-- the continuation finishes, the new binding won't be in scope anymore.
+withSubstBndr :: Id -> (Id -> LiftM a) -> LiftM a
+withSubstBndr bndr inner = LiftM $ do
+  subst <- RWS.asks e_subst
+  let (bndr', subst') = substBndr bndr subst
+  RWS.local (\e -> e { e_subst = subst' }) (unwrapLiftM (inner bndr'))
+
+-- | See 'withSubstBndr'.
+withSubstBndrs :: Traversable f => f Id -> (f Id -> LiftM a) -> LiftM a
+withSubstBndrs = runContT . traverse (ContT . withSubstBndr)
+
+-- | Similarly to 'withSubstBndr', this function takes a set of variables to
+-- abstract over, the binder to lift (and generate a fresh, substituted name
+-- for) and a continuation in which that fresh, lifted binder is in scope.
+--
+-- It takes care of all the details involved with copying and adjusting the
+-- binder, fresh name generation and caffyness.
+withLiftedBndr :: DIdSet -> Id -> (Id -> LiftM a) -> LiftM a
+withLiftedBndr abs_ids bndr inner = do
+  uniq <- getUniqueM
+  let str = "$l" ++ occNameString (getOccName bndr)
+  let ty = mkLamTypes (dVarSetElems abs_ids) (idType bndr)
+  -- When the enclosing top-level binding is not caffy, then the lifted
+  -- binding will not be caffy either. If we don't recognize this, non-caffy
+  -- things call caffy things and then codegen screws up.
+  in_caffy_ctxt <- LiftM (RWS.asks e_in_caffy_context)
+  let caf_info = if in_caffy_ctxt then MayHaveCafRefs else NoCafRefs
+  let bndr'
+        -- See Note [transferPolyIdInfo] in Id.hs. We need to do this at least
+        -- for arity information.
+        = transferPolyIdInfo bndr (dVarSetElems abs_ids)
+        -- Otherwise we confuse code gen if bndr was not caffy: the new bndr is
+        -- assumed to be caffy and will need an SRT. Transitive call sites might
+        -- not be caffy themselves and subsequently will miss a static link
+        -- field in their closure. Chaos ensues.
+        . flip setIdCafInfo caf_info
+        . mkSysLocalOrCoVar (mkFastString str) uniq
+        $ ty
+  LiftM $ RWS.local
+    (\e -> e
+      { e_subst = extendSubst bndr bndr' $ extendInScope bndr' $ e_subst e
+      , e_expansions = extendVarEnv (e_expansions e) bndr abs_ids
+      })
+    (unwrapLiftM (inner bndr'))
+
+-- | See 'withLiftedBndr'.
+withLiftedBndrs :: Traversable f => DIdSet -> f Id -> (f Id -> LiftM a) -> LiftM a
+withLiftedBndrs abs_ids = runContT . traverse (ContT . withLiftedBndr abs_ids)
+
+-- | Substitutes a binder /occurrence/, which was brought in scope earlier by
+-- 'withSubstBndr'\/'withLiftedBndr'.
+substOcc :: Id -> LiftM Id
+substOcc id = LiftM (RWS.asks (lookupIdSubst id . e_subst))
+
+-- | Whether the given binding was decided to be lambda lifted.
+isLifted :: InId -> LiftM Bool
+isLifted bndr = LiftM (RWS.asks (elemVarEnv bndr . e_expansions))
+
+-- | Returns an empty list for a binding that was not lifted and the list of all
+-- local variables the binding abstracts over (so, exactly the additional
+-- arguments at adjusted call sites) otherwise.
+formerFreeVars :: InId -> LiftM [OutId]
+formerFreeVars f = LiftM $ do
+  expansions <- RWS.asks e_expansions
+  pure $ case lookupVarEnv expansions f of
+    Nothing -> []
+    Just fvs -> dVarSetElems fvs
+
+-- | Creates an /expander function/ for the current set of lifted binders.
+-- This expander function will replace any 'InId' by their corresponding 'OutId'
+-- and, in addition, will expand any lifted binders by the former free variables
+-- it abstracts over.
+liftedIdsExpander :: LiftM (DIdSet -> DIdSet)
+liftedIdsExpander = LiftM $ do
+  expansions <- RWS.asks e_expansions
+  subst <- RWS.asks e_subst
+  -- We use @noWarnLookupIdSubst@ here in order to suppress "not in scope"
+  -- warnings generated by 'lookupIdSubst' due to local bindings within RHS.
+  -- These are not in the InScopeSet of @subst@ and extending the InScopeSet in
+  -- @goodToLift@/@closureGrowth@ before passing it on to @expander@ is too much
+  -- trouble.
+  let go set fv = case lookupVarEnv expansions fv of
+        Nothing -> extendDVarSet set (noWarnLookupIdSubst fv subst) -- Not lifted
+        Just fvs' -> unionDVarSet set fvs'
+  let expander fvs = foldl' go emptyDVarSet (dVarSetElems fvs)
+  pure expander
diff --git a/compiler/simplStg/StgLiftLams/Transformation.hs b/compiler/simplStg/StgLiftLams/Transformation.hs
new file mode 100644 (file)
index 0000000..8c4d616
--- /dev/null
@@ -0,0 +1,155 @@
+{-# LANGUAGE CPP #-}
+
+-- | (Mostly) textbook instance of the lambda lifting transformation,
+-- selecting which bindings to lambda lift by consulting 'goodToLift'.
+module StgLiftLams.Transformation (stgLiftLams) where
+
+#include "HsVersions.h"
+
+import GhcPrelude
+
+import BasicTypes
+import DynFlags
+import Id
+import IdInfo
+import StgFVs ( annBindingFreeVars )
+import StgLiftLams.Analysis
+import StgLiftLams.LiftM
+import StgSyn
+import Outputable
+import UniqSupply
+import Util
+import VarSet
+import Control.Monad ( when )
+import Data.Maybe ( isNothing )
+
+-- | Lambda lifts bindings to top-level deemed worth lifting (see 'goodToLift').
+stgLiftLams :: DynFlags -> UniqSupply -> [InStgTopBinding] -> [OutStgTopBinding]
+stgLiftLams dflags us = runLiftM dflags us . foldr liftTopLvl (pure ())
+
+liftTopLvl :: InStgTopBinding -> LiftM () -> LiftM ()
+liftTopLvl (StgTopStringLit bndr lit) rest = withSubstBndr bndr $ \bndr' -> do
+  addTopStringLit bndr' lit
+  rest
+liftTopLvl (StgTopLifted bind) rest = do
+  let is_rec = isRec $ fst $ decomposeStgBinding bind
+  when is_rec startBindingGroup
+  let bind_w_fvs = annBindingFreeVars bind
+  withLiftedBind TopLevel (tagSkeletonTopBind bind_w_fvs) NilSk $ \mb_bind' -> do
+    -- We signal lifting of a binding through returning Nothing.
+    -- Should never happen for a top-level binding, though, since we are already
+    -- at top-level.
+    case mb_bind' of
+      Nothing -> pprPanic "StgLiftLams" (text "Lifted top-level binding")
+      Just bind' -> addLiftedBinding bind'
+    when is_rec endBindingGroup
+    rest
+
+withLiftedBind
+  :: TopLevelFlag
+  -> LlStgBinding
+  -> Skeleton
+  -> (Maybe OutStgBinding -> LiftM a)
+  -> LiftM a
+withLiftedBind top_lvl bind scope k
+  | isTopLevel top_lvl
+  = withCaffyness (is_caffy pairs) go
+  | otherwise
+  = go
+  where
+    (rec, pairs) = decomposeStgBinding bind
+    is_caffy = any (mayHaveCafRefs . idCafInfo . binderInfoBndr . fst)
+    go = withLiftedBindPairs top_lvl rec pairs scope (k . fmap (mkStgBinding rec))
+
+withLiftedBindPairs
+  :: TopLevelFlag
+  -> RecFlag
+  -> [(BinderInfo, LlStgRhs)]
+  -> Skeleton
+  -> (Maybe [(Id, OutStgRhs)] -> LiftM a)
+  -> LiftM a
+withLiftedBindPairs top rec pairs scope k = do
+  let (infos, rhss) = unzip pairs
+  let bndrs = map binderInfoBndr infos
+  expander <- liftedIdsExpander
+  dflags <- getDynFlags
+  case goodToLift dflags top rec expander pairs scope of
+    -- @abs_ids@ is the set of all variables that need to become parameters.
+    Just abs_ids -> withLiftedBndrs abs_ids bndrs $ \bndrs' -> do
+      -- Within this block, all binders in @bndrs@ will be noted as lifted, so
+      -- that the return value of @liftedIdsExpander@ in this context will also
+      -- expand the bindings in @bndrs@ to their free variables.
+      -- Now we can recurse into the RHSs and see if we can lift any further
+      -- bindings. We pass the set of expanded free variables (thus OutIds) on
+      -- to @liftRhs@ so that it can add them as parameter binders.
+      when (isRec rec) startBindingGroup
+      rhss' <- traverse (liftRhs (Just abs_ids)) rhss
+      let pairs' = zip bndrs' rhss'
+      addLiftedBinding (mkStgBinding rec pairs')
+      when (isRec rec) endBindingGroup
+      k Nothing
+    Nothing -> withSubstBndrs bndrs $ \bndrs' -> do
+      -- Don't lift the current binding, but possibly some bindings in their
+      -- RHSs.
+      rhss' <- traverse (liftRhs Nothing) rhss
+      let pairs' = zip bndrs' rhss'
+      k (Just pairs')
+
+liftRhs
+  :: Maybe (DIdSet)
+  -- ^ @Just former_fvs@ <=> this RHS was lifted and we have to add @former_fvs@
+  -- as lambda binders, discarding all free vars.
+  -> LlStgRhs
+  -> LiftM OutStgRhs
+liftRhs mb_former_fvs rhs@(StgRhsCon ccs con args)
+  = ASSERT2 ( isNothing mb_former_fvs, text "Should never lift a constructor" $$ ppr rhs)
+    StgRhsCon ccs con <$> traverse liftArgs args
+liftRhs Nothing (StgRhsClosure _ ccs upd infos body) = do
+  -- This RHS wasn't lifted.
+  withSubstBndrs (map binderInfoBndr infos) $ \bndrs' ->
+    StgRhsClosure noExtSilent ccs upd bndrs' <$> liftExpr body
+liftRhs (Just former_fvs) (StgRhsClosure _ ccs upd infos body) = do
+  -- This RHS was lifted. Insert extra binders for @former_fvs@.
+  withSubstBndrs (map binderInfoBndr infos) $ \bndrs' -> do
+    let bndrs'' = dVarSetElems former_fvs ++ bndrs'
+    StgRhsClosure noExtSilent ccs upd bndrs'' <$> liftExpr body
+
+liftArgs :: InStgArg -> LiftM OutStgArg
+liftArgs a@(StgLitArg _) = pure a
+liftArgs (StgVarArg occ) = do
+  ASSERTM2( not <$> isLifted occ, text "StgArgs should never be lifted" $$ ppr occ )
+  StgVarArg <$> substOcc occ
+
+liftExpr :: LlStgExpr -> LiftM OutStgExpr
+liftExpr (StgLit lit) = pure (StgLit lit)
+liftExpr (StgTick t e) = StgTick t <$> liftExpr e
+liftExpr (StgApp f args) = do
+  f' <- substOcc f
+  args' <- traverse liftArgs args
+  fvs' <- formerFreeVars f
+  let top_lvl_args = map StgVarArg fvs' ++ args'
+  pure (StgApp f' top_lvl_args)
+liftExpr (StgConApp con args tys) = StgConApp con <$> traverse liftArgs args <*> pure tys
+liftExpr (StgOpApp op args ty) = StgOpApp op <$> traverse liftArgs args <*> pure ty
+liftExpr (StgLam _ _) = pprPanic "stgLiftLams" (text "StgLam")
+liftExpr (StgCase scrut info ty alts) = do
+  scrut' <- liftExpr scrut
+  withSubstBndr (binderInfoBndr info) $ \bndr' -> do
+    alts' <- traverse liftAlt alts
+    pure (StgCase scrut' bndr' ty alts')
+liftExpr (StgLet scope bind body)
+  = withLiftedBind NotTopLevel bind scope $ \mb_bind' -> do
+      body' <- liftExpr body
+      case mb_bind' of
+        Nothing -> pure body' -- withLiftedBindPairs decided to lift it and already added floats
+        Just bind' -> pure (StgLet noExtSilent bind' body')
+liftExpr (StgLetNoEscape scope bind body)
+  = withLiftedBind NotTopLevel bind scope $ \mb_bind' -> do
+      body' <- liftExpr body
+      case mb_bind' of
+        Nothing -> pprPanic "stgLiftLams" (text "Should never decide to lift LNEs")
+        Just bind' -> pure (StgLetNoEscape noExtSilent bind' body')
+
+liftAlt :: LlStgAlt -> LiftM OutStgAlt
+liftAlt (con, infos, rhs) = withSubstBndrs (map binderInfoBndr infos) $ \bndrs' ->
+  (,,) con bndrs' <$> liftExpr rhs
index a2a9a85..05a0cf9 100644 (file)
@@ -153,12 +153,12 @@ statExpr (StgConApp _ _ _)= countOne ConstructorApps
 statExpr (StgOpApp _ _ _) = countOne PrimitiveApps
 statExpr (StgTick _ e)    = statExpr e
 
-statExpr (StgLetNoEscape binds body)
+statExpr (StgLetNoEscape binds body)
   = statBinding False{-not top-level-} binds    `combineSE`
     statExpr body                               `combineSE`
     countOne LetNoEscapes
 
-statExpr (StgLet binds body)
+statExpr (StgLet binds body)
   = statBinding False{-not top-level-} binds    `combineSE`
     statExpr body
 
index e87fd85..c908580 100644 (file)
@@ -353,11 +353,11 @@ unariseExpr rho (StgCase scrut bndr alt_ty alts)
                        -- bndr may have a unboxed sum/tuple type but it will be
                        -- dead after unarise (checked in StgLint)
 
-unariseExpr rho (StgLet bind e)
-  = StgLet <$> unariseBinding rho bind <*> unariseExpr rho e
+unariseExpr rho (StgLet ext bind e)
+  = StgLet ext <$> unariseBinding rho bind <*> unariseExpr rho e
 
-unariseExpr rho (StgLetNoEscape bind e)
-  = StgLetNoEscape <$> unariseBinding rho bind <*> unariseExpr rho e
+unariseExpr rho (StgLetNoEscape ext bind e)
+  = StgLetNoEscape ext <$> unariseBinding rho bind <*> unariseExpr rho e
 
 unariseExpr rho (StgTick tick e)
   = StgTick tick <$> unariseExpr rho e
index 74bb7b6..573db78 100644 (file)
@@ -631,8 +631,8 @@ coreToStgLet bind body = do
 
         -- Compute the new let-expression
     let
-        new_let | isJoinBind bind = StgLetNoEscape bind2 body2
-                | otherwise       = StgLet bind2 body2
+        new_let | isJoinBind bind = StgLetNoEscape noExtSilent bind2 body2
+                | otherwise       = StgLet noExtSilent bind2 body2
 
     return new_let
   where
index 80ce33f..edfc94e 100644 (file)
@@ -1,6 +1,7 @@
 -- | Free variable analysis on STG terms.
 module StgFVs (
-    annTopBindingsFreeVars
+    annTopBindingsFreeVars,
+    annBindingFreeVars
   ) where
 
 import GhcPrelude
@@ -26,13 +27,17 @@ addLocals :: [Id] -> Env -> Env
 addLocals bndrs env
   = env { locals = extendVarSetList (locals env) bndrs }
 
--- | Annotates a top-level STG binding with its free variables.
+-- | Annotates a top-level STG binding group with its free variables.
 annTopBindingsFreeVars :: [StgTopBinding] -> [CgStgTopBinding]
 annTopBindingsFreeVars = map go
   where
     go (StgTopStringLit id bs) = StgTopStringLit id bs
     go (StgTopLifted bind)
-      = StgTopLifted (fst (binding emptyEnv emptyVarSet bind))
+      = StgTopLifted (annBindingFreeVars bind)
+
+-- | Annotates an STG binding with its free variables.
+annBindingFreeVars :: StgBinding -> CgStgBinding
+annBindingFreeVars = fst . binding emptyEnv emptyDVarSet
 
 boundIds :: StgBinding -> [Id]
 boundIds (StgNonRec b _) = [b]
@@ -53,35 +58,35 @@ boundIds (StgRec pairs)  = map fst pairs
 --      knot-tying.
 
 -- | This makes sure that only local, non-global free vars make it into the set.
-mkFreeVarSet :: Env -> [Id] -> IdSet
-mkFreeVarSet env = mkVarSet . filter (`elemVarSet` locals env)
+mkFreeVarSet :: Env -> [Id] -> DIdSet
+mkFreeVarSet env = mkDVarSet . filter (`elemVarSet` locals env)
 
-args :: Env -> [StgArg] -> IdSet
+args :: Env -> [StgArg] -> DIdSet
 args env = mkFreeVarSet env . mapMaybe f
   where
     f (StgVarArg occ) = Just occ
     f _               = Nothing
 
-binding :: Env -> IdSet -> StgBinding -> (CgStgBinding, IdSet)
+binding :: Env -> DIdSet -> StgBinding -> (CgStgBinding, DIdSet)
 binding env body_fv (StgNonRec bndr r) = (StgNonRec bndr r', fvs)
   where
     -- See Note [Tacking local binders]
     (r', rhs_fvs) = rhs env r
-    fvs = delVarSet body_fv bndr `unionVarSet` rhs_fvs
+    fvs = delDVarSet body_fv bndr `unionDVarSet` rhs_fvs
 binding env body_fv (StgRec pairs) = (StgRec pairs', fvs)
   where
     -- See Note [Tacking local binders]
     bndrs = map fst pairs
     (rhss, rhs_fvss) = mapAndUnzip (rhs env . snd) pairs
     pairs' = zip bndrs rhss
-    fvs = delVarSetList (unionVarSets (body_fv:rhs_fvss)) bndrs
+    fvs = delDVarSetList (unionDVarSets (body_fv:rhs_fvss)) bndrs
 
-expr :: Env -> StgExpr -> (CgStgExpr, IdSet)
+expr :: Env -> StgExpr -> (CgStgExpr, DIdSet)
 expr env = go
   where
     go (StgApp occ as)
-      = (StgApp occ as, unionVarSet (args env as) (mkFreeVarSet env [occ]))
-    go (StgLit lit) = (StgLit lit, emptyVarSet)
+      = (StgApp occ as, unionDVarSet (args env as) (mkFreeVarSet env [occ]))
+    go (StgLit lit) = (StgLit lit, emptyDVarSet)
     go (StgConApp dc as tys) = (StgConApp dc as tys, args env as)
     go (StgOpApp op as ty) = (StgOpApp op as ty, args env as)
     go StgLam{} = pprPanic "StgFVs: StgLam" empty
@@ -90,16 +95,16 @@ expr env = go
         (scrut', scrut_fvs) = go scrut
         -- See Note [Tacking local binders]
         (alts', alt_fvss) = mapAndUnzip (alt (addLocals [bndr] env)) alts
-        alt_fvs = unionVarSets alt_fvss
-        fvs = delVarSet (unionVarSet scrut_fvs alt_fvs) bndr
-    go (StgLet bind body) = go_bind StgLet bind body
-    go (StgLetNoEscape bind body) = go_bind StgLetNoEscape bind body
+        alt_fvs = unionDVarSets alt_fvss
+        fvs = delDVarSet (unionDVarSet scrut_fvs alt_fvs) bndr
+    go (StgLet ext bind body) = go_bind (StgLet ext) bind body
+    go (StgLetNoEscape ext bind body) = go_bind (StgLetNoEscape ext) bind body
     go (StgTick tick e) = (StgTick tick e', fvs')
       where
         (e', fvs) = go e
-        fvs' = unionVarSet (tickish tick) fvs
-        tickish (Breakpoint _ ids) = mkVarSet ids
-        tickish _                  = emptyVarSet
+        fvs' = unionDVarSet (tickish tick) fvs
+        tickish (Breakpoint _ ids) = mkDVarSet ids
+        tickish _                  = emptyDVarSet
 
     go_bind dc bind body = (dc bind' body', fvs)
       where
@@ -108,18 +113,18 @@ expr env = go
         (body', body_fvs) = expr env' body
         (bind', fvs) = binding env' body_fvs bind
 
-rhs :: Env -> StgRhs -> (CgStgRhs, IdSet)
+rhs :: Env -> StgRhs -> (CgStgRhs, DIdSet)
 rhs env (StgRhsClosure _ ccs uf bndrs body)
   = (StgRhsClosure fvs ccs uf bndrs body', fvs)
   where
     -- See Note [Tacking local binders]
     (body', body_fvs) = expr (addLocals bndrs env) body
-    fvs = delVarSetList body_fvs bndrs
+    fvs = delDVarSetList body_fvs bndrs
 rhs env (StgRhsCon ccs dc as) = (StgRhsCon ccs dc as, args env as)
 
-alt :: Env -> StgAlt -> (CgStgAlt, IdSet)
+alt :: Env -> StgAlt -> (CgStgAlt, DIdSet)
 alt env (con, bndrs, e) = ((con, bndrs, e'), fvs)
   where
     -- See Note [Tacking local binders]
     (e', rhs_fvs) = expr (addLocals bndrs env) e
-    fvs = delVarSetList rhs_fvs bndrs
+    fvs = delDVarSetList rhs_fvs bndrs
index 35a498f..383b016 100644 (file)
@@ -40,6 +40,8 @@ import StgSyn
 
 import DynFlags
 import Bag              ( Bag, emptyBag, isEmptyBag, snocBag, bagToList )
+import BasicTypes       ( TopLevelFlag(..), isTopLevel )
+import CostCentre       ( isCurrentCCS )
 import Id               ( Id, idType, isLocalId, isJoinId )
 import VarSet
 import DataCon
@@ -84,7 +86,7 @@ lintStgTopBindings dflags unarised whodunnit binds
         addInScopeVars binders $
             lint_binds binds
 
-    lint_bind (StgTopLifted bind) = lintStgBinds bind
+    lint_bind (StgTopLifted bind) = lintStgBinds TopLevel bind
     lint_bind (StgTopStringLit v _) = return [v]
 
 lintStgArg :: StgArg -> LintM ()
@@ -94,26 +96,39 @@ lintStgArg (StgVarArg v) = lintStgVar v
 lintStgVar :: Id -> LintM ()
 lintStgVar id = checkInScope id
 
-lintStgBinds :: StgBinding -> LintM [Id] -- Returns the binders
-lintStgBinds (StgNonRec binder rhs) = do
-    lint_binds_help (binder,rhs)
+lintStgBinds :: TopLevelFlag -> StgBinding -> LintM [Id] -- Returns the binders
+lintStgBinds top_lvl (StgNonRec binder rhs) = do
+    lint_binds_help top_lvl (binder,rhs)
     return [binder]
 
-lintStgBinds (StgRec pairs)
+lintStgBinds top_lvl (StgRec pairs)
   = addInScopeVars binders $ do
-        mapM_ lint_binds_help pairs
+        mapM_ (lint_binds_help top_lvl) pairs
         return binders
   where
     binders = [b | (b,_) <- pairs]
 
-lint_binds_help :: (Id, StgRhs) -> LintM ()
-lint_binds_help (binder, rhs)
+lint_binds_help :: TopLevelFlag -> (Id, StgRhs) -> LintM ()
+lint_binds_help top_lvl (binder, rhs)
   = addLoc (RhsOf binder) $ do
+        when (isTopLevel top_lvl) (checkNoCurrentCCS rhs)
         lintStgRhs rhs
         -- Check binder doesn't have unlifted type or it's a join point
         checkL (isJoinId binder || not (isUnliftedType (idType binder)))
                (mkUnliftedTyMsg binder rhs)
 
+-- | Top-level bindings can't inherit the cost centre stack from their
+-- (static) allocation site.
+checkNoCurrentCCS :: StgRhs -> LintM ()
+checkNoCurrentCCS (StgRhsClosure _ ccs _ _ _)
+  | isCurrentCCS ccs
+  = addErrL (text "Top-level StgRhsClosure with CurrentCCS")
+checkNoCurrentCCS (StgRhsCon ccs _ _)
+  | isCurrentCCS ccs
+  = addErrL (text "Top-level StgRhsCon with CurrentCCS")
+checkNoCurrentCCS _
+  = return ()
+
 lintStgRhs :: StgRhs -> LintM ()
 
 lintStgRhs (StgRhsClosure _ _ _ [] expr)
@@ -154,14 +169,14 @@ lintStgExpr (StgOpApp _ args _) =
 lintStgExpr lam@(StgLam _ _) =
     addErrL (text "Unexpected StgLam" <+> ppr lam)
 
-lintStgExpr (StgLet binds body) = do
-    binders <- lintStgBinds binds
+lintStgExpr (StgLet binds body) = do
+    binders <- lintStgBinds NotTopLevel binds
     addLoc (BodyOfLetRec binders) $
       addInScopeVars binders $
         lintStgExpr body
 
-lintStgExpr (StgLetNoEscape binds body) = do
-    binders <- lintStgBinds binds
+lintStgExpr (StgLetNoEscape binds body) = do
+    binders <- lintStgBinds NotTopLevel binds
     addLoc (BodyOfLetRec binders) $
       addInScopeVars binders $
         lintStgExpr body
diff --git a/compiler/stgSyn/StgSubst.hs b/compiler/stgSyn/StgSubst.hs
new file mode 100644 (file)
index 0000000..72fbe41
--- /dev/null
@@ -0,0 +1,80 @@
+{-# LANGUAGE CPP #-}
+
+module StgSubst where
+
+#include "HsVersions.h"
+
+import GhcPrelude
+
+import Id
+import VarEnv
+import Control.Monad.Trans.State.Strict
+import Outputable
+import Util
+
+-- | A renaming substitution from 'Id's to 'Id's. Like 'RnEnv2', but not
+-- maintaining pairs of substitutions. Like @"CoreSubst".'CoreSubst.Subst'@, but
+-- with the domain being 'Id's instead of entire 'CoreExpr'.
+data Subst = Subst InScopeSet IdSubstEnv
+
+type IdSubstEnv = IdEnv Id
+
+-- | @emptySubst = 'mkEmptySubst' 'emptyInScopeSet'@
+emptySubst :: Subst
+emptySubst = mkEmptySubst emptyInScopeSet
+
+-- | Constructs a new 'Subst' assuming the variables in the given 'InScopeSet'
+-- are in scope.
+mkEmptySubst :: InScopeSet -> Subst
+mkEmptySubst in_scope = Subst in_scope emptyVarEnv
+
+-- | Substitutes an 'Id' for another one according to the 'Subst' given in a way
+-- that avoids shadowing the 'InScopeSet', returning the result and an updated
+-- 'Subst' that should be used by subsequent substitutions.
+substBndr :: Id -> Subst -> (Id, Subst)
+substBndr id (Subst in_scope env)
+  = (new_id, Subst new_in_scope new_env)
+  where
+    new_id = uniqAway in_scope id
+    no_change = new_id == id -- in case nothing shadowed
+    new_in_scope = in_scope `extendInScopeSet` new_id
+    new_env
+      | no_change = delVarEnv env id
+      | otherwise = extendVarEnv env id new_id
+
+-- | @substBndrs = runState . traverse (state . substBndr)@
+substBndrs :: Traversable f => f Id -> Subst -> (f Id, Subst)
+substBndrs = runState . traverse (state . substBndr)
+
+-- | Substitutes an occurrence of an identifier for its counterpart recorded
+-- in the 'Subst'.
+lookupIdSubst :: HasCallStack => Id -> Subst -> Id
+lookupIdSubst id (Subst in_scope env)
+  | not (isLocalId id) = id
+  | Just id' <- lookupVarEnv env id = id'
+  | Just id' <- lookupInScope in_scope id = id'
+  | otherwise = WARN( True, text "StgSubst.lookupIdSubst" <+> ppr id $$ ppr in_scope)
+                id
+
+-- | Substitutes an occurrence of an identifier for its counterpart recorded
+-- in the 'Subst'. Does not generate a debug warning if the identifier to
+-- to substitute wasn't in scope.
+noWarnLookupIdSubst :: HasCallStack => Id -> Subst -> Id
+noWarnLookupIdSubst id (Subst in_scope env)
+  | not (isLocalId id) = id
+  | Just id' <- lookupVarEnv env id = id'
+  | Just id' <- lookupInScope in_scope id = id'
+  | otherwise = id
+
+-- | Add the 'Id' to the in-scope set and remove any existing substitutions for
+-- it.
+extendInScope :: Id -> Subst -> Subst
+extendInScope id (Subst in_scope env) = Subst (in_scope `extendInScopeSet` id) env
+
+-- | Add a substitution for an 'Id' to the 'Subst': you must ensure that the
+-- in-scope set is such that TyCORep Note [The substitution invariant]
+-- holds after extending the substitution like this.
+extendSubst :: Id -> Id -> Subst -> Subst
+extendSubst id new_id (Subst in_scope env)
+  = ASSERT2( new_id `elemInScopeSet` in_scope, ppr id <+> ppr new_id $$ ppr in_scope )
+    Subst in_scope (extendVarEnv env id new_id)
index 145c001..5ba63e4 100644 (file)
@@ -16,6 +16,7 @@ generation.
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE ConstraintKinds #-}
 
 module StgSyn (
         StgArg(..),
@@ -23,7 +24,8 @@ module StgSyn (
         GenStgTopBinding(..), GenStgBinding(..), GenStgExpr(..), GenStgRhs(..),
         GenStgAlt, AltType(..),
 
-        StgPass(..), XRhsClosure, NoExtSilent, noExtSilent,
+        StgPass(..), BinderP, XRhsClosure, XLet, XLetNoEscape,
+        NoExtSilent, noExtSilent,
 
         UpdateFlag(..), isUpdatable,
 
@@ -33,6 +35,9 @@ module StgSyn (
         -- a set of synonyms for the code gen parameterisation
         CgStgTopBinding, CgStgBinding, CgStgExpr, CgStgRhs, CgStgAlt,
 
+        -- a set of synonyms for the lambda lifting parameterisation
+        LlStgTopBinding, LlStgBinding, LlStgExpr, LlStgRhs, LlStgAlt,
+
         -- a set of synonyms to distinguish in- and out variants
         InStgArg,  InStgTopBinding,  InStgBinding,  InStgExpr,  InStgRhs,  InStgAlt,
         OutStgArg, OutStgTopBinding, OutStgBinding, OutStgExpr, OutStgRhs, OutStgAlt,
@@ -101,8 +106,8 @@ data GenStgTopBinding pass
   | StgTopStringLit Id ByteString
 
 data GenStgBinding pass
-  = StgNonRec Id (GenStgRhs pass)
-  | StgRec    [(Id, GenStgRhs pass)]
+  = StgNonRec (BinderP pass) (GenStgRhs pass)
+  | StgRec    [(BinderP pass, GenStgRhs pass)]
 
 {-
 ************************************************************************
@@ -245,7 +250,7 @@ TODO: Encode this via an extension to GenStgExpr à la TTG.
 -}
 
   | StgLam
-        (NonEmpty Id)
+        (NonEmpty (BinderP pass))
         StgExpr    -- Body of lambda
 
 {-
@@ -259,13 +264,9 @@ This has the same boxed/unboxed business as Core case expressions.
 -}
 
   | StgCase
-        (GenStgExpr pass)
-                    -- the thing to examine
-
-        Id          -- binds the result of evaluating the scrutinee
-
+        (GenStgExpr pass) -- the thing to examine
+        (BinderP pass) -- binds the result of evaluating the scrutinee
         AltType
-
         [GenStgAlt pass]
                     -- The DEFAULT case is always *first*
                     -- if it is there at all
@@ -365,10 +366,12 @@ And so the code for let(rec)-things:
 -}
 
   | StgLet
+        (XLet pass)
         (GenStgBinding pass)    -- right hand sides (see below)
         (GenStgExpr pass)       -- body
 
   | StgLetNoEscape
+        (XLetNoEscape pass)
         (GenStgBinding pass)    -- right hand sides (see below)
         (GenStgExpr pass)       -- body
 
@@ -405,7 +408,7 @@ data GenStgRhs pass
                            --   list just before 'CodeGen'.
         CostCentreStack    -- ^ CCS to be attached (default is CurrentCCS)
         !UpdateFlag        -- ^ 'ReEntrant' | 'Updatable' | 'SingleEntry'
-        [Id]               -- ^ arguments; if empty, then not a function;
+        [BinderP pass]     -- ^ arguments; if empty, then not a function;
                            --   as above, order is important.
         (GenStgExpr pass)  -- ^ body
 
@@ -437,8 +440,9 @@ The second flavour of right-hand-side is for constructors (simple but important)
 
 -- | Used as a data type index for the stgSyn AST
 data StgPass
-  = CodeGen
-  | Vanilla
+  = Vanilla
+  | LiftLams
+  | CodeGen
 
 -- | Like 'HsExpression.NoExt', but with an 'Outputable' instance that returns
 -- 'empty'.
@@ -455,9 +459,24 @@ noExtSilent = NoExtSilent
 -- TODO: Maybe move this to HsExtensions? I'm not sure about the implications
 -- on build time...
 
-type family XRhsClosure (pass :: StgPass) where
-  XRhsClosure 'CodeGen = IdSet -- code gen needs to track non-global free vars
-  XRhsClosure 'Vanilla = NoExtSilent
+-- TODO: Do we really want to the extension point type families to have a closed
+-- domain?
+type family BinderP (pass :: StgPass)
+type instance BinderP 'Vanilla = Id
+type instance BinderP 'CodeGen = Id
+
+type family XRhsClosure (pass :: StgPass)
+type instance XRhsClosure 'Vanilla = NoExtSilent
+-- | Code gen needs to track non-global free vars
+type instance XRhsClosure 'CodeGen = DIdSet
+
+type family XLet (pass :: StgPass)
+type instance XLet 'Vanilla = NoExtSilent
+type instance XLet 'CodeGen = NoExtSilent
+
+type family XLetNoEscape (pass :: StgPass)
+type instance XLetNoEscape 'Vanilla = NoExtSilent
+type instance XLetNoEscape 'CodeGen = NoExtSilent
 
 stgRhsArity :: StgRhs -> Int
 stgRhsArity (StgRhsClosure _ _ _ bndrs _)
@@ -506,9 +525,9 @@ exprHasCafRefs (StgLam _ body)
   = exprHasCafRefs body
 exprHasCafRefs (StgCase scrt _ _ alts)
   = exprHasCafRefs scrt || any altHasCafRefs alts
-exprHasCafRefs (StgLet bind body)
+exprHasCafRefs (StgLet bind body)
   = bindHasCafRefs bind || exprHasCafRefs body
-exprHasCafRefs (StgLetNoEscape bind body)
+exprHasCafRefs (StgLetNoEscape bind body)
   = bindHasCafRefs bind || exprHasCafRefs body
 exprHasCafRefs (StgTick _ expr)
   = exprHasCafRefs expr
@@ -562,7 +581,7 @@ rather than from the scrutinee type.
 
 type GenStgAlt pass
   = (AltCon,          -- alts: data constructor,
-     [Id],            -- constructor's parameters,
+     [BinderP pass],  -- constructor's parameters,
      GenStgExpr pass) -- ...right-hand side.
 
 data AltType
@@ -589,6 +608,12 @@ type StgExpr       = GenStgExpr       'Vanilla
 type StgRhs        = GenStgRhs        'Vanilla
 type StgAlt        = GenStgAlt        'Vanilla
 
+type LlStgTopBinding = GenStgTopBinding 'LiftLams
+type LlStgBinding    = GenStgBinding    'LiftLams
+type LlStgExpr       = GenStgExpr       'LiftLams
+type LlStgRhs        = GenStgRhs        'LiftLams
+type LlStgAlt        = GenStgAlt        'LiftLams
+
 type CgStgTopBinding = GenStgTopBinding 'CodeGen
 type CgStgBinding    = GenStgBinding    'CodeGen
 type CgStgExpr       = GenStgExpr       'CodeGen
@@ -676,8 +701,15 @@ Robin Popplestone asked for semi-colon separators on STG binds; here's
 hoping he likes terminators instead...  Ditto for case alternatives.
 -}
 
+type OutputablePass pass =
+  ( Outputable (XLet pass)
+  , Outputable (XLetNoEscape pass)
+  , Outputable (XRhsClosure pass)
+  , OutputableBndr (BinderP pass)
+  )
+
 pprGenStgTopBinding
-  :: Outputable (XRhsClosure pass) => GenStgTopBinding pass -> SDoc
+  :: OutputablePass pass => GenStgTopBinding pass -> SDoc
 pprGenStgTopBinding (StgTopStringLit bndr str)
   = hang (hsep [pprBndr LetBind bndr, equals])
         4 (pprHsBytes str <> semi)
@@ -685,7 +717,7 @@ pprGenStgTopBinding (StgTopLifted bind)
   = pprGenStgBinding bind
 
 pprGenStgBinding
-  :: (Outputable (XRhsClosure pass)) => GenStgBinding pass -> SDoc
+  :: OutputablePass pass => GenStgBinding pass -> SDoc
 
 pprGenStgBinding (StgNonRec bndr rhs)
   = hang (hsep [pprBndr LetBind bndr, equals])
@@ -709,27 +741,23 @@ pprStgTopBindings binds
 instance Outputable StgArg where
     ppr = pprStgArg
 
-instance (Outputable (XRhsClosure pass))
-                => Outputable (GenStgTopBinding pass) where
+instance OutputablePass pass => Outputable (GenStgTopBinding pass) where
     ppr = pprGenStgTopBinding
 
-instance (Outputable (XRhsClosure pass))
-                => Outputable (GenStgBinding pass) where
+instance OutputablePass pass => Outputable (GenStgBinding pass) where
     ppr = pprGenStgBinding
 
-instance (Outputable (XRhsClosure pass))
-                => Outputable (GenStgExpr pass) where
+instance OutputablePass pass => Outputable (GenStgExpr pass) where
     ppr = pprStgExpr
 
-instance (Outputable (XRhsClosure pass))
-                => Outputable (GenStgRhs pass) where
+instance OutputablePass pass => Outputable (GenStgRhs pass) where
     ppr rhs = pprStgRhs rhs
 
 pprStgArg :: StgArg -> SDoc
 pprStgArg (StgVarArg var) = ppr var
 pprStgArg (StgLitArg con) = ppr con
 
-pprStgExpr :: (Outputable (XRhsClosure pass)) => GenStgExpr pass -> SDoc
+pprStgExpr :: OutputablePass pass => GenStgExpr pass -> SDoc
 -- special case
 pprStgExpr (StgLit lit)     = ppr lit
 
@@ -773,19 +801,19 @@ pprStgExpr (StgLet srt (StgNonRec bndr (StgRhsClosure cc bi free_vars upd_flag a
 
 -- special case: let ... in let ...
 
-pprStgExpr (StgLet bind expr@(StgLet _ _))
+pprStgExpr (StgLet ext bind expr@StgLet{})
   = ($$)
-      (sep [hang (text "let {")
+      (sep [hang (text "let" <+> ppr ext <+> text "{")
                 2 (hsep [pprGenStgBinding bind, text "} in"])])
       (ppr expr)
 
 -- general case
-pprStgExpr (StgLet bind expr)
-  = sep [hang (text "let {") 2 (pprGenStgBinding bind),
+pprStgExpr (StgLet ext bind expr)
+  = sep [hang (text "let" <+> ppr ext <+> text "{") 2 (pprGenStgBinding bind),
            hang (text "} in ") 2 (ppr expr)]
 
-pprStgExpr (StgLetNoEscape bind expr)
-  = sep [hang (text "let-no-escape {")
+pprStgExpr (StgLetNoEscape ext bind expr)
+  = sep [hang (text "let-no-escape" <+> ppr ext <+> text "{")
                 2 (pprGenStgBinding bind),
            hang (text "} in ")
                 2 (ppr expr)]
@@ -805,7 +833,7 @@ pprStgExpr (StgCase expr bndr alt_type alts)
            nest 2 (vcat (map pprStgAlt alts)),
            char '}']
 
-pprStgAlt :: (Outputable (XRhsClosure pass)) => GenStgAlt pass -> SDoc
+pprStgAlt :: OutputablePass pass => GenStgAlt pass -> SDoc
 pprStgAlt (con, params, expr)
   = hang (hsep [ppr con, sep (map (pprBndr CasePatBind) params), text "->"])
          4 (ppr expr <> semi)
@@ -821,7 +849,7 @@ instance Outputable AltType where
   ppr (AlgAlt tc)     = text "Alg"    <+> ppr tc
   ppr (PrimAlt tc)    = text "Prim"   <+> ppr tc
 
-pprStgRhs :: (Outputable (XRhsClosure pass)) => GenStgRhs pass -> SDoc
+pprStgRhs :: OutputablePass pass => GenStgRhs pass -> SDoc
 
 -- special case
 pprStgRhs (StgRhsClosure ext cc upd_flag [{-no args-}] (StgApp func []))
index 0048478..bdae8b6 100644 (file)
@@ -1004,6 +1004,58 @@ by saying ``-fno-wombat``.
     Chapter 7 of `Andre Santos's PhD
     thesis <http://research.microsoft.com/en-us/um/people/simonpj/papers/santos-thesis.ps.gz>`__
 
+.. ghc-flag:: -fstg-lift-lams
+    :shortdesc: Enable late lambda lifting on the STG intermediate
+        language. Implied by :ghc-flag:`-O2`.
+    :type: dynamic
+    :reverse: -fno-stg-lift-lams
+    :category:
+
+    :default: on
+
+    Enables the late lambda lifting optimisation on the STG
+    intermediate language. This selectively lifts local functions to
+    top-level by converting free variables into function parameters.
+
+.. ghc-flag:: -fstg-lift-lams-known
+    :shortdesc: Allow turning known into unknown calls while performing
+        late lambda lifting.
+    :type: dynamic
+    :reverse: -fno-stg-lift-lams-known
+    :category:
+
+    :default: off
+
+    Allow turning known into unknown calls while performing
+    late lambda lifting. This is deemed non-beneficial, so it's
+    off by default.
+
+.. ghc-flag:: -fstg-lift-lams-non-rec-args
+    :shortdesc: Create top-level non-recursive functions with at most <n>
+        parameters while performing late lambda lifting.
+    :type: dynamic
+    :reverse: -fno-stg-lift-lams-non-rec-args-any
+    :category:
+
+    :default: 5
+
+    Create top-level non-recursive functions with at most <n> parameters
+    while performing late lambda lifting. The default is 5, the number of
+    available parameter registers on x86_64.
+
+.. ghc-flag:: -fstg-lift-lams-rec-args
+    :shortdesc: Create top-level recursive functions with at most <n>
+        parameters while performing late lambda lifting.
+    :type: dynamic
+    :reverse: -fno-stg-lift-lams-rec-args-any
+    :category:
+
+    :default: 5
+
+    Create top-level recursive functions with at most <n> parameters
+    while performing late lambda lifting. The default is 5, the number of
+    available parameter registers on x86_64.
+
 .. ghc-flag:: -fstrictness
     :shortdesc: Turn on strictness analysis.
         Implied by :ghc-flag:`-O`. Implies :ghc-flag:`-fworker-wrapper`
diff --git a/inplace/test b/inplace/test
deleted file mode 100755 (executable)
index cccdc75..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-# See Note [Spaces in TEST_HC].
-echo
-echo 'Possible fix: put quotes around $(TEST_HC) in your Makefile.'
diff --git a/inplace/test spaces b/inplace/test spaces
deleted file mode 120000 (symlink)
index c5e82d7..0000000
+++ /dev/null
@@ -1 +0,0 @@
-bin
\ No newline at end of file
index eedf0c0..99b1726 100644 (file)
@@ -17,7 +17,7 @@ test('join003',
 test('join004',
   [collect_stats('bytes allocated',5),],
   compile_and_run,
-  [''])
+  ['-fno-stg-lift-lams'])
 
 test('join005', normal, compile, [''])
 test('join006', normal, compile, [''])