Join points
authorLuke Maurer <maurerl@cs.uoregon.edu>
Wed, 1 Feb 2017 16:56:01 +0000 (11:56 -0500)
committerDavid Feuer <David.Feuer@gmail.com>
Wed, 1 Feb 2017 18:44:52 +0000 (13:44 -0500)
This major patch implements Join Points, as described in
https://ghc.haskell.org/trac/ghc/wiki/SequentCore.  You have
to read that page, and especially the paper it links to, to
understand what's going on; but it is very cool.

It's Luke Maurer's work, but done in close collaboration with Simon PJ.

This Phab is a squash-merge of wip/join-points branch of
http://github.com/lukemaurer/ghc. There are many, many interdependent
changes.

Reviewers: goldfire, mpickering, bgamari, simonmar, dfeuer, austin

Subscribers: simonpj, dfeuer, mpickering, Mikolaj, thomie

Differential Revision: https://phabricator.haskell.org/D2853

77 files changed:
compiler/backpack/RnModIface.hs
compiler/basicTypes/BasicTypes.hs
compiler/basicTypes/Demand.hs
compiler/basicTypes/Id.hs
compiler/basicTypes/IdInfo.hs
compiler/basicTypes/IdInfo.hs-boot
compiler/basicTypes/Var.hs
compiler/basicTypes/VarEnv.hs
compiler/coreSyn/CoreArity.hs
compiler/coreSyn/CoreArity.hs-boot [new file with mode: 0644]
compiler/coreSyn/CoreLint.hs
compiler/coreSyn/CorePrep.hs
compiler/coreSyn/CoreStats.hs
compiler/coreSyn/CoreSubst.hs
compiler/coreSyn/CoreSyn.hs
compiler/coreSyn/CoreUnfold.hs
compiler/coreSyn/CoreUtils.hs
compiler/coreSyn/MkCore.hs
compiler/coreSyn/PprCore.hs
compiler/deSugar/DsUtils.hs
compiler/iface/IfaceSyn.hs
compiler/iface/TcIface.hs
compiler/iface/ToIface.hs
compiler/simplCore/CSE.hs
compiler/simplCore/CoreMonad.hs
compiler/simplCore/FloatIn.hs
compiler/simplCore/FloatOut.hs
compiler/simplCore/LiberateCase.hs
compiler/simplCore/OccurAnal.hs
compiler/simplCore/SetLevels.hs
compiler/simplCore/SimplCore.hs
compiler/simplCore/SimplEnv.hs
compiler/simplCore/SimplUtils.hs
compiler/simplCore/Simplify.hs
compiler/specialise/Rules.hs
compiler/specialise/SpecConstr.hs
compiler/specialise/Specialise.hs
compiler/stgSyn/CoreToStg.hs
compiler/stranal/DmdAnal.hs
compiler/stranal/WorkWrap.hs
compiler/stranal/WwLib.hs
compiler/types/Type.hs
compiler/utils/Outputable.hs
compiler/utils/UniqFM.hs
testsuite/tests/deSugar/should_compile/T2431.stderr
testsuite/tests/deriving/perf/all.T
testsuite/tests/numeric/should_compile/T7116.stdout
testsuite/tests/perf/compiler/all.T
testsuite/tests/perf/haddock/all.T
testsuite/tests/perf/join_points/Makefile [new file with mode: 0644]
testsuite/tests/perf/join_points/all.T [new file with mode: 0644]
testsuite/tests/perf/join_points/join001.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join002.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join002.stdout [new file with mode: 0644]
testsuite/tests/perf/join_points/join003.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join003.stdout [new file with mode: 0644]
testsuite/tests/perf/join_points/join004.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join004.stdout [new file with mode: 0644]
testsuite/tests/perf/join_points/join005.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join006.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join007.hs [new file with mode: 0644]
testsuite/tests/perf/join_points/join007.stdout [new file with mode: 0644]
testsuite/tests/perf/should_run/all.T
testsuite/tests/roles/should_compile/Roles13.stderr
testsuite/tests/simplCore/should_compile/Makefile
testsuite/tests/simplCore/should_compile/T13156.hs
testsuite/tests/simplCore/should_compile/T13156.stdout
testsuite/tests/simplCore/should_compile/T3717.stderr
testsuite/tests/simplCore/should_compile/T3772.stdout
testsuite/tests/simplCore/should_compile/T4908.stderr
testsuite/tests/simplCore/should_compile/T4930.stderr
testsuite/tests/simplCore/should_compile/T5658b.stdout
testsuite/tests/simplCore/should_compile/T7360.stderr
testsuite/tests/simplCore/should_compile/T9400.stderr
testsuite/tests/simplCore/should_compile/all.T
testsuite/tests/simplCore/should_compile/par01.stderr
testsuite/tests/simplCore/should_compile/spec-inline.stderr

index a6d6edd..e32bb74 100644 (file)
@@ -606,8 +606,8 @@ rnIfaceConAlt (IfaceDataAlt data_occ) = IfaceDataAlt <$> rnIfaceGlobal data_occ
 rnIfaceConAlt alt = pure alt
 
 rnIfaceLetBndr :: Rename IfaceLetBndr
-rnIfaceLetBndr (IfLetBndr fs ty info)
-    = IfLetBndr fs <$> rnIfaceType ty <*> rnIfaceIdInfo info
+rnIfaceLetBndr (IfLetBndr fs ty info jpi)
+    = IfLetBndr fs <$> rnIfaceType ty <*> rnIfaceIdInfo info <*> pure jpi
 
 rnIfaceLamBndr :: Rename IfaceLamBndr
 rnIfaceLamBndr (bndr, oneshot) = (,) <$> rnIfaceBndr bndr <*> pure oneshot
index cf4c970..ff4d2c7 100644 (file)
@@ -24,7 +24,7 @@ module BasicTypes(
 
         ConTag, ConTagZ, fIRST_TAG,
 
-        Arity, RepArity,
+        Arity, RepArity, JoinArity,
 
         Alignment,
 
@@ -64,13 +64,15 @@ module BasicTypes(
         noOneShotInfo, hasNoOneShotInfo, isOneShotInfo,
         bestOneShot, worstOneShot,
 
-        OccInfo(..), seqOccInfo, zapFragileOcc, isOneOcc,
-        isDeadOcc, isStrongLoopBreaker, isWeakLoopBreaker, isNoOcc,
+        OccInfo(..), noOccInfo, seqOccInfo, zapFragileOcc, isOneOcc,
+        isDeadOcc, isStrongLoopBreaker, isWeakLoopBreaker, isManyOccs,
         strongLoopBreaker, weakLoopBreaker,
 
         InsideLam, insideLam, notInsideLam,
         OneBranch, oneBranch, notOneBranch,
         InterestingCxt,
+        TailCallInfo(..), tailCallInfo, zapOccTailCallInfo,
+        isAlwaysTailCalled,
 
         EP(..),
 
@@ -154,6 +156,12 @@ type Arity = Int
 --  \(# x, y #) -> fib (x + y) has representation arity 2
 type RepArity = Int
 
+-- | The number of arguments that a join point takes. Unlike the arity of a
+-- function, this is a purely syntactic property and is fixed when the join
+-- point is created (or converted from a value). Both type and value arguments
+-- are counted.
+type JoinArity = Int
+
 {-
 ************************************************************************
 *                                                                      *
@@ -808,20 +816,23 @@ defn of OccInfo here, safely at the bottom
 
 -- | identifier Occurrence Information
 data OccInfo
-  = NoOccInfo           -- ^ There are many occurrences, or unknown occurrences
+  = ManyOccs        { occ_tail    :: !TailCallInfo }
+                        -- ^ There are many occurrences, or unknown occurrences
 
   | IAmDead             -- ^ Marks unused variables.  Sometimes useful for
                         -- lambda and case-bound variables.
 
-  | OneOcc
-        !InsideLam
-        !OneBranch
-        !InterestingCxt -- ^ Occurs exactly once, not inside a rule
+  | OneOcc          { occ_in_lam  :: !InsideLam
+                    , occ_one_br  :: !OneBranch
+                    , occ_int_cxt :: !InterestingCxt
+                    , occ_tail    :: !TailCallInfo }
+                        -- ^ Occurs exactly once (per branch), not inside a rule
 
   -- | This identifier breaks a loop of mutually recursive functions. The field
   -- marks whether it is only a loop breaker due to a reference in a rule
-  | IAmALoopBreaker     -- Note [LoopBreaker OccInfo]
-        !RulesOnly
+  | IAmALoopBreaker { occ_rules_only :: !RulesOnly
+                    , occ_tail       :: !TailCallInfo }
+                        -- Note [LoopBreaker OccInfo]
 
   deriving (Eq)
 
@@ -839,9 +850,12 @@ Note [LoopBreaker OccInfo]
 See OccurAnal Note [Weak loop breakers]
 -}
 
-isNoOcc :: OccInfo -> Bool
-isNoOcc NoOccInfo = True
-isNoOcc _         = False
+noOccInfo :: OccInfo
+noOccInfo = ManyOccs { occ_tail = NoTailCallInfo }
+
+isManyOccs :: OccInfo -> Bool
+isManyOccs ManyOccs{} = True
+isManyOccs _          = False
 
 seqOccInfo :: OccInfo -> ()
 seqOccInfo occ = occ `seq` ()
@@ -868,17 +882,41 @@ oneBranch, notOneBranch :: OneBranch
 oneBranch    = True
 notOneBranch = False
 
+-----------------
+data TailCallInfo = AlwaysTailCalled JoinArity -- See Note [TailCallInfo]
+                  | NoTailCallInfo
+  deriving (Eq)
+
+tailCallInfo :: OccInfo -> TailCallInfo
+tailCallInfo IAmDead   = NoTailCallInfo
+tailCallInfo other     = occ_tail other
+
+zapOccTailCallInfo :: OccInfo -> OccInfo
+zapOccTailCallInfo IAmDead   = IAmDead
+zapOccTailCallInfo occ       = occ { occ_tail = NoTailCallInfo }
+
+isAlwaysTailCalled :: OccInfo -> Bool
+isAlwaysTailCalled occ
+  = case tailCallInfo occ of AlwaysTailCalled{} -> True
+                             NoTailCallInfo     -> False
+
+instance Outputable TailCallInfo where
+  ppr (AlwaysTailCalled ar) = sep [ text "Tail", int ar ]
+  ppr _                     = empty
+
+-----------------
 strongLoopBreaker, weakLoopBreaker :: OccInfo
-strongLoopBreaker = IAmALoopBreaker False
-weakLoopBreaker   = IAmALoopBreaker True
+strongLoopBreaker = IAmALoopBreaker False NoTailCallInfo
+weakLoopBreaker   = IAmALoopBreaker True  NoTailCallInfo
 
 isWeakLoopBreaker :: OccInfo -> Bool
-isWeakLoopBreaker (IAmALoopBreaker _) = True
+isWeakLoopBreaker (IAmALoopBreaker{}) = True
 isWeakLoopBreaker _                   = False
 
 isStrongLoopBreaker :: OccInfo -> Bool
-isStrongLoopBreaker (IAmALoopBreaker False) = True   -- Loop-breaker that breaks a non-rule cycle
-isStrongLoopBreaker _                       = False
+isStrongLoopBreaker (IAmALoopBreaker { occ_rules_only = False }) = True
+  -- Loop-breaker that breaks a non-rule cycle
+isStrongLoopBreaker _                                            = False
 
 isDeadOcc :: OccInfo -> Bool
 isDeadOcc IAmDead = True
@@ -889,16 +927,21 @@ isOneOcc (OneOcc {}) = True
 isOneOcc _           = False
 
 zapFragileOcc :: OccInfo -> OccInfo
-zapFragileOcc (OneOcc {}) = NoOccInfo
-zapFragileOcc occ         = occ
+-- Keep only the most robust data: deadness, loop-breaker-hood
+zapFragileOcc (OneOcc {}) = noOccInfo
+zapFragileOcc occ         = zapOccTailCallInfo occ
 
 instance Outputable OccInfo where
   -- only used for debugging; never parsed.  KSW 1999-07
-  ppr NoOccInfo            = empty
-  ppr (IAmALoopBreaker ro) = text "LoopBreaker" <> if ro then char '!' else empty
+  ppr (ManyOccs tails)     = pprShortTailCallInfo tails
   ppr IAmDead              = text "Dead"
-  ppr (OneOcc inside_lam one_branch int_cxt)
-        = text "Once" <> pp_lam <> pp_br <> pp_args
+  ppr (IAmALoopBreaker rule_only tails)
+        = text "LoopBreaker" <> pp_ro <> pprShortTailCallInfo tails
+        where
+          pp_ro | rule_only = char '!'
+                | otherwise = empty
+  ppr (OneOcc inside_lam one_branch int_cxt tail_info)
+        = text "Once" <> pp_lam <> pp_br <> pp_args <> pp_tail
         where
           pp_lam | inside_lam = char 'L'
                  | otherwise  = empty
@@ -906,8 +949,43 @@ instance Outputable OccInfo where
                  | otherwise  = char '*'
           pp_args | int_cxt   = char '!'
                   | otherwise = empty
+          pp_tail             = pprShortTailCallInfo tail_info
+
+pprShortTailCallInfo :: TailCallInfo -> SDoc
+pprShortTailCallInfo (AlwaysTailCalled ar) = char 'T' <> brackets (int ar)
+pprShortTailCallInfo NoTailCallInfo        = empty
 
 {-
+Note [TailCallInfo]
+~~~~~~~~~~~~~~~~~~~
+The occurrence analyser determines what can be made into a join point, but it
+doesn't change the binder into a JoinId because then it would be inconsistent
+with the occurrences. Thus it's left to the simplifier (or to simpleOptExpr) to
+change the IdDetails.
+
+The AlwaysTailCalled marker actually means slightly more than simply that the
+function is always tail-called. See Note [Invariants on join points].
+
+This info is quite fragile and should not be relied upon unless the occurrence
+analyser has *just* run. Use 'Id.isJoinId_maybe' for the permanent state of
+the join-point-hood of a binder; a join id itself will not be marked
+AlwaysTailCalled.
+
+Note that there is a 'TailCallInfo' on a 'ManyOccs' value. One might expect that
+being tail-called would mean that the variable could only appear once per branch
+(thus getting a `OneOcc { occ_one_br = True }` occurrence info), but a join
+point can also be invoked from other join points, not just from case branches:
+
+  let j1 x = ...
+      j2 y = ... j1 z {- tail call -} ...
+  in case w of
+       A -> j1 v
+       B -> j2 u
+       C -> j2 q
+
+Here both 'j1' and 'j2' will get marked AlwaysTailCalled, but j1 will get
+ManyOccs and j2 will get `OneOcc { occ_one_br = True }`.
+
 ************************************************************************
 *                                                                      *
                 Default method specification
index c72bf39..8cacf22 100644 (file)
@@ -304,7 +304,9 @@ splitArgStrProdDmd n (Str _ s) = splitStrProdDmd n s
 splitStrProdDmd :: Int -> StrDmd -> Maybe [ArgStr]
 splitStrProdDmd n HyperStr   = Just (replicate n strBot)
 splitStrProdDmd n HeadStr    = Just (replicate n strTop)
-splitStrProdDmd n (SProd ds) = ASSERT( ds `lengthIs` n) Just ds
+splitStrProdDmd n (SProd ds) = WARN( not (ds `lengthIs` n),
+                                     text "splitStrProdDmd" $$ ppr n $$ ppr ds )
+                               Just ds
 splitStrProdDmd _ (SCall {}) = Nothing
       -- This can happen when the programmer uses unsafeCoerce,
       -- and we don't then want to crash the compiler (Trac #9208)
@@ -586,7 +588,9 @@ seqArgUse _          = ()
 splitUseProdDmd :: Int -> UseDmd -> Maybe [ArgUse]
 splitUseProdDmd n Used        = Just (replicate n useTop)
 splitUseProdDmd n UHead       = Just (replicate n Abs)
-splitUseProdDmd n (UProd ds)  = ASSERT2( ds `lengthIs` n, text "splitUseProdDmd" $$ ppr n $$ ppr ds )
+splitUseProdDmd n (UProd ds)  = WARN( not (ds `lengthIs` n),
+                                      text "splitUseProdDmd" $$ ppr n
+                                                             $$ ppr ds )
                                 Just ds
 splitUseProdDmd _ (UCall _ _) = Nothing
       -- This can happen when the programmer uses unsafeCoerce,
index 2b1bdfd..acb22e8 100644 (file)
@@ -52,7 +52,7 @@ module Id (
         globaliseId, localiseId,
         setIdInfo, lazySetIdInfo, modifyIdInfo, maybeModifyIdInfo,
         zapLamIdInfo, zapIdDemandInfo, zapIdUsageInfo, zapIdUsageEnvInfo,
-        zapIdUsedOnceInfo,
+        zapIdUsedOnceInfo, zapIdTailCallInfo,
         zapFragileIdInfo, zapIdStrictness,
         transferPolyIdInfo,
 
@@ -73,6 +73,10 @@ module Id (
         -- ** Evidence variables
         DictId, isDictId, isEvVar,
 
+        -- ** Join variables
+        JoinId, isJoinId, isJoinId_maybe, idJoinArity,
+        asJoinId, asJoinId_maybe, zapJoinId,
+
         -- ** Inline pragma stuff
         idInlinePragma, setInlinePragma, modifyInlinePragma,
         idInlineActivation, setInlineActivation, idRuleMatchInfo,
@@ -118,11 +122,12 @@ import IdInfo
 import BasicTypes
 
 -- Imported and re-exported
-import Var( Id, CoVar, DictId,
+import Var( Id, CoVar, DictId, JoinId,
             InId,  InVar,
             OutId, OutVar,
-            idInfo, idDetails, globaliseId, varType,
-            isId, isLocalId, isGlobalId, isExportedId )
+            idInfo, idDetails, setIdDetails, globaliseId, varType,
+            isId, isLocalId, isGlobalId, isExportedId,
+            isJoinId, isJoinId_maybe )
 import qualified Var
 
 import Type
@@ -157,7 +162,10 @@ infixl  1 `setIdUnfolding`,
           `idCafInfo`,
 
           `setIdDemandInfo`,
-          `setIdStrictness`
+          `setIdStrictness`,
+
+          `asJoinId`,
+          `asJoinId_maybe`
 
 {-
 ************************************************************************
@@ -546,6 +554,40 @@ isDictId id = isDictTy (idType id)
 {-
 ************************************************************************
 *                                                                      *
+              Join variables
+*                                                                      *
+************************************************************************
+-}
+
+idJoinArity :: JoinId -> JoinArity
+idJoinArity id = isJoinId_maybe id `orElse` pprPanic "idJoinArity" (ppr id)
+
+asJoinId :: Id -> JoinArity -> JoinId
+asJoinId id arity = WARN(not (isLocalId id),
+                         text "global id being marked as join var:" <+> ppr id)
+                    WARN(not (is_vanilla_or_join id),
+                         ppr id <+> pprIdDetails (idDetails id))
+                    id `setIdDetails` JoinId arity
+  where
+    is_vanilla_or_join id = case Var.idDetails id of
+                              VanillaId -> True
+                              JoinId {} -> True
+                              _         -> False
+
+zapJoinId :: Id -> Id
+-- May be a regular id already
+zapJoinId jid | isJoinId jid = zapIdTailCallInfo (jid `setIdDetails` VanillaId)
+                                 -- Core Lint may complain if still marked
+                                 -- as AlwaysTailCalled
+              | otherwise    = jid
+
+asJoinId_maybe :: Id -> Maybe JoinArity -> Id
+asJoinId_maybe id (Just arity) = asJoinId id arity
+asJoinId_maybe id Nothing      = zapJoinId id
+
+{-
+************************************************************************
+*                                                                      *
 \subsection{IdInfo stuff}
 *                                                                      *
 ************************************************************************
@@ -590,9 +632,11 @@ zapIdStrictness id = modifyIdInfo (`setStrictnessInfo` nopSig) id
 isStrictId :: Id -> Bool
 isStrictId id
   = ASSERT2( isId id, text "isStrictId: not an id: " <+> ppr id )
+         not (isJoinId id) && (
            (isStrictType (idType id)) ||
            -- Take the best of both strictnesses - old and new
            (isStrictDmd (idDemandInfo id))
+         )
 
         ---------------------------------
         -- UNFOLDING
@@ -660,7 +704,7 @@ setIdOccInfo :: Id -> OccInfo -> Id
 setIdOccInfo id occ_info = modifyIdInfo (`setOccInfo` occ_info) id
 
 zapIdOccInfo :: Id -> Id
-zapIdOccInfo b = b `setIdOccInfo` NoOccInfo
+zapIdOccInfo b = b `setIdOccInfo` noOccInfo
 
 {-
         ---------------------------------
@@ -804,6 +848,9 @@ zapIdUsageEnvInfo = zapInfo zapUsageEnvInfo
 zapIdUsedOnceInfo :: Id -> Id
 zapIdUsedOnceInfo = zapInfo zapUsedOnceInfo
 
+zapIdTailCallInfo :: Id -> Id
+zapIdTailCallInfo = zapInfo zapTailCallInfo
+
 {-
 Note [transferPolyIdInfo]
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -869,13 +916,14 @@ transferPolyIdInfo old_id abstract_wrt new_id
     old_inline_prag = inlinePragInfo old_info
     old_occ_info    = occInfo old_info
     new_arity       = old_arity + arity_increase
+    new_occ_info    = zapOccTailCallInfo old_occ_info
 
     old_strictness  = strictnessInfo old_info
     new_strictness  = increaseStrictSigArity arity_increase old_strictness
 
     transfer new_info = new_info `setArityInfo` new_arity
                                  `setInlinePragInfo` old_inline_prag
-                                 `setOccInfo` old_occ_info
+                                 `setOccInfo` new_occ_info
                                  `setStrictnessInfo` new_strictness
 
 isNeverLevPolyId :: Id -> Bool
index 4481539..f29fba7 100644 (file)
@@ -14,6 +14,7 @@ Haskell. [WDP 94/11])
 module IdInfo (
         -- * The IdDetails type
         IdDetails(..), pprIdDetails, coVarDetails, isCoVarDetails,
+        JoinArity, isJoinIdDetails_maybe,
         RecSelParent(..),
 
         -- * The IdInfo type
@@ -28,6 +29,7 @@ module IdInfo (
         -- ** Zapping various forms of Info
         zapLamInfo, zapFragileInfo,
         zapDemandInfo, zapUsageInfo, zapUsageEnvInfo, zapUsedOnceInfo,
+        zapTailCallInfo,
 
         -- ** The ArityInfo type
         ArityInfo,
@@ -55,6 +57,9 @@ module IdInfo (
         InsideLam, OneBranch,
         insideLam, notInsideLam, oneBranch, notOneBranch,
 
+        TailCallInfo(..),
+        tailCallInfo, isAlwaysTailCalled,
+
         -- ** The RuleInfo type
         RuleInfo(..),
         emptyRuleInfo,
@@ -153,6 +158,8 @@ data IdDetails
   | CoVarId    -- ^ A coercion variable
                -- This only covers /un-lifted/ coercions, of type
                -- (t1 ~# t2) or (t1 ~R# t2), not their lifted variants
+  | JoinId JoinArity           -- ^ An 'Id' for a join point taking n arguments
+       -- Note [Join points] in CoreSyn
 
 -- | Recursive Selector Parent
 data RecSelParent = RecSelData TyCon | RecSelPatSyn PatSyn deriving Eq
@@ -176,6 +183,10 @@ isCoVarDetails :: IdDetails -> Bool
 isCoVarDetails CoVarId = True
 isCoVarDetails _       = False
 
+isJoinIdDetails_maybe :: IdDetails -> Maybe JoinArity
+isJoinIdDetails_maybe (JoinId join_arity) = Just join_arity
+isJoinIdDetails_maybe _                   = Nothing
+
 instance Outputable IdDetails where
     ppr = pprIdDetails
 
@@ -195,6 +206,7 @@ pprIdDetails other     = brackets (pp other)
                               = brackets $ text "RecSel" <>
                                            ppWhen is_naughty (text "(naughty)")
    pp CoVarId                 = text "CoVarId"
+   pp (JoinId arity)          = text "JoinId" <> parens (int arity)
 
 {-
 ************************************************************************
@@ -285,7 +297,7 @@ vanillaIdInfo
             unfoldingInfo       = noUnfolding,
             oneShotInfo         = NoOneShotInfo,
             inlinePragInfo      = defaultInlinePragma,
-            occInfo             = NoOccInfo,
+            occInfo             = noOccInfo,
             demandInfo          = topDmd,
             strictnessInfo      = nopSig,
             callArityInfo       = unknownArity,
@@ -482,12 +494,16 @@ zapLamInfo info@(IdInfo {occInfo = occ, demandInfo = demand})
   where
         -- The "unsafe" occ info is the ones that say I'm not in a lambda
         -- because that might not be true for an unsaturated lambda
-    is_safe_occ (OneOcc in_lam _ _) = in_lam
-    is_safe_occ _other              = True
+    is_safe_occ occ | isAlwaysTailCalled occ     = False
+    is_safe_occ (OneOcc { occ_in_lam = in_lam }) = in_lam
+    is_safe_occ _other                           = True
 
     safe_occ = case occ of
-                 OneOcc _ once int_cxt -> OneOcc insideLam once int_cxt
-                 _other                -> occ
+                 OneOcc{} -> occ { occ_in_lam = True
+                                 , occ_tail   = NoTailCallInfo }
+                 IAmALoopBreaker{}
+                          -> occ { occ_tail   = NoTailCallInfo }
+                 _other   -> occ
 
     is_safe_dmd dmd = not (isStrictDmd dmd)
 
@@ -529,6 +545,14 @@ zapFragileUnfolding unf
  | isFragileUnfolding unf = noUnfolding
  | otherwise              = unf
 
+zapTailCallInfo :: IdInfo -> Maybe IdInfo
+zapTailCallInfo info
+  = case occInfo info of
+      occ | isAlwaysTailCalled occ -> Just (info `setOccInfo` safe_occ)
+          | otherwise              -> Nothing
+        where
+          safe_occ = occ { occ_tail = NoTailCallInfo }
+
 {-
 ************************************************************************
 *                                                                      *
index 0fabad3..27c1217 100644 (file)
@@ -1,4 +1,5 @@
 module IdInfo where
+import BasicTypes
 import Outputable
 data IdInfo
 data IdDetails
@@ -6,5 +7,6 @@ data IdDetails
 vanillaIdInfo :: IdInfo
 coVarDetails :: IdDetails
 isCoVarDetails :: IdDetails -> Bool
+isJoinIdDetails_maybe :: IdDetails -> Maybe JoinArity
 pprIdDetails :: IdDetails -> SDoc
 
index 3f78c28..2b728af 100644 (file)
@@ -34,7 +34,7 @@
 
 module Var (
         -- * The main data type and synonyms
-        Var, CoVar, Id, NcId, DictId, DFunId, EvVar, EqVar, EvId, IpId,
+        Var, CoVar, Id, NcId, DictId, DFunId, EvVar, EqVar, EvId, IpId, JoinId,
         TyVar, TypeVar, KindVar, TKVar, TyCoVar,
 
         -- * In and Out variants
@@ -57,6 +57,7 @@ module Var (
         -- ** Predicates
         isId, isTyVar, isTcTyVar,
         isLocalVar, isLocalId, isCoVar, isNonCoVarId, isTyCoVar,
+        isJoinId, isJoinId_maybe,
         isGlobalId, isExportedId,
         mustHaveLocalBinding,
 
@@ -83,8 +84,11 @@ module Var (
 
 import {-# SOURCE #-}   TyCoRep( Type, Kind, pprKind )
 import {-# SOURCE #-}   TcType( TcTyVarDetails, pprTcTyVarDetails, vanillaSkolemTv )
-import {-# SOURCE #-}   IdInfo( IdDetails, IdInfo, coVarDetails, isCoVarDetails, vanillaIdInfo, pprIdDetails )
+import {-# SOURCE #-}   IdInfo( IdDetails, IdInfo, coVarDetails, isCoVarDetails,
+                                isJoinIdDetails_maybe,
+                                vanillaIdInfo, pprIdDetails )
 
+import BasicTypes ( JoinArity )
 import Name hiding (varName)
 import Unique ( Uniquable, Unique, getKey, getUnique
               , mkUniqueGrimily, nonDetCmpUnique )
@@ -92,6 +96,7 @@ import Util
 import Binary
 import DynFlags
 import Outputable
+import Maybes
 
 import Data.Data
 
@@ -149,6 +154,7 @@ type IpId   = EvId      -- A term-level implicit parameter
 
 -- | Equality Variable
 type EqVar  = EvId      -- Boxed equality evidence
+type JoinId = Id        -- A join variable
 
 -- | Type or Coercion Variable
 type TyCoVar = Id       -- Type, *or* coercion variable
@@ -612,6 +618,14 @@ isNonCoVarId :: Var -> Bool
 isNonCoVarId (Id { id_details = details }) = not (isCoVarDetails details)
 isNonCoVarId _                             = False
 
+isJoinId :: Var -> Bool
+isJoinId (Id { id_details = details }) = isJust (isJoinIdDetails_maybe details)
+isJoinId _                             = False
+
+isJoinId_maybe :: Var -> Maybe JoinArity
+isJoinId_maybe (Id { id_details = details }) = isJoinIdDetails_maybe details
+isJoinId_maybe _                             = Nothing
+
 isLocalId :: Var -> Bool
 isLocalId (Id { idScope = LocalId _ }) = True
 isLocalId _                            = False
index dcb64a9..64357d7 100644 (file)
@@ -12,8 +12,8 @@ module VarEnv (
         elemVarEnv,
         extendVarEnv, extendVarEnv_C, extendVarEnv_Acc, extendVarEnv_Directly,
         extendVarEnvList,
-        plusVarEnv, plusVarEnv_C, plusVarEnv_CD, plusVarEnvList,
-        alterVarEnv,
+        plusVarEnv, plusVarEnv_C, plusVarEnv_CD, plusMaybeVarEnv_C,
+        plusVarEnvList, alterVarEnv,
         delVarEnvList, delVarEnv, delVarEnv_Directly,
         minusVarEnv, intersectsVarEnv,
         lookupVarEnv, lookupVarEnv_NF, lookupWithDefaultVarEnv,
@@ -41,6 +41,7 @@ module VarEnv (
         unitDVarEnv,
         delDVarEnv,
         delDVarEnvList,
+        minusDVarEnv,
         partitionDVarEnv,
         anyDVarEnv,
 
@@ -450,6 +451,7 @@ minusVarEnv       :: VarEnv a -> VarEnv b -> VarEnv a
 intersectsVarEnv  :: VarEnv a -> VarEnv a -> Bool
 plusVarEnv_C      :: (a -> a -> a) -> VarEnv a -> VarEnv a -> VarEnv a
 plusVarEnv_CD     :: (a -> a -> a) -> VarEnv a -> a -> VarEnv a -> a -> VarEnv a
+plusMaybeVarEnv_C :: (a -> a -> Maybe a) -> VarEnv a -> VarEnv a -> VarEnv a
 mapVarEnv         :: (a -> b) -> VarEnv a -> VarEnv b
 modifyVarEnv      :: (a -> a) -> VarEnv a -> Var -> VarEnv a
 
@@ -471,6 +473,7 @@ extendVarEnv_Directly = addToUFM_Directly
 extendVarEnvList = addListToUFM
 plusVarEnv_C     = plusUFM_C
 plusVarEnv_CD    = plusUFM_CD
+plusMaybeVarEnv_C = plusMaybeUFM_C
 delVarEnvList    = delListFromUFM
 delVarEnv        = delFromUFM
 minusVarEnv      = minusUFM
@@ -541,6 +544,9 @@ mkDVarEnv = listToUDFM
 extendDVarEnv :: DVarEnv a -> Var -> a -> DVarEnv a
 extendDVarEnv = addToUDFM
 
+minusDVarEnv :: DVarEnv a -> DVarEnv a' -> DVarEnv a
+minusDVarEnv = minusUDFM
+
 lookupDVarEnv :: DVarEnv a -> Var -> Maybe a
 lookupDVarEnv = lookupUDFM
 
index 0d6f4b6..49f58c6 100644 (file)
@@ -11,7 +11,8 @@
 -- | Arity and eta expansion
 module CoreArity (
         manifestArity, exprArity, typeArity, exprBotStrictness_maybe,
-        exprEtaExpandArity, findRhsArity, CheapFun, etaExpand
+        exprEtaExpandArity, findRhsArity, CheapFun, etaExpand,
+        etaExpandToJoinPoint, etaExpandToJoinPointRule
     ) where
 
 #include "HsVersions.h"
@@ -952,11 +953,17 @@ etaInfoApp subst (Case e b ty alts) eis
 etaInfoApp subst (Let b e) eis
   = Let b' (etaInfoApp subst' e eis)
   where
-    (subst', b') = subst_bind subst b
+    (subst', b') = etaInfoAppBind subst b eis
 
 etaInfoApp subst (Tick t e) eis
   = Tick (substTickish subst t) (etaInfoApp subst e eis)
 
+etaInfoApp subst expr _
+  | (Var fun, _) <- collectArgs expr
+  , Var fun' <- lookupIdSubst (text "etaInfoApp" <+> ppr fun) subst fun
+  , isJoinId fun'
+  = subst_expr subst expr
+
 etaInfoApp subst e eis
   = go (subst_expr subst e) eis
   where
@@ -965,6 +972,94 @@ etaInfoApp subst e eis
     go e (EtaCo co    : eis) = go (Cast e co) eis
 
 --------------
+-- | Apply the eta info to a local binding. Mostly delegates to
+-- `etaInfoAppLocalBndr` and `etaInfoAppRhs`.
+etaInfoAppBind :: Subst -> CoreBind -> [EtaInfo] -> (Subst, CoreBind)
+etaInfoAppBind subst (NonRec bndr rhs) eis
+  = (subst', NonRec bndr' rhs')
+  where
+    bndr_w_new_type = etaInfoAppLocalBndr bndr eis
+    (subst', bndr1) = substBndr subst bndr_w_new_type
+    rhs'            = etaInfoAppRhs subst bndr1 rhs eis
+    bndr'           | isJoinId bndr = bndr1 `setIdArity` manifestArity rhs'
+                                        -- Arity may have changed
+                                        -- (see etaInfoAppRhs example)
+                    | otherwise     = bndr1
+etaInfoAppBind subst (Rec pairs) eis
+  = (subst', Rec (bndrs' `zip` rhss'))
+  where
+    (bndrs, rhss)     = unzip pairs
+    bndrs_w_new_types = map (\bndr -> etaInfoAppLocalBndr bndr eis) bndrs
+    (subst', bndrs1)  = substRecBndrs subst bndrs_w_new_types
+    rhss'             = zipWith process bndrs1 rhss
+    process bndr' rhs = etaInfoAppRhs subst' bndr' rhs eis
+    bndrs'            | isJoinId (head bndrs)
+                      = [ bndr1 `setIdArity` manifestArity rhs'
+                        | (bndr1, rhs') <- bndrs1 `zip` rhss' ]
+                          -- Arities may have changed
+                          -- (see etaInfoAppRhs example)
+                      | otherwise
+                      = bndrs1
+
+--------------
+-- | Apply the eta info to a binder's RHS. Only interesting for a join point,
+-- where we might have this:
+--   join j :: a -> [a] -> [a]
+--        j x = \xs -> x : xs in jump j z
+-- Eta-expanding produces this:
+--   \ys -> (join j :: a -> [a] -> [a]
+--                j x = \xs -> x : xs in jump j z) ys
+-- Now when we push the application to ys inward (see Note [No crap in
+-- eta-expanded code]), it goes to the body of the RHS of the join point (after
+-- the lambda x!):
+--   \ys -> join j :: a -> [a]
+--               j x = x : ys in jump j z
+-- Note that the type and arity of j have both changed.
+etaInfoAppRhs :: Subst -> CoreBndr -> CoreExpr -> [EtaInfo] -> CoreExpr
+etaInfoAppRhs subst bndr expr eis
+  | Just arity <- isJoinId_maybe bndr
+  = do_join_point arity
+  | otherwise
+  = subst_expr subst expr
+  where
+    do_join_point arity = mkLams join_bndrs' join_body'
+      where
+        (join_bndrs, join_body) = collectNBinders arity expr
+        (subst', join_bndrs') = substBndrs subst join_bndrs
+        join_body' = etaInfoApp subst' join_body eis
+
+
+--------------
+-- | Apply the eta info to a local binder. A join point will have the EtaInfos
+-- applied to its RHS, so its type may change. See comment on etaInfoAppRhs for
+-- an example. See Note [No crap in eta-expanded code] for why all this is
+-- necessary.
+etaInfoAppLocalBndr :: CoreBndr -> [EtaInfo] -> CoreBndr
+etaInfoAppLocalBndr bndr orig_eis
+  = case isJoinId_maybe bndr of
+      Just arity -> bndr `setIdType` modifyJoinResTy arity (app orig_eis) ty
+      Nothing    -> bndr
+  where
+    ty = idType bndr
+
+    -- | Apply the given EtaInfos to the result type of the join point.
+    app :: [EtaInfo] -- To apply
+        -> Type      -- Result type of join point
+        -> Type      -- New result type
+    app [] ty
+      = ty
+    app (EtaVar v : eis) ty
+      | isId v    = app eis (funResultTy ty)
+      | otherwise = app eis (piResultTy ty (mkTyVarTy v))
+    app (EtaCo co : eis) ty
+      = ASSERT2(from_ty `eqType` ty, fsep ([text "can't apply", ppr orig_eis,
+                                            text "to", ppr bndr <+> dcolon <+>
+                                                       ppr (idType bndr)]))
+        app eis to_ty
+      where
+        Pair from_ty to_ty = coercionKind co
+
+--------------
 mkEtaWW :: Arity -> CoreExpr -> InScopeSet -> Type
         -> (InScopeSet, [EtaInfo])
         -- EtaInfo contains fresh variables,
@@ -1018,14 +1113,65 @@ mkEtaWW orig_n orig_expr in_scope orig_ty
 
 
 --------------
--- Avoiding unnecessary substitution; use short-cutting versions
+-- Don't use short-cutting substitution - we may be changing the types of join
+-- points, so applying the in-scope set is necessary
+-- TODO Check if we actually *are* changing any join points' types
 
 subst_expr :: Subst -> CoreExpr -> CoreExpr
-subst_expr = substExprSC (text "CoreArity:substExpr")
+subst_expr = substExpr (text "CoreArity:substExpr")
+
+
+--------------
 
-subst_bind :: Subst -> CoreBind -> (Subst, CoreBind)
-subst_bind = substBindSC
+-- | Split an expression into the given number of binders and a body,
+-- eta-expanding if necessary. Counts value *and* type binders.
+etaExpandToJoinPoint :: JoinArity -> CoreExpr -> ([CoreBndr], CoreExpr)
+etaExpandToJoinPoint join_arity expr
+  = go join_arity [] expr
+  where
+    go 0 rev_bs e         = (reverse rev_bs, e)
+    go n rev_bs (Lam b e) = go (n-1) (b : rev_bs) e
+    go n rev_bs e         = case etaBodyForJoinPoint n e of
+                              (bs, e') -> (reverse rev_bs ++ bs, e')
+
+etaExpandToJoinPointRule :: JoinArity -> CoreRule -> CoreRule
+etaExpandToJoinPointRule _ rule@(BuiltinRule {})
+  = WARN(True, (sep [text "Can't eta-expand built-in rule:", ppr rule]))
+      -- How did a local binding get a built-in rule anyway? Probably a plugin.
+    rule
+etaExpandToJoinPointRule join_arity rule@(Rule { ru_bndrs = bndrs, ru_rhs = rhs
+                                               , ru_args  = args })
+  | need_args == 0
+  = rule
+  | need_args < 0
+  = pprPanic "etaExpandToJoinPointRule" (ppr join_arity $$ ppr rule)
+  | otherwise
+  = rule { ru_bndrs = bndrs ++ new_bndrs, ru_args = args ++ new_args
+         , ru_rhs = new_rhs }
+  where
+    need_args = join_arity - length args
+    (new_bndrs, new_rhs) = etaBodyForJoinPoint need_args rhs
+    new_args = varsToCoreExprs new_bndrs
+
+-- Adds as many binders as asked for; assumes expr is not a lambda
+etaBodyForJoinPoint :: Int -> CoreExpr -> ([CoreBndr], CoreExpr)
+etaBodyForJoinPoint need_args body
+  = go need_args (exprType body) (init_subst body) [] body
+  where
+    go 0 _  _     rev_bs e
+      = (reverse rev_bs, e)
+    go n ty subst rev_bs e
+      | Just (tv, res_ty) <- splitForAllTy_maybe ty
+      , let (subst', tv') = Type.substTyVarBndr subst tv
+      = go (n-1) res_ty subst' (tv' : rev_bs) (e `App` Type (mkTyVarTy tv'))
+      | Just (arg_ty, res_ty) <- splitFunTy_maybe ty
+      , let (subst', b) = freshEtaId n subst arg_ty
+      = go (n-1) res_ty subst' (b : rev_bs) (e `App` Var b)
+      | otherwise
+      = pprPanic "etaBodyForJoinPoint" $ int need_args $$
+                                         ppr body $$ ppr (exprType body)
 
+    init_subst e = mkEmptyTCvSubst (mkInScopeSet (exprFreeVars e))
 
 --------------
 freshEtaId :: Int -> TCvSubst -> Type -> (TCvSubst, Id)
diff --git a/compiler/coreSyn/CoreArity.hs-boot b/compiler/coreSyn/CoreArity.hs-boot
new file mode 100644 (file)
index 0000000..4c155da
--- /dev/null
@@ -0,0 +1,6 @@
+module CoreArity where
+
+import BasicTypes
+import CoreSyn
+
+etaExpandToJoinPoint :: JoinArity -> CoreExpr -> ([CoreBndr], CoreExpr)
index c09b4a0..a776038 100644 (file)
@@ -37,6 +37,7 @@ import VarEnv
 import VarSet
 import Name
 import Id
+import IdInfo
 import PprCore
 import ErrUtils
 import Coercion
@@ -168,6 +169,28 @@ different types, called bad coercions. Following coercions are forbidden:
       coerced to (# B_1,..,B_m #) if n=m and for each pair A_i, B_i rules
       (a-e) holds.
 
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+
+We check the rules listed in Note [Invariants on join points] in CoreSyn. The
+only one that causes any difficulty is the first: All occurrences must be tail
+calls. To this end, along with the in-scope set, we remember in le_bad_joins the
+subset of join ids that are no longer allowed because they were declared "too
+far away." For example:
+
+  join j x = ... in
+  case e of
+    A -> jump j y -- good
+    B -> case (jump j z) of -- BAD
+           C -> join h = jump j w in ... -- good
+           D -> let x = jump j v in ... -- BAD
+
+A join point remains valid in case branches, so when checking the A branch, j
+is still valid. When we check the scrutinee of the inner case, however, we add j
+to le_bad_joins and catch the error. Similarly, join points can occur free in
+RHSes of other join points but not the RHSes of value bindings (thunks and
+functions).
+
 ************************************************************************
 *                                                                      *
                  Beginning and ending passes
@@ -251,6 +274,7 @@ coreDumpFlag CoreDesugar              = Just Opt_D_dump_ds
 coreDumpFlag CoreDesugarOpt           = Just Opt_D_dump_ds
 coreDumpFlag CoreTidy                 = Just Opt_D_dump_simpl
 coreDumpFlag CorePrep                 = Just Opt_D_dump_prep
+coreDumpFlag CoreOccurAnal            = Just Opt_D_dump_occur_anal
 
 coreDumpFlag CoreDoPrintCore          = Nothing
 coreDumpFlag (CoreDoRuleCheck {})     = Nothing
@@ -473,7 +497,7 @@ lintSingleBinding :: TopLevelFlag -> RecFlag -> (Id, CoreExpr) -> LintM ()
 lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
   = addLoc (RhsOf binder) $
          -- Check the rhs
-    do { ty <- lintRhs rhs
+    do { ty <- lintRhs binder rhs
        ; lint_bndr binder -- Check match to RHS type
        ; binder_ty <- applySubstTy (idType binder)
        ; ensureEqTys binder_ty ty (mkRhsMsg binder (text "RHS") ty)
@@ -481,6 +505,7 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
         -- Check the let/app invariant
         -- See Note [CoreSyn let/app invariant] in CoreSyn
        ; checkL (not (isUnliftedType binder_ty)
+            || isJoinId binder
             || (isNonRec rec_flag && exprOkForSpeculation rhs)
             || exprIsLiteralString rhs)
            (mkRhsPrimMsg binder rhs)
@@ -501,6 +526,11 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
            (mkTopNonLitStrMsg binder)
 
        ; flags <- getLintFlags
+
+        -- Check that if the binder is top-level, it's not a join point
+       ; checkL (not (isJoinId binder && isTopLevel top_lvl_flag))
+           (mkTopJoinMsg binder)
+
        ; when (lf_check_inline_loop_breakers flags
                && isStrongLoopBreaker (idOccInfo binder)
                && isInlinePragma (idInlinePragma binder))
@@ -535,7 +565,7 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
                ppr binder)
            _ -> return ()
 
-       ; mapM_ (lintCoreRule binder_ty) (idCoreRules binder)
+       ; mapM_ (lintCoreRule binder binder_ty) (idCoreRules binder)
        ; lintIdUnfolding binder binder_ty (idUnfolding binder) }
 
         -- We should check the unfolding, if any, but this is tricky because
@@ -546,20 +576,45 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
     lint_bndr var | isId var  = lintIdBndr top_lvl_flag var $ \_ -> return ()
                   | otherwise = return ()
 
--- | Checks the RHS of top-level bindings. It only differs from 'lintCoreExpr'
+-- | Checks the RHS of bindings. It only differs from 'lintCoreExpr'
 -- in that it doesn't reject occurrences of the function 'makeStatic' when they
--- appear at the top level and @lf_check_static_ptrs == AllowAtTopLevel@.
+-- appear at the top level and @lf_check_static_ptrs == AllowAtTopLevel@, and
+-- for join points, it skips the outer lambdas that take arguments to the
+-- join point.
 --
 -- See Note [Checking StaticPtrs].
-lintRhs :: CoreExpr -> LintM OutType
-lintRhs rhs = fmap lf_check_static_ptrs getLintFlags >>= go
+lintRhs :: Id -> CoreExpr -> LintM OutType
+lintRhs bndr rhs
+    | Just arity <- isJoinId_maybe bndr
+    = lint_join_lams arity arity True rhs
+    | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
+    = lint_join_lams arity arity False rhs
+  where
+    lint_join_lams 0 _ _ rhs
+      = lintCoreExpr rhs
+    lint_join_lams n tot enforce (Lam var expr)
+      = addLoc (LambdaBodyOf var) $
+        lintBinder var $ \ var' ->
+        do { body_ty <- lint_join_lams (n-1) tot enforce expr
+           ; return $ mkLamType var' body_ty }
+    lint_join_lams n tot True _other
+      = failWithL $ mkBadJoinArityMsg bndr tot (tot-n)
+    lint_join_lams _ _ False rhs
+      = markAllJoinsBad $ lintCoreExpr rhs
+          -- Future join point, not yet eta-expanded
+          -- Body is not a tail position
+
+-- Allow applications of the data constructor @StaticPtr@ at the top
+-- but produce errors otherwise.
+lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go
   where
     -- Allow occurrences of 'makeStatic' at the top-level but produce errors
     -- otherwise.
     go AllowAtTopLevel
       | (binders0, rhs') <- collectTyBinders rhs
       , Just (fun, t, info, e) <- collectMakeStaticArgs rhs'
-      = foldr
+      = markAllJoinsBad $
+        foldr
         -- imitate @lintCoreExpr (Lam ...)@
         (\var loopBinders ->
           addLoc (LambdaBodyOf var) $
@@ -572,12 +627,12 @@ lintRhs rhs = fmap lf_check_static_ptrs getLintFlags >>= go
             addLoc (AnExpr rhs') $ lintCoreArgs fun_ty [Type t, info, e]
         )
         binders0
-    go _ = lintCoreExpr rhs
+    go _ = markAllJoinsBad $ lintCoreExpr rhs
 
 lintIdUnfolding :: Id -> Type -> Unfolding -> LintM ()
 lintIdUnfolding bndr bndr_ty (CoreUnfolding { uf_tmpl = rhs, uf_src = src })
   | isStableSource src
-  = do { ty <- lintCoreExpr rhs
+  = do { ty <- lintRhs bndr rhs
        ; ensureEqTys bndr_ty ty (mkRhsMsg bndr (text "unfolding") ty) }
 
 lintIdUnfolding bndr bndr_ty (DFunUnfolding { df_con = con, df_bndrs = bndrs
@@ -624,18 +679,13 @@ lintCoreExpr :: CoreExpr -> LintM OutType
 -- If you edit this function, you may need to update the GHC formalism
 -- See Note [GHC Formalism]
 lintCoreExpr (Var var)
-  = do  { checkL (isNonCoVarId var)
-                 (text "Non term variable" <+> ppr var)
-
-        ; checkDeadIdOcc var
-        ; var' <- lookupIdInScope var
-        ; return (idType var') }
+  = lintCoreVar var 0
 
 lintCoreExpr (Lit lit)
   = return (literalType lit)
 
 lintCoreExpr (Cast expr co)
-  = do { expr_ty <- lintCoreExpr expr
+  = do { expr_ty <- markAllJoinsBad $ lintCoreExpr expr
        ; co' <- applySubstCo co
        ; (_, k2, from_ty, to_ty, r) <- lintCoercion co'
        ; lintL (classifiesTypeWithValues k2)
@@ -644,14 +694,20 @@ lintCoreExpr (Cast expr co)
        ; ensureEqTys from_ty expr_ty (mkCastErr expr co' from_ty expr_ty)
        ; return to_ty }
 
-lintCoreExpr (Tick (Breakpoint _ ids) expr)
-  = do forM_ ids $ \id -> do
-         checkDeadIdOcc id
-         lookupIdInScope id
-       lintCoreExpr expr
-
-lintCoreExpr (Tick _other_tickish expr)
-  = lintCoreExpr expr
+lintCoreExpr (Tick tickish expr)
+  = do case tickish of
+         Breakpoint _ ids -> forM_ ids $ \id -> do
+                               checkDeadIdOcc id
+                               lookupIdInScope id
+         _                -> return ()
+       markAllJoinsBadIf block_joins $ lintCoreExpr expr
+  where
+    block_joins = not (tickish `tickishScopesLike` SoftScope)
+      -- TODO Consider whether this is the correct rule. It is consistent with
+      -- the simplifier's behaviour - cost-centre-scoped ticks become part of
+      -- the continuation, and thus they behave like part of an evaluation
+      -- context, but soft-scoped and non-scoped ticks simply wrap the result
+      -- (see Simplify.simplTick).
 
 lintCoreExpr (Let (NonRec tv (Type ty)) body)
   | isTyVar tv
@@ -661,7 +717,7 @@ lintCoreExpr (Let (NonRec tv (Type ty)) body)
     do  { addLoc (RhsOf tv) $ lintTyKind tv' ty'
                 -- Now extend the substitution so we
                 -- take advantage of it in the body
-        ; extendSubstL tv' ty'       $
+        ; extendSubstL tv ty'        $
           addLoc (BodyOfLetRec [tv]) $
           lintCoreExpr body } }
 
@@ -677,6 +733,8 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
 lintCoreExpr (Let (Rec pairs) body)
   = lintIdBndrs bndrs       $ \_ ->
     do  { checkL (null dups) (dupVars dups)
+        ; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
+            mkInconsistentRecMsg bndrs
         ; mapM_ (lintSingleBinding NotTopLevel Recursive) pairs
         ; addLoc (BodyOfLetRec bndrs) (lintCoreExpr body) }
   where
@@ -684,24 +742,15 @@ lintCoreExpr (Let (Rec pairs) body)
     (_, dups) = removeDups compare bndrs
 
 lintCoreExpr e@(App _ _)
-    = do lf <- getLintFlags
-         -- Check for a nested occurrence of the StaticPtr constructor.
-         -- See Note [Checking StaticPtrs].
-         case fun of
-           Var b | lf_check_static_ptrs lf /= AllowAnywhere
-                 , idName b == makeStaticName
-                 -> do
-              failWithL $ text "Found makeStatic nested in an expression: " <+>
-                          ppr e
-           _     -> go
+  = addLoc (AnExpr e) $
+    do { fun_ty <- lintCoreFun fun (length args)
+       ; lintCoreArgs fun_ty args }
   where
-    go = do { fun_ty <- lintCoreExpr fun
-            ; addLoc (AnExpr e) $ lintCoreArgs fun_ty args }
-
     (fun, args) = collectArgs e
 
 lintCoreExpr (Lam var expr)
   = addLoc (LambdaBodyOf var) $
+    markAllJoinsBad $
     lintBinder var $ \ var' ->
     do { body_ty <- lintCoreExpr expr
        ; return $ mkLamType var' body_ty }
@@ -709,7 +758,7 @@ lintCoreExpr (Lam var expr)
 lintCoreExpr e@(Case scrut var alt_ty alts) =
        -- Check the scrutinee
   do { let scrut_diverges = exprIsBottom scrut
-     ; scrut_ty <- lintCoreExpr scrut
+     ; scrut_ty <- markAllJoinsBad $ lintCoreExpr scrut
      ; (alt_ty, _) <- lintInTy alt_ty
      ; (var_ty, _) <- lintInTy (idType var)
 
@@ -762,6 +811,63 @@ lintCoreExpr (Coercion co)
   = do { (k1, k2, ty1, ty2, role) <- lintInCo co
        ; return (mkHeteroCoercionType role k1 k2 ty1 ty2) }
 
+lintCoreVar :: Var -> Int -- Number of arguments (type or value) being passed
+            -> LintM Type -- returns type of the *variable*
+lintCoreVar var nargs
+  = do  { checkL (isNonCoVarId var)
+                 (text "Non term variable" <+> ppr var)
+
+        ; lf <- getLintFlags
+          -- Check for a nested occurrence of the StaticPtr constructor.
+          -- See Note [Checking StaticPtrs].
+        ; when (nargs /= 0 && lf_check_static_ptrs lf /= AllowAnywhere) $
+            checkL (idName var /= makeStaticName) $
+              text "Found makeStatic nested in an expression"
+
+        ; checkDeadIdOcc var
+        ; ty   <- applySubstTy (idType var)
+        ; var' <- lookupIdInScope var
+        ; let ty' = idType var'
+        ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
+        ; mb_join_arity
+            <- case isJoinId_maybe var' of
+                 Just join_arity ->
+                   do  { checkL (isJoinId_maybe var == Just join_arity) $
+                           mkJoinBndrOccMismatchMsg var' var
+                       ; return $ Just join_arity }
+                 Nothing ->
+                   case tailCallInfo (idOccInfo var') of
+                     AlwaysTailCalled join_arity -> return $ Just join_arity
+                       -- This function will be turned into a join point by the
+                       -- simplifier; typecheck it as if it already were one
+                     NoTailCallInfo              -> return $ Nothing
+        ; case mb_join_arity of
+            Just join_arity ->
+              do  { bad <- isBadJoin var'
+                  ; checkL (not bad) $ mkJoinOutOfScopeMsg var'
+                  ; checkL (nargs == join_arity) $
+                      mkBadJumpMsg var' join_arity nargs }
+            Nothing ->
+              do  { checkL (not (isJoinId var)) $
+                      mkJoinBndrOccMismatchMsg var' var }
+        ; return (idType var') }
+
+lintCoreFun :: CoreExpr -> Int -- Number of arguments (type or val) being passed
+            -> LintM Type -- returns type of the *function*
+lintCoreFun (Var var) nargs
+  = lintCoreVar var nargs
+lintCoreFun (Lam var body) nargs
+  -- Act like lintCoreExpr of Lam, but *don't* call markAllJoinsBad; see
+  -- Note [Beta redexes]
+  | nargs /= 0
+  = addLoc (LambdaBodyOf var) $
+    lintBinder var $ \ var' ->
+    do { body_ty <- lintCoreFun body (nargs - 1)
+       ; return $ mkLamType var' body_ty }
+lintCoreFun expr nargs
+  = markAllJoinsBadIf (nargs /= 0) $
+    lintCoreExpr expr
+
 {-
 Note [No alternatives lint check]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -783,6 +889,33 @@ correct, but that exprIsBottom is unable to see it. In particular, the
 empty-type check in exprIsBottom is an approximation. Therefore, this
 check is not fully reliable, and we keep both around.
 
+Note [Beta redexes]
+~~~~~~~~~~~~~~~~~~~
+Consider:
+
+  join j @x y z = ... in
+  (\@x y z -> jump j @x y z) @t e1 e2
+
+This is clearly ill-typed, since the jump is inside both an application and a
+lambda, either of which is enough to disqualify it as a tail call (see Note
+[Invariants on join points] in CoreSyn). However, strictly from a
+lambda-calculus perspective, the term doesn't go wrong---after the two beta
+reductions, the jump *is* a tail call and everything is fine.
+
+Why would we want to allow this when we have let? One reason is that a compound
+beta redex (that is, one with more than one argument) has different scoping
+rules: naively reducing the above example using lets will capture any free
+occurrence of y in e2. More fundamentally, type lets are tricky; many passes,
+such as Float Out, tacitly assume that the incoming program's type lets have
+all been dealt with by the simplifier. Thus we don't want to let-bind any types
+in, say, CoreSubst.simpleOptPgm, which in some circumstances can run immediately
+before Float Out.
+
+All that said, currently CoreSubst.simpleOptPgm is the only thing using this
+loophole, doing so to avoid re-traversing large functions (beta-reducing a type
+lambda without introducing a type let requires a substitution). TODO: Improve
+simpleOptPgm so that we can forget all this ever happened.
+
 ************************************************************************
 *                                                                      *
 \subsection[lintCoreArgs]{lintCoreArgs}
@@ -806,7 +939,7 @@ lintCoreArg fun_ty (Type arg_ty)
        ; lintTyApp fun_ty arg_ty' }
 
 lintCoreArg fun_ty arg
-  = do { arg_ty <- lintCoreExpr arg
+  = do { arg_ty <- markAllJoinsBad $ lintCoreExpr arg
            -- See Note [Levity polymorphism invariants] in CoreSyn
        ; lintL (not (isTypeLevPoly arg_ty))
            (text "Levity-polymorphic argument:" <+>
@@ -1225,15 +1358,21 @@ lint_app doc kfn kas
 *                                                                      *
 ********************************************************************* -}
 
-lintCoreRule :: OutType -> CoreRule -> LintM ()
-lintCoreRule _ (BuiltinRule {})
+lintCoreRule :: OutVar -> OutType -> CoreRule -> LintM ()
+lintCoreRule _ (BuiltinRule {})
   = return ()  -- Don't bother
 
-lintCoreRule fun_ty (Rule { ru_name = name, ru_bndrs = bndrs
-                          , ru_args = args, ru_rhs = rhs })
+lintCoreRule fun fun_ty rule@(Rule { ru_name = name, ru_bndrs = bndrs
+                                   , ru_args = args, ru_rhs = rhs })
   = lintBinders bndrs $ \ _ ->
     do { lhs_ty <- foldM lintCoreArg fun_ty args
-       ; rhs_ty <- lintCoreExpr rhs
+       ; rhs_ty <- case isJoinId_maybe fun of
+                     Just join_arity
+                       -> do { checkL (args `lengthIs` join_arity) $
+                                 mkBadJoinPointRuleMsg fun join_arity rule
+                               -- See Note [Rules for join points]
+                             ; lintCoreExpr rhs }
+                     _ -> markAllJoinsBad $ lintCoreExpr rhs
        ; ensureEqTys lhs_ty rhs_ty $
          (rule_doc <+> vcat [ text "lhs type:" <+> ppr lhs_ty
                             , text "rhs type:" <+> ppr rhs_ty ])
@@ -1273,6 +1412,26 @@ we'll end up with
    RULE forall x y. f ($gw y) = $gw (x+1)
 This seems sufficiently obscure that there isn't enough payoff to
 try to trim the forall'd binder list.
+
+Note [Rules for join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+A join point cannot be partially applied. However, the left-hand side of a rule
+for a join point is effectively a *pattern*, not a piece of code, so there's an
+argument to be made for allowing a situation like this:
+
+  join $sj :: Int -> Int -> String
+       $sj n m = ...
+       j :: forall a. Eq a => a -> a -> String
+       {-# RULES "SPEC j" jump j @ Int $dEq = jump $sj #-}
+       j @a $dEq x y = ...
+
+Applying this rule can't turn a well-typed program into an ill-typed one, so
+conceivably we could allow it. But we can always eta-expand such an
+"undersaturated" rule (see 'CoreArity.etaExpandToJoinPointRule'), and in fact
+the simplifier would have to in order to deal with the RHS. So we take a
+conservative view and don't allow undersaturated rules for join points. See
+Note [Rules and join points] in OccurAnal for further discussion.
 -}
 
 {-
@@ -1624,6 +1783,8 @@ data LintEnv
        , le_subst :: TCvSubst        -- Current type substitution; we also use this
                                      -- to keep track of all the variables in scope,
                                      -- both Ids and TyVars
+       , le_bad_joins :: IdSet       -- Join points that are no longer valid
+                                     -- See Note [Join points]
        , le_dynflags :: DynFlags     -- DynamicFlags
        }
 
@@ -1734,7 +1895,8 @@ initL dflags flags m
   = case unLintM m env (emptyBag, emptyBag) of
       (_, errs) -> errs
   where
-    env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = [], le_dynflags = dflags }
+    env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = []
+             , le_dynflags = dflags, le_bad_joins = emptyVarSet }
 
 getLintFlags :: LintM LintFlags
 getLintFlags = LintM $ \ env errs -> (Just (le_flags env), errs)
@@ -1791,8 +1953,11 @@ inCasePat = LintM $ \ env errs -> (Just (is_case_pat env), errs)
 addInScopeVars :: [Var] -> LintM a -> LintM a
 addInScopeVars vars m
   = LintM $ \ env errs ->
-    unLintM m (env { le_subst = extendTCvInScopeList (le_subst env) vars })
+    unLintM m (env { le_subst     = extendTCvInScopeList (le_subst env) vars
+                   , le_bad_joins = bad_joins' env })
               errs
+  where
+    bad_joins' env = delVarSetList (le_bad_joins env) (filter isJoinId vars)
 
 addInScopeVarSet :: VarSet -> LintM a -> LintM a
 addInScopeVarSet vars m
@@ -1803,7 +1968,11 @@ addInScopeVarSet vars m
 addInScopeVar :: Var -> LintM a -> LintM a
 addInScopeVar var m
   = LintM $ \ env errs ->
-    unLintM m (env { le_subst = extendTCvInScope (le_subst env) var }) errs
+    unLintM m (env { le_subst     = extendTCvInScope (le_subst env) var
+                   , le_bad_joins = bad_joins' env }) errs
+  where
+    bad_joins' env | isJoinId var = delVarSet (le_bad_joins env) var
+                   | otherwise    = le_bad_joins env
 
 extendSubstL :: TyVar -> Type -> LintM a -> LintM a
 extendSubstL tv ty m
@@ -1814,6 +1983,18 @@ updateTCvSubst :: TCvSubst -> LintM a -> LintM a
 updateTCvSubst subst' m
   = LintM $ \ env errs -> unLintM m (env { le_subst = subst' }) errs
 
+markAllJoinsBad :: LintM a -> LintM a
+markAllJoinsBad m
+  = LintM $ \ env errs -> unLintM m (marked env) errs
+  where
+    marked env = env { le_bad_joins = filterVarSet isJoinId in_set }
+      where
+        in_set = getInScopeVars (getTCvInScope (le_subst env))
+
+markAllJoinsBadIf :: Bool -> LintM a -> LintM a
+markAllJoinsBadIf True  m = markAllJoinsBad m
+markAllJoinsBadIf False m = m
+
 getTCvSubst :: LintM TCvSubst
 getTCvSubst = LintM (\ env errs -> (Just (le_subst env), errs))
 
@@ -1839,6 +2020,10 @@ lookupIdInScope id
   where
     out_of_scope = pprBndr LetBind id <+> text "is out of scope"
 
+isBadJoin :: Id -> LintM Bool
+isBadJoin id = LintM $ \env errs -> (Just (id `elemVarSet` le_bad_joins env),
+                                     errs)
+
 lintTyCoVarInScope :: Var -> LintM ()
 lintTyCoVarInScope v = lintInScope (text "is out of scope") v
 
@@ -2096,6 +2281,62 @@ mkBadTyVarMsg tv
   = text "Non-tyvar used in TyVarTy:"
       <+> ppr tv <+> dcolon <+> ppr (varType tv)
 
+mkTopJoinMsg :: Var -> SDoc
+mkTopJoinMsg var
+  = text "Join point at top level:" <+> ppr var
+
+mkBadJoinArityMsg :: Var -> Int -> Int -> SDoc
+mkBadJoinArityMsg var ar nlams
+  = vcat [ text "Join point has too few lambdas",
+           text "Join var:" <+> ppr var,
+           text "Join arity:" <+> ppr ar,
+           text "Number of lambdas:" <+> ppr nlams ]
+
+mkJoinOutOfScopeMsg :: Var -> SDoc
+mkJoinOutOfScopeMsg var
+  = text "Join variable no longer in scope:" <+> ppr var
+
+mkBadJumpMsg :: Var -> Int -> Int -> SDoc
+mkBadJumpMsg var ar nargs
+  = vcat [ text "Join point invoked with wrong number of arguments",
+           text "Join var:" <+> ppr var,
+           text "Join arity:" <+> ppr ar,
+           text "Number of arguments:" <+> int nargs ]
+
+mkInconsistentRecMsg :: [Var] -> SDoc
+mkInconsistentRecMsg bndrs
+  = vcat [ text "Recursive let binders mix values and join points",
+           text "Binders:" <+> hsep (map ppr_with_details bndrs) ]
+  where
+    ppr_with_details bndr = ppr bndr <> ppr (idDetails bndr)
+
+mkJoinBndrOccMismatchMsg :: Var -> Var -> SDoc
+mkJoinBndrOccMismatchMsg bndr var
+  = vcat [ text "Mismatch in join point status between binder and occurrence",
+           text "Var:" <+> ppr bndr,
+           text "Binder:" <+> ppr_join_status bndr,
+           text "Occ:" <+> ppr_join_status var ]
+  where
+    ppr_join_status v = case details of JoinId _ -> ppr details
+                                        _        -> text "not a join id"
+      where
+        details = idDetails v
+
+mkBndrOccTypeMismatchMsg :: Var -> Var -> OutType -> OutType -> SDoc
+mkBndrOccTypeMismatchMsg bndr var bndr_ty var_ty
+  = vcat [ text "Mismatch in type between binder and occurrence"
+         , text "Var:" <+> ppr bndr
+         , text "Binder type:" <+> ppr bndr_ty
+         , text "Occurrence type:" <+> ppr var_ty
+         , text "  Before subst:" <+> ppr (idType var) ]
+
+mkBadJoinPointRuleMsg :: JoinId -> JoinArity -> CoreRule -> SDoc
+mkBadJoinPointRuleMsg bndr join_arity rule
+  = vcat [ text "Join point has rule with wrong number of arguments"
+         , text "Var:" <+> ppr bndr
+         , text "Join arity:" <+> ppr join_arity
+         , text "Rule:" <+> ppr rule ]
+
 pprLeftOrRight :: LeftOrRight -> MsgDoc
 pprLeftOrRight CLeft  = text "left"
 pprLeftOrRight CRight = text "right"
index 4e4cbb9..74de5af 100644 (file)
@@ -204,9 +204,13 @@ corePrepTopBinds initialCorePrepEnv binds
   = go initialCorePrepEnv binds
   where
     go _   []             = return emptyFloats
-    go env (bind : binds) = do (env', bind') <- cpeBind TopLevel env bind
-                               binds' <- go env' binds
-                               return (bind' `appendFloats` binds')
+    go env (bind : binds) = do (env', floats, maybe_new_bind)
+                                 <- cpeBind TopLevel env bind
+                               MASSERT(isNothing maybe_new_bind)
+                                 -- Only join points get returned this way by
+                                 -- cpeBind, and no join point may float to top
+                               floatss <- go env' binds
+                               return (floats `appendFloats` floatss)
 
 mkDataConWorkers :: DynFlags -> ModLocation -> [TyCon] -> [CoreBind]
 -- See Note [Data constructor workers]
@@ -280,6 +284,29 @@ This is all very gruesome and horrible. It would be better to figure
 out CafInfo later, after CorePrep.  We'll do that in due course.
 Meanwhile this horrible hack works.
 
+Note [Join points and floating]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Join points can float out of other join points but not out of value bindings:
+
+  let z =
+    let  w = ... in -- can float
+    join k = ... in -- can't float
+    ... jump k ...
+  join j x1 ... xn =
+    let  y = ... in -- can float (but don't want to)
+    join h = ... in -- can float (but not much point)
+    ... jump h ...
+  in ...
+
+Here, the jump to h remains valid if h is floated outward, but the jump to k
+does not.
+
+We don't float *out* of join points. It would only be safe to float out of
+nullary join points (or ones where the arguments are all either type arguments
+or dead binders). Nullary join points aren't ever recursive, so they're always
+effectively one-shot functions, which we don't float out of. We *could* float
+join points from nullary join points, but there's no clear benefit at this
+stage.
 
 Note [Data constructor workers]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -369,8 +396,12 @@ Into this one:
 -}
 
 cpeBind :: TopLevelFlag -> CorePrepEnv -> CoreBind
-        -> UniqSM (CorePrepEnv, Floats)
+        -> UniqSM (CorePrepEnv,
+                   Floats,         -- Floating value bindings
+                   Maybe CoreBind) -- Just bind' <=> returned new bind; no float
+                                   -- Nothing <=> added bind' to floats instead
 cpeBind top_lvl env (NonRec bndr rhs)
+  | not (isJoinId bndr)
   = do { (_, bndr1) <- cpCloneBndr env bndr
        ; let dmd         = idDemandInfo bndr
              is_unlifted = isUnliftedType (idType bndr)
@@ -380,7 +411,7 @@ cpeBind top_lvl env (NonRec bndr rhs)
                                           env bndr1 rhs
        -- See Note [Inlining in CorePrep]
        ; if exprIsTrivial rhs2 && isNotTopLevel top_lvl
-            then return (extendCorePrepEnvExpr env bndr rhs2, floats)
+            then return (extendCorePrepEnvExpr env bndr rhs2, floats, Nothing)
             else do {
 
        ; let new_float = mkFloat dmd is_unlifted bndr2 rhs2
@@ -388,19 +419,38 @@ cpeBind top_lvl env (NonRec bndr rhs)
         -- We want bndr'' in the envt, because it records
         -- the evaluated-ness of the binder
        ; return (extendCorePrepEnv env bndr bndr2,
-                 addFloat floats new_float) }}
+                 addFloat floats new_float,
+                 Nothing) }}
+  | otherwise -- See Note [Join points and floating]
+  = ASSERT(not (isTopLevel top_lvl)) -- can't have top-level join point
+    do { (_, bndr1) <- cpCloneBndr env bndr
+       ; (bndr2, rhs1) <- cpeJoinPair env bndr1 rhs
+       ; return (extendCorePrepEnv env bndr bndr2,
+                 emptyFloats,
+                 Just (NonRec bndr2 rhs1)) }
 
 cpeBind top_lvl env (Rec pairs)
-  = do { let (bndrs,rhss) = unzip pairs
-       ; (env', bndrs1) <- cpCloneBndrs env (map fst pairs)
+  | not (isJoinId (head bndrs))
+  = do { (env', bndrs1) <- cpCloneBndrs env bndrs
        ; stuff <- zipWithM (cpePair top_lvl Recursive topDmd False env') bndrs1 rhss
 
        ; let (floats_s, bndrs2, rhss2) = unzip3 stuff
              all_pairs = foldrOL add_float (bndrs2 `zip` rhss2)
                                            (concatFloats floats_s)
        ; return (extendCorePrepEnvList env (bndrs `zip` bndrs2),
-                 unitFloat (FloatLet (Rec all_pairs))) }
+                 unitFloat (FloatLet (Rec all_pairs)),
+                 Nothing) }
+  | otherwise -- See Note [Join points and floating]
+  = do { (env', bndrs1) <- cpCloneBndrs env bndrs
+       ; pairs1 <- zipWithM (cpeJoinPair env') bndrs1 rhss
+
+       ; let bndrs2 = map fst pairs1
+       ; return (extendCorePrepEnvList env' (bndrs `zip` bndrs2),
+                 emptyFloats,
+                 Just (Rec pairs1)) }
   where
+    (bndrs, rhss) = unzip pairs
+
         -- Flatten all the floats, and the currrent
         -- group into a single giant Rec
     add_float (FloatLet (NonRec b r)) prs2 = (b,r) : prs2
@@ -413,7 +463,8 @@ cpePair :: TopLevelFlag -> RecFlag -> Demand -> Bool
         -> UniqSM (Floats, Id, CpeRhs)
 -- Used for all bindings
 cpePair top_lvl is_rec dmd is_unlifted env bndr rhs
-  = do { (floats1, rhs1) <- cpeRhsE env rhs
+  = ASSERT(not (isJoinId bndr)) -- those should use cpeJoinPair
+    do { (floats1, rhs1) <- cpeRhsE env rhs
 
        -- See if we are allowed to float this stuff out of the RHS
        ; (floats2, rhs2) <- float_from_rhs floats1 rhs1
@@ -496,6 +547,45 @@ When InlineMe notes go away this won't happen any more.  But
 it seems good for CorePrep to be robust.
 -}
 
+---------------
+cpeJoinPair :: CorePrepEnv -> JoinId -> CoreExpr
+            -> UniqSM (JoinId, CpeRhs)
+-- Used for all join bindings
+cpeJoinPair env bndr rhs
+  = ASSERT(isJoinId bndr)
+    do { let Just join_arity = isJoinId_maybe bndr
+             (bndrs, body)   = collectNBinders join_arity rhs
+
+       ; (env', bndrs') <- cpCloneBndrs env bndrs
+
+       ; body' <- cpeBodyNF env' body -- Will let-bind the body if it starts
+                                      -- with a lambda
+
+       ; let rhs'  = mkCoreLams bndrs' body'
+             bndr' = bndr `setIdUnfolding` evaldUnfolding
+                          `setIdArity` count isId bndrs
+                            -- See Note [Arity and join points]
+
+       ; return (bndr', rhs') }
+
+{-
+Note [Arity and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Up to now, we've allowed a join point to have an arity greater than its join
+arity (minus type arguments), since this is what's useful for eta expansion.
+However, for code gen purposes, its arity must be exactly the number of value
+arguments it will be called with, and it must have exactly that many value
+lambdas. Hence if there are extra lambdas we must let-bind the body of the RHS:
+
+  join j x y z = \w -> ... in ...
+    =>
+  join j x y z = (let f = \w -> ... in f) in ...
+
+This is also what happens with Note [Silly extra arguments]. Note that it's okay
+for us to mess with the arity because a join point is never exported.
+-}
+
 -- ---------------------------------------------------------------------------
 --              CpeRhs: produces a result satisfying CpeRhs
 -- ---------------------------------------------------------------------------
@@ -518,10 +608,12 @@ cpeRhsE _env expr@(Lit {}) = return (emptyFloats, expr)
 cpeRhsE env expr@(Var {})  = cpeApp env expr
 cpeRhsE env expr@(App {}) = cpeApp env expr
 
-cpeRhsE env (Let bind expr)
-  = do { (env', new_binds) <- cpeBind NotTopLevel env bind
-       ; (floats, body) <- cpeRhsE env' expr
-       ; return (new_binds `appendFloats` floats, body) }
+cpeRhsE env (Let bind body)
+  = do { (env', bind_floats, maybe_bind') <- cpeBind NotTopLevel env bind
+       ; (body_floats, body') <- cpeRhsE env' body
+       ; let expr' = case maybe_bind' of Just bind' -> Let bind' body'
+                                         Nothing    -> body'
+       ; return (bind_floats `appendFloats` body_floats, expr') }
 
 cpeRhsE env (Tick tickish expr)
   | tickishPlace tickish == PlaceNonLam && tickish `tickishScopesLike` SoftScope
index 9ad8321..4da81fd 100644 (file)
@@ -11,50 +11,64 @@ module CoreStats (
         CoreStats(..), coreBindsStats, exprStats,
     ) where
 
+import BasicTypes
 import CoreSyn
 import Outputable
 import Coercion
 import Var
 import Type (Type, typeSize, seqType)
-import Id (idType)
+import Id (idType, isJoinId)
 import CoreSeq (megaSeqIdInfo)
 
 data CoreStats = CS { cs_tm :: Int    -- Terms
                     , cs_ty :: Int    -- Types
-                    , cs_co :: Int }  -- Coercions
+                    , cs_co :: Int    -- Coercions
+                    , cs_vb :: Int    -- Local value bindings
+                    , cs_jb :: Int }  -- Local join bindings
 
 
 instance Outputable CoreStats where
- ppr (CS { cs_tm = i1, cs_ty = i2, cs_co = i3 })
+ ppr (CS { cs_tm = i1, cs_ty = i2, cs_co = i3, cs_vb = i4, cs_jb = i5 })
    = braces (sep [text "terms:"     <+> intWithCommas i1 <> comma,
                   text "types:"     <+> intWithCommas i2 <> comma,
-                  text "coercions:" <+> intWithCommas i3])
+                  text "coercions:" <+> intWithCommas i3 <> comma,
+                  text "joins:"     <+> intWithCommas i5 <> char '/' <>
+                                        intWithCommas (i4 + i5) ])
 
 plusCS :: CoreStats -> CoreStats -> CoreStats
-plusCS (CS { cs_tm = p1, cs_ty = q1, cs_co = r1 })
-       (CS { cs_tm = p2, cs_ty = q2, cs_co = r2 })
-  = CS { cs_tm = p1+p2, cs_ty = q1+q2, cs_co = r1+r2 }
+plusCS (CS { cs_tm = p1, cs_ty = q1, cs_co = r1, cs_vb = v1, cs_jb = j1 })
+       (CS { cs_tm = p2, cs_ty = q2, cs_co = r2, cs_vb = v2, cs_jb = j2 })
+  = CS { cs_tm = p1+p2, cs_ty = q1+q2, cs_co = r1+r2, cs_vb = v1+v2
+       , cs_jb = j1+j2 }
 
 zeroCS, oneTM :: CoreStats
-zeroCS = CS { cs_tm = 0, cs_ty = 0, cs_co = 0 }
+zeroCS = CS { cs_tm = 0, cs_ty = 0, cs_co = 0, cs_vb = 0, cs_jb = 0 }
 oneTM  = zeroCS { cs_tm = 1 }
 
 sumCS :: (a -> CoreStats) -> [a] -> CoreStats
 sumCS f = foldr (plusCS . f) zeroCS
 
 coreBindsStats :: [CoreBind] -> CoreStats
-coreBindsStats = sumCS bindStats
+coreBindsStats = sumCS (bindStats TopLevel)
 
-bindStats :: CoreBind -> CoreStats
-bindStats (NonRec v r) = bindingStats v r
-bindStats (Rec prs)    = sumCS (\(v,r) -> bindingStats v r) prs
+bindStats :: TopLevelFlag -> CoreBind -> CoreStats
+bindStats top_lvl (NonRec v r) = bindingStats top_lvl v r
+bindStats top_lvl (Rec prs)    = sumCS (\(v,r) -> bindingStats top_lvl v r) prs
 
-bindingStats :: Var -> CoreExpr -> CoreStats
-bindingStats v r = bndrStats v `plusCS` exprStats r
+bindingStats :: TopLevelFlag -> Var -> CoreExpr -> CoreStats
+bindingStats top_lvl v r = letBndrStats top_lvl v `plusCS` exprStats r
 
 bndrStats :: Var -> CoreStats
 bndrStats v = oneTM `plusCS` tyStats (varType v)
 
+letBndrStats :: TopLevelFlag -> Var -> CoreStats
+letBndrStats top_lvl v
+  | isTyVar v || isTopLevel top_lvl = bndrStats v
+  | isJoinId v = oneTM { cs_jb = 1 } `plusCS` ty_stats
+  | otherwise  = oneTM { cs_vb = 1 } `plusCS` ty_stats
+  where
+    ty_stats = tyStats (varType v)
+
 exprStats :: CoreExpr -> CoreStats
 exprStats (Var {})        = oneTM
 exprStats (Lit {})        = oneTM
@@ -62,7 +76,7 @@ exprStats (Type t)        = tyStats t
 exprStats (Coercion c)    = coStats c
 exprStats (App f a)       = exprStats f `plusCS` exprStats a
 exprStats (Lam b e)       = bndrStats b `plusCS` exprStats e
-exprStats (Let b e)       = bindStats b `plusCS` exprStats e
+exprStats (Let b e)       = bindStats NotTopLevel b `plusCS` exprStats e
 exprStats (Case e b _ as) = exprStats e `plusCS` bndrStats b
                                         `plusCS` sumCS altStats as
 exprStats (Cast e co)     = coStats co `plusCS` exprStats e
index 72df704..9d69493 100644 (file)
@@ -39,6 +39,10 @@ module CoreSubst (
 
 #include "HsVersions.h"
 
+import {-# SOURCE #-} CoreArity ( etaExpandToJoinPoint )
+                        -- Needed for simpleOptPgm to convert bindings to join
+                        -- points, but CoreArity uses substitutions throughout
+
 import CoreSyn
 import CoreFVs
 import CoreSeq
@@ -867,6 +871,9 @@ simpleOptExpr :: CoreExpr -> CoreExpr
 -- We also inline bindings that bind a Eq# box: see
 -- See Note [Getting the map/coerce RULE to work].
 --
+-- Also we convert functions to join points where possible (as
+-- the occurrence analyser does most of the work anyway).
+--
 -- The result is NOT guaranteed occurrence-analysed, because
 -- in  (let x = y in ....) we substitute for x; so y's occ-info
 -- may change radically
@@ -1012,8 +1019,9 @@ simple_opt_bind' subst (Rec prs)
   = (subst'', res_bind)
   where
     res_bind            = Just (Rec (reverse rev_prs'))
-    (subst', bndrs')    = subst_opt_bndrs subst (map fst prs)
-    (subst'', rev_prs') = foldl do_pr (subst', []) (prs `zip` bndrs')
+    prs'                = map (uncurry convert_if_marked) prs
+    (subst', bndrs')    = subst_opt_bndrs subst (map fst prs')
+    (subst'', rev_prs') = foldl do_pr (subst', []) (prs' `zip` bndrs')
     do_pr (subst, prs) ((b,r), b')
        = case maybe_substitute subst b r2 of
            Just subst' -> (subst', prs)
@@ -1023,7 +1031,20 @@ simple_opt_bind' subst (Rec prs)
          r2 = simple_opt_expr subst r
 
 simple_opt_bind' subst (NonRec b r)
-  = simple_opt_out_bind subst (b, simple_opt_expr subst r)
+  = simple_opt_out_bind subst (b', simple_opt_expr subst r')
+  where
+    (b', r') = convert_if_marked b r
+
+convert_if_marked :: InVar -> InExpr -> (InVar, InExpr)
+convert_if_marked bndr rhs
+  | isId bndr
+  , AlwaysTailCalled ar <- tailCallInfo (idOccInfo bndr)
+    -- Marked to become a join point
+  , (bndrs, body) <- etaExpandToJoinPoint ar rhs
+  = -- Tail call info now unnecessary
+    (zapIdTailCallInfo (bndr `asJoinId` ar), mkLams bndrs body)
+  | otherwise
+  = (bndr, rhs)
 
 ----------------------
 simple_opt_out_bind :: Subst -> (InVar, OutExpr) -> (Subst, Maybe CoreBind)
@@ -1072,8 +1093,10 @@ maybe_substitute subst b r
     safe_to_inline :: OccInfo -> Bool
     safe_to_inline (IAmALoopBreaker {})     = False
     safe_to_inline IAmDead                  = True
-    safe_to_inline (OneOcc in_lam one_br _) = (not in_lam && one_br) || trivial
-    safe_to_inline NoOccInfo                = trivial
+    safe_to_inline occ@(OneOcc {})          = (not (occ_in_lam occ) &&
+                                                occ_one_br occ)
+                                            || trivial
+    safe_to_inline (ManyOccs {})            = trivial
 
     trivial | exprIsTrivial r = True
             | (Var fun, args) <- collectArgs r
index 333a55b..f74e3e5 100644 (file)
@@ -3,7 +3,7 @@
 (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
 -}
 
-{-# LANGUAGE CPP, DeriveDataTypeable #-}
+{-# LANGUAGE CPP, DeriveDataTypeable, FlexibleContexts #-}
 
 -- | CoreSyn holds all the main data types for use by for the Glasgow Haskell Compiler midsection
 module CoreSyn (
@@ -21,7 +21,7 @@ module CoreSyn (
 
         -- ** 'Expr' construction
         mkLets, mkLams,
-        mkApps, mkTyApps, mkCoApps, mkVarApps,
+        mkApps, mkTyApps, mkCoApps, mkVarApps, mkTyArg,
 
         mkIntLit, mkIntLitInt,
         mkWordLit, mkWordLitWord,
@@ -38,6 +38,7 @@ module CoreSyn (
         -- ** Simple 'Expr' access functions and predicates
         bindersOf, bindersOfBinds, rhssOfBind, rhssOfAlts,
         collectBinders, collectTyBinders, collectTyAndValBinders,
+        collectNBinders,
         collectArgs, collectArgsTicks, flattenBinds,
 
         exprToType, exprToCoercion_maybe,
@@ -75,7 +76,8 @@ module CoreSyn (
         collectAnnArgs, collectAnnArgsTicks,
 
         -- ** Operations on annotations
-        deAnnotate, deAnnotate', deAnnAlt, collectAnnBndrs,
+        deAnnotate, deAnnotate', deAnnAlt,
+        collectAnnBndrs, collectNAnnBndrs,
 
         -- * Orphanhood
         IsOrphan(..), isOrphan, notOrphan, chooseOrphanAnchor,
@@ -408,7 +410,8 @@ The let/app invariant
      the right hand side of a non-recursive 'Let', and
      the argument of an 'App',
     /may/ be of unlifted type, but only if
-    the expression is ok-for-speculation.
+    the expression is ok-for-speculation
+    or the 'Let' is for a join point.
 
 This means that the let can be floated around
 without difficulty. For example, this is OK:
@@ -510,6 +513,181 @@ this exhaustive list can be empty!
   conversion; remember STG is un-typed, so there is no need for
   the empty case to do the type conversion.
 
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+In Core, a *join point* is a specially tagged function whose only occurrences
+are saturated tail calls. A tail call can appear in these places:
+
+  1. In the branches (not the scrutinee) of a case
+  2. Underneath a let (value or join point)
+  3. Inside another join point
+
+We write a join-point declaration as
+  join j @a @b x y = e1 in e2,
+like a let binding but with "join" instead (or "join rec" for "let rec"). Note
+that we put the parameters before the = rather than using lambdas; this is
+because it's relevant how many parameters the join point takes *as a join
+point.* This number is called the *join arity,* distinct from arity because it
+counts types as well as values. Note that a join point may return a lambda! So
+  join j x = x + 1
+is different from
+  join j = \x -> x + 1
+The former has join arity 1, while the latter has join arity 0.
+
+The identifier for a join point is called a join id or a *label.* An invocation
+is called a *jump.* We write a jump using the jump keyword:
+
+  jump j 3
+
+The words *label* and *jump* are evocative of assembly code (or Cmm) for a
+reason: join points are indeed compiled as labeled blocks, and jumps become
+actual jumps (plus argument passing and stack adjustment). There is no closure
+allocated and only a fraction of the function-call overhead. Hence we would
+like as many functions as possible to become join points (see OccurAnal) and
+the type rules for join points ensure we preserve the properties that make them
+efficient.
+
+In the actual AST, a join point is indicated by the IdDetails of the binder: a
+local value binding gets 'VanillaId' but a join point gets a 'JoinId' with its
+join arity.
+
+For more details, see the paper:
+
+  Luke Maurer, Paul Downen, Zena Ariola, and Simon Peyton Jones. "Compiling
+  without continuations." Submitted to PLDI'17.
+
+  https://www.microsoft.com/en-us/research/publication/compiling-without-continuations/
+
+Note [Invariants on join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Join points must follow these invariants:
+
+  1. All occurrences must be tail calls. Each of these tail calls must pass the
+     same number of arguments, counting both types and values; we call this the
+     "join arity" (to distinguish from regular arity, which only counts values).
+  2. For join arity n, the right-hand side must begin with at least n lambdas.
+  3. If the binding is recursive, then all other bindings in the recursive group
+     must also be join points.
+  4. The binding's type must not be polymorphic in its return type (as defined
+     in Note [The polymorphism rule of join points]).
+
+Examples:
+
+  join j1  x = 1 + x in jump j (jump j x)  -- Fails 1: non-tail call
+  join j1' x = 1 + x in if even a
+                          then jump j1 a
+                          else jump j1 a b -- Fails 1: inconsistent calls
+  join j2  x = flip (+) x in j2 1 2        -- Fails 2: not enough lambdas
+  join j2' x = \y -> x + y in j3 1         -- Passes: extra lams ok
+  join j @a (x :: a) = x                   -- Fails 4: polymorphic in ret type
+
+Invariant 1 applies to left-hand sides of rewrite rules, so a rule for a join
+point must have an exact call as its LHS.
+
+Strictly speaking, invariant 3 is redundant, since a call from inside a lazy
+binding isn't a tail call. Since a let-bound value can't invoke a free join
+point, then, they can't be mutually recursive. (A Core binding group *can*
+include spurious extra bindings if the occurrence analyser hasn't run, so
+invariant 3 does still need to be checked.) For the rigorous definition of
+"tail call", see Section 3 of the paper (Note [Join points]).
+
+Invariant 4 is subtle; see Note [The polymorphism rule of join points].
+
+Core Lint will check these invariants, anticipating that any binder whose
+OccInfo is marked AlwaysTailCalled will become a join point as soon as the
+simplifier (or simpleOptPgm) runs.
+
+Note [The type of a join point]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+A join point has the same type it would have as a function. That is, if it takes
+an Int and a Bool and its body produces a String, its type is `Int -> Bool ->
+String`. Natural as this may seem, it can be awkward. A join point shouldn't be
+thought to "return" in the same sense a function does---a jump is one-way. This
+is crucial for understanding how case-of-case interacts with join points:
+
+  case (join
+          j :: Int -> Bool -> String
+          j x y = ...
+        in
+          jump j z w) of
+    "" -> True
+    _  -> False
+
+The simplifier will pull the case into the join point (see Note [Case-of-case
+and join points] in Simplify):
+
+  join
+    j :: Int -> Bool -> Bool -- changed!
+    j x y = case ... of "" -> True
+                        _  -> False
+  in
+    jump j z w
+
+The body of the join point now returns a Bool, so the label `j` has to have its
+type updated accordingly. Inconvenient though this may be, it has the advantage
+that 'CoreUtils.exprType' can still return a type for any expression, including
+a jump.
+
+This differs from the paper (see Note [Invariants on join points]). In the
+paper, we instead give j the type `Int -> Bool -> forall a. a`. Then each jump
+carries the "return type" as a parameter, exactly the way other non-returning
+functions like `error` work:
+
+  case (join
+          j :: Int -> Bool -> forall a. a
+          j x y = ...
+        in
+          jump j z w @String) of
+    "" -> True
+    _  -> False
+
+Now we can move the case inward and we only have to change the jump:
+
+  join
+    j :: Int -> Bool -> forall a. a
+    j x y = case ... of "" -> True
+                        _  -> False
+  in
+    jump j z w @Bool
+
+(Core Lint would still check that the body of the join point has the right type;
+that type would simply not be reflected in the join id.)
+
+Note [The polymorphism rule of join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Invariant 4 of Note [Invariants on join points] forbids a join point to be
+polymorphic in its return type. That is, if its type is
+
+  forall a1 ... ak. t1 -> ... -> tn -> r
+
+where its join arity is k+n, none of the type parameters ai may occur free in r.
+The most direct explanation is that given
+
+  join j @a1 ... @ak x1 ... xn = e1 in e2
+
+our typing rules require `e1` and `e2` to have the same type. Therefore the type
+of `e1`---the return type of the join point---must be the same as the type of
+e2. Since the type variables aren't bound in `e2`, its type can't include them,
+and thus neither can the type of `e1`.
+
+There's a deeper explanation in terms of the sequent calculus in Section 5.3 of
+a previous paper:
+
+  Paul Downen, Luke Maurer, Zena Ariola, and Simon Peyton Jones. "Sequent
+  calculus as a compiler intermediate language." ICFP'16.
+
+  https://www.microsoft.com/en-us/research/wp-content/uploads/2016/04/sequent-calculus-icfp16.pdf
+
+The quick version: Consider the CPS term (the paper uses the sequent calculus,
+but we can translate readily):
+
+  \k -> join j @a1 ... @ak x1 ... xn = e1 k in e2 k
+
+Since `j` is a join point, it doesn't bind a continuation variable but reuses
+the variable `k` from the context. But the parameters `ai` are not in `k`'s
+scope, and `k`'s type determines the return type of `j`; thus the `ai`s don't
+appear in the return type of `j`. (Also, since `e1` and `e2` are passed the same
+continuation, they must have the same type; hence the direct explanation above.)
 
 ************************************************************************
 *                                                                      *
@@ -1534,10 +1712,16 @@ type TaggedAlt  t = Alt  (TaggedBndr t)
 instance Outputable b => Outputable (TaggedBndr b) where
   ppr (TB b l) = char '<' <> ppr b <> comma <> ppr l <> char '>'
 
-instance Outputable b => OutputableBndr (TaggedBndr b) where
+-- OutputableBndr Var is declared separately in PprCore; using a FlexibleContext
+-- to avoid circularity
+instance (OutputableBndr Var, Outputable b) =>
+    OutputableBndr (TaggedBndr b) where
   pprBndr _ b = ppr b   -- Simple
   pprInfixOcc  b = ppr b
   pprPrefixOcc b = ppr b
+  pprNonRecBndrKeyword (TB b _) = pprNonRecBndrKeyword b
+  pprRecBndrKeyword    (TB b _) = pprRecBndrKeyword    b
+  pprLamsOnLhs         (TB b _) = pprLamsOnLhs         b
 
 deTagExpr :: TaggedExpr t -> CoreExpr
 deTagExpr (Var v)                   = Var v
@@ -1584,17 +1768,17 @@ mkCoApps  f args = foldl (\ e a -> App e (Coercion a)) f args
 mkVarApps f vars = foldl (\ e a -> App e (varToCoreExpr a)) f vars
 mkConApp con args = mkApps (Var (dataConWorkId con)) args
 
-mkTyApps  f args = foldl (\ e a -> App e (typeOrCoercion a)) f args
-  where
-    typeOrCoercion ty
-      | Just co <- isCoercionTy_maybe ty = Coercion co
-      | otherwise                        = Type ty
+mkTyApps  f args = foldl (\ e a -> App e (mkTyArg a)) f args
 
 mkConApp2 :: DataCon -> [Type] -> [Var] -> Expr b
 mkConApp2 con tys arg_ids = Var (dataConWorkId con)
                             `mkApps` map Type tys
                             `mkApps` map varToCoreExpr arg_ids
 
+mkTyArg :: Type -> Expr b
+mkTyArg ty
+  | Just co <- isCoercionTy_maybe ty = Coercion co
+  | otherwise                        = Type ty
 
 -- | Create a machine integer literal expression of type @Int#@ from an @Integer@.
 -- If you want an expression of type @Int@ use 'MkCore.mkIntExpr'
@@ -1750,6 +1934,9 @@ collectBinders         :: Expr b   -> ([b],     Expr b)
 collectTyBinders       :: CoreExpr -> ([TyVar], CoreExpr)
 collectValBinders      :: CoreExpr -> ([Id],    CoreExpr)
 collectTyAndValBinders :: CoreExpr -> ([TyVar], [Id], CoreExpr)
+-- | Strip off exactly N leading lambdas (type or value). Good for use with
+-- join points.
+collectNBinders        :: Int -> Expr b -> ([b], Expr b)
 
 collectBinders expr
   = go [] expr
@@ -1775,6 +1962,13 @@ collectTyAndValBinders expr
     (tvs, body1) = collectTyBinders expr
     (ids, body)  = collectValBinders body1
 
+collectNBinders orig_n orig_expr
+  = go orig_n [] orig_expr
+  where
+    go 0 bs expr      = (reverse bs, expr)
+    go n bs (Lam b e) = go (n-1) (b:bs) e
+    go _ _  _         = pprPanic "collectNBinders" $ int orig_n
+
 -- | Takes a nested application expression and returns the the function
 -- being applied and the arguments to which it is applied
 collectArgs :: Expr b -> (Expr b, [Arg b])
@@ -1929,3 +2123,12 @@ collectAnnBndrs e
   where
     collect bs (_, AnnLam b body) = collect (b:bs) body
     collect bs body               = (reverse bs, body)
+
+-- | As 'collectNBinders' but for 'AnnExpr' rather than 'Expr'
+collectNAnnBndrs :: Int -> AnnExpr bndr annot -> ([bndr], AnnExpr bndr annot)
+collectNAnnBndrs orig_n e
+  = collect orig_n [] e
+  where
+    collect 0 bs body               = (reverse bs, body)
+    collect n bs (_, AnnLam b body) = collect (n-1) (b:bs) body
+    collect _ _  _                  = pprPanic "collectNBinders" $ int orig_n
index 574d841..11c4a5e 100644 (file)
@@ -523,15 +523,13 @@ sizeExpr dflags bOMB_OUT_SIZE top_args expr
       | otherwise = size_up e
 
     size_up (Let (NonRec binder rhs) body)
-      = size_up rhs             `addSizeNSD`
-        size_up body            `addSizeN`
-        (if isUnliftedType (idType binder) then 0 else 10)
-                -- For the allocation
-                -- If the binder has an unlifted type there is no allocation
+      = size_up_rhs (binder, rhs) `addSizeNSD`
+        size_up body              `addSizeN`
+        size_up_alloc binder
 
     size_up (Let (Rec pairs) body)
-      = foldr (addSizeNSD . size_up . snd)
-              (size_up body `addSizeN` (10 * length pairs))     -- (length pairs) for the allocation
+      = foldr (addSizeNSD . size_up_rhs)
+              (size_up body `addSizeN` sum (map (size_up_alloc . fst) pairs))
               pairs
 
     size_up (Case e _ _ alts)
@@ -606,6 +604,14 @@ sizeExpr dflags bOMB_OUT_SIZE top_args expr
               | otherwise
                 = False
 
+    size_up_rhs (bndr, rhs)
+      | Just join_arity <- isJoinId_maybe bndr
+        -- Skip arguments to join point
+      , (_bndrs, body) <- collectNBinders join_arity rhs
+      = size_up body
+      | otherwise
+      = size_up rhs
+
     ------------
     -- size_up_app is used when there's ONE OR MORE value args
     size_up_app (App fun arg) args voids
@@ -642,6 +648,16 @@ sizeExpr dflags bOMB_OUT_SIZE top_args expr
         -- A good example is Foreign.C.Error.errrnoToIOError
 
     ------------
+    -- Cost to allocate binding with given binder
+    size_up_alloc bndr
+      |  isTyVar bndr                 -- Doesn't exist at runtime
+      || isUnliftedType (idType bndr) -- Doesn't live in heap
+      || isJoinId bndr                -- Not allocated at all
+      = 0
+      | otherwise
+      = 10
+
+    ------------
         -- These addSize things have to be here because
         -- I don't want to give them bOMB_OUT_SIZE as an argument
     addSizeN TooBig          _  = TooBig
@@ -706,6 +722,17 @@ callSize
  -> Int
 callSize n_val_args voids = 10 * (1 + n_val_args - voids)
 
+-- | The size of a jump to a join point
+jumpSize
+ :: Int  -- ^ number of value args
+ -> Int  -- ^ number of value args that are void
+ -> Int
+jumpSize n_val_args voids = 2 * (1 + n_val_args - voids)
+  -- A jump is 20% the size of a function call. Making jumps free reopens
+  -- bug #6048, but making them any more expensive loses a 21% improvement in
+  -- spectral/puzzle. TODO Perhaps adjusting the default threshold would be a
+  -- better solution?
+
 funSize :: DynFlags -> [Id] -> Id -> Int -> Int -> ExprSize
 -- Size for functions that are not constructors or primops
 -- Note [Function applications]
@@ -715,9 +742,11 @@ funSize dflags top_args fun n_val_args voids
   | otherwise = SizeIs size arg_discount res_discount
   where
     some_val_args = n_val_args > 0
+    is_join = isJoinId fun
 
-    size | some_val_args = callSize n_val_args voids
-         | otherwise     = 0
+    size | is_join              = jumpSize n_val_args voids
+         | not some_val_args    = 0
+         | otherwise            = callSize n_val_args voids
         -- The 1+ is for the function itself
         -- Add 1 for each non-trivial arg;
         -- the allocation cost, as in let(rec)
index d856e3d..4eef079 100644 (file)
@@ -49,7 +49,10 @@ module CoreUtils (
         stripTicksE, stripTicksT,
 
         -- * StaticPtr
-        collectMakeStaticArgs
+        collectMakeStaticArgs,
+
+        -- * Join points
+        isJoinBind
     ) where
 
 #include "HsVersions.h"
@@ -2304,3 +2307,17 @@ collectMakeStaticArgs e
     | (fun@(Var b), [Type t, loc, arg], _) <- collectArgsTicks (const True) e
     , idName b == makeStaticName = Just (fun, t, loc, arg)
 collectMakeStaticArgs _          = Nothing
+
+{-
+************************************************************************
+*                                                                      *
+\subsection{Join points}
+*                                                                      *
+************************************************************************
+-}
+
+-- | Does this binding bind a join point (or a recursive group of join points)?
+isJoinBind :: CoreBind -> Bool
+isJoinBind (NonRec b _)       = isJoinId b
+isJoinBind (Rec ((b, _) : _)) = isJoinId b
+isJoinBind _                  = False
index 882faa7..7d24202 100644 (file)
@@ -107,6 +107,7 @@ sortQuantVars vs = sorted_tcvs ++ ids
 mkCoreLet :: CoreBind -> CoreExpr -> CoreExpr
 mkCoreLet (NonRec bndr rhs) body        -- See Note [CoreSyn let/app invariant]
   | needsCaseBinding (idType bndr) rhs
+  , not (isJoinId bndr)
   = Case rhs bndr (exprType body) [(DEFAULT,[],body)]
 mkCoreLet bind body
   = Let bind body
index 152a701..196a9b9 100644 (file)
@@ -29,6 +29,7 @@ import Type
 import Coercion
 import DynFlags
 import BasicTypes
+import Maybes
 import Util
 import Outputable
 import FastString
@@ -113,7 +114,14 @@ ppr_bind ann (Rec binds)           = vcat (map pp binds)
 ppr_binding :: OutputableBndr b => Annotation b -> (b, Expr b) -> SDoc
 ppr_binding ann (val_bdr, expr)
   = ann expr $$ pprBndr LetBind val_bdr $$
-    hang (ppr val_bdr <+> equals) 2 (pprCoreExpr expr)
+    hang (ppr val_bdr <+> sep (map (pprBndr LambdaBind) lhs_bndrs) <+> equals) 2
+         (pprCoreExpr rhs)
+  where
+    (bndrs, body)          = collectBinders expr
+    (lhs_bndrs, rhs_bndrs) = splitAt (pprLamsOnLhs val_bdr) bndrs
+    rhs                    = mkLams rhs_bndrs body
+                      -- Returns ([], expr) unless it's a join point, in which
+                      -- case we want the args before the =
 
 pprParendExpr expr = ppr_expr parens expr
 pprCoreExpr   expr = ppr_expr noParens expr
@@ -131,7 +139,8 @@ ppr_expr :: OutputableBndr b => (SDoc -> SDoc) -> Expr b -> SDoc
         -- The function adds parens in context that need
         -- an atomic value (e.g. function args)
 
-ppr_expr _       (Var name)    = ppr name
+ppr_expr _       (Var name)    = ppWhen (isJoinId name) (text "jump") <+>
+                                   ppr name
 ppr_expr add_par (Type ty)     = add_par (text "TYPE:" <+> ppr ty)       -- Weird
 ppr_expr add_par (Coercion co) = add_par (text "CO:" <+> ppr co)
 ppr_expr add_par (Lit lit)     = pprLiteral add_par lit
@@ -172,7 +181,10 @@ ppr_expr add_par expr@(App {})
                              tc        = dataConTyCon dc
                              saturated = val_args `lengthIs` idArity f
 
-                   _ -> parens (hang (ppr f) 2 pp_args)
+                   _ -> parens (hang fun_doc 2 pp_args)
+                   where
+                     fun_doc | isJoinId f = text "jump" <+> ppr f
+                             | otherwise  = ppr f
 
         _ -> parens (hang (pprParendExpr fun) 2 pp_args)
     }
@@ -239,12 +251,14 @@ ppr_expr add_par (Let bind@(NonRec val_bdr rhs) expr@(Let _ _))
 -- General case (recursive case, too)
 ppr_expr add_par (Let bind expr)
   = add_par $
-    sep [hang (ptext keyword) 2 (ppr_bind noAnn bind <+> text "} in"),
+    sep [hang (keyword <+> char '{') 2 (ppr_bind noAnn bind <+> text "} in"),
          pprCoreExpr expr]
   where
     keyword = case bind of
-                Rec _      -> (sLit "letrec {")
-                NonRec _ _ -> (sLit "let {")
+                NonRec b _    -> pprNonRecBndrKeyword b
+                Rec ((b,_):_) -> pprRecBndrKeyword    b
+                Rec []        -> text "let" -- This *shouldn't* happen, but
+                                            -- let's be tolerant here
 
 ppr_expr add_par (Tick tickish expr)
   = sdocWithDynFlags $ \dflags ->
@@ -315,6 +329,11 @@ instance OutputableBndr Var where
   pprBndr = pprCoreBinder
   pprInfixOcc  = pprInfixName  . varName
   pprPrefixOcc = pprPrefixName . varName
+  pprNonRecBndrKeyword bndr | isJoinId bndr = text "join"
+                            | otherwise     = text "let"
+  pprRecBndrKeyword    bndr | isJoinId bndr = text "joinrec"
+                            | otherwise     = text "letrec"
+  pprLamsOnLhs bndr = isJoinId_maybe bndr `orElse` 0
 
 pprCoreBinder :: BindingSite -> Var -> SDoc
 pprCoreBinder LetBind binder
@@ -398,7 +417,7 @@ pprIdBndrInfo info
     lbv_info  = oneShotInfo info
 
     has_prag  = not (isDefaultInlinePragma prag_info)
-    has_occ   = not (isNoOcc occ_info)
+    has_occ   = not (isManyOccs occ_info)
     has_dmd   = not $ isTopDmd dmd_info
     has_lbv   = not (hasNoOneShotInfo lbv_info)
 
index 0d336ad..165130a 100644 (file)
@@ -893,6 +893,15 @@ for the primitive case:
 \end{verbatim}
 
 Now @fail.33@ is a function, so it can be let-bound.
+
+We would *like* to use join points here; in fact, these "fail variables" are
+paradigmatic join points! Sadly, this breaks pattern synonyms, which desugar as
+CPS functions - i.e. they take "join points" as parameters. It's not impossible
+to imagine extending our type system to allow passing join points around (very
+carefully), but we certainly don't support it now.
+
+99.99% of the time, the fail variables wind up as join points in short order
+anyway, and the Void# doesn't do much harm.
 -}
 
 mkFailurePair :: CoreExpr       -- Result type of the whole case expression
@@ -912,6 +921,11 @@ mkFailurePair expr
 {-
 Note [Failure thunks and CPR]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+(This note predates join points as formal entities (hence the quotation marks).
+We can't use actual join points here (see above); if we did, this would also
+solve the CPR problem, since join points don't get CPR'd. See Note [Don't CPR
+join points] in WorkWrap.)
+
 When we make a failure point we ensure that it
 does not look like a thunk. Example:
 
index d4dd51e..7740977 100644 (file)
@@ -10,7 +10,7 @@ module IfaceSyn (
 
         IfaceDecl(..), IfaceFamTyConFlav(..), IfaceClassOp(..), IfaceAT(..),
         IfaceConDecl(..), IfaceConDecls(..), IfaceEqSpec,
-        IfaceExpr(..), IfaceAlt, IfaceLetBndr(..),
+        IfaceExpr(..), IfaceAlt, IfaceLetBndr(..), IfaceJoinInfo(..),
         IfaceBinding(..), IfaceConAlt(..),
         IfaceIdInfo(..), IfaceIdDetails(..), IfaceUnfolding(..),
         IfaceInfoItem(..), IfaceRule(..), IfaceAnnotation(..), IfaceAnnTarget,
@@ -502,7 +502,10 @@ data IfaceBinding
 -- IfaceLetBndr is like IfaceIdBndr, but has IdInfo too
 -- It's used for *non-top-level* let/rec binders
 -- See Note [IdInfo on nested let-bindings]
-data IfaceLetBndr = IfLetBndr IfLclName IfaceType IfaceIdInfo
+data IfaceLetBndr = IfLetBndr IfLclName IfaceType IfaceIdInfo IfaceJoinInfo
+
+data IfaceJoinInfo = IfaceNotJoinPoint
+                   | IfaceJoinPoint JoinArity
 
 {-
 Note [Empty case alternatives]
@@ -1158,8 +1161,8 @@ ppr_con_bs :: IfaceConAlt -> [IfLclName] -> SDoc
 ppr_con_bs con bs = ppr con <+> hsep (map ppr bs)
 
 ppr_bind :: (IfaceLetBndr, IfaceExpr) -> SDoc
-ppr_bind (IfLetBndr b ty info, rhs)
-  = sep [hang (ppr b <+> dcolon <+> ppr ty) 2 (ppr info),
+ppr_bind (IfLetBndr b ty info ji, rhs)
+  = sep [hang (ppr b <+> dcolon <+> ppr ty) 2 (ppr ji <+> ppr info),
          equals <+> pprIfaceExpr noParens rhs]
 
 ------------------
@@ -1207,6 +1210,10 @@ instance Outputable IfaceInfoItem where
   ppr HsNoCafRefs           = text "HasNoCafRefs"
   ppr HsLevity              = text "Never levity-polymorphic"
 
+instance Outputable IfaceJoinInfo where
+  ppr IfaceNotJoinPoint   = empty
+  ppr (IfaceJoinPoint ar) = angleBrackets (text "join" <+> ppr ar)
+
 instance Outputable IfaceUnfolding where
   ppr (IfCompulsory e)     = text "<compulsory>" <+> parens (ppr e)
   ppr (IfCoreUnfold s e)   = (if s
@@ -1407,8 +1414,8 @@ freeNamesIfLetBndr :: IfaceLetBndr -> NameSet
 -- Remember IfaceLetBndr is used only for *nested* bindings
 -- The IdInfo can contain an unfolding (in the case of
 -- local INLINE pragmas), so look there too
-freeNamesIfLetBndr (IfLetBndr _name ty info) = freeNamesIfType ty
-                                             &&& freeNamesIfIdInfo info
+freeNamesIfLetBndr (IfLetBndr _name ty info _ji) = freeNamesIfType ty
+                                                 &&& freeNamesIfIdInfo info
 
 freeNamesIfTvBndr :: IfaceTvBndr -> NameSet
 freeNamesIfTvBndr (_fs,k) = freeNamesIfKind k
@@ -2075,14 +2082,27 @@ instance Binary IfaceBinding where
             _ -> do { ac <- get bh; return (IfaceRec ac) }
 
 instance Binary IfaceLetBndr where
-    put_ bh (IfLetBndr a b c) = do
+    put_ bh (IfLetBndr a b c d) = do
             put_ bh a
             put_ bh b
             put_ bh c
+            put_ bh d
     get bh = do a <- get bh
                 b <- get bh
                 c <- get bh
-                return (IfLetBndr a b c)
+                d <- get bh
+                return (IfLetBndr a b c d)
+
+instance Binary IfaceJoinInfo where
+    put_ bh IfaceNotJoinPoint = putByte bh 0
+    put_ bh (IfaceJoinPoint ar) = do
+        putByte bh 1
+        put_ bh ar
+    get bh = do
+        h <- getByte bh
+        case h of
+            0 -> return IfaceNotJoinPoint
+            _ -> liftM IfaceJoinPoint $ get bh
 
 instance Binary IfaceTyConParent where
     put_ bh IfNoParent = putByte bh 0
index e08a3d7..f6a4f41 100644 (file)
@@ -1367,12 +1367,13 @@ tcIfaceExpr (IfaceCase scrut case_bndr alts)  = do
      alts' <- mapM (tcIfaceAlt scrut' tc_app) alts
      return (Case scrut' case_bndr' (coreAltsType alts') alts')
 
-tcIfaceExpr (IfaceLet (IfaceNonRec (IfLetBndr fs ty info) rhs) body)
+tcIfaceExpr (IfaceLet (IfaceNonRec (IfLetBndr fs ty info ji) rhs) body)
   = do  { name    <- newIfaceName (mkVarOccFS fs)
         ; ty'     <- tcIfaceType ty
         ; id_info <- tcIdInfo False {- Don't ignore prags; we are inside one! -}
                               name ty' info
         ; let id = mkLocalIdOrCoVarWithInfo name ty' id_info
+                     `asJoinId_maybe` tcJoinInfo ji
         ; rhs' <- tcIfaceExpr rhs
         ; body' <- extendIfaceIdEnv [id] (tcIfaceExpr body)
         ; return (Let (NonRec id rhs') body') }
@@ -1384,11 +1385,11 @@ tcIfaceExpr (IfaceLet (IfaceRec pairs) body)
        ; body' <- tcIfaceExpr body
        ; return (Let (Rec pairs') body') } }
  where
-   tc_rec_bndr (IfLetBndr fs ty _)
+   tc_rec_bndr (IfLetBndr fs ty _ ji)
      = do { name <- newIfaceName (mkVarOccFS fs)
           ; ty'  <- tcIfaceType ty
-          ; return (mkLocalIdOrCoVar name ty') }
-   tc_pair (IfLetBndr _ _ info, rhs) id
+          ; return (mkLocalIdOrCoVar name ty' `asJoinId_maybe` tcJoinInfo ji) }
+   tc_pair (IfLetBndr _ _ info _, rhs) id
      = do { rhs' <- tcIfaceExpr rhs
           ; id_info <- tcIdInfo False {- Don't ignore prags; we are inside one! -}
                                 (idName id) (idType id) info
@@ -1509,6 +1510,10 @@ tcIdInfo ignore_prags name ty info = do
                        | otherwise = info
            ; return (info1 `setUnfoldingInfo` unf) }
 
+tcJoinInfo :: IfaceJoinInfo -> Maybe JoinArity
+tcJoinInfo (IfaceJoinPoint ar) = Just ar
+tcJoinInfo IfaceNotJoinPoint   = Nothing
+
 tcUnfolding :: Name -> Type -> IdInfo -> IfaceUnfolding -> IfL Unfolding
 tcUnfolding name _ info (IfCoreUnfold stable if_expr)
   = do  { dflags <- getDynFlags
index 696d0ff..37d41f4 100644 (file)
@@ -325,6 +325,7 @@ toIfaceLetBndr :: Id -> IfaceLetBndr
 toIfaceLetBndr id  = IfLetBndr (occNameFS (getOccName id))
                                (toIfaceType (idType id))
                                (toIfaceIdInfo (idInfo id))
+                               (toIfaceJoinInfo (isJoinId_maybe id))
   -- Put into the interface file any IdInfo that CoreTidy.tidyLetBndr
   -- has left on the Id.  See Note [IdInfo on nested let-bindings] in IfaceSyn
 
@@ -382,6 +383,10 @@ toIfaceIdInfo id_info
     levity_hsinfo | isNeverLevPolyIdInfo id_info = Just HsLevity
                   | otherwise                    = Nothing
 
+toIfaceJoinInfo :: Maybe JoinArity -> IfaceJoinInfo
+toIfaceJoinInfo (Just ar) = IfaceJoinPoint ar
+toIfaceJoinInfo Nothing   = IfaceNotJoinPoint
+
 --------------------------
 toIfUnfolding :: Bool -> Unfolding -> Maybe IfaceInfoItem
 toIfUnfolding lb (CoreUnfolding { uf_tmpl = rhs
index f9314bd..971b3e0 100644 (file)
@@ -11,7 +11,7 @@ module CSE (cseProgram) where
 #include "HsVersions.h"
 
 import CoreSubst
-import Var              ( Var )
+import Var              ( Var, isJoinId )
 import Id               ( Id, idType, idUnfolding, idInlineActivation
                         , zapIdOccInfo, zapIdUsageInfo )
 import CoreUtils        ( mkAltExpr
@@ -245,6 +245,18 @@ not if you are using unsafe casts.  I actually found a case where we
 had
    (x :: HValue) |> (UnsafeCo :: HValue ~ Array# Int)
 
+Note [CSE for join points?]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We must not be naive about join points in CSE:
+   join j = e in
+   if b then jump j else 1 + e
+The expression (1 + jump j) is not good (see Note [Invariants on join points] in
+CoreSyn). This seems to come up quite seldom, but it happens (first seen
+compiling ppHtml in Haddock.Backends.Xhtml).
+
+We could try and be careful by tracking which join points are still valid at
+each subexpression, but since join points aren't allocated or shared, there's
+less to gain by trying to CSE them.
 
 ************************************************************************
 *                                                                      *
@@ -304,6 +316,8 @@ addBinding env in_id out_id rhs'
              -- See Note [CSE for INLINE and NOINLINE]
           || isStableUnfolding (idUnfolding out_id)
              -- See Note [CSE for stable unfoldings]
+          || isJoinId in_id
+             -- See Note [CSE for join points?]
 
     -- Should we use SUBSTITUTE or EXTEND?
     -- See Note [CSE for bindings]
index 12e69b9..7b80776 100644 (file)
@@ -133,6 +133,7 @@ data CoreToDo           -- These are diff core-to-core passes,
 
   | CoreTidy
   | CorePrep
+  | CoreOccurAnal
 
 instance Outputable CoreToDo where
   ppr (CoreDoSimplify _ _)     = text "Simplifier"
@@ -152,6 +153,7 @@ instance Outputable CoreToDo where
   ppr CoreDesugarOpt           = text "Desugar (after optimization)"
   ppr CoreTidy                 = text "Tidy Core"
   ppr CorePrep                 = text "CorePrep"
+  ppr CoreOccurAnal            = text "Occurrence analysis"
   ppr CoreDoPrintCore          = text "Print core"
   ppr (CoreDoRuleCheck {})     = text "Rule check"
   ppr CoreDoNothing            = text "CoreDoNothing"
index f32b5a3..1fd969e 100644 (file)
@@ -23,7 +23,7 @@ import MkCore
 import CoreUtils        ( exprIsDupable, exprIsExpandable,
                           exprOkForSideEffects, mkTicks )
 import CoreFVs
-import Id               ( isOneShotBndr, idType )
+import Id               ( isJoinId, isJoinId_maybe, isOneShotBndr, idType )
 import Var
 import Type             ( isUnliftedType )
 import VarSet
@@ -31,6 +31,7 @@ import Util
 import DynFlags
 import Outputable
 import Data.List( mapAccumL )
+import BasicTypes       ( RecFlag(..), isRec )
 
 {-
 Top-level interface function, @floatInwards@.  Note that we do not
@@ -160,18 +161,25 @@ fiExpr dflags to_drop ann_expr@(_,AnnApp {})
            (zipWith (fiExpr dflags) arg_drops ann_args)
   where
     (ann_fun, ann_args, ticks) = collectAnnArgsTicks tickishFloatable ann_expr
-    (extra_fvs, arg_fvs) = mapAccumL mk_arg_fvs emptyDVarSet ann_args
+    (extra_fvs0, fun_fvs)
+      | (_, AnnVar _) <- ann_fun = (freeVarsOf ann_fun, emptyDVarSet)
+          -- Don't float the binding for f into f x y z; see Note [Join points]
+          -- for why we *can't* do it when f is a join point. (If f isn't a
+          -- join point, floating it in isn't especially harmful but it's
+          -- useless since the simplifier will immediately float it back out.)
+      | otherwise                = (emptyDVarSet, freeVarsOf ann_fun)
+    (extra_fvs, arg_fvs) = mapAccumL mk_arg_fvs extra_fvs0 ann_args
 
     mk_arg_fvs :: FreeVarSet -> CoreExprWithFVs -> (FreeVarSet, FreeVarSet)
     mk_arg_fvs extra_fvs ann_arg
-      | noFloatIntoRhs ann_arg
+      | noFloatIntoRhs False NonRecursive ann_arg
       = (extra_fvs `unionDVarSet` freeVarsOf ann_arg, emptyDVarSet)
       | otherwise
       = (extra_fvs, freeVarsOf ann_arg)
 
     drop_here : extra_drop : fun_drop : arg_drops
       = sepBindsByDropPoint dflags False
-          (extra_fvs : freeVarsOf ann_fun : arg_fvs)
+          (extra_fvs : fun_fvs : arg_fvs)
           (freeVarsOfType ann_fun `unionDVarSet`
            mapUnionDVarSet freeVarsOfType ann_args)
           to_drop
@@ -186,6 +194,28 @@ We don't want to float bindings into here
 because that might destroy the let/app invariant, which requires
 unlifted function arguments to be ok-for-speculation.
 
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+
+Generally, we don't need to worry about join points - there are places we're
+not allowed to float them, but since they can't have occurrences in those
+places, we're not tempted.
+
+We do need to be careful about jumps, however:
+
+  joinrec j x y z = ... in
+  jump j a b c
+
+Previous versions often floated the definition of a recursive function into its
+only non-recursive occurrence. But for a join point, this is a disaster:
+
+  (joinrec j x y z = ... in
+  jump j) a b c -- wrong!
+
+Every jump must be exact, so the jump to j must have three arguments. Hence
+we're careful not to float into the target of a jump (though we can float into
+the arguments just fine).
+
 Note [Floating in past a lambda group]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * We must be careful about floating inside a value lambda.
@@ -221,6 +251,9 @@ So we treat lambda in groups, using the following rule:
 
 This is what the 'go' function in the AnnLam case is doing.
 
+(Join points are handled similarly: a join point is considered one-shot iff
+it's non-recursive, so we float only into non-recursive join points.)
+
 Urk! if all are tyvars, and we don't float in, we may miss an
       opportunity to float inside a nested case branch
 -}
@@ -308,11 +341,14 @@ fiExpr dflags to_drop (_,AnnLet (AnnNonRec id rhs) body)
     rhs_fvs  = freeVarsOf rhs
 
     rule_fvs = idRuleAndUnfoldingVarsDSet id        -- See Note [extra_fvs (2): free variables of rules]
-    extra_fvs | noFloatIntoRhs rhs = rule_fvs `unionDVarSet` freeVarsOf rhs
-              | otherwise          = rule_fvs
+    extra_fvs | noFloatIntoRhs (isJoinId id) NonRecursive rhs
+              = rule_fvs `unionDVarSet` freeVarsOf rhs
+              | otherwise
+              = rule_fvs
         -- See Note [extra_fvs (1): avoid floating into RHS]
         -- No point in floating in only to float straight out again
-        -- Ditto ok-for-speculation unlifted RHSs
+        -- We *can't* float into ok-for-speculation unlifted RHSs
+        -- But do float into join points
 
     [shared_binds, extra_binds, rhs_binds, body_binds]
         = sepBindsByDropPoint dflags False
@@ -327,7 +363,7 @@ fiExpr dflags to_drop (_,AnnLet (AnnNonRec id rhs) body)
                   shared_binds                          -- the bindings used both in rhs and body
 
         -- Push rhs_binds into the right hand side of the binding
-    rhs'     = fiExpr dflags rhs_binds rhs
+    rhs'     = fiRhs dflags rhs_binds id rhs
     rhs_fvs' = rhs_fvs `unionDVarSet` floatedBindsFVs rhs_binds `unionDVarSet` rule_fvs
                         -- Don't forget the rule_fvs; the binding mentions them!
 
@@ -341,8 +377,8 @@ fiExpr dflags to_drop (_,AnnLet (AnnRec bindings) body)
         -- See Note [extra_fvs (1,2)]
     rule_fvs = mapUnionDVarSet idRuleAndUnfoldingVarsDSet ids
     extra_fvs = rule_fvs `unionDVarSet`
-                unionDVarSets [ freeVarsOf rhs | rhs@(_, rhs') <- rhss
-                              , noFloatIntoExpr rhs' ]
+                unionDVarSets [ freeVarsOf rhs | (bndr, rhs) <- bindings
+                              , noFloatIntoRhs (isJoinId bndr) Recursive rhs ]
 
     (shared_binds:extra_binds:body_binds:rhss_binds)
         = sepBindsByDropPoint dflags False
@@ -367,7 +403,7 @@ fiExpr dflags to_drop (_,AnnLet (AnnRec bindings) body)
             -> [(Id, CoreExpr)]
 
     fi_bind to_drops pairs
-      = [ (binder, fiExpr dflags to_drop rhs)
+      = [ (binder, fiRhs dflags to_drop binder rhs)
         | ((binder, rhs), to_drop) <- zipEqual "fi_bind" pairs to_drops ]
 
 {-
@@ -418,7 +454,8 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts)
 
         -- Float into the alts with the is_case flag set
     (drop_here2 : alts_drops_s)
-      = sepBindsByDropPoint dflags True alts_fvs all_alts_ty_fvs alts_drops
+      = sepBindsByDropPoint dflags True alts_fvs all_alts_ty_fvs
+                            alts_drops
 
     scrut_fvs       = freeVarsOf scrut
     alts_fvs        = map alt_fvs alts
@@ -434,17 +471,29 @@ fiExpr dflags to_drop (_, AnnCase scrut case_bndr ty alts)
 
     fi_alt to_drop (con, args, rhs) = (con, args, fiExpr dflags to_drop rhs)
 
+fiRhs :: DynFlags -> FloatInBinds -> CoreBndr -> CoreExprWithFVs -> CoreExpr
+fiRhs dflags to_drop bndr rhs
+  | Just join_arity <- isJoinId_maybe bndr
+  , let (bndrs, body) = collectNAnnBndrs join_arity rhs
+  = mkLams bndrs (fiExpr dflags to_drop body)
+  | otherwise
+  = fiExpr dflags to_drop rhs
+
 okToFloatInside :: [Var] -> Bool
 okToFloatInside bndrs = all ok bndrs
   where
     ok b = not (isId b) || isOneShotBndr b
     -- Push the floats inside there are no non-one-shot value binders
 
-noFloatIntoRhs :: CoreExprWithFVs -> Bool
+noFloatIntoRhs :: Bool -> RecFlag -> CoreExprWithFVs -> Bool
 -- ^ True if it's a bad idea to float bindings into this RHS
 -- Preconditio:  rhs :: rhs_ty
-noFloatIntoRhs rhs@(_, rhs')
-  =  isUnliftedType rhs_ty   -- See Note [Do not destroy the let/app invariant]
+noFloatIntoRhs is_join is_rec rhs@(_, rhs')
+  |  is_join
+  =  isRec is_rec -- Joins are one-shot iff non-recursive
+  |  otherwise
+  =  isUnliftedType rhs_ty
+       -- See Note [Do not destroy the let/app invariant]
   || noFloatIntoExpr rhs'
   where
     rhs_ty = exprTypeFV rhs
index 10955d2..17ffba4 100644 (file)
@@ -19,7 +19,8 @@ import CoreMonad        ( FloatOutSwitches(..) )
 
 import DynFlags
 import ErrUtils         ( dumpIfSet_dyn )
-import Id               ( Id, idArity, isBottomingId )
+import Id               ( Id, idArity, idType, isBottomingId,
+                          isJoinId, isJoinId_maybe )
 import Var              ( Var )
 import SetLevels
 import UniqSupply       ( UniqSupply )
@@ -27,8 +28,11 @@ import Bag
 import Util
 import Maybes
 import Outputable
+import Type
 import qualified Data.IntMap as M
 
+import Data.List        ( partition )
+
 #include "HsVersions.h"
 
 {-
@@ -104,6 +108,52 @@ vwhich might usefully be separated to
 @
 Well, maybe.  We don't do this at the moment.
 
+Note [Join points]
+~~~~~~~~~~~~~~~~~~
+Every occurrence of a join point must be a tail call (see Note [Invariants on
+join points] in CoreSyn), so we must be careful with how far we float them. The
+mechanism for doing so is the *join ceiling*, detailed in Note [Join ceiling]
+in SetLevels. For us, the significance is that a binder might be marked to be
+dropped at the nearest boundary between tail calls and non-tail calls. For
+example:
+
+  (< join j = ... in
+     let x = < ... > in
+     case < ... > of
+       A -> ...
+       B -> ...
+   >) < ... > < ... >
+
+Here the join ceilings are marked with angle brackets. Either side of an
+application is a join ceiling, as is the scrutinee position of a case
+expression or the RHS of a let binding (but not a join point).
+
+Why do we *want* do float join points at all? After all, they're never
+allocated, so there's no sharing to be gained by floating them. However, the
+other benefit of floating is making RHSes small, and this can have a significant
+impact. In particular, stream fusion has been known to produce nested loops like
+this:
+
+  joinrec j1 x1 =
+    joinrec j2 x2 =
+      joinrec j3 x3 = ... jump j1 (x3 + 1) ... jump j2 (x3 + 1) ...
+      in jump j3 x2
+    in jump j2 x1
+  in jump j1 x
+
+(Assume x1 and x2 do *not* occur free in j3.)
+
+Here j1 and j2 are wholly superfluous---each of them merely forwards its
+argument to j3. Since j3 only refers to x3, we can float j2 and j3 to make
+everything one big mutual recursion:
+
+  joinrec j1 x1 = jump j2 x1
+          j2 x2 = jump j3 x2
+          j3 x3 = ... jump j1 (x3 + 1) ... jump j2 (x3 + 1) ...
+  in jump j1 x
+
+Now the simplifier will happily inline the trivial j1 and j2, leaving only j3.
+Without floating, we're stuck with three loops instead of one.
 
 ************************************************************************
 *                                                                      *
@@ -141,8 +191,11 @@ floatTopBind bind
   = case (floatBind bind) of { (fs, floats, bind') ->
     let float_bag = flattenTopFloats floats
     in case bind' of
-      Rec prs   -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
-      NonRec {} -> (fs, float_bag `snocBag` bind') }
+      -- bind' can't have unlifted values or join points, so can only be one
+      -- value bind, rec or non-rec (see comment on floatBind)
+      [Rec prs]    -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
+      [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
+      _            -> pprPanic "floatTopBind" (ppr bind') }
 
 {-
 ************************************************************************
@@ -152,42 +205,76 @@ floatTopBind bind
 ************************************************************************
 -}
 
-floatBind :: LevelledBind -> (FloatStats, FloatBinds, CoreBind)
+floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
+  -- Returns a list with either
+  --   * A single non-recursive binding (value or join point), or
+  --   * The following, in order:
+  --     * Zero or more non-rec unlifted bindings
+  --     * One or both of:
+  --       * A recursive group of join binds
+  --       * A recursive group of value binds
+  -- See Note [Floating out of Rec rhss] for why things get arranged this way.
 floatBind (NonRec (TB var _) rhs)
-  = case (floatExpr rhs) of { (fs, rhs_floats, rhs') ->
+  = case (floatRhs var rhs) of { (fs, rhs_floats, rhs') ->
 
         -- A tiresome hack:
         -- see Note [Bottoming floats: eta expansion] in SetLevels
     let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
               | otherwise         = rhs'
 
-    in (fs, rhs_floats, NonRec var rhs'') }
+    in (fs, rhs_floats, [NonRec var rhs'']) }
 
 floatBind (Rec pairs)
   = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
-    (fs, rhs_floats, Rec (concat new_pairs)) }
+    let (new_ul_pairss, new_other_pairss) = unzip new_pairs
+        (new_join_pairs, new_l_pairs)     = partition (isJoinId . fst)
+                                                      (concat new_other_pairss)
+        -- Can't put the join points and the values in the same rec group
+        new_rec_binds | null new_join_pairs = [ Rec new_l_pairs    ]
+                      | null new_l_pairs    = [ Rec new_join_pairs ]
+                      | otherwise           = [ Rec new_l_pairs
+                                              , Rec new_join_pairs ]
+        new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
+    in
+    (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
   where
+    do_pair :: (LevelledBndr, LevelledExpr)
+            -> (FloatStats, FloatBinds,
+                ([(Id,CoreExpr)],  -- Non-recursive unlifted value bindings
+                 [(Id,CoreExpr)])) -- Join points and lifted value bindings
     do_pair (TB name spec, rhs)
       | isTopLvl dest_lvl  -- See Note [floatBind for top level]
-      = case (floatExpr rhs) of { (fs, rhs_floats, rhs') ->
-        (fs, emptyFloats, addTopFloatPairs (flattenTopFloats rhs_floats) [(name, rhs')])}
+      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
+        (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
+                                                [(name, rhs')]))}
       | otherwise         -- Note [Floating out of Rec rhss]
-      = case (floatExpr rhs) of { (fs, rhs_floats, rhs') ->
+      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
         case (partitionByLevel dest_lvl rhs_floats) of { (rhs_floats', heres) ->
-        case (splitRecFloats heres) of { (pairs, case_heres) ->
-        (fs, rhs_floats', (name, installUnderLambdas case_heres rhs') : pairs) }}}
+        case (splitRecFloats heres) of { (ul_pairs, pairs, case_heres) ->
+        let pairs' = (name, installUnderLambdas case_heres rhs') : pairs in
+        (fs, rhs_floats', (ul_pairs, pairs')) }}}
       where
         dest_lvl = floatSpecLevel spec
 
-splitRecFloats :: Bag FloatBind -> ([(Id,CoreExpr)], Bag FloatBind)
+splitRecFloats :: Bag FloatBind
+               -> ([(Id,CoreExpr)], -- Non-recursive unlifted value bindings
+                   [(Id,CoreExpr)], -- Join points and lifted value bindings
+                   Bag FloatBind)   -- A tail of further bindings
 -- The "tail" begins with a case
 -- See Note [Floating out of Rec rhss]
 splitRecFloats fs
-  = go [] (bagToList fs)
+  = go [] [] (bagToList fs)
   where
-    go prs (FloatLet (NonRec b r) : fs) = go ((b,r):prs) fs
-    go prs (FloatLet (Rec prs')   : fs) = go (prs' ++ prs) fs
-    go prs fs                           = (prs, listToBag fs)
+    go ul_prs prs (FloatLet (NonRec b r) : fs) | isUnliftedType (idType b)
+                                               , not (isJoinId b)
+                                               = go ((b,r):ul_prs) prs fs
+                                               | otherwise
+                                               = go ul_prs ((b,r):prs) fs
+    go ul_prs prs (FloatLet (Rec prs')   : fs) = go ul_prs (prs' ++ prs) fs
+    go ul_prs prs fs                           = (reverse ul_prs, prs,
+                                                  listToBag fs)
+                                                   -- Order only matters for
+                                                   -- non-rec
 
 installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
 -- Note [Floating out of Rec rhss]
@@ -227,6 +314,31 @@ So, gruesomely, we split the floats into
 This loses full-laziness the rare situation where there is a
 FloatCase and a Rec interacting.
 
+If there are unlifted FloatLets (that *aren't* join points) among the floats,
+we can't add them to the recursive group without angering Core Lint, but since
+they must be ok-for-speculation, they can't actually be making any recursive
+calls, so we can safely pull them out and keep them non-recursive.
+
+(Why is something getting floated to <1,0> that doesn't make a recursive call?
+The case that came up in testing was that f *and* the unlifted binding were
+getting floated *to the same place*:
+
+  \x<2,0> ->
+    ... <3,0>
+    letrec { f<F<2,0>> =
+      ... let x'<F<2,0>> = x +# 1# in ...
+    } in ...
+
+Everything gets labeled "float to <2,0>" because it all depends on x, but this
+makes f and x' look mutually recursive when they're not.
+
+The test was shootout/k-nucleotide, as compiled using commit 47d5dd68 on the
+wip/join-points branch.
+
+TODO: This can probably be solved somehow in SetLevels. The difference between
+"this *is at* level <2,0>" and "this *depends on* level <2,0>" is very
+important.)
+
 Note [floatBind for top level]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 We may have a *nested* binding whose destination level is (FloatMe tOP_LEVEL), thus
@@ -285,27 +397,28 @@ floatExpr (Coercion co) = (zeroStats, emptyFloats, Coercion co)
 floatExpr (Lit lit) = (zeroStats, emptyFloats, Lit lit)
 
 floatExpr (App e a)
-  = case (floatExpr  e) of { (fse, floats_e, e') ->
-    case (floatExpr  a) of { (fsa, floats_a, a') ->
+  = case (atJoinCeiling $ floatExpr  e) of { (fse, floats_e, e') ->
+    case (atJoinCeiling $ floatExpr  a) of { (fsa, floats_a, a') ->
     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
 
 floatExpr lam@(Lam (TB _ lam_spec) _)
   = let (bndrs_w_lvls, body) = collectBinders lam
         bndrs                = [b | TB b _ <- bndrs_w_lvls]
-        bndr_lvl             = floatSpecLevel lam_spec
+        bndr_lvl             = asJoinCeilLvl (floatSpecLevel lam_spec)
         -- All the binders have the same level
         -- See SetLevels.lvlLamBndrs
+        -- Use asJoinCeilLvl to make this the join ceiling
     in
     case (floatBody bndr_lvl body) of { (fs, floats, body') ->
     (add_to_stats fs floats, floats, mkLams bndrs body') }
 
 floatExpr (Tick tickish expr)
   | tickish `tickishScopesLike` SoftScope -- not scoped, can just float
-  = case (floatExpr expr)    of { (fs, floating_defns, expr') ->
+  = case (atJoinCeiling $ floatExpr expr)    of { (fs, floating_defns, expr') ->
     (fs, floating_defns, Tick tickish expr') }
 
   | not (tickishCounts tickish) || tickishCanSplit tickish
-  = case (floatExpr expr)    of { (fs, floating_defns, expr') ->
+  = case (atJoinCeiling $ floatExpr expr)    of { (fs, floating_defns, expr') ->
     let -- Annotate bindings floated outwards past an scc expression
         -- with the cc.  We mark that cc as "duplicated", though.
         annotated_defns = wrapTick (mkNoCount tickish) floating_defns
@@ -321,25 +434,27 @@ floatExpr (Tick tickish expr)
   = pprPanic "floatExpr tick" (ppr tickish)
 
 floatExpr (Cast expr co)
-  = case (floatExpr expr) of { (fs, floating_defns, expr') ->
+  = case (atJoinCeiling $ floatExpr expr) of { (fs, floating_defns, expr') ->
     (fs, floating_defns, Cast expr' co) }
 
 floatExpr (Let bind body)
   = case bind_spec of
       FloatMe dest_lvl
-        -> case (floatBind bind) of { (fsb, bind_floats, bind') ->
+        -> case (floatBind bind) of { (fsb, bind_floats, binds') ->
            case (floatExpr body) of { (fse, body_floats, body') ->
+           let new_bind_floats = foldr plusFloats emptyFloats
+                                   (map (unitLetFloat dest_lvl) binds') in
            ( add_stats fsb fse
-           , bind_floats `plusFloats` unitLetFloat dest_lvl bind'
+           , bind_floats `plusFloats` new_bind_floats
                          `plusFloats` body_floats
            , body') }}
 
       StayPut bind_lvl  -- See Note [Avoiding unnecessary floating]
-        -> case (floatBind bind)          of { (fsb, bind_floats, bind') ->
+        -> case (floatBind bind)          of { (fsb, bind_floats, binds') ->
            case (floatBody bind_lvl body) of { (fse, body_floats, body') ->
            ( add_stats fsb fse
            , bind_floats `plusFloats` body_floats
-           , Let bind' body') }}
+           , foldr Let body' binds' ) }}
   where
     bind_spec = case bind of
                  NonRec (TB _ s) _     -> s
@@ -350,8 +465,8 @@ floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
   = case case_spec of
       FloatMe dest_lvl  -- Case expression moves
         | [(con@(DataAlt {}), bndrs, rhs)] <- alts
-        -> case floatExpr scrut of { (fse, fde, scrut') ->
-           case floatExpr rhs   of { (fsb, fdb, rhs') ->
+        -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
+           case                 floatExpr rhs   of { (fsb, fdb, rhs') ->
            let
              float = unitCaseFloat dest_lvl scrut'
                           case_bndr con [b | TB b _ <- bndrs]
@@ -361,7 +476,7 @@ floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
         -> pprPanic "Floating multi-case" (ppr alts)
 
       StayPut bind_lvl  -- Case expression stays put
-        -> case floatExpr scrut of { (fse, fde, scrut') ->
+        -> case atJoinCeiling $ floatExpr scrut of { (fse, fde, scrut') ->
            case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts')  ->
            (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
            }}
@@ -370,6 +485,25 @@ floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
         = case (floatBody bind_lvl rhs) of { (fs, rhs_floats, rhs') ->
           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
 
+floatRhs :: CoreBndr
+         -> LevelledExpr
+         -> (FloatStats, FloatBinds, CoreExpr)
+floatRhs bndr rhs
+  | Just join_arity <- isJoinId_maybe bndr
+  , Just (bndrs, body) <- try_collect join_arity rhs []
+  = case bndrs of
+      []                -> floatExpr rhs
+      (TB _ lam_spec):_ ->
+        let lvl = floatSpecLevel lam_spec in
+        case floatBody lvl body of { (fs, floats, body') ->
+        (fs, floats, mkLams [b | TB b _ <- bndrs] body') }
+  | otherwise
+  = atJoinCeiling $ floatExpr rhs
+  where
+    try_collect 0 expr      acc = Just (reverse acc, expr)
+    try_collect n (Lam b e) acc = try_collect (n-1) e (b:acc)
+    try_collect _ _         _   = Nothing
+
 {-
 Note [Avoiding unnecessary floating]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -439,8 +573,10 @@ add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
 
 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
-add_to_stats (FlS a b c) (FB tops others)
-  = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
+add_to_stats (FlS a b c) (FB tops ceils others)
+  = FlS (a + lengthBag tops)
+        (b + lengthBag ceils + lengthBag (flattenMajor others))
+        (c + 1)
 
 {-
 ************************************************************************
@@ -474,18 +610,21 @@ type MajorEnv = M.IntMap MinorEnv         -- Keyed by major level
 type MinorEnv = M.IntMap (Bag FloatBind)  -- Keyed by minor level
 
 data FloatBinds  = FB !(Bag FloatLet)           -- Destined for top level
-                      !MajorEnv                 -- Levels other than top
+                      !(Bag FloatBind)          -- Destined for join ceiling
+                      !MajorEnv                 -- Other levels
      -- See Note [Representation of FloatBinds]
 
 instance Outputable FloatBinds where
-  ppr (FB fbs defs)
+  ppr (FB fbs ceils defs)
       = text "FB" <+> (braces $ vcat
            [ text "tops ="     <+> ppr fbs
+           , text "ceils ="    <+> ppr ceils
            , text "non-tops =" <+> ppr defs ])
 
 flattenTopFloats :: FloatBinds -> Bag CoreBind
-flattenTopFloats (FB tops defs)
+flattenTopFloats (FB tops ceils defs)
   = ASSERT2( isEmptyBag (flattenMajor defs), ppr defs )
+    ASSERT2( isEmptyBag ceils, ppr ceils )
     tops
 
 addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
@@ -502,22 +641,29 @@ flattenMinor :: MinorEnv -> Bag FloatBind
 flattenMinor = M.foldr unionBags emptyBag
 
 emptyFloats :: FloatBinds
-emptyFloats = FB emptyBag M.empty
+emptyFloats = FB emptyBag emptyBag M.empty
 
 unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
-unitCaseFloat (Level major minor) e b con bs
-  = FB emptyBag (M.singleton major (M.singleton minor (unitBag (FloatCase e b con bs))))
+unitCaseFloat (Level major minor t) e b con bs
+  | t == JoinCeilLvl
+  = FB emptyBag floats M.empty
+  | otherwise
+  = FB emptyBag emptyBag (M.singleton major (M.singleton minor floats))
+  where
+    floats = unitBag (FloatCase e b con bs)
 
 unitLetFloat :: Level -> FloatLet -> FloatBinds
-unitLetFloat lvl@(Level major minor) b
-  | isTopLvl lvl = FB (unitBag b) M.empty
-  | otherwise    = FB emptyBag (M.singleton major (M.singleton minor floats))
+unitLetFloat lvl@(Level major minor t) b
+  | isTopLvl lvl     = FB (unitBag b) emptyBag M.empty
+  | t == JoinCeilLvl = FB emptyBag floats M.empty
+  | otherwise        = FB emptyBag emptyBag (M.singleton major
+                                              (M.singleton minor floats))
   where
     floats = unitBag (FloatLet b)
 
 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
-plusFloats (FB t1 l1) (FB t2 l2)
-  = FB (t1 `unionBags` t2) (l1 `plusMajor` l2)
+plusFloats (FB t1 c1 l1) (FB t2 c2 l2)
+  = FB (t1 `unionBags` t2) (c1 `unionBags` c2) (l1 `plusMajor` l2)
 
 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
 plusMajor = M.unionWith plusMinor
@@ -557,9 +703,10 @@ partitionByMajorLevel (Level major _) (FB tops defns)
                Just h  -> flattenMinor h
 -}
 
-partitionByLevel (Level major minor) (FB tops defns)
-  = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
-     here_min `unionBags` flattenMinor inner_min
+partitionByLevel (Level major minor typ) (FB tops ceils defns)
+  = (FB tops ceils' (outer_maj `plusMajor` M.singleton major outer_min),
+     here_min `unionBags` here_ceil
+              `unionBags` flattenMinor inner_min
               `unionBags` flattenMajor inner_maj)
 
   where
@@ -568,10 +715,28 @@ partitionByLevel (Level major minor) (FB tops defns)
                                             Nothing -> (M.empty, Nothing, M.empty)
                                             Just min_defns -> M.splitLookup minor min_defns
     here_min = mb_here_min `orElse` emptyBag
+    (here_ceil, ceils') | typ == JoinCeilLvl = (ceils, emptyBag)
+                        | otherwise          = (emptyBag, ceils)
+
+-- Like partitionByLevel, but instead split out the bindings that are marked
+-- to float to the nearest join ceiling (see Note [Join points])
+partitionAtJoinCeiling :: FloatBinds -> (FloatBinds, Bag FloatBind)
+partitionAtJoinCeiling (FB tops ceils defs)
+  = (FB tops emptyBag defs, ceils)
+
+-- Perform some action at a join ceiling, i.e., don't let join points float out
+-- (see Note [Join points])
+atJoinCeiling :: (FloatStats, FloatBinds, CoreExpr)
+              -> (FloatStats, FloatBinds, CoreExpr)
+atJoinCeiling (fs, floats, expr')
+  = (fs, floats', install ceils expr')
+  where
+    (floats', ceils) = partitionAtJoinCeiling floats
 
 wrapTick :: Tickish Id -> FloatBinds -> FloatBinds
-wrapTick t (FB tops defns)
-  = FB (mapBag wrap_bind tops) (M.map (M.map wrap_defns) defns)
+wrapTick t (FB tops ceils defns)
+  = FB (mapBag wrap_bind tops) (wrap_defns ceils)
+       (M.map (M.map wrap_defns) defns)
   where
     wrap_defns = mapBag wrap_one
 
index 1df1405..1776db5 100644 (file)
@@ -197,10 +197,13 @@ libCase :: LibCaseEnv
         -> CoreExpr
         -> CoreExpr
 
-libCase env (Var v)             = libCaseId env v
+libCase env (Var v)             = libCaseApp env v []
 libCase _   (Lit lit)           = Lit lit
 libCase _   (Type ty)           = Type ty
 libCase _   (Coercion co)       = Coercion co
+libCase env e@(App {})          | let (fun, args) = collectArgs e
+                                , Var v <- fun
+                                = libCaseApp env v args
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Tick tickish body) = Tick tickish (libCase env body)
 libCase env (Cast e co)         = Cast (libCase env e) co
@@ -228,20 +231,31 @@ libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 {-
 Ids
 ~~~
+
+To unfold, we can't just wrap the id itself in its binding if it's a join point:
+
+  jump j a b c  =>  (joinrec j x y z = ... in jump j) a b c -- wrong!!!
+
+Every jump must provide all arguments, so we have to be careful to wrap the
+whole jump instead:
+
+  jump j a b c  =>  joinrec j x y z = ... in jump j a b c -- right
+
 -}
 
-libCaseId :: LibCaseEnv -> Id -> CoreExpr
-libCaseId env v
+libCaseApp :: LibCaseEnv -> Id -> [CoreExpr] -> CoreExpr
+libCaseApp env v args
   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
   , notNull free_scruts                 -- with free vars scrutinised in RHS
-  = Let the_bind (Var v)
+  = Let the_bind expr'
 
   | otherwise
-  = Var v
+  = expr'
 
   where
     rec_id_level = lookupLevel env v
     free_scruts  = freeScruts env rec_id_level
+    expr'        = mkApps (Var v) (map (libCase env) args)
 
 freeScruts :: LibCaseEnv
            -> LibCaseLevel      -- Level of the recursive Id
index a50fe22..864d468 100644 (file)
@@ -11,7 +11,7 @@ The occurrence analyser re-typechecks a core expression, returning a new
 core expression with (hopefully) improved usage information.
 -}
 
-{-# LANGUAGE CPP, BangPatterns #-}
+{-# LANGUAGE CPP, BangPatterns, MultiWayIf #-}
 
 module OccurAnal (
         occurAnalysePgm, occurAnalyseExpr, occurAnalyseExpr_NoBinderSwap
@@ -24,16 +24,17 @@ import CoreFVs
 import CoreUtils        ( exprIsTrivial, isDefaultAlt, isExpandableApp,
                           stripTicksTopE, mkTicks )
 import Id
+import IdInfo
 import Name( localiseName )
 import BasicTypes
 import Module( Module )
 import Coercion
+import Type
 
 import VarSet
 import VarEnv
 import Var
 import Demand           ( argOneShots, argsOneShots )
-import Maybes           ( orElse )
 import Digraph          ( SCC(..), Node
                         , stronglyConnCompFromEdgedVerticesUniq
                         , stronglyConnCompFromEdgedVerticesUniqR )
@@ -59,7 +60,7 @@ occurAnalysePgm :: Module       -- Used only in debug output
                 -> [CoreRule] -> [CoreVect] -> VarSet
                 -> CoreProgram -> CoreProgram
 occurAnalysePgm this_mod active_rule imp_rules vects vectVars binds
-  | isEmptyVarEnv final_usage
+  | isEmptyDetails final_usage
   = occ_anald_binds
 
   | otherwise   -- See Note [Glomming]
@@ -69,14 +70,15 @@ occurAnalysePgm this_mod active_rule imp_rules vects vectVars binds
   where
     init_env = initOccEnv active_rule
     (final_usage, occ_anald_binds) = go init_env binds
-    (_, occ_anald_glommed_binds)   = occAnalRecBind init_env imp_rule_edges
+    (_, occ_anald_glommed_binds)   = occAnalRecBind init_env TopLevel
+                                                    imp_rule_edges
                                                     (flattenBinds occ_anald_binds)
                                                     initial_uds
           -- It's crucial to re-analyse the glommed-together bindings
           -- so that we establish the right loop breakers. Otherwise
           -- we can easily create an infinite loop (Trac #9583 is an example)
 
-    initial_uds = addIdOccs emptyDetails
+    initial_uds = addManyOccsSet emptyDetails
                             (rulesFreeVars imp_rules `unionVarSet`
                              vectsFreeVars vects `unionVarSet`
                              vectVars)
@@ -100,7 +102,8 @@ occurAnalysePgm this_mod active_rule imp_rules vects vectVars binds
         = (final_usage, bind' ++ binds')
         where
            (bs_usage, binds')   = go env binds
-           (final_usage, bind') = occAnalBind env imp_rule_edges bind bs_usage
+           (final_usage, bind') = occAnalBind env TopLevel imp_rule_edges bind
+                                              bs_usage
 
 occurAnalyseExpr :: CoreExpr -> CoreExpr
         -- Do occurrence analysis, and discard occurrence info returned
@@ -640,6 +643,133 @@ But watch out!  If 'fs' is not chosen as a loop breaker, we may get an infinite
   - now there's another opportunity to apply the RULE
 
 This showed up when compiling Control.Concurrent.Chan.getChanContents.
+
+------------------------------------------------------------
+Note [Finding join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+It's the occurrence analyser's job to find bindings that we can turn into join
+points, but it doesn't perform that transformation right away. Rather, it marks
+the eligible bindings as part of their occurrence data, leaving it to the
+simplifier (or to simpleOptPgm) to actually change the binder's 'IdDetails'.
+The simplifier then eta-expands the RHS if needed and then updates the
+occurrence sites. Dividing the work this way means that the occurrence analyser
+still only takes one pass, yet one can always tell the difference between a
+function call and a jump by looking at the occurrence (because the same pass
+changes the 'IdDetails' and propagates the binders to their occurrence sites).
+
+To track potential join points, we use the 'occ_tail' field of OccInfo. A value
+of `AlwaysTailCalled n` indicates that every occurrence of the variable is a
+tail call with `n` arguments (counting both value and type arguments). Otherwise
+'occ_tail' will be 'NoTailCallInfo'. The tail call info flows bottom-up with the
+rest of 'OccInfo' until it goes on the binder.
+
+Note [Rules and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Things get fiddly with rules. Suppose we have:
+
+  let j :: Int -> Int
+      j y = 2 * y
+      k :: Int -> Int -> Int
+      {-# RULES "SPEC k 0" k 0 = j #-}
+      k x y = x + 2 * y
+  in ...
+
+Now suppose that both j and k appear only as saturated tail calls in the body.
+Thus we would like to make them both join points. The rule complicates matters,
+though, as its RHS has an unapplied occurrence of j. *However*, if we were to
+eta-expand the rule, all would be well:
+
+  {-# RULES "SPEC k 0" forall a. k 0 a = j a #-}
+
+So conceivably we could notice that a potential join point would have an
+"undersaturated" rule and account for it. This would mean we could make
+something that's been specialised a join point, for instance. But local bindings
+are rarely specialised, and being overly cautious about rules only
+costs us anything when, for some `j`:
+
+  * Before specialisation, `j` has non-tail calls, so it can't be a join point.
+  * During specialisation, `j` gets specialised and thus acquires rules.
+  * Sometime afterward, the non-tail calls to `j` disappear (as dead code, say),
+    and so now `j` *could* become a join point.
+
+This appears to be very rare in practice. TODO Perhaps we should gather
+statistics to be sure.
+
+Note [Excess polymorphism and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In principle, if a function would be a join point except that it fails
+the polymorphism rule (see Note [The polymorphism rule of join points] in
+CoreSyn), it can still be made a join point with some effort. This is because
+all tail calls must return the same type (they return to the same context!), and
+thus if the return type depends on an argument, that argument must always be the
+same.
+
+For instance, consider:
+
+  let f :: forall a. a -> Char -> [a]
+      f @a x c = ... f @a x 'a' ...
+  in ... f @Int 1 'b' ... f @Int 2 'c' ...
+
+(where the calls are tail calls). `f` fails the polymorphism rule because its
+return type is [a], where [a] is bound. But since the type argument is always
+'Int', we can rewrite it as:
+
+  let f' :: Int -> Char -> [Int]
+      f' x c = ... f' x 'a' ...
+  in ... f' 1 'b' ... f 2 'c' ...
+
+and now we can make f' a join point:
+
+  join f' :: Int -> Char -> [Int]
+       f' x c = ... jump f' x 'a' ...
+  in ... jump f' 1 'b' ... jump f' 2 'c' ...
+
+It's not clear that this comes up often, however. TODO: Measure how often and
+add this analysis if necessary.
+
+------------------------------------------------------------
+Note [Adjusting for lambdas]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There's a bit of a dance we need to do after analysing a lambda expression or
+a right-hand side. In particular, we need to
+
+  a) call 'markAllInsideLam' *unless* the binding is for a thunk, a one-shot
+     lambda, or a non-recursive join point; and
+  b) call 'markAllNonTailCalled' *unless* the binding is for a join point.
+
+Some examples, with how the free occurrences in e (assumed not to be a value
+lambda) get marked:
+
+                             inside lam    non-tail-called
+  ------------------------------------------------------------
+  let x = e                  No            Yes
+  let f = \x -> e            Yes           Yes
+  let f = \x{OneShot} -> e   No            Yes
+  \x -> e                    Yes           Yes
+  join j x = e               No            No
+  joinrec j x = e            Yes           No
+
+There are a few other caveats; most importantly, if we're marking a binding as
+'AlwaysTailCalled', it's *going* to be a join point, so we treat it as one so
+that the effect cascades properly. Consequently, at the time the RHS is
+analysed, we won't know what adjustments to make; thus 'occAnalLamOrRhs' must
+return the unadjusted 'UsageDetails', to be adjusted by 'adjustRhsUsage' once
+join-point-hood has been decided.
+
+Thus the overall sequence taking place in 'occAnalNonRecBind' and
+'occAnalRecBind' is as follows:
+
+  1. Call 'occAnalLamOrRhs' to find usage information for the RHS.
+  2. Call 'tagNonRecBinder' or 'tagRecBinders', which decides whether to make
+     the binding a join point.
+  3. Call 'adjustRhsUsage' accordingly. (Done as part of 'tagRecBinders' when
+     recursive.)
+
+(In the recursive case, this logic is spread between 'makeNode' and
+'occAnalRec'.)
 -}
 
 ------------------------------------------------------------------
@@ -647,21 +777,22 @@ This showed up when compiling Control.Concurrent.Chan.getChanContents.
 ------------------------------------------------------------------
 
 occAnalBind :: OccEnv           -- The incoming OccEnv
+            -> TopLevelFlag
             -> ImpRuleEdges
             -> CoreBind
             -> UsageDetails             -- Usage details of scope
             -> (UsageDetails,           -- Of the whole let(rec)
                 [CoreBind])
 
-occAnalBind env top_env (NonRec binder rhs) body_usage
-  = occAnalNonRecBind env top_env binder rhs body_usage
-occAnalBind env top_env (Rec pairs) body_usage
-  = occAnalRecBind env top_env pairs body_usage
+occAnalBind env lvl top_env (NonRec binder rhs) body_usage
+  = occAnalNonRecBind env lvl top_env binder rhs body_usage
+occAnalBind env lvl top_env (Rec pairs) body_usage
+  = occAnalRecBind env lvl top_env pairs body_usage
 
 -----------------
-occAnalNonRecBind :: OccEnv -> ImpRuleEdges -> Var -> CoreExpr
+occAnalNonRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> Var -> CoreExpr
                   -> UsageDetails -> (UsageDetails, [CoreBind])
-occAnalNonRecBind env imp_rule_edges binder rhs body_usage
+occAnalNonRecBind env lvl imp_rule_edges binder rhs body_usage
   | isTyVar binder      -- A type let; we don't gather usage info
   = (body_usage, [NonRec binder rhs])
 
@@ -669,24 +800,36 @@ occAnalNonRecBind env imp_rule_edges binder rhs body_usage
   = (body_usage, [])
 
   | otherwise                   -- It's mentioned in the body
-  = (body_usage' +++ rhs_usage4, [NonRec tagged_binder rhs'])
+  = (body_usage' +++ rhs_usage', [NonRec tagged_binder rhs'])
   where
-    (body_usage', tagged_binder) = tagBinder body_usage binder
-    (rhs_usage1, rhs')           = occAnalNonRecRhs env tagged_binder rhs
-    rhs_usage2 = addIdOccs rhs_usage1 (idUnfoldingVars binder)
-
-    rhs_usage3 = addIdOccs rhs_usage2 (idRuleVars binder)
+    (bndrs, body) = collectBinders rhs
+    (body_usage', tagged_binder) = tagNonRecBinder lvl body_usage binder
+    (rhs_usage1, bndrs', body') = occAnalNonRecRhs env tagged_binder bndrs body
+    rhs' = mkLams bndrs' body'
+    rhs_usage2 = case occAnalUnfolding env NonRecursive binder of
+                   Just unf_usage -> rhs_usage1 +++ unf_usage
+                   Nothing        -> rhs_usage1
+       -- See Note [Unfoldings and join points]
+
+    mb_join_arity = willBeJoinId_maybe tagged_binder
+    rules_w_uds = occAnalRules env mb_join_arity NonRecursive tagged_binder
+
+    rhs_usage3 = rhs_usage2 +++ combineUsageDetailsList
+                                  (map (\(_, l, r) -> l +++ r) rules_w_uds)
        -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
 
-    rhs_usage4 = maybe rhs_usage3 (addIdOccs rhs_usage3) $
+    rhs_usage4 = maybe rhs_usage3 (addManyOccsSet rhs_usage3) $
                  lookupVarEnv imp_rule_edges binder
        -- See Note [Preventing loops due to imported functions rules]
 
+    rhs_usage' = adjustRhsUsage (willBeJoinId_maybe tagged_binder) NonRecursive
+                                bndrs' rhs_usage4
+
 -----------------
-occAnalRecBind :: OccEnv -> ImpRuleEdges -> [(Var,CoreExpr)]
+occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)]
                -> UsageDetails -> (UsageDetails, [CoreBind])
-occAnalRecBind env imp_rule_edges pairs body_usage
-  = foldr occAnalRec (body_usage, []) sccs
+occAnalRecBind env lvl imp_rule_edges pairs body_usage
+  = foldr (occAnalRec lvl) (body_usage, []) sccs
         -- For a recursive group, we
         --      * occ-analyse all the RHSs
         --      * compute strongly-connected components
@@ -703,27 +846,40 @@ occAnalRecBind env imp_rule_edges pairs body_usage
 
     bndr_set = mkVarSet (map fst pairs)
 
+{-
+Note [Unfoldings and join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We assume that anything in an unfolding occurs multiple times, since unfoldings
+are often copied (that's the whole point!). But we still need to track tail
+calls for the purpose of finding join points.
+-}
+
 -----------------------------
-occAnalRec :: SCC Details
+occAnalRec :: TopLevelFlag
+           -> SCC Details
            -> (UsageDetails, [CoreBind])
            -> (UsageDetails, [CoreBind])
 
         -- The NonRec case is just like a Let (NonRec ...) above
-occAnalRec (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs, nd_uds = rhs_uds}))
+occAnalRec lvl (AcyclicSCC (ND { nd_bndr = bndr, nd_rhs = rhs
+                               , nd_uds = rhs_uds, nd_rhs_bndrs = rhs_bndrs }))
            (body_uds, binds)
   | not (bndr `usedIn` body_uds)
   = (body_uds, binds)           -- See Note [Dead code]
 
   | otherwise                   -- It's mentioned in the body
-  = (body_uds' +++ rhs_uds,
+  = (body_uds' +++ rhs_uds',
      NonRec tagged_bndr rhs : binds)
   where
-    (body_uds', tagged_bndr) = tagBinder body_uds bndr
+    (body_uds', tagged_bndr) = tagNonRecBinder lvl body_uds bndr
+    rhs_uds' = adjustRhsUsage (willBeJoinId_maybe tagged_bndr) NonRecursive
+                              rhs_bndrs rhs_uds
 
         -- The Rec case is the interesting one
         -- See Note [Recursive bindings: the grand plan]
         -- See Note [Loop breaking]
-occAnalRec (CyclicSCC details_s) (body_uds, binds)
+occAnalRec lvl (CyclicSCC details_s) (body_uds, binds)
   | not (any (`usedIn` body_uds) bndrs) -- NB: look at body_uds, not total_uds
   = (body_uds, binds)                   -- See Note [Dead code]
 
@@ -738,16 +894,12 @@ occAnalRec (CyclicSCC details_s) (body_uds, binds)
     bndrs    = map nd_bndr details_s
     bndr_set = mkVarSet bndrs
 
-    ----------------------------
-    -- Compute usage details
-    total_uds = foldl add_uds body_uds details_s
-    final_uds = total_uds `minusVarEnv` bndr_set
-    add_uds usage_so_far nd = usage_so_far +++ nd_uds nd
-
     ------------------------------
         -- See Note [Choosing loop breakers] for loop_breaker_nodes
+    final_uds :: UsageDetails
     loop_breaker_nodes :: [LetrecNode]
-    loop_breaker_nodes = mkLoopBreakerNodes bndr_set total_uds details_s
+    (final_uds, loop_breaker_nodes)
+      = mkLoopBreakerNodes lvl bndr_set body_uds details_s
 
     ------------------------------
     weak_fvs :: VarSet
@@ -832,13 +984,18 @@ reOrderNodes depth bndr_set weak_fvs (node : nodes) binds
 
 mk_loop_breaker :: LetrecNode -> Binding
 mk_loop_breaker (ND { nd_bndr = bndr, nd_rhs = rhs}, _, _)
-  = (setIdOccInfo bndr strongLoopBreaker, rhs)
+  = (bndr `setIdOccInfo` strongLoopBreaker { occ_tail = tail_info }, rhs)
+  where
+    tail_info = tailCallInfo (idOccInfo bndr)
 
 mk_non_loop_breaker :: VarSet -> LetrecNode -> Binding
 -- See Note [Weak loop breakers]
 mk_non_loop_breaker weak_fvs (ND { nd_bndr = bndr, nd_rhs = rhs}, _, _)
-  | bndr `elemVarSet` weak_fvs = (setIdOccInfo bndr weakLoopBreaker, rhs)
+  | bndr `elemVarSet` weak_fvs = (setIdOccInfo bndr occ', rhs)
   | otherwise                  = (bndr, rhs)
+  where
+    occ' = weakLoopBreaker { occ_tail = tail_info }
+    tail_info = tailCallInfo (idOccInfo bndr)
 
 ----------------------------------
 chooseLoopBreaker :: Bool             -- True <=> Too many iterations,
@@ -982,7 +1139,7 @@ we choose 'plus1' as the loop breaker (which is entirely possible
 otherwise), the loop does not unravel nicely.
 
 
-@occAnalRhs@ deals with the question of bindings where the Id is marked
+@occAnalUnfolding@ deals with the question of bindings where the Id is marked
 by an INLINE pragma.  For these we record that anything which occurs
 in its RHS occurs many times.  This pessimistically assumes that ths
 inlined binder also occurs many times in its scope, but if it doesn't
@@ -1010,6 +1167,9 @@ type LetrecNode = Node Unique Details  -- Node comes from Digraph
 data Details
   = ND { nd_bndr :: Id          -- Binder
        , nd_rhs  :: CoreExpr    -- RHS, already occ-analysed
+       , nd_rhs_bndrs :: [CoreBndr] -- Outer lambdas of RHS
+                                    -- INVARIANT: (nd_rhs_bndrs nd, _) ==
+                                    --              collectBinders (nd_rhs nd)
 
        , nd_uds  :: UsageDetails  -- Usage from RHS, and RULES, and stable unfoldings
                                   -- ignoring phase (ie assuming all are active)
@@ -1064,6 +1224,7 @@ makeNode env imp_rule_edges bndr_set (bndr, rhs)
   where
     details = ND { nd_bndr            = bndr
                  , nd_rhs             = rhs'
+                 , nd_rhs_bndrs       = bndrs'
                  , nd_uds             = rhs_usage3
                  , nd_inl             = inl_fvs
                  , nd_weak            = node_fvs `minusVarSet` inl_fvs
@@ -1072,54 +1233,66 @@ makeNode env imp_rule_edges bndr_set (bndr, rhs)
 
     -- Constructing the edges for the main Rec computation
     -- See Note [Forming Rec groups]
-    (rhs_usage1, rhs') = occAnalRecRhs env rhs
-    rhs_usage2 = addIdOccs rhs_usage1 all_rule_fvs   -- Note [Rules are extra RHSs]
-                                                     -- Note [Rule dependency info]
-    rhs_usage3 = case mb_unf_fvs of
-                   Just unf_fvs -> addIdOccs rhs_usage2 unf_fvs
+    (bndrs, body) = collectBinders rhs
+    (rhs_usage1, bndrs', body') = occAnalRecRhs env bndrs body
+    rhs' = mkLams bndrs' body'
+    rhs_usage2 = rhs_usage1 +++ all_rule_uds
+                   -- Note [Rules are extra RHSs]
+                   -- Note [Rule dependency info]
+    rhs_usage3 = case mb_unf_uds of
+                   Just unf_uds -> rhs_usage2 +++ unf_uds
                    Nothing      -> rhs_usage2
     node_fvs = udFreeVars bndr_set rhs_usage3
 
     -- Finding the free variables of the rules
     is_active = occ_rule_act env :: Activation -> Bool
-    rules = filterOut isBuiltinRule (idCoreRules bndr)
-    rules_w_fvs :: [(Activation, VarSet)]    -- Find the RHS fvs
-    rules_w_fvs = maybe id (\ids -> ((AlwaysActive, ids):)) (lookupVarEnv imp_rule_edges bndr)
-                   -- See Note [Preventing loops due to imported functions rules]
-                  [ (ru_act rule, fvs)
-                  | rule <- rules
-                  , let fvs = exprFreeVars (ru_rhs rule)
-                              `delVarSetList` ru_bndrs rule
-                  , not (isEmptyVarSet fvs) ]
-    all_rule_fvs = rule_lhs_fvs `unionVarSet` rule_rhs_fvs
-    rule_rhs_fvs = mapUnionVarSet snd rules_w_fvs
-    rule_lhs_fvs = mapUnionVarSet (\ru -> exprsFreeVars (ru_args ru)
-                                          `delVarSetList` ru_bndrs ru) rules
-    active_rule_fvs = unionVarSets [fvs | (a,fvs) <- rules_w_fvs, is_active a]
-
-    -- Finding the free variables of the INLINE pragma (if any)
-    unf        = realIdUnfolding bndr     -- Ignore any current loop-breaker flag
-    mb_unf_fvs = stableUnfoldingVars unf
+
+    rules_w_uds :: [(CoreRule, UsageDetails, UsageDetails)]
+    rules_w_uds = occAnalRules env (Just (length bndrs)) Recursive bndr
+
+    rules_w_rhs_fvs :: [(Activation, VarSet)]    -- Find the RHS fvs
+    rules_w_rhs_fvs = maybe id (\ids -> ((AlwaysActive, ids):))
+                               (lookupVarEnv imp_rule_edges bndr)
+      -- See Note [Preventing loops due to imported functions rules]
+                      [ (ru_act rule, udFreeVars bndr_set rhs_uds)
+                      | (rule, _, rhs_uds) <- rules_w_uds ]
+    all_rule_uds = combineUsageDetailsList $
+                     concatMap (\(_, l, r) -> [l, r]) rules_w_uds
+    active_rule_fvs = unionVarSets [fvs | (a,fvs) <- rules_w_rhs_fvs
+                                        , is_active a]
+
+    -- Finding the usage details of the INLINE pragma (if any)
+    mb_unf_uds = occAnalUnfolding env Recursive bndr
 
     -- Find the "nd_inl" free vars; for the loop-breaker phase
-    inl_fvs = case mb_unf_fvs of
+    inl_fvs = case mb_unf_uds of
                 Nothing -> udFreeVars bndr_set rhs_usage1 -- No INLINE, use RHS
-                Just unf_fvs -> unf_fvs
+                Just unf_uds -> udFreeVars bndr_set unf_uds
                       -- We could check for an *active* INLINE (returning
                       -- emptyVarSet for an inactive one), but is_active
                       -- isn't the right thing (it tells about
                       -- RULE activation), so we'd need more plumbing
 
-mkLoopBreakerNodes :: VarSet -> UsageDetails -> [Details] -> [LetrecNode]
--- Does three things
+mkLoopBreakerNodes :: TopLevelFlag
+                   -> VarSet
+                   -> UsageDetails   -- for BODY of let
+                   -> [Details]
+                   -> (UsageDetails, -- adjusted
+                       [LetrecNode])
+-- Does four things
 --   a) tag each binder with its occurrence info
 --   b) add a NodeScore to each node
 --   c) make a Node with the right dependency edges for
 --      the loop-breaker SCC analysis
-mkLoopBreakerNodes bndr_set total_uds details_s
-  = map mk_lb_node details_s
+--   d) adjust each RHS's usage details according to
+--      the binder's (new) shotness and join-point-hood
+mkLoopBreakerNodes lvl bndr_set body_uds details_s
+  = (final_uds, zipWith mk_lb_node details_s bndrs')
   where
-    mk_lb_node nd@(ND { nd_bndr = bndr, nd_rhs = rhs, nd_inl = inl_fvs })
+    (final_uds, bndrs') = tagRecBinders lvl body_uds
+                            [ (nd_bndr nd, nd_uds nd, nd_rhs_bndrs nd)
+                            | nd <- details_s ]
+    mk_lb_node nd@(ND { nd_bndr = bndr, nd_rhs = rhs, nd_inl = inl_fvs }) bndr'
       = (nd', varUnique bndr, nonDetKeysUFM lb_deps)
               -- It's OK to use nonDetKeysUFM here as
               -- stronglyConnCompFromEdgedVerticesR is still deterministic with edges
@@ -1127,7 +1300,6 @@ mkLoopBreakerNodes bndr_set total_uds details_s
               -- Note [Deterministic SCC] in Digraph.
       where
         nd'     = nd { nd_bndr = bndr', nd_score = score }
-        bndr'   = setBinderOcc total_uds bndr
         score   = nodeScore bndr bndr' rhs lb_deps
         lb_deps = extendFvs_ rule_fv_env inl_fvs
 
@@ -1156,59 +1328,57 @@ nodeScore old_bndr new_bndr bind_rhs lb_deps
   | old_bndr `elemVarSet` lb_deps  -- Self-recursive things are great loop breakers
   = (0, 0, True)                   -- See Note [Self-recursion and loop breakers]
 
-  | otherwise  -- An Id has an unfolding
-  = case id_unfolding of
-      DFunUnfolding { df_args = args }
-        -- Never choose a DFun as a loop breaker
-        -- Note [DFuns should not be loop breakers]
-        -> (9, length args, is_lb)
-
-      CoreUnfolding { uf_src = src, uf_tmpl = unf_rhs, uf_guidance = guide }
-        | isStableSource src
-        -> case guide of
-             UnfWhen {}                      -> (6, cheapExprSize unf_rhs, is_lb)
-             UnfIfGoodArgs { ug_size = size} -> (3, size,                  is_lb)
-             UnfNever                        -> (0, 0,                     is_lb)
-              -- See Note [Loop breakers and INLINE/INLINABLE pragmas] for
-              -- the 6 vs 3 choice
-
-         -- Note that this case hits /all/ stable unfoldings, so we
-         -- never look at 'bind_rhs' for stable unfoldings. That's right, because
-         -- 'rhs' is irrelevant for inlining things with a stable unfolding
-
-         -- Data structures are more important than INLINE pragmas
-         -- so that dictionary/method recursion unravels
-
-      _ | exprIsTrivial bind_rhs
-        -> mk_score 10  -- Practically certain to be inlined
-          -- Used to have also: && not (isExportedId bndr)
-          -- But I found this sometimes cost an extra iteration when we have
-          --      rec { d = (a,b); a = ...df...; b = ...df...; df = d }
-          -- where df is the exported dictionary. Then df makes a really
-          -- bad choice for loop breaker
-
-        | is_con_app bind_rhs   -- Data types help with cases: Note [Constructor applications]
-        -> mk_score 5
-
-        | isOneOcc (idOccInfo new_bndr)
-        -> mk_score 2  -- Likely to be inlined
-
-        | canUnfold id_unfolding   -- The Id has some kind of unfolding
-        -> mk_score 1
+  | exprIsTrivial rhs
+  = mk_score 10  -- Practically certain to be inlined
+    -- Used to have also: && not (isExportedId bndr)
+    -- But I found this sometimes cost an extra iteration when we have
+    --      rec { d = (a,b); a = ...df...; b = ...df...; df = d }
+    -- where df is the exported dictionary. Then df makes a really
+    -- bad choice for loop breaker
 
-        | otherwise
-        -> (0, 0, is_lb)
+  | DFunUnfolding { df_args = args } <- id_unfolding
+    -- Never choose a DFun as a loop breaker
+    -- Note [DFuns should not be loop breakers]
+  = (9, length args, is_lb)
+
+    -- Data structures are more important than INLINE pragmas
+    -- so that dictionary/method recursion unravels
+
+  | CoreUnfolding { uf_guidance = UnfWhen {} } <- id_unfolding
+  = mk_score 6
+
+  | is_con_app rhs   -- Data types help with cases:
+  = mk_score 5       -- Note [Constructor applications]
+
+  | isStableUnfolding id_unfolding
+  , canUnfold id_unfolding
+  = mk_score 3
+
+  | isOneOcc (idOccInfo new_bndr)
+  = mk_score 2  -- Likely to be inlined
+
+  | canUnfold id_unfolding  -- The Id has some kind of unfolding
+  = mk_score 1
+
+  | otherwise
+  = (0, 0, is_lb)
 
   where
     mk_score :: Int -> NodeScore
     mk_score rank = (rank, rhs_size, is_lb)
 
     is_lb    = isStrongLoopBreaker (idOccInfo old_bndr)
+    rhs      = case id_unfolding of
+                 CoreUnfolding { uf_src = src, uf_tmpl = unf_rhs }
+                    | isStableSource src
+                    -> unf_rhs
+                 _  -> bind_rhs
+       -- 'bind_rhs' is irrelevant for inlining things with a stable unfolding
     rhs_size = case id_unfolding of
                  CoreUnfolding { uf_guidance = guidance }
                     | UnfIfGoodArgs { ug_size = size } <- guidance
                     -> size
-                 _  -> cheapExprSize bind_rhs
+                 _  -> cheapExprSize rhs
 
     id_unfolding = realIdUnfolding old_bndr
        -- realIdUnfolding: Ignore loop-breaker-ness here because
@@ -1349,20 +1519,29 @@ Hence the is_lb field of NodeScore
 ************************************************************************
 -}
 
-occAnalRecRhs :: OccEnv -> CoreExpr    -- Rhs
-           -> (UsageDetails, CoreExpr)
+occAnalRhs :: OccEnv -> RecFlag -> Id -> [CoreBndr] -> CoreExpr
+           -> (UsageDetails, [CoreBndr], CoreExpr)
               -- Returned usage details covers only the RHS,
               -- and *not* the RULE or INLINE template for the Id
-occAnalRecRhs env rhs = occAnal (rhsCtxt env) rhs
+occAnalRhs env Recursive _ bndrs body
+  = occAnalRecRhs env bndrs body
+occAnalRhs env NonRecursive id bndrs body
+  = occAnalNonRecRhs env id bndrs body
+
+occAnalRecRhs :: OccEnv -> [CoreBndr] -> CoreExpr    -- Rhs lambdas, body
+           -> (UsageDetails, [CoreBndr], CoreExpr)
+              -- Returned usage details covers only the RHS,
+              -- and *not* the RULE or INLINE template for the Id
+occAnalRecRhs env bndrs body = occAnalLamOrRhs (rhsCtxt env) bndrs body
 
 occAnalNonRecRhs :: OccEnv
-                 -> Id -> CoreExpr    -- Binder and rhs
+                 -> Id -> [CoreBndr] -> CoreExpr    -- Binder; rhs lams, body
                      -- Binder is already tagged with occurrence info
-                 -> (UsageDetails, CoreExpr)
+                 -> (UsageDetails, [CoreBndr], CoreExpr)
               -- Returned usage details covers only the RHS,
               -- and *not* the RULE or INLINE template for the Id
-occAnalNonRecRhs env bndr rhs
-  = occAnal rhs_env rhs
+occAnalNonRecRhs env bndr bndrs body
+  = occAnalLamOrRhs rhs_env bndrs body
   where
     -- See Note [Cascading inlines]
     env1 | certainly_inline = env
@@ -1374,13 +1553,70 @@ occAnalNonRecRhs env bndr rhs
 
     certainly_inline -- See Note [Cascading inlines]
       = case idOccInfo bndr of
-          OneOcc in_lam one_br _ -> not in_lam && one_br && active && not_stable
+          OneOcc { occ_in_lam = in_lam, occ_one_br = one_br }
+                                 -> not in_lam && one_br && active && not_stable
           _                      -> False
 
     dmd        = idDemandInfo bndr
     active     = isAlwaysActive (idInlineActivation bndr)
     not_stable = not (isStableUnfolding (idUnfolding bndr))
 
+occAnalUnfolding :: OccEnv
+                 -> RecFlag
+                 -> Id
+                 -> Maybe UsageDetails
+                      -- Just the analysis, not a new unfolding. The unfolding
+                      -- got analysed when it was created and we don't need to
+                      -- update it.
+occAnalUnfolding env rec_flag id
+  = case realIdUnfolding id of -- ignore previous loop-breaker flag
+      CoreUnfolding { uf_tmpl = rhs, uf_src = src }
+        | not (isStableSource src)
+        -> Nothing
+        | otherwise
+        -> Just $ zapDetails usage
+        where
+          (bndrs, body) = collectBinders rhs
+          (usage, _, _) = occAnalRhs env rec_flag id bndrs body
+
+      DFunUnfolding { df_bndrs = bndrs, df_args = args }
+        -> Just $ zapDetails (delDetailsList usage bndrs)
+        where
+          usage = foldr (+++) emptyDetails (map (fst . occAnal env) args)
+
+      _ -> Nothing
+
+occAnalRules :: OccEnv
+             -> Maybe JoinArity -- If the binder is (or MAY become) a join
+                                -- point, what its join arity is (or WOULD
+                                -- become). See Note [Rules and join points].
+             -> RecFlag
+             -> Id
+             -> [(CoreRule,      -- Each (non-built-in) rule
+                  UsageDetails,  -- Usage details for LHS
+                  UsageDetails)] -- Usage details for RHS
+occAnalRules env mb_expected_join_arity rec_flag id
+  = [ (rule, lhs_uds, rhs_uds) | rule@Rule {} <- idCoreRules id
+                               , let (lhs_uds, rhs_uds) = occ_anal_rule rule ]
+  where
+    occ_anal_rule (Rule { ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs })
+      = (lhs_uds, final_rhs_uds)
+      where
+        lhs_uds = addManyOccsSet emptyDetails $
+                    (exprsFreeVars args `delVarSetList` bndrs)
+        (rhs_bndrs, rhs_body) = collectBinders rhs
+        (rhs_uds, _, _) = occAnalRhs env rec_flag id rhs_bndrs rhs_body
+                            -- Note [Rules are extra RHSs]
+                            -- Note [Rule dependency info]
+        final_rhs_uds = adjust_tail_info bndrs $ markAllMany $
+                          (rhs_uds `delDetailsList` bndrs)
+    occ_anal_rule _
+      = (emptyDetails, emptyDetails)
+
+    adjust_tail_info bndrs uds -- see Note [Rules and join points]
+      = case mb_expected_join_arity of
+          Just ar | bndrs `lengthIs` ar -> uds
+          _                             -> markAllNonTailCalled uds
 {-
 Note [Cascading inlines]
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1437,8 +1673,8 @@ occAnal :: OccEnv
 
 occAnal _   expr@(Type _) = (emptyDetails,         expr)
 occAnal _   expr@(Lit _)  = (emptyDetails,         expr)
-occAnal env expr@(Var v)  = (mkOneOcc env v False, expr)
-    -- At one stage, I gathered the idRuleVars for v here too,
+occAnal env expr@(Var _)  = occAnalApp env (expr, [], [])
+    -- At one stage, I gathered the idRuleVars for the variable here too,
     -- which in a way is the right thing to do.
     -- But that went wrong right after specialisation, when
     -- the *occurrences* of the overloaded function didn't have any
@@ -1446,7 +1682,7 @@ occAnal env expr@(Var v)  = (mkOneOcc env v False, expr)
     -- weren't used at all.
 
 occAnal _ (Coercion co)
-  = (addIdOccs emptyDetails (coVarsOfCo co), Coercion co)
+  = (addManyOccsSet emptyDetails (coVarsOfCo co), Coercion co)
         -- See Note [Gather occurrences of coercion variables]
 
 {-
@@ -1458,10 +1694,10 @@ we can sort them into the right place when doing dependency analysis.
 
 occAnal env (Tick tickish body)
   | tickish `tickishScopesLike` SoftScope
-  = (usage, Tick tickish body')
+  = (markAllNonTailCalled usage, Tick tickish body')
 
   | Breakpoint _ ids <- tickish
-  = (usage_lam +++ mkVarEnv (zip ids (repeat NoOccInfo)), Tick tickish body')
+  = (usage_lam +++ foldr addManyOccs emptyDetails ids, Tick tickish body')
     -- never substitute for any of the Ids in a Breakpoint
 
   | otherwise
@@ -1469,14 +1705,20 @@ occAnal env (Tick tickish body)
   where
     !(usage,body') = occAnal env body
     -- for a non-soft tick scope, we can inline lambdas only
-    usage_lam = mapVarEnv markInsideLam usage
+    usage_lam = markAllNonTailCalled (markAllInsideLam usage)
+                  -- TODO There may be ways to make ticks and join points play
+                  -- nicer together, but right now there are problems:
+                  --   let j x = ... in tick<t> (j 1)
+                  -- Making j a join point may cause the simplifier to drop t
+                  -- (if the tick is put into the continuation). So we don't
+                  -- count j 1 as a tail call.
 
 occAnal env (Cast expr co)
   = case occAnal env expr of { (usage, expr') ->
-    let usage1 = markManyIf (isRhsEnv env) usage
-        usage2 = addIdOccs usage1 (coVarsOfCo co)
+    let usage1 = zapDetailsIf (isRhsEnv env) usage
+        usage2 = addManyOccsSet usage1 (coVarsOfCo co)
           -- See Note [Gather occurrences of coercion variables]
-    in (usage2, Cast expr' co)
+    in (markAllNonTailCalled usage2, Cast expr' co)
         -- If we see let x = y `cast` co
         -- then mark y as 'Many' so that we don't
         -- immediately inline y again.
@@ -1491,7 +1733,7 @@ occAnal env app@(App _ _)
 
 occAnal env (Lam x body) | isTyVar x
   = case occAnal env body of { (body_usage, body') ->
-    (body_usage, Lam x body')
+    (markAllNonTailCalled body_usage, Lam x body')
     }
 
 -- For value lambdas we do a special hack.  Consider
@@ -1504,19 +1746,17 @@ occAnal env (Lam x body) | isTyVar x
 -- Then, the simplifier is careful when partially applying lambdas.
 
 occAnal env expr@(Lam _ _)
-  = case occAnal env_body body of { (body_usage, body') ->
+  = case occAnalLamOrRhs env binders body of { (usage, tagged_binders, body') ->
     let
-        (final_usage, tagged_binders) = tagLamBinders body_usage binders'
-                      -- Use binders' to put one-shot info on the lambdas
-
-        really_final_usage
-          | all isOneShotBndr binders' = final_usage
-          | otherwise = mapVarEnv markInsideLam final_usage
+        expr'       = mkLams tagged_binders body'
+        final_usage | all isOneShotBndr tagged_binders
+                    = markAllNonTailCalled usage
+                    | otherwise
+                    = markAllInsideLam $ markAllNonTailCalled usage
     in
-    (really_final_usage, mkLams tagged_binders body') }
+    (final_usage, expr') }
   where
     (binders, body)      = collectBinders expr
-    (env_body, binders') = oneShotGroup env binders
 
 occAnal env (Case scrut bndr ty alts)
   = case occ_anal_scrut scrut alts     of { (scrut_usage, scrut') ->
@@ -1524,7 +1764,8 @@ occAnal env (Case scrut bndr ty alts)
     let
         alts_usage  = foldr combineAltsUsageDetails emptyDetails alts_usage_s
         (alts_usage1, tagged_bndr) = tag_case_bndr alts_usage bndr
-        total_usage = scrut_usage +++ alts_usage1
+        total_usage = markAllNonTailCalled scrut_usage +++ alts_usage1
+                        -- Alts can have tail calls, but the scrutinee can't
     in
     total_usage `seq` (total_usage, Case scrut' tagged_bndr ty alts') }}
   where
@@ -1538,18 +1779,21 @@ occAnal env (Case scrut bndr ty alts)
         -- into
         --      case x of w { (p,q) -> f (p,q) }
     tag_case_bndr usage bndr
-      = case lookupVarEnv usage bndr of
-          Nothing -> (usage,                  setIdOccInfo bndr IAmDead)
-          Just _  -> (usage `delVarEnv` bndr, setIdOccInfo bndr NoOccInfo)
+      = (usage', setIdOccInfo bndr final_occ_info)
+      where
+        occ_info       = lookupDetails usage bndr
+        usage'         = usage `delDetails` bndr
+        final_occ_info = case occ_info of IAmDead -> IAmDead
+                                          _       -> noOccInfo
 
     alt_env = mkAltEnv env scrut bndr
     occ_anal_alt = occAnalAlt alt_env
 
     occ_anal_scrut (Var v) (alt1 : other_alts)
         | not (null other_alts) || not (isDefaultAlt alt1)
-        = (mkOneOcc env v True, Var v)  -- The 'True' says that the variable occurs
-                                        -- in an interesting context; the case has
-                                        -- at least one non-default alternative
+        = (mkOneOcc env v True 0, Var v)
+            -- The 'True' says that the variable occurs in an interesting
+            -- context; the case has at least one non-default alternative
     occ_anal_scrut (Tick t e) alts
         | t `tickishScopesLike` SoftScope
           -- No reason to not look through all ticks here, but only
@@ -1561,8 +1805,10 @@ occAnal env (Case scrut bndr ty alts)
         = occAnal (vanillaCtxt env) scrut    -- No need for rhsCtxt
 
 occAnal env (Let bind body)
-  = case occAnal env body                               of { (body_usage, body') ->
-    case occAnalBind env noImpRuleEdges bind body_usage of { (final_usage, new_binds) ->
+  = case occAnal env body                of { (body_usage, body') ->
+    case occAnalBind env NotTopLevel
+                     noImpRuleEdges bind
+                     body_usage          of { (final_usage, new_binds) ->
        (final_usage, mkLets new_binds body') }}
 
 occAnalArgs :: OccEnv -> [CoreExpr] -> [OneShots] -> (UsageDetails, [CoreExpr])
@@ -1608,8 +1854,9 @@ occAnalApp env (Var fun, args, ticks)
 
     !(args_uds, args') = occAnalArgs env args one_shots
     !final_args_uds
-       | isRhsEnv env && is_exp = mapVarEnv markInsideLam args_uds
-       | otherwise              = args_uds
+       | isRhsEnv env && is_exp = markAllNonTailCalled $
+                                  markAllInsideLam args_uds
+       | otherwise              = markAllNonTailCalled args_uds
        -- We mark the free vars of the argument of a constructor or PAP
        -- as "inside-lambda", if it is the RHS of a let(rec).
        -- This means that nothing gets inlined into a constructor or PAP
@@ -1621,7 +1868,8 @@ occAnalApp env (Var fun, args, ticks)
        -- See Note [Arguments of let-bound constructors]
 
     n_val_args = valArgCount args
-    fun_uds    = mkOneOcc env fun (n_val_args > 0)
+    n_args     = length args
+    fun_uds    = mkOneOcc env fun (n_val_args > 0) n_args
     is_exp     = isExpandableApp fun n_val_args
            -- See Note [CONLIKE pragma] in BasicTypes
            -- The definition of is_exp should match that in
@@ -1631,7 +1879,8 @@ occAnalApp env (Var fun, args, ticks)
                  -- See Note [Use one-shot info]
 
 occAnalApp env (fun, args, ticks)
-  = (fun_uds +++ args_uds, mkTicks ticks $ mkApps fun' args')
+  = (markAllNonTailCalled (fun_uds +++ args_uds),
+     mkTicks ticks $ mkApps fun' args')
   where
     !(fun_uds, fun') = occAnal (addAppCtxt env args) fun
         -- The addAppCtxt is a bit cunning.  One iteration of the simplifier
@@ -1642,11 +1891,11 @@ occAnalApp env (fun, args, ticks)
         -- onto the context stack.
     !(args_uds, args') = occAnalArgs env args []
 
-markManyIf :: Bool              -- If this is true
-           -> UsageDetails      -- Then do markMany on this
-           -> UsageDetails
-markManyIf True  uds = mapVarEnv markMany uds
-markManyIf False uds = uds
+zapDetailsIf :: Bool              -- If this is true
+             -> UsageDetails      -- Then do zapDetails on this
+             -> UsageDetails
+zapDetailsIf True  uds = zapDetails uds
+zapDetailsIf False uds = uds
 
 {-
 Note [Use one-shot information]
@@ -1690,6 +1939,28 @@ life, beause it binds 'y' to (a,b) (imagine got inlined and
 scrutinised y).
 -}
 
+occAnalLamOrRhs :: OccEnv -> [CoreBndr] -> CoreExpr
+                -> (UsageDetails, [CoreBndr], CoreExpr)
+occAnalLamOrRhs env [] body
+  = case occAnal env body of (body_usage, body') -> (body_usage, [], body')
+      -- RHS of thunk or nullary join point
+occAnalLamOrRhs env (bndr:bndrs) body
+  | isTyVar bndr
+  = -- Important: Keep the environment so that we don't inline into an RHS like
+    --   \(@ x) -> C @x (f @x)
+    -- (see the beginning of Note [Cascading inlines]).
+    case occAnalLamOrRhs env bndrs body of
+      (body_usage, bndrs', body') -> (body_usage, bndr:bndrs', body')
+occAnalLamOrRhs env binders body
+  = case occAnal env_body body of { (body_usage, body') ->
+    let
+        (final_usage, tagged_binders) = tagLamBinders body_usage binders'
+                      -- Use binders' to put one-shot info on the lambdas
+    in
+    (final_usage, tagged_binders, body') }
+  where
+    (env_body, binders') = oneShotGroup env binders
+
 occAnalAlt :: (OccEnv, Maybe (Id, CoreExpr))
            -> CoreAlt
            -> (UsageDetails, Alt IdWithOccInfo)
@@ -1722,7 +1993,7 @@ wrapAltRHS env (Just (scrut_var, let_rhs)) alt_usg bndrs alt_rhs
     -- if the scrutinee was a cast, so we must gather their
     -- usage. See Note [Gather occurrences of coercion variables]
     (let_rhs_usg, let_rhs') = occAnal env let_rhs
-    (alt_usg', tagged_scrut_var) = tagBinder alt_usg scrut_var
+    (alt_usg', [tagged_scrut_var]) = tagLamBinders alt_usg [scrut_var]
 
 wrapAltRHS _ _ alt_usg _ alt_rhs
   = (alt_usg, alt_rhs)
@@ -2054,48 +2325,191 @@ mkAltEnv env@(OccEnv { occ_gbl_scrut = pe }) scrut case_bndr
 \subsection[OccurAnal-types]{OccEnv}
 *                                                                      *
 ************************************************************************
+
+Note [UsageDetails and zapping]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+On many occasions, we must modify all gathered occurrence data at once. For
+instance, all occurrences underneath a (non-one-shot) lambda set the
+'occ_in_lam' flag to become 'True'. We could use 'mapVarEnv' to do this, but
+that takes O(n) time and we will do this often---in particular, there are many
+places where tail calls are not allowed, and each of these causes all variables
+to get marked with 'NoTailCallInfo'.
+
+Instead of relying on `mapVarEnv`, then, we carry three 'IdEnv's around along
+with the 'OccInfoEnv'. Each of these extra environments is a "zapped set"
+recording which variables have been zapped in some way. Zapping all occurrence
+info then simply means setting the corresponding zapped set to the whole
+'OccInfoEnv', a fast O(1) operation.
 -}
 
-type UsageDetails = IdEnv OccInfo       -- A finite map from ids to their usage
+type OccInfoEnv = IdEnv OccInfo -- A finite map from ids to their usage
                 -- INVARIANT: never IAmDead
                 -- (Deadness is signalled by not being in the map at all)
 
+type ZappedSet = OccInfoEnv -- Values are ignored
+
+data UsageDetails
+  = UD { ud_env       :: !OccInfoEnv
+       , ud_z_many    :: ZappedSet   -- apply 'markMany' to these
+       , ud_z_in_lam  :: ZappedSet   -- apply 'markInsideLam' to these
+       , ud_z_no_tail :: ZappedSet } -- apply 'markNonTailCalled' to these
+  -- INVARIANT: All three zapped sets are subsets of the OccInfoEnv
+
+instance Outputable UsageDetails where
+  ppr ud = ppr (ud_env (flattenUsageDetails ud))
+
+-------------------
+-- UsageDetails API
+
 (+++), combineAltsUsageDetails
         :: UsageDetails -> UsageDetails -> UsageDetails
+(+++) = combineUsageDetailsWith addOccInfo
+combineAltsUsageDetails = combineUsageDetailsWith orOccInfo
 
-(+++) usage1 usage2
-  = plusVarEnv_C addOccInfo usage1 usage2
-
-combineAltsUsageDetails usage1 usage2
-  = plusVarEnv_C orOccInfo usage1 usage2
+combineUsageDetailsList :: [UsageDetails] -> UsageDetails
+combineUsageDetailsList = foldl (+++) emptyDetails
 
-addOneOcc :: UsageDetails -> Id -> OccInfo -> UsageDetails
-addOneOcc usage id info
-  = plusVarEnv_C addOccInfo usage (unitVarEnv id info)
-        -- ToDo: make this more efficient
+mkOneOcc :: OccEnv -> Id -> InterestingCxt -> JoinArity -> UsageDetails
+mkOneOcc env id int_cxt arity
+  | isLocalId id
+  = singleton $ OneOcc { occ_in_lam  = False
+                       , occ_one_br  = True
+                       , occ_int_cxt = int_cxt
+                       , occ_tail    = AlwaysTailCalled arity }
+  | id `elemVarEnv` occ_gbl_scrut env
+  = singleton noOccInfo
 
-emptyDetails :: UsageDetails
-emptyDetails = (emptyVarEnv :: UsageDetails)
+  | otherwise
+  = emptyDetails
+  where
+    singleton info = emptyDetails { ud_env = unitVarEnv id info }
 
-usedIn :: Id -> UsageDetails -> Bool
-v `usedIn` details = isExportedId v || v `elemVarEnv` details
+addOneOcc :: UsageDetails -> Id -> OccInfo -> UsageDetails
+addOneOcc ud id info
+  = ud { ud_env = extendVarEnv_C plus_zapped (ud_env ud) id info }
+      `alterZappedSets` (`delVarEnv` id)
+  where
+    plus_zapped old new = doZapping ud id old `addOccInfo` new
 
-addIdOccs :: UsageDetails -> VarSet -> UsageDetails
-addIdOccs usage id_set = nonDetFoldUFM addIdOcc usage id_set
-  -- It's OK to use nonDetFoldUFM here because addIdOcc commutes
+addManyOccsSet :: UsageDetails -> VarSet -> UsageDetails
+addManyOccsSet usage id_set = nonDetFoldUFM addManyOccs usage id_set
+  -- It's OK to use nonDetFoldUFM here because addManyOccs commutes
 
-addIdOcc :: Id -> UsageDetails -> UsageDetails
-addIdOcc v u | isId v    = addOneOcc u v NoOccInfo
-             | otherwise = u
-        -- Give a non-committal binder info (i.e NoOccInfo) because
+-- Add several occurrences, assumed not to be tail calls
+addManyOccs :: Var -> UsageDetails -> UsageDetails
+addManyOccs v u | isId v    = addOneOcc u v noOccInfo
+                | otherwise = u
+        -- Give a non-committal binder info (i.e noOccInfo) because
         --   a) Many copies of the specialised thing can appear
         --   b) We don't want to substitute a BIG expression inside a RULE
         --      even if that's the only occurrence of the thing
         --      (Same goes for INLINE.)
 
+delDetails :: UsageDetails -> Id -> UsageDetails
+delDetails ud bndr
+  = ud `alterUsageDetails` (`delVarEnv` bndr)
+
+delDetailsList :: UsageDetails -> [Id] -> UsageDetails
+delDetailsList ud bndrs
+  = ud `alterUsageDetails` (`delVarEnvList` bndrs)
+
+emptyDetails :: UsageDetails
+emptyDetails = UD { ud_env       = emptyVarEnv
+                  , ud_z_many    = emptyVarEnv
+                  , ud_z_in_lam  = emptyVarEnv
+                  , ud_z_no_tail = emptyVarEnv }
+
+isEmptyDetails :: UsageDetails -> Bool
+isEmptyDetails = isEmptyVarEnv . ud_env
+
+markAllMany, markAllInsideLam, markAllNonTailCalled, zapDetails
+  :: UsageDetails -> UsageDetails
+markAllMany          ud = ud { ud_z_many    = ud_env ud }
+markAllInsideLam     ud = ud { ud_z_in_lam  = ud_env ud }
+markAllNonTailCalled ud = ud { ud_z_no_tail = ud_env ud }
+
+zapDetails = markAllMany . markAllNonTailCalled -- effectively sets to noOccInfo
+
+lookupDetails :: UsageDetails -> Id -> OccInfo
+lookupDetails ud id
+  = case lookupVarEnv (ud_env ud) id of
+      Just occ -> doZapping ud id occ
+      Nothing  -> IAmDead
+
+usedIn :: Id -> UsageDetails -> Bool
+v `usedIn` ud = isExportedId v || v `elemVarEnv` ud_env ud
+
 udFreeVars :: VarSet -> UsageDetails -> VarSet
 -- Find the subset of bndrs that are mentioned in uds
-udFreeVars bndrs uds = intersectUFM_C (\b _ -> b) bndrs uds
+udFreeVars bndrs ud = intersectUFM_C (\b _ -> b) bndrs (ud_env ud)
+
+-------------------
+-- Auxiliary functions for UsageDetails implementation
+
+combineUsageDetailsWith :: (OccInfo -> OccInfo -> OccInfo)
+                        -> UsageDetails -> UsageDetails -> UsageDetails
+combineUsageDetailsWith plus_occ_info ud1 ud2
+  | isEmptyDetails ud1 = ud2
+  | isEmptyDetails ud2 = ud1
+  | otherwise
+  = UD { ud_env       = plusVarEnv_C plus_occ_info (ud_env ud1) (ud_env ud2)
+       , ud_z_many    = plusVarEnv (ud_z_many    ud1) (ud_z_many    ud2)
+       , ud_z_in_lam  = plusVarEnv (ud_z_in_lam  ud1) (ud_z_in_lam  ud2)
+       , ud_z_no_tail = plusVarEnv (ud_z_no_tail ud1) (ud_z_no_tail ud2) }
+
+doZapping :: UsageDetails -> Var -> OccInfo -> OccInfo
+doZapping ud var occ
+  = doZappingByUnique ud (varUnique var) occ
+
+doZappingByUnique :: UsageDetails -> Unique -> OccInfo -> OccInfo
+doZappingByUnique ud uniq
+  = (if | in_subset ud_z_many    -> markMany
+        | in_subset ud_z_in_lam  -> markInsideLam
+        | otherwise              -> id) .
+    (if | in_subset ud_z_no_tail -> markNonTailCalled
+        | otherwise              -> id)
+  where
+    in_subset field = uniq `elemVarEnvByKey` field ud
+
+alterZappedSets :: UsageDetails -> (ZappedSet -> ZappedSet) -> UsageDetails
+alterZappedSets ud f
+  = ud { ud_z_many    = f (ud_z_many    ud)
+       , ud_z_in_lam  = f (ud_z_in_lam  ud)
+       , ud_z_no_tail = f (ud_z_no_tail ud) }
+
+alterUsageDetails :: UsageDetails -> (OccInfoEnv -> OccInfoEnv) -> UsageDetails
+alterUsageDetails ud f
+  = ud { ud_env = f (ud_env ud) }
+      `alterZappedSets` f
+
+flattenUsageDetails :: UsageDetails -> UsageDetails
+flattenUsageDetails ud
+  = ud { ud_env = mapUFM_Directly (doZappingByUnique ud) (ud_env ud) }
+      `alterZappedSets` const emptyVarEnv
+
+-------------------
+-- See Note [Adjusting right-hand sides]
+adjustRhsUsage :: Maybe JoinArity -> RecFlag
+               -> [CoreBndr] -- Outer lambdas, AFTER occ anal
+               -> UsageDetails -> UsageDetails
+adjustRhsUsage mb_join_arity rec_flag bndrs usage
+  = maybe_mark_lam (maybe_drop_tails usage)
+  where
+    maybe_mark_lam ud   | one_shot   = ud
+                        | otherwise  = markAllInsideLam ud
+    maybe_drop_tails ud | exact_join = ud
+                        | otherwise  = markAllNonTailCalled ud
+
+    one_shot = case mb_join_arity of
+                 Just join_arity
+                   | isRec rec_flag -> False
+                   | otherwise      -> all isOneShotBndr (drop join_arity bndrs)
+                 Nothing            -> all isOneShotBndr bndrs
+
+    exact_join = case mb_join_arity of
+                   Just join_arity -> join_arity == length bndrs
+                   _               -> False
 
 type IdWithOccInfo = Id
 
@@ -2109,37 +2523,145 @@ tagLamBinders :: UsageDetails          -- Of scope
 tagLamBinders usage binders = usage' `seq` (usage', bndrs')
   where
     (usage', bndrs') = mapAccumR tag_lam usage binders
-    tag_lam usage bndr = (usage2, setBinderOcc usage bndr)
+    tag_lam usage bndr = (usage2, bndr')
       where
-        usage1 = usage `delVarEnv` bndr
-        usage2 | isId bndr = addIdOccs usage1 (idUnfoldingVars bndr)
+        occ    = lookupDetails usage bndr
+        bndr'  = setBinderOcc (markNonTailCalled occ) bndr
+                   -- Don't try to make an argument into a join point
+        usage1 = usage `delDetails` bndr
+        usage2 | isId bndr = addManyOccsSet usage1 (idUnfoldingVars bndr)
+                               -- This is effectively the RHS of a
+                               -- non-join-point binding, so it's okay to use
+                               -- addManyOccsSet, which assumes no tail calls
                | otherwise = usage1
 
-tagBinder :: UsageDetails           -- Of scope
-          -> Id                     -- Binders
-          -> (UsageDetails,         -- Details with binders removed
-              IdWithOccInfo)        -- Tagged binders
+tagNonRecBinder :: TopLevelFlag           -- At top level?
+                -> UsageDetails           -- Of scope
+                -> CoreBndr               -- Binder
+                -> (UsageDetails,         -- Details with binder removed
+                    IdWithOccInfo)        -- Tagged binder
 
-tagBinder usage binder
+tagNonRecBinder lvl usage binder
  = let
-     usage'  = usage `delVarEnv` binder
-     binder' = setBinderOcc usage binder
+     occ     = lookupDetails usage binder
+     will_be_join = decideJoinPointHood lvl usage [binder]
+     occ'    | will_be_join = occ -- must already be marked AlwaysTailCalled
+             | otherwise    = markNonTailCalled occ
+     binder' = setBinderOcc occ' binder
+     usage'  = usage `delDetails` binder
    in
    usage' `seq` (usage', binder')
 
-setBinderOcc :: UsageDetails -> CoreBndr -> CoreBndr
-setBinderOcc usage bndr
+tagRecBinders :: TopLevelFlag           -- At top level?
+              -> UsageDetails           -- Of body of let ONLY
+              -> [(CoreBndr,            -- Binder
+                   UsageDetails,        -- RHS usage details
+                   [CoreBndr])]         -- Lambdas in new RHS
+              -> (UsageDetails,         -- Adjusted details for whole scope,
+                                        -- with binders removed
+                  [IdWithOccInfo])      -- Tagged binders
+-- Substantially more complicated than non-recursive case. Need to adjust RHS
+-- details *before* tagging binders (because the tags depend on the RHSes).
+tagRecBinders lvl body_uds triples
+ = let
+     (bndrs, rhs_udss, _) = unzip3 triples
+
+     -- 1. Determine join-point-hood of whole group, as determined by
+     --    the *unadjusted* usage details
+     unadj_uds     = body_uds +++ combineUsageDetailsList rhs_udss
+     will_be_joins = decideJoinPointHood lvl unadj_uds bndrs
+
+     -- 2. Adjust usage details of each RHS, taking into account the
+     --    join-point-hood decision
+     rhs_udss' = map adjust triples
+     adjust (bndr, rhs_uds, rhs_bndrs)
+       = adjustRhsUsage mb_join_arity Recursive rhs_bndrs rhs_uds
+       where
+         -- Can't use willBeJoinId_maybe here because we haven't tagged the
+         -- binder yet (the tag depends on these adjustments!)
+         mb_join_arity
+           | will_be_joins
+           , let occ = lookupDetails unadj_uds bndr
+           , AlwaysTailCalled arity <- tailCallInfo occ
+           = Just arity
+           | otherwise
+           = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if we're
+                                       -- making join points!
+             Nothing
+
+     -- 3. Compute final usage details from adjusted RHS details
+     adj_uds   = body_uds +++ combineUsageDetailsList rhs_udss'
+
+     -- 4. Tag each binder with its adjusted details modulo the
+     --    join-point-hood decision
+     occs      = map (lookupDetails adj_uds) bndrs
+     occs'     | will_be_joins = occs
+               | otherwise     = map markNonTailCalled occs
+     bndrs'    = zipWith setBinderOcc occs' bndrs
+
+     -- 5. Drop the binders from the adjusted details and return
+     usage'    = adj_uds `delDetailsList` bndrs
+   in
+   (usage', bndrs')
+
+setBinderOcc :: OccInfo -> CoreBndr -> CoreBndr
+setBinderOcc occ_info bndr
   | isTyVar bndr      = bndr
-  | isExportedId bndr = case idOccInfo bndr of
-                          NoOccInfo -> bndr
-                          _         -> setIdOccInfo bndr NoOccInfo
+  | isExportedId bndr = if isManyOccs (idOccInfo bndr)
+                          then bndr
+                          else setIdOccInfo bndr noOccInfo
             -- Don't use local usage info for visible-elsewhere things
             -- BUT *do* erase any IAmALoopBreaker annotation, because we're
             -- about to re-generate it and it shouldn't be "sticky"
 
   | otherwise = setIdOccInfo bndr occ_info
+
+-- | Decide whether some bindings should be made into join points or not.
+-- Returns `False` if they can't be join points. Note that it's an
+-- all-or-nothing decision, as if multiple binders are given, they're assumed to
+-- be mutually recursive.
+--
+-- See Note [Invariants for join points] in CoreSyn.
+decideJoinPointHood :: TopLevelFlag -> UsageDetails
+                    -> [CoreBndr]
+                    -> Bool
+decideJoinPointHood TopLevel _ _
+  = False
+decideJoinPointHood NotTopLevel usage bndrs
+  | isJoinId (head bndrs)
+  = WARN(not all_ok, text "OccurAnal failed to rediscover join point(s):" <+>
+                       ppr bndrs)
+    all_ok
+  | otherwise
+  = all_ok
   where
-    occ_info = lookupVarEnv usage bndr `orElse` IAmDead
+    -- See Note [Invariants on join points]; invariants cited by number below.
+    -- Invariant 2 is always satisfiable by the simplifier by eta expansion.
+    all_ok = -- Invariant 3: Either all are join points or none are
+             all ok bndrs
+
+    ok bndr
+      | -- Invariant 1: Only tail calls, all same join arity
+        AlwaysTailCalled arity <- tailCallInfo (lookupDetails usage bndr)
+      , -- Invariant 1 as applied to LHSes of rules
+        all (ok_rule arity) (idCoreRules bndr)
+        -- Invariant 4: Satisfies polymorphism rule
+      , isValidJoinPointType arity (idType bndr)
+      = True
+      | otherwise
+      = False
+
+    ok_rule _ BuiltinRule{} = False -- only possible with plugin shenanigans
+    ok_rule join_arity (Rule { ru_args = args })
+      = length args == join_arity
+        -- Invariant 1 as applied to LHSes of rules
+
+willBeJoinId_maybe :: CoreBndr -> Maybe JoinArity
+willBeJoinId_maybe bndr
+  | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
+  = Just arity
+  | otherwise
+  = isJoinId_maybe bndr
 
 {-
 ************************************************************************
@@ -2149,37 +2671,41 @@ setBinderOcc usage bndr
 ************************************************************************
 -}
 
-mkOneOcc :: OccEnv -> Id -> InterestingCxt -> UsageDetails
-mkOneOcc env id int_cxt
-  | isLocalId id
-  = unitVarEnv id (OneOcc False True int_cxt)
+markMany, markInsideLam, markNonTailCalled :: OccInfo -> OccInfo
 
-  | id `elemVarEnv` occ_gbl_scrut env
-  = unitVarEnv id NoOccInfo
-
-  | otherwise
-  = emptyDetails
-
-markMany, markInsideLam :: OccInfo -> OccInfo
+markMany IAmDead = IAmDead
+markMany occ     = ManyOccs { occ_tail = occ_tail occ }
 
-markMany _  = NoOccInfo
+markInsideLam occ@(OneOcc {}) = occ { occ_in_lam = True }
+markInsideLam occ             = occ
 
-markInsideLam (OneOcc _ one_br int_cxt) = OneOcc True one_br int_cxt
-markInsideLam occ                       = occ
+markNonTailCalled IAmDead = IAmDead
+markNonTailCalled occ     = occ { occ_tail = NoTailCallInfo }
 
 addOccInfo, orOccInfo :: OccInfo -> OccInfo -> OccInfo
 
 addOccInfo a1 a2  = ASSERT( not (isDeadOcc a1 || isDeadOcc a2) )
-                    NoOccInfo   -- Both branches are at least One
+                    ManyOccs { occ_tail = tailCallInfo a1 `andTailCallInfo`
+                                          tailCallInfo a2 }
+                                -- Both branches are at least One
                                 -- (Argument is never IAmDead)
 
 -- (orOccInfo orig new) is used
 -- when combining occurrence info from branches of a case
 
-orOccInfo (OneOcc in_lam1 _ int_cxt1)
-          (OneOcc in_lam2 _ int_cxt2)
-  = OneOcc (in_lam1 || in_lam2)
-           False        -- False, because it occurs in both branches
-           (int_cxt1 && int_cxt2)
+orOccInfo (OneOcc { occ_in_lam = in_lam1, occ_int_cxt = int_cxt1
+                  , occ_tail   = tail1 })
+          (OneOcc { occ_in_lam = in_lam2, occ_int_cxt = int_cxt2
+                  , occ_tail   = tail2 })
+  = OneOcc { occ_in_lam  = in_lam1 || in_lam2
+           , occ_one_br  = False -- False, because it occurs in both branches
+           , occ_int_cxt = int_cxt1 && int_cxt2
+           , occ_tail    = tail1 `andTailCallInfo` tail2 }
 orOccInfo a1 a2 = ASSERT( not (isDeadOcc a1 || isDeadOcc a2) )
-                  NoOccInfo
+                  ManyOccs { occ_tail = tailCallInfo a1 `andTailCallInfo`
+                                        tailCallInfo a2 }
+
+andTailCallInfo :: TailCallInfo -> TailCallInfo -> TailCallInfo
+andTailCallInfo info@(AlwaysTailCalled arity1) (AlwaysTailCalled arity2)
+  | arity1 == arity2 = info
+andTailCallInfo _ _  = NoTailCallInfo
index c0d6e8d..d1ff3fc 100644 (file)
   the scrutinee of the case, and we can inline it.
 -}
 
-{-# LANGUAGE CPP #-}
+{-# LANGUAGE CPP, MultiWayIf #-}
 module SetLevels (
         setLevels,
 
-        Level(..), tOP_LEVEL,
+        Level(..), LevelType(..), tOP_LEVEL, isJoinCeilLvl, asJoinCeilLvl,
         LevelledBind, LevelledExpr, LevelledBndr,
         FloatSpec(..), floatSpecLevel,
 
@@ -74,6 +74,7 @@ import CoreArity        ( exprBotStrictness_maybe )
 import CoreFVs          -- all of it
 import CoreSubst
 import MkCore           ( sortQuantVars )
+
 import Id
 import IdInfo
 import Var
@@ -84,7 +85,7 @@ import Demand           ( StrictSig, increaseStrictSigArity )
 import Name             ( getOccName, mkSystemVarName )
 import OccName          ( occNameString )
 import Type             ( isUnliftedType, Type, mkLamTypes, splitTyConApp_maybe )
-import BasicTypes       ( Arity, RecFlag(..) )
+import BasicTypes       ( Arity, RecFlag(..), isRec )
 import DataCon          ( dataConOrigResTy )
 import TysWiredIn
 import UniqSupply
@@ -95,6 +96,8 @@ import UniqDFM
 import FV
 import Data.Maybe
 
+import Control.Monad    ( zipWithM )
+
 {-
 ************************************************************************
 *                                                                      *
@@ -107,10 +110,12 @@ type LevelledExpr = TaggedExpr FloatSpec
 type LevelledBind = TaggedBind FloatSpec
 type LevelledBndr = TaggedBndr FloatSpec
 
-data Level = Level Int  -- Major level: number of enclosing value lambdas
-                   Int  -- Minor level: number of big-lambda and/or case
-                        -- expressions between here and the nearest
-                        -- enclosing value lambda
+data Level = Level Int  -- Level number of enclosing lambdas
+                   Int  -- Number of big-lambda and/or case expressions and/or
+                        -- context boundaries between
+                        -- here and the nearest enclosing lambda
+                   LevelType -- Binder or join ceiling?
+data LevelType = BndrLvl | JoinCeilLvl deriving (Eq)
 
 data FloatSpec
   = FloatMe Level       -- Float to just inside the binding
@@ -139,7 +144,7 @@ a_0 = let  b_? = ...  in
            x_1 = ... b ... in ...
 \end{verbatim}
 
-The main function @lvlExpr@ carries a ``context level'' (@ctxt_lvl@).
+The main function @lvlExpr@ carries a ``context level'' (@le_ctxt_lvl@).
 That's meant to be the level number of the enclosing binder in the
 final (floated) program.  If the level number of a sub-expression is
 less than that of the context, then it might be worth let-binding the
@@ -176,6 +181,26 @@ One particular case is that of workers: we don't want to float the
 call to the worker outside the wrapper, otherwise the worker might get
 inlined into the floated expression, and an importing module won't see
 the worker at all.
+
+Note [Join ceiling]
+~~~~~~~~~~~~~~~~~~~
+Join points can't float very far; too far, and they can't remain join points
+(though see Note [When to ruin a join point]). So, suppose we have:
+
+  f x =
+    (joinrec j y = ... x ... in jump j x) + 1
+
+One may be tempted to float j out to the top of f's RHS, but then the jump
+would not be a tail call. Thus we keep track of a level called the *join
+ceiling* past which join points are not allowed to float.
+
+The troublesome thing is that, unlike most levels to which something might
+float, there is not necessarily an identifier to which the join ceiling is
+attached. Fortunately, if something is to be floated to a join ceiling, it must
+be dropped at the *nearest* join ceiling. Thus each level is marked as to
+whether it is a join ceiling, so that FloatOut can tell which binders are being
+floated to the nearest join ceiling and which to a particular binder (or set of
+binders).
 -}
 
 instance Outputable FloatSpec where
@@ -183,36 +208,44 @@ instance Outputable FloatSpec where
   ppr (StayPut l) = ppr l
 
 tOP_LEVEL :: Level
-tOP_LEVEL   = Level 0 0
+tOP_LEVEL   = Level 0 0 BndrLvl
 
 incMajorLvl :: Level -> Level
-incMajorLvl (Level major _) = Level (major + 1) 0
+incMajorLvl (Level major _ _) = Level (major + 1) 0 BndrLvl
 
 incMinorLvl :: Level -> Level
-incMinorLvl (Level major minor) = Level major (minor+1)
+incMinorLvl (Level major minor _) = Level major (minor+1) BndrLvl
+
+asJoinCeilLvl :: Level -> Level
+asJoinCeilLvl (Level major minor _) = Level major minor JoinCeilLvl
 
 maxLvl :: Level -> Level -> Level
-maxLvl l1@(Level maj1 min1) l2@(Level maj2 min2)
+maxLvl l1@(Level maj1 min1 _) l2@(Level maj2 min2 _)
   | (maj1 > maj2) || (maj1 == maj2 && min1 > min2) = l1
   | otherwise                                      = l2
 
 ltLvl :: Level -> Level -> Bool
-ltLvl (Level maj1 min1) (Level maj2 min2)
+ltLvl (Level maj1 min1 _) (Level maj2 min2 _)
   = (maj1 < maj2) || (maj1 == maj2 && min1 < min2)
 
 ltMajLvl :: Level -> Level -> Bool
     -- Tells if one level belongs to a difft *lambda* level to another
-ltMajLvl (Level maj1 _) (Level maj2 _) = maj1 < maj2
+ltMajLvl (Level maj1 _ _) (Level maj2 _ _) = maj1 < maj2
 
 isTopLvl :: Level -> Bool
-isTopLvl (Level 0 0) = True
-isTopLvl _           = False
+isTopLvl (Level 0 0 _) = True
+isTopLvl _             = False
+
+isJoinCeilLvl :: Level -> Bool
+isJoinCeilLvl (Level _ _ t) = t == JoinCeilLvl
 
 instance Outputable Level where
-  ppr (Level maj min) = hcat [ char '<', int maj, char ',', int min, char '>' ]
+  ppr (Level maj min typ)
+    = hcat [ char '<', int maj, char ',', int min, char '>'
+           , ppWhen (typ == JoinCeilLvl) (char 'C') ]
 
 instance Eq Level where
-  (Level maj1 min1) == (Level maj2 min2) = maj1 == maj2 && min1 == min2
+  (Level maj1 min1 _) == (Level maj2 min2 _) = maj1 == maj2 && min1 == min2
 
 {-
 ************************************************************************
@@ -241,14 +274,14 @@ setLevels float_lams binds us
 
 lvlTopBind :: LevelEnv -> Bind Id -> LvlM (LevelledBind, LevelEnv)
 lvlTopBind env (NonRec bndr rhs)
-  = do { rhs' <- lvlExpr env (freeVars rhs)
+  = do { rhs' <- lvlNonTailExpr env (freeVars rhs)
        ; let (env', [bndr']) = substAndLvlBndrs NonRecursive env tOP_LEVEL [bndr]
        ; return (NonRec bndr' rhs', env') }
 
 lvlTopBind env (Rec pairs)
   = do let (bndrs,rhss) = unzip pairs
            (env', bndrs') = substAndLvlBndrs Recursive env tOP_LEVEL bndrs
-       rhss' <- mapM (lvlExpr env' . freeVars) rhss
+       rhss' <- mapM (lvlNonTailExpr env' . freeVars) rhss
        return (Rec (bndrs' `zip` rhss'), env')
 
 {-
@@ -278,16 +311,16 @@ lvlExpr :: LevelEnv             -- Context
         -> LvlM LevelledExpr    -- Result expression
 
 {-
-The @ctxt_lvl@ is, roughly, the level of the innermost enclosing
+The @le_ctxt_lvl@ is, roughly, the level of the innermost enclosing
 binder.  Here's an example
 
         v = \x -> ...\y -> let r = case (..x..) of
                                         ..x..
                            in ..
 
-When looking at the rhs of @r@, @ctxt_lvl@ will be 1 because that's
+When looking at the rhs of @r@, @le_ctxt_lvl@ will be 1 because that's
 the level of @r@, even though it's inside a level-2 @\y@.  It's
-important that @ctxt_lvl@ is 1 and not 2 in @r@'s rhs, because we
+important that @le_ctxt_lvl@ is 1 and not 2 in @r@'s rhs, because we
 don't want @lvlExpr@ to turn the scrutinee of the @case@ into an MFE
 --- because it isn't a *maximal* free expression.
 
@@ -300,11 +333,11 @@ lvlExpr env (_, AnnVar v)       = return (lookupVar env v)
 lvlExpr _   (_, AnnLit lit)     = return (Lit lit)
 
 lvlExpr env (_, AnnCast expr (_, co)) = do
-    expr' <- lvlExpr env expr
+    expr' <- lvlNonTailExpr env expr
     return (Cast expr' (substCo (le_subst env) co))
 
 lvlExpr env (_, AnnTick tickish expr) = do
-    expr' <- lvlExpr env expr
+    expr' <- lvlNonTailExpr env expr
     let tickish' = substTickish (le_subst env) tickish
     return (Tick tickish' expr')
 
@@ -319,8 +352,8 @@ lvlExpr env expr@(_, AnnApp _ _) = do
                     , Nothing <- isClassOpId_maybe f ->
         do
          let (lapp, rargs) = left (n_val_args - arity) expr []
-         rargs' <- mapM (lvlMFE False env) rargs
-         lapp' <- lvlMFE False env lapp
+         rargs' <- mapM (lvlNonTailMFE False env) rargs
+         lapp' <- lvlNonTailMFE False env lapp
          return (foldl App lapp' rargs')
         where
          n_val_args = count (isValArg . deAnnotate) args
@@ -338,8 +371,8 @@ lvlExpr env expr@(_, AnnApp _ _) = do
          -- No PAPs that we can float: just carry on with the
          -- arguments and the function.
       _otherwise -> do
-         args' <- mapM (lvlMFE False env) args
-         fun'  <- lvlExpr env fun
+         args' <- mapM (lvlNonTailMFE False env) args
+         fun'  <- lvlNonTailExpr env fun
          return (foldl App fun' args')
 
 -- We don't split adjacent lambdas.  That is, given
@@ -350,7 +383,7 @@ lvlExpr env expr@(_, AnnApp _ _) = do
 -- lambdas makes them more expensive.
 
 lvlExpr env expr@(_, AnnLam {})
-  = do { new_body <- lvlMFE True new_env body
+  = do { new_body <- lvlNonTailMFE True new_env body
        ; return (mkLams new_bndrs new_body) }
   where
     (bndrs, body)        = collectAnnBndrs expr
@@ -372,9 +405,15 @@ lvlExpr env (_, AnnLet bind body)
        ; return (Let bind' body') }
 
 lvlExpr env (_, AnnCase scrut case_bndr ty alts)
-  = do { scrut' <- lvlMFE True env scrut
+  = do { scrut' <- lvlNonTailMFE True env scrut
        ; lvlCase env (freeVarsOf scrut) scrut' case_bndr ty alts }
 
+lvlNonTailExpr :: LevelEnv             -- Context
+               -> CoreExprWithFVs      -- Input expression
+               -> LvlM LevelledExpr    -- Result expression
+lvlNonTailExpr env expr
+  = lvlExpr (placeJoinCeiling env) expr
+
 -------------------------------------------
 lvlCase :: LevelEnv             -- Level of in-scope names/tyvars
         -> DVarSet              -- Free vars of input scrutinee
@@ -394,14 +433,16 @@ lvlCase env scrut_fvs scrut' case_bndr ty alts
        ; let rhs_env = extendCaseBndrEnv env1 case_bndr scrut'
        ; body' <- lvlMFE True rhs_env body
        ; let alt' = (con, [TB b (StayPut dest_lvl) | b <- bs'], body')
-       ; return (Case scrut' (TB case_bndr' (FloatMe dest_lvl)) ty [alt']) }
+       ; return (Case scrut' (TB case_bndr' (FloatMe dest_lvl)) ty' [alt']) }
 
   | otherwise     -- Stays put
   = do { let (alts_env1, [case_bndr']) = substAndLvlBndrs NonRecursive env incd_lvl [case_bndr]
              alts_env = extendCaseBndrEnv alts_env1 case_bndr scrut'
        ; alts' <- mapM (lvl_alt alts_env) alts
-       ; return (Case scrut' case_bndr' ty alts') }
+       ; return (Case scrut' case_bndr' ty' alts') }
   where
+    ty' = substTy (le_subst env) ty
+
     incd_lvl = incMinorLvl (le_ctxt_lvl env)
     dest_lvl = maxFvLevel (const True) env scrut_fvs
             -- Don't abstact over type variables, hence const True
@@ -487,6 +528,7 @@ lvlMFE True env e@(_, AnnCase {})
 lvlMFE strict_ctxt env ann_expr
   |  floatTopLvlOnly env && not (isTopLvl dest_lvl)
          -- Only floating to the top level is allowed.
+  || isTopLvl dest_lvl && need_join -- Can't put join point at top level
   || isExprLevPoly expr
          -- We can't let-bind levity polymorphic expressions
          -- See Note [Levity polymorphism invariants] in CoreSyn
@@ -496,10 +538,11 @@ lvlMFE strict_ctxt env ann_expr
     lvlExpr env ann_expr
 
   | Just (wrap_float, wrap_use)
-       <- canFloat_maybe rhs_env strict_ctxt float_is_lam expr
-  = do { expr1 <- lvlExpr rhs_env ann_expr
+       <- canFloat_maybe rhs_env strict_ctxt (float_is_lam || need_join) expr
+  = do { expr1 <- if need_join then lvlExpr rhs_env ann_expr
+                               else lvlNonTailExpr rhs_env ann_expr
        ; let abs_expr = mkLams abs_vars_w_lvls (wrap_float expr1)
-       ; var <- newLvlVar abs_expr
+       ; var <- newLvlVar abs_expr join_arity_maybe
        ; let var2 = annotateBotStr var float_n_lams mb_bot_str
        ; return (Let (NonRec (TB var2 (FloatMe dest_lvl)) abs_expr)
                      (wrap_use (mkVarApps (Var var2) abs_vars))) }
@@ -514,13 +557,18 @@ lvlMFE strict_ctxt env ann_expr
     mb_bot_str   = exprBotStrictness_maybe expr
                            -- See Note [Bottoming floats]
                            -- esp Bottoming floats (2)
-    dest_lvl     = destLevel env fvs (isFunction ann_expr) is_bot
+    dest_lvl     = destLevel env fvs (isFunction ann_expr) is_bot need_join
     abs_vars     = abstractVars dest_lvl env fvs
     float_is_lam = float_n_lams > 0       -- The floated thing will be a value lambda
     float_n_lams = count isId abs_vars    -- so nothing is shared; the only benefit
                                           -- is getting it to the top level
     (rhs_env, abs_vars_w_lvls) = lvlLamBndrs env dest_lvl abs_vars
 
+        -- Note [Join points and MFEs]
+    need_join = any (\v -> isId v && remainsJoinId env v) (dVarSetElems fvs)
+    join_arity_maybe | need_join = Just (length abs_vars)
+                     | otherwise = Nothing
+
         -- A decision to float entails let-binding this thing, and we only do
         -- that if we'll escape a value lambda, or will go to the top level.
     float_me = (dest_lvl `ltMajLvl` (le_ctxt_lvl env) -- Escapes a value lambda
@@ -542,6 +590,14 @@ lvlMFE strict_ctxt env ann_expr
           --    concat = /\ a -> lvl a
           -- which is pretty stupid.  Hence the strict_ctxt test
 
+lvlNonTailMFE :: Bool                 -- True <=> strict context [body of case
+                                      --   or let]
+              -> LevelEnv             -- Level of in-scope names/tyvars
+              -> CoreExprWithFVs      -- input expression
+              -> LvlM LevelledExpr    -- Result expression
+lvlNonTailMFE strict_ctxt env ann_expr
+  = lvlMFE strict_ctxt (placeJoinCeiling env) ann_expr
+
 canFloat_maybe :: LevelEnv
                -> Bool      -- Strict context
                -> Bool      -- The float has a value lambda
@@ -553,6 +609,7 @@ canFloat_maybe env strict_ctxt float_is_lam expr
   | float_is_lam || exprIsTopLevelBindable expr
   = Just (id, id) -- No wrapping needed if the type is lifted, or
                   -- if we are wrapping it in one or more value lambdas
+                  -- or making it a join point
 
   -- OK, so the float has an unlifted type and no value lambdas
   | strict_ctxt
@@ -668,6 +725,43 @@ Because in doing so we share a tiny bit of computation (the switch) but
 in exchange we build a thunk, which is bad.  This case reduces allocation
 by 7% in spectral/puzzle (a rather strange benchmark) and 1.2% in real/fem.
 Doesn't change any other allocation at all.
+
+Note [Join points and MFEs]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When we create an MFE float, if it has a free join variable, the new binding
+must be a join point:
+
+  let join j x = ...
+  in case a of A -> ...
+               B -> j 3
+
+  =>
+
+  let join j x = ...
+      join k = j 3 -- only valid because k is a join point
+  in case a of A -> ...
+               B -> k
+
+Normally we're very circumspect about floating join points, but in this case
+it's definitely safe because we can only be floating it as far as another join
+binding. In other words, one might worry about a situation like:
+
+  let join j x = ...
+  in case a of A -> ...
+               B -> f (j 3)
+
+  =>
+
+  let join j x = ...
+  in case a of A -> ...
+               B -> f (let join k = j 3 in k)
+
+Here we have created the MFE float k, and are contemplating floating it up to
+j. This would indeed be an invalid operation on a join point like k. However,
+this example is ill-typed to begin with, since this time the call to j is not a
+tail call. In summary, the very occurrence of the join variable in the MFE is
+proof that we can float the MFE as far as that binding.
 -}
 
 annotateBotStr :: Id -> Arity -> Maybe (Arity, StrictSig) -> Id
@@ -779,8 +873,9 @@ lvlBind env (AnnNonRec bndr rhs)
           -- We can't float an unlifted binding to top level, so we don't
           -- float it at all.  It's a bit brutal, but unlifted bindings
           -- aren't expensive either
+
   = -- No float
-    do { rhs' <- lvlExpr env rhs
+    do { rhs' <- lvlRhs env NonRecursive False mb_join_arity rhs
        ; let  bind_lvl        = incMinorLvl (le_ctxt_lvl env)
               (env', [bndr']) = substAndLvlBndrs NonRecursive env bind_lvl [bndr]
        ; return (NonRec bndr' rhs', env') }
@@ -788,15 +883,19 @@ lvlBind env (AnnNonRec bndr rhs)
   -- Otherwise we are going to float
   | null abs_vars
   = do {  -- No type abstraction; clone existing binder
-         rhs' <- lvlExpr (setCtxtLvl env dest_lvl) rhs
-       ; (env', [bndr']) <- cloneLetVars NonRecursive env dest_lvl [bndr]
+         rhs' <- lvlRhs (setCtxtLvl env dest_lvl) NonRecursive
+                        zapping_join mb_join_arity rhs
+       ; (env', [bndr']) <- cloneLetVars NonRecursive env dest_lvl
+                                         zapping_join [bndr]
        ; let bndr2 = annotateBotStr bndr' 0 mb_bot_str
        ; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') }
 
   | otherwise
   = do {  -- Yes, type abstraction; create a new binder, extend substitution, etc
-         rhs' <- lvlFloatRhs abs_vars dest_lvl env rhs
-       ; (env', [bndr']) <- newPolyBndrs dest_lvl env abs_vars [bndr]
+         rhs' <- lvlFloatRhs abs_vars dest_lvl env NonRecursive
+                             zapping_join mb_join_arity rhs
+       ; (env', [bndr']) <- newPolyBndrs dest_lvl env abs_vars
+                                         zapping_join [bndr]
        ; let bndr2 = annotateBotStr bndr' n_extra mb_bot_str
        ; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') }
 
@@ -805,24 +904,34 @@ lvlBind env (AnnNonRec bndr rhs)
     bind_fvs   = rhs_fvs `unionDVarSet` dIdFreeVars bndr
     abs_vars   = abstractVars dest_lvl env bind_fvs
     dest_lvl   = destLevel env bind_fvs (isFunction rhs) is_bot
+                                        is_unfloatable_join
     mb_bot_str = exprBotStrictness_maybe (deAnnotate rhs)
                            -- See Note [Bottoming floats]
                            -- esp Bottoming floats (2)
     is_bot     = isJust mb_bot_str
     n_extra    = count isId abs_vars
 
+    mb_join_arity = isJoinId_maybe bndr
+    is_unfloatable_join = case mb_join_arity of Just ar -> ar > 0
+                                                Nothing -> False
+      -- See Note [When to ruin a join point]
+    zapping_join = dest_lvl `ltLvl` joinCeilingLevel env
+
 lvlBind env (AnnRec pairs)
   |  floatTopLvlOnly env && not (isTopLvl dest_lvl)
          -- Only floating to the top level is allowed.
   || not (profitableFloat env dest_lvl)
   = do { let bind_lvl = incMinorLvl (le_ctxt_lvl env)
              (env', bndrs') = substAndLvlBndrs Recursive env bind_lvl bndrs
-       ; rhss' <- mapM (lvlExpr env') rhss
+       ; rhss' <- zipWithM (lvlRhs env' Recursive False) mb_join_arities rhss
        ; return (Rec (bndrs' `zip` rhss'), env') }
 
   | null abs_vars
-  = do { (new_env, new_bndrs) <- cloneLetVars Recursive env dest_lvl bndrs
-       ; new_rhss <- mapM (lvlExpr (setCtxtLvl new_env dest_lvl)) rhss
+  = do { (new_env, new_bndrs) <- cloneLetVars Recursive env dest_lvl
+                                              zapping_joins bndrs
+       ; let env_rhs = setCtxtLvl new_env dest_lvl
+       ; new_rhss <- zipWithM (lvlRhs env_rhs Recursive zapping_joins)
+                              mb_join_arities rhss
        ; return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss)
                 , new_env) }
 
@@ -843,13 +952,17 @@ lvlBind env (AnnRec pairs)
     let (rhs_env, abs_vars_w_lvls) = lvlLamBndrs env dest_lvl abs_vars
         rhs_lvl = le_ctxt_lvl rhs_env
 
-    (rhs_env', [new_bndr]) <- cloneLetVars Recursive rhs_env rhs_lvl [bndr]
+    (rhs_env', [new_bndr]) <- cloneLetVars Recursive rhs_env rhs_lvl
+                                           zapping_joins [bndr]
     let
         (lam_bndrs, rhs_body)   = collectAnnBndrs rhs
         (body_env1, lam_bndrs1) = substBndrsSL NonRecursive rhs_env' lam_bndrs
         (body_env2, lam_bndrs2) = lvlLamBndrs body_env1 rhs_lvl lam_bndrs1
-    new_rhs_body <- lvlExpr body_env2 rhs_body
-    (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars [bndr]
+        mb_join_arity           = isJoinId_maybe bndr
+    new_rhs_body <- lvlRhs body_env2 Recursive zapping_joins
+                           mb_join_arity rhs_body
+    (poly_env, [poly_bndr]) <- newPolyBndrs dest_lvl env abs_vars
+                                            zapping_joins [bndr]
     return (Rec [(TB poly_bndr (FloatMe dest_lvl)
                  , mkLams abs_vars_w_lvls $
                    mkLams lam_bndrs2 $
@@ -859,8 +972,11 @@ lvlBind env (AnnRec pairs)
            , poly_env)
 
   | otherwise  -- Non-null abs_vars
-  = do { (new_env, new_bndrs) <- newPolyBndrs dest_lvl env abs_vars bndrs
-       ; new_rhss <- mapM (lvlFloatRhs abs_vars dest_lvl new_env) rhss
+  = do { (new_env, new_bndrs) <- newPolyBndrs dest_lvl env abs_vars
+                                              zapping_joins bndrs
+       ; new_rhss <- zipWithM (lvlFloatRhs abs_vars dest_lvl new_env
+                                           Recursive zapping_joins)
+                              mb_join_arities rhss
        ; return ( Rec ([TB b (FloatMe dest_lvl) | b <- new_bndrs] `zip` new_rhss)
                 , new_env) }
 
@@ -876,26 +992,72 @@ lvlBind env (AnnRec pairs)
                 bndrs
 
     dest_lvl = destLevel env bind_fvs (all isFunction rhss) False
+                         has_unfloatable_join
     abs_vars = abstractVars dest_lvl env bind_fvs
 
+    mb_join_arities = map isJoinId_maybe bndrs
+    has_unfloatable_join
+      = any (\mb_ar -> case mb_ar of Just ar -> ar > 0
+                                     Nothing -> False) mb_join_arities
+    zapping_joins = dest_lvl `ltLvl` joinCeilingLevel env
+
+lvlRhs :: LevelEnv
+       -> RecFlag
+       -> Bool -- True <=> we're zapping a join point back to a value
+       -> Maybe JoinArity
+       -> CoreExprWithFVs
+       -> LvlM LevelledExpr
+lvlRhs env rec_flag zapping_join mb_join_arity expr
+  = lvlFloatRhs [] (le_ctxt_lvl env) env rec_flag zapping_join
+                mb_join_arity expr
+
 profitableFloat :: LevelEnv -> Level -> Bool
 profitableFloat env dest_lvl
   =  (dest_lvl `ltMajLvl` le_ctxt_lvl env)  -- Escapes a value lambda
   || isTopLvl dest_lvl                      -- Going all the way to top level
 
+
+{-
+Note [When to ruin a join point]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Generally, we protect join points zealously. However, there are two situations
+in which it can pay to promote a join point to a function:
+
+1. If the join point has no value arguments, then floating it outward will make
+   it a *thunk*, not a function, so we might get increased sharing.
+2. If we float the join point all the way to the top level, it still won't be
+   allocated, so the cost is much less.
+
+Refusing to lose a join point in either of these cases can be disastrous---for
+instance, allocation in imaginary/x2n1 *triples* because $w$s^ becomes too big
+to inline, which prevents Float In from making a particular binding strictly
+demanded.
+-}
+
 ----------------------------------------------------
 -- Three help functions for the type-abstraction case
 
-lvlFloatRhs :: [OutVar] -> Level -> LevelEnv -> CoreExprWithFVs
-            -> UniqSM (Expr LevelledBndr)
-lvlFloatRhs abs_vars dest_lvl env rhs
-  = do { body' <- lvlExpr rhs_env body
+lvlFloatRhs :: [OutVar] -> Level -> LevelEnv -> RecFlag -> Bool
+            -> Maybe JoinArity -> CoreExprWithFVs
+            -> LvlM (Expr LevelledBndr)
+lvlFloatRhs abs_vars dest_lvl env rec zapping_joins mb_join_arity rhs
+  = do { body' <- if | Just _ <- mb_join_arity, not zapping_joins
+                       -> lvlExpr rhs_env body
+                     | otherwise
+                       -> lvlNonTailExpr rhs_env body
        ; return (mkLams all_bndrs_w_lvls body') }
   where
-    (bndrs, body)               = collectAnnBndrs rhs
+    (bndrs, body)               | Just join_arity <- mb_join_arity
+                                = collectNAnnBndrs join_arity rhs
+                                | otherwise
+                                = collectAnnBndrs rhs
     (env1, bndrs1)              = substBndrsSL NonRecursive env bndrs
     all_bndrs                   = abs_vars ++ bndrs1
-    (rhs_env, all_bndrs_w_lvls) = lvlLamBndrs env1 dest_lvl all_bndrs
+    (rhs_env, all_bndrs_w_lvls) | Just _ <- mb_join_arity
+                                = lvlJoinBndrs env1 dest_lvl rec all_bndrs
+                                | otherwise
+                                = lvlLamBndrs env1 dest_lvl all_bndrs
         -- The important thing here is that we call lvlLamBndrs on
         -- all these binders at once (abs_vars and bndrs), so they
         -- all get the same major level.  Otherwise we create stupid
@@ -941,13 +1103,21 @@ lvlLamBndrs env lvl bndrs
        -- probable one-shot lambda"
        -- See Note [Computing one-shot info] in Demand.hs
 
+lvlJoinBndrs :: LevelEnv -> Level -> RecFlag -> [OutVar]
+             -> (LevelEnv, [LevelledBndr])
+lvlJoinBndrs env lvl rec bndrs
+  = lvlBndrs env new_lvl bndrs
+  where
+    new_lvl | isRec rec = incMajorLvl lvl
+            | otherwise = incMinorLvl lvl
+      -- Non-recursive join points are one-shot; recursive ones are not
 
 lvlBndrs :: LevelEnv -> Level -> [CoreBndr] -> (LevelEnv, [LevelledBndr])
 -- The binders returned are exactly the same as the ones passed,
 -- apart from applying the substitution, but they are now paired
 -- with a (StayPut level)
 --
--- The returned envt has ctxt_lvl updated to the new_lvl
+-- The returned envt has le_ctxt_lvl updated to the new_lvl
 --
 -- All the new binders get the same level, because
 -- any floating binding is either going to float past
@@ -964,8 +1134,9 @@ lvlBndrs env@(LE { le_lvl_env = lvl_env }) new_lvl bndrs
 destLevel :: LevelEnv -> DVarSet
           -> Bool   -- True <=> is function
           -> Bool   -- True <=> is bottom
+          -> Bool   -- True <=> is join point (or can be floated anyway)
           -> Level
-destLevel env fvs is_function is_bot
+destLevel env fvs is_function is_bot is_join
   | is_bot = tOP_LEVEL  -- Send bottoming bindings to the top
                         -- regardless; see Note [Bottoming floats]
                         -- Esp Bottoming floats (1)
@@ -975,9 +1146,16 @@ destLevel env fvs is_function is_bot
   , countFreeIds fvs <= n_args
   = tOP_LEVEL   -- Send functions to top level; see
                 -- the comments with isFunction
+  | is_join, hits_ceiling = join_ceiling
+  | otherwise = max_fv_level
+  where
+    max_fv_level = maxFvLevel isId env fvs -- Max over Ids only; the tyvars
+                                           -- will be abstracted
 
-  | otherwise = maxFvLevel isId env fvs  -- Max over Ids only; the tyvars
-                                         -- will be abstracted
+    hits_ceiling = max_fv_level `ltLvl` join_ceiling &&
+                   not (isTopLvl max_fv_level)
+                     -- Note [When to ruin a join point]
+    join_ceiling = joinCeilingLevel env
 
 isFunction :: CoreExprWithFVs -> Bool
 -- The idea here is that we want to float *functions* to
@@ -1019,6 +1197,7 @@ data LevelEnv
   = LE { le_switches :: FloatOutSwitches
        , le_ctxt_lvl :: Level           -- The current level
        , le_lvl_env  :: VarEnv Level    -- Domain is *post-cloned* TyVars and Ids
+       , le_join_ceil:: Level           -- Highest level to which joins float
        , le_subst    :: Subst           -- Domain is pre-cloned TyVars and Ids
                                         -- The Id -> CoreExpr in the Subst is ignored
                                         -- (since we want to substitute a LevelledExpr for
@@ -1050,6 +1229,7 @@ initialEnv :: FloatOutSwitches -> LevelEnv
 initialEnv float_lams
   = LE { le_switches = float_lams
        , le_ctxt_lvl = tOP_LEVEL
+       , le_join_ceil = panic "initialEnv"
        , le_lvl_env = emptyVarEnv
        , le_subst = emptySubst
        , le_env = emptyVarEnv }
@@ -1087,6 +1267,13 @@ extendCaseBndrEnv le@(LE { le_subst = subst, le_env = id_env })
        , le_env     = add_id id_env (case_bndr, scrut_var) }
 extendCaseBndrEnv env _ _ = env
 
+-- See Note [Join ceiling]
+placeJoinCeiling :: LevelEnv -> LevelEnv
+placeJoinCeiling le@(LE { le_ctxt_lvl = lvl })
+  = le { le_ctxt_lvl = lvl', le_join_ceil = lvl' }
+  where
+    lvl' = asJoinCeilLvl (incMinorLvl lvl)
+
 maxFvLevel :: (Var -> Bool) -> LevelEnv -> DVarSet -> Level
 maxFvLevel max_me (LE { le_lvl_env = lvl_env, le_env = id_env }) var_set
   = foldDVarSet max_in tOP_LEVEL var_set
@@ -1107,6 +1294,18 @@ lookupVar le v = case lookupVarEnv (le_env le) v of
                     Just (_, expr) -> expr
                     _              -> Var v
 
+-- Level to which join points are allowed to float (boundary of current tail
+-- context). See Note [Join ceiling]
+joinCeilingLevel :: LevelEnv -> Level
+joinCeilingLevel = le_join_ceil
+
+remainsJoinId :: LevelEnv -> Id -> Bool
+remainsJoinId le v = case lookupVarEnv (le_env le) v of
+                         Just (v':_, _) -> isJoinId v'
+                         Nothing        -> isJoinId v
+                         Just ([], e)   -> pprPanic "remainsJoinId" $
+                                             ppr v $$ ppr e
+
 abstractVars :: Level -> LevelEnv -> DVarSet -> [OutVar]
         -- Find the variables in fvs, free vars of the target expression,
         -- whose level is greater than the destination level
@@ -1154,12 +1353,13 @@ type LvlM result = UniqSM result
 initLvl :: UniqSupply -> UniqSM a -> a
 initLvl = initUs_
 
-newPolyBndrs :: Level -> LevelEnv -> [OutVar] -> [InId] -> UniqSM (LevelEnv, [OutId])
+newPolyBndrs :: Level -> LevelEnv -> [OutVar] -> Bool -> [InId]
+             -> LvlM (LevelEnv, [OutId])
 -- The envt is extended to bind the new bndrs to dest_lvl, but
--- the ctxt_lvl is unaffected
+-- the le_ctxt_lvl is unaffected
 newPolyBndrs dest_lvl
              env@(LE { le_lvl_env = lvl_env, le_subst = subst, le_env = id_env })
-             abs_vars bndrs
+             abs_vars zapping_joins bndrs
  = ASSERT( all (not . isCoVar) bndrs )   -- What would we add to the CoSubst in this case. No easy answer.
    do { uniqs <- getUniquesM
       ; let new_bndrs = zipWith mk_poly_bndr bndrs uniqs
@@ -1173,17 +1373,28 @@ newPolyBndrs dest_lvl
     add_id    env (v, v') = extendVarEnv env v ((v':abs_vars), mkVarApps (Var v') abs_vars)
 
     mk_poly_bndr bndr uniq = transferPolyIdInfo bndr abs_vars $         -- Note [transferPolyIdInfo] in Id.hs
+                             maybe_transfer_join_info bndr $
                              mkSysLocalOrCoVar (mkFastString str) uniq poly_ty
                            where
                              str     = "poly_" ++ occNameString (getOccName bndr)
                              poly_ty = mkLamTypes abs_vars (CoreSubst.substTy subst (idType bndr))
+                             maybe_transfer_join_info bndr new_bndr
+                               | not zapping_joins
+                               , Just join_arity <- isJoinId_maybe bndr
+                               = new_bndr `asJoinId`
+                                   join_arity + length abs_vars
+                               | otherwise
+                               = new_bndr
 
 newLvlVar :: LevelledExpr        -- The RHS of the new binding
+          -> Maybe JoinArity     -- Its join arity, if it is a join point
           -> LvlM Id
-newLvlVar lvld_rhs
+newLvlVar lvld_rhs join_arity_maybe
   = do { uniq <- getUniqueM
-       ; return (mk_id uniq rhs_ty) }
+       ; return (add_join_info (mk_id uniq rhs_ty))
+       }
   where
+    add_join_info var = var `asJoinId_maybe` join_arity_maybe
     de_tagged_rhs = deTagExpr lvld_rhs
     rhs_ty        = exprType de_tagged_rhs
 
@@ -1208,25 +1419,30 @@ cloneCaseBndrs env@(LE { le_subst = subst, le_lvl_env = lvl_env, le_env = id_env
 
        ; return (env', vs') }
 
-cloneLetVars :: RecFlag -> LevelEnv -> Level -> [Var] -> LvlM (LevelEnv, [Var])
+cloneLetVars :: RecFlag -> LevelEnv -> Level -> Bool -> [InVar]
+             -> LvlM (LevelEnv, [OutVar])
 -- See Note [Need for cloning during float-out]
 -- Works for Ids bound by let(rec)
 -- The dest_lvl is attributed to the binders in the new env,
--- but cloneVars doesn't affect the ctxt_lvl of the incoming env
+-- but cloneVars doesn't affect the le_ctxt_lvl of the incoming env
 cloneLetVars is_rec
           env@(LE { le_subst = subst, le_lvl_env = lvl_env, le_env = id_env })
-          dest_lvl vs
+          dest_lvl zapping_joins vs
   = do { us <- getUniqueSupplyM
-       ; let (subst', vs1) = case is_rec of
-                               NonRecursive -> cloneBndrs      subst us vs
-                               Recursive    -> cloneRecIdBndrs subst us vs
-             vs2  = map zap_demand_info vs1  -- See Note [Zapping the demand info]
+       ; let vs1  = map (zap_demand_info . maybe_zap_join) vs
+                      -- See Note [Zapping the demand info]
+             (subst', vs2) = case is_rec of
+                               NonRecursive -> cloneBndrs      subst us vs1
+                               Recursive    -> cloneRecIdBndrs subst us vs1
              prs  = vs `zip` vs2
              env' = env { le_lvl_env = addLvls dest_lvl lvl_env vs2
                         , le_subst   = subst'
                         , le_env     = foldl add_id id_env prs }
 
        ; return (env', vs2) }
+  where
+    maybe_zap_join v | isId v, zapping_joins = zapJoinId v
+                     | otherwise             = v
 
 add_id :: IdEnv ([Var], LevelledExpr) -> (Var, Var) -> IdEnv ([Var], LevelledExpr)
 add_id id_env (v, v1)
@@ -1247,4 +1463,7 @@ binding site.  Eg
    f :: Int -> Int
    f x = let v = 3*4 in v+x
 Here v is strict; but if we float v to top level, it isn't any more.
+
+Similarly, if we're floating a join point, it won't be one anymore, so we zap
+join point information as well.
 -}
index 304dc5a..f032aad 100644 (file)
@@ -207,12 +207,16 @@ getCoreToDo dflags
     -- Static forms are moved to the top level with the FloatOut pass.
     -- See Note [Grand plan for static forms] in StaticPtrTable.
     static_ptrs_float_outwards =
-      runWhen static_ptrs $ CoreDoFloatOutwards FloatOutSwitches
-        { floatOutLambdas   = Just 0
-        , floatOutConstants = True
-        , floatOutOverSatApps = False
-        , floatToTopLevelOnly = True
-        }
+      runWhen static_ptrs $ CoreDoPasses
+        [ simpl_gently -- Float Out can't handle type lets (sometimes created
+                       -- by simpleOptPgm via mkParallelBindings)
+        , CoreDoFloatOutwards FloatOutSwitches
+          { floatOutLambdas   = Just 0
+          , floatOutConstants = True
+          , floatOutOverSatApps = False
+          , floatToTopLevelOnly = True
+          }
+        ]
 
     core_todo =
      if opt_level == 0 then
@@ -704,6 +708,7 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
                } ;
            Err.dumpIfSet_dyn dflags Opt_D_dump_occur_anal "Occurrence analysis"
                      (pprCoreBindings tagged_binds);
+           lintPassResult hsc_env CoreOccurAnal tagged_binds;
 
                 -- Get any new rules, and extend the rule base
                 -- See Note [Overall plumbing for rules] in Rules.hs
index 99d8291..f35d120 100644 (file)
@@ -20,17 +20,22 @@ module SimplEnv (
 
         -- * Substitution results
         SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
+        isJoinIdInEnv_maybe,
 
         -- * Simplifying 'Id' binders
-        simplNonRecBndr, simplRecBndrs,
+        simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
         simplBinder, simplBinders,
         substTy, substTyVar, getTCvSubst,
         substCo, substCoVar,
 
         -- * Floats
-        Floats, emptyFloats, isEmptyFloats, addNonRec, addFloats, extendFloats,
+        Floats, emptyFloats, isEmptyFloats,
+        addNonRec, addFloats, extendFloats,
         wrapFloats, setFloats, zapFloats, addRecFloats, mapFloats,
-        doFloatFromRhs, getFloatBinds
+        doFloatFromRhs, getFloatBinds,
+
+        JoinFloats, emptyJoinFloats, isEmptyJoinFloats,
+        wrapJoinFloats, zapJoinFloats, restoreJoinFloats, getJoinFloatBinds,
     ) where
 
 #include "HsVersions.h"
@@ -54,6 +59,7 @@ import BasicTypes
 import MonadUtils
 import Outputable
 import Util
+import UniqFM                   ( pprUniqFM )
 
 import Data.List
 
@@ -86,8 +92,10 @@ data SimplEnv
         -- They are all OutVars, and all bound in this module
         seInScope   :: InScopeSet,      -- OutVars only
                 -- Includes all variables bound by seFloats
-        seFloats    :: Floats
+        seFloats    :: Floats,
                 -- See Note [Simplifier floats]
+        seJoinFloats :: JoinFloats
+                -- Handled separately; they don't go very far
     }
 
 type StaticEnv = SimplEnv       -- Just the static part is relevant
@@ -97,17 +105,24 @@ pprSimplEnv :: SimplEnv -> SDoc
 pprSimplEnv env
   = vcat [text "TvSubst:" <+> ppr (seTvSubst env),
           text "CvSubst:" <+> ppr (seCvSubst env),
-          text "IdSubst:" <+> ppr (seIdSubst env),
+          text "IdSubst:" <+> id_subst_doc,
           text "InScope:" <+> in_scope_vars_doc
     ]
   where
+   id_subst_doc = pprUniqFM ppr_id_subst (seIdSubst env)
+   ppr_id_subst (m_ar, sr) = arity_part <+> ppr sr
+     where arity_part = case m_ar of Just ar -> brackets $
+                                                  text "join" <+> int ar
+                                     Nothing -> empty
+
    in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
                                  (vcat . map ppr_one)
    ppr_one v | isId v = ppr v <+> ppr (idUnfolding v)
              | otherwise = ppr v
 
-type SimplIdSubst = IdEnv SimplSR       -- IdId |--> OutExpr
+type SimplIdSubst = IdEnv (Maybe JoinArity, SimplSR) -- IdId |--> OutExpr
         -- See Note [Extending the Subst] in CoreSubst
+        -- See Note [Join arity in SimplIdSubst]
 
 -- | A substitution result.
 data SimplSR
@@ -192,6 +207,20 @@ seIdSubst:
   map to the same target:  x->x, y->x.  Notably:
         case y of x { ... }
   That's why the "set" is actually a VarEnv Var
+
+Note [Join arity in SimplIdSubst]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We have to remember which incoming variables are join points (the occurrences
+may not be marked correctly yet; we're in change of propagating the change if
+OccurAnal makes something a join point). Normally the in-scope set is where we
+keep the latest information, but the in-scope set tracks only OutVars; if a
+binding is unconditionally inlined, it never makes it into the in-scope set,
+and we need to know at the occurrence site that the variable is a join point so
+that we know to drop the context. Thus we remember which join points we're
+substituting. Clumsily, finding whether an InVar is a join variable may require
+looking in both the substitution *and* the in-scope set (see
+'isJoinIdInEnv_maybe').
 -}
 
 mkSimplEnv :: SimplifierMode -> SimplEnv
@@ -199,6 +228,7 @@ mkSimplEnv mode
   = SimplEnv { seMode = mode
              , seInScope = init_in_scope
              , seFloats = emptyFloats
+             , seJoinFloats = emptyJoinFloats
              , seTvSubst = emptyVarEnv
              , seCvSubst = emptyVarEnv
              , seIdSubst = emptyVarEnv }
@@ -241,7 +271,7 @@ updMode upd env = env { seMode = upd (seMode env) }
 extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
 extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
   = ASSERT2( isId var && not (isCoVar var), ppr var )
-    env {seIdSubst = extendVarEnv subst var res}
+    env { seIdSubst = extendVarEnv subst var (isJoinId_maybe var, res) }
 
 extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
 extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
@@ -264,13 +294,22 @@ setInScope :: SimplEnv -> SimplEnv -> SimplEnv
 -- Set the in-scope set, and *zap* the floats
 setInScope env env_with_scope
   = env { seInScope = seInScope env_with_scope,
-          seFloats = emptyFloats }
+          seFloats = emptyFloats,
+          seJoinFloats = emptyJoinFloats }
 
 setFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Set the in-scope set *and* the floats
 setFloats env env_with_floats
   = env { seInScope = seInScope env_with_floats,
-          seFloats  = seFloats  env_with_floats }
+          seFloats = seFloats  env_with_floats,
+          seJoinFloats = seJoinFloats env_with_floats }
+
+restoreJoinFloats :: SimplEnv -> SimplEnv -> SimplEnv
+-- Put back floats previously zapped
+-- Unlike 'setFloats', does *not* update the in-scope set, since the right-hand
+-- env is assumed to be *older*
+restoreJoinFloats env old_env
+  = env { seJoinFloats = seJoinFloats old_env }
 
 addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
         -- The new Ids are guaranteed to be freshly allocated
@@ -331,6 +370,8 @@ Can't happen:
 data Floats = Floats (OrdList OutBind) FloatFlag
         -- See Note [Simplifier floats]
 
+type JoinFloats = OrdList OutBind
+
 data FloatFlag
   = FltLifted   -- All bindings are lifted and lazy *or*
                 --     consist of a single primitive string literal
@@ -389,9 +430,13 @@ so we must take the 'or' of the two.
 emptyFloats :: Floats
 emptyFloats = Floats nilOL FltLifted
 
+emptyJoinFloats :: JoinFloats
+emptyJoinFloats = nilOL
+
 unitFloat :: OutBind -> Floats
 -- This key function constructs a singleton float with the right form
-unitFloat bind = Floats (unitOL bind) (flag bind)
+unitFloat bind = ASSERT(all (not . isJoinId) (bindersOf bind))
+                 Floats (unitOL bind) (flag bind)
   where
     flag (Rec {})                = FltLifted
     flag (NonRec bndr rhs)
@@ -404,6 +449,10 @@ unitFloat bind = Floats (unitOL bind) (flag bind)
                                    FltCareful
       -- Unlifted binders can only be let-bound if exprOkForSpeculation holds
 
+unitJoinFloat :: OutBind -> JoinFloats
+unitJoinFloat bind = ASSERT(all isJoinId (bindersOf bind))
+                     unitOL bind
+
 addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
 -- Add a non-recursive binding and extend the in-scope set
 -- The latter is important; the binder may already be in the
@@ -412,58 +461,104 @@ addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
 addNonRec env id rhs
   = id `seq`   -- This seq forces the Id, and hence its IdInfo,
                -- and hence any inner substitutions
-    env { seFloats = seFloats env `addFlts` unitFloat (NonRec id rhs),
+    env { seFloats = floats',
+          seJoinFloats = jfloats',
           seInScope = extendInScopeSet (seInScope env) id }
+  where
+    bind = NonRec id rhs
+
+    floats'  | isJoinId id = seFloats env
+             | otherwise   = seFloats env `addFlts` unitFloat bind
+    jfloats' | isJoinId id = seJoinFloats env `addJoinFlts` unitJoinFloat bind
+             | otherwise   = seJoinFloats env
 
 extendFloats :: SimplEnv -> OutBind -> SimplEnv
 -- Add these bindings to the floats, and extend the in-scope env too
 extendFloats env bind
-  = env { seFloats  = seFloats env `addFlts` unitFloat bind,
+  = ASSERT(all (not . isJoinId) (bindersOf bind))
+    env { seFloats  = floats',
+          seJoinFloats = jfloats',
           seInScope = extendInScopeSetList (seInScope env) bndrs }
   where
     bndrs = bindersOf bind
 
+    floats'  | isJoinBind bind = seFloats env
+             | otherwise       = seFloats env `addFlts` unitFloat bind
+    jfloats' | isJoinBind bind = seJoinFloats env `addJoinFlts`
+                                   unitJoinFloat bind
+             | otherwise       = seJoinFloats env
+
 addFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Add the floats for env2 to env1;
 -- *plus* the in-scope set for env2, which is bigger
 -- than that for env1
 addFloats env1 env2
   = env1 {seFloats = seFloats env1 `addFlts` seFloats env2,
+          seJoinFloats = seJoinFloats env1 `addJoinFlts` seJoinFloats env2,
           seInScope = seInScope env2 }
 
 addFlts :: Floats -> Floats -> Floats
 addFlts (Floats bs1 l1) (Floats bs2 l2)
   = Floats (bs1 `appOL` bs2) (l1 `andFF` l2)
 
+addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
+addJoinFlts = appOL
+
 zapFloats :: SimplEnv -> SimplEnv
-zapFloats env = env { seFloats = emptyFloats }
+zapFloats env = env { seFloats = emptyFloats
+                    , seJoinFloats = emptyJoinFloats }
+
+zapJoinFloats :: SimplEnv -> SimplEnv
+zapJoinFloats env = env { seJoinFloats = emptyJoinFloats }
 
 addRecFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Flattens the floats from env2 into a single Rec group,
 -- prepends the floats from env1, and puts the result back in env2
 -- This is all very specific to the way recursive bindings are
 -- handled; see Simplify.simplRecBind
-addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff})
+addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff
+                                 ,seJoinFloats = jbs })
   = ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
-    env2 {seFloats = seFloats env1 `addFlts` unitFloat (Rec (flattenBinds (fromOL bs)))}
+    env2 {seFloats = seFloats env1 `addFlts` floats'
+         ,seJoinFloats = seJoinFloats env1 `addJoinFlts` jfloats'}
+  where
+    floats'  | isNilOL bs  = emptyFloats
+             | otherwise   = unitFloat (Rec (flattenBinds (fromOL bs)))
+    jfloats' | isNilOL jbs = emptyJoinFloats
+             | otherwise   = unitJoinFloat (Rec (flattenBinds (fromOL jbs)))
 
 wrapFloats :: SimplEnv -> OutExpr -> OutExpr
 -- Wrap the floats around the expression; they should all
 -- satisfy the let/app invariant, so mkLets should do the job just fine
-wrapFloats (SimplEnv {seFloats = Floats bs _}) body
-  = foldrOL Let body bs
+wrapFloats env@(SimplEnv {seFloats = Floats bs _}) body
+  = foldrOL Let (wrapJoinFloats env body) bs
+      -- Note: Always safe to put the joins on the inside since the values
+      -- can't refer to them
+
+wrapJoinFloats :: SimplEnv -> OutExpr -> OutExpr
+wrapJoinFloats (SimplEnv {seJoinFloats = jbs}) body
+  = foldrOL Let body jbs
 
 getFloatBinds :: SimplEnv -> [CoreBind]
-getFloatBinds (SimplEnv {seFloats = Floats bs _})
-  = fromOL bs
+getFloatBinds env@(SimplEnv {seFloats = Floats bs _})
+  = fromOL bs ++ getJoinFloatBinds env
+
+getJoinFloatBinds :: SimplEnv -> [CoreBind]
+getJoinFloatBinds (SimplEnv {seJoinFloats = jbs})
+  = fromOL jbs
 
 isEmptyFloats :: SimplEnv -> Bool
-isEmptyFloats (SimplEnv {seFloats = Floats bs _})
-  = isNilOL bs
+isEmptyFloats env@(SimplEnv {seFloats = Floats bs _})
+  = isNilOL bs && isEmptyJoinFloats env
+
+isEmptyJoinFloats :: SimplEnv -> Bool
+isEmptyJoinFloats (SimplEnv {seJoinFloats = jbs})
+  = isNilOL jbs
 
 mapFloats :: SimplEnv -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> SimplEnv
-mapFloats env@SimplEnv { seFloats = Floats fs ff } fun
-   = env { seFloats = Floats (mapOL app fs) ff }
+mapFloats env@SimplEnv { seFloats = Floats fs ff, seJoinFloats = jfs } fun
+   = env { seFloats = Floats (mapOL app fs) ff
+         , seJoinFloats = mapOL app jfs }
    where
     app (NonRec b e) = case fun (b,e) of (b',e') -> NonRec b' e'
     app (Rec bs)     = Rec (map fun bs)
@@ -490,7 +585,7 @@ find that it has been substituted by b.  (Or conceivably cloned.)
 substId :: SimplEnv -> InId -> SimplSR
 -- Returns DoneEx only on a non-Var expression
 substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
-  = case lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
+  = case snd <$> lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
         Nothing               -> DoneId (refineFromInScope in_scope v)
         Just (DoneId v)       -> DoneId (refineFromInScope in_scope v)
         Just (DoneEx (Var v)) -> DoneId (refineFromInScope in_scope v)
@@ -499,6 +594,15 @@ substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
         -- Get the most up-to-date thing from the in-scope set
         -- Even though it isn't in the substitution, it may be in
         -- the in-scope set with better IdInfo
+
+isJoinIdInEnv_maybe :: SimplEnv -> InId -> Maybe JoinArity
+isJoinIdInEnv_maybe (SimplEnv { seInScope = inScope, seIdSubst = ids }) v
+  | not (isLocalId v)                         = Nothing
+  | Just (m_ar, _) <- lookupVarEnv ids v      = m_ar
+  | Just v'        <- lookupInScope inScope v = isJoinId_maybe v'
+  | otherwise                                 = WARN( True , ppr v )
+                                                isJoinId_maybe v
+
 refineFromInScope :: InScopeSet -> Var -> Var
 refineFromInScope in_scope v
   | isLocalId v = case lookupInScope in_scope v of
@@ -511,7 +615,7 @@ lookupRecBndr :: SimplEnv -> InId -> OutId
 -- but where we have not yet done its RHS
 lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
   = case lookupVarEnv ids v of
-        Just (DoneId v) -> v
+        Just (_, DoneId v) -> v
         Just _ -> pprPanic "lookupRecBndr" (ppr v)
         Nothing -> refineFromInScope in_scope v
 
@@ -539,33 +643,53 @@ simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 simplBinder env bndr
   | isTyVar bndr  = do  { let (env', tv) = substTyVarBndr env bndr
                         ; seqTyVar tv `seq` return (env', tv) }
-  | otherwise     = do  { let (env', id) = substIdBndr env bndr
+  | otherwise     = do  { let (env', id) = substIdBndr Nothing env bndr
                         ; seqId id `seq` return (env', id) }
 
 ---------------
 simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 -- A non-recursive let binder
 simplNonRecBndr env id
-  = do  { let (env1, id1) = substIdBndr env id
+  = do  { let (env1, id1) = substIdBndr Nothing env id
+        ; seqId id1 `seq` return (env1, id1) }
+
+---------------
+simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
+                    -> SimplM (SimplEnv, OutBndr)
+-- A non-recursive let binder for a join point; context being pushed inward may
+-- change the type
+simplNonRecJoinBndr env res_ty id
+  = do  { let (env1, id1) = substIdBndr (Just res_ty) env id
         ; seqId id1 `seq` return (env1, id1) }
 
 ---------------
 simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
 -- Recursive let binders
 simplRecBndrs env@(SimplEnv {}) ids
-  = do  { let (env1, ids1) = mapAccumL substIdBndr env ids
+  = ASSERT(all (not . isJoinId) ids)
+    do  { let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids
+        ; seqIds ids1 `seq` return env1 }
+
+---------------
+simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
+-- Recursive let binders for join points; context being pushed inward may
+-- change types
+simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
+  = ASSERT(all isJoinId ids)
+    do  { let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids
         ; seqIds ids1 `seq` return env1 }
 
 ---------------
-substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
+substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
 -- Might be a coercion variable
-substIdBndr env bndr
+substIdBndr new_res_ty env bndr
   | isCoVar bndr  = substCoVarBndr env bndr
-  | otherwise     = substNonCoVarIdBndr env bndr
+  | otherwise     = substNonCoVarIdBndr new_res_ty env bndr
 
 ---------------
 substNonCoVarIdBndr
-   :: SimplEnv
+   :: Maybe OutType -- New result type, if a join binder
+   -> SimplEnv
    -> InBndr    -- Env and binder to transform
    -> (SimplEnv, OutBndr)
 -- Clone Id if necessary, substitute its type
@@ -585,7 +709,9 @@ substNonCoVarIdBndr
 -- Similar to CoreSubst.substIdBndr, except that
 --      the type of id_subst differs
 --      all fragile info is zapped
-substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
+substNonCoVarIdBndr new_res_ty
+                    env@(SimplEnv { seInScope = in_scope
+                                  , seIdSubst = id_subst })
                     old_id
   = ASSERT2( not (isCoVar old_id), ppr old_id )
     (env { seInScope = in_scope `extendInScopeSet` new_id,
@@ -593,14 +719,19 @@ substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }
   where
     id1    = uniqAway in_scope old_id
     id2    = substIdType env id1
-    new_id = zapFragileIdInfo id2       -- Zaps rules, worker-info, unfolding
+    id3    | Just res_ty <- new_res_ty
+           = id2 `setIdType` setJoinResTy (idJoinArity id2) res_ty (idType id2)
+           | otherwise
+           = id2
+    new_id = zapFragileIdInfo id3       -- Zaps rules, worker-info, unfolding
                                         -- and fragile OccInfo
 
         -- Extend the substitution if the unique has changed,
         -- or there's some useful occurrence information
         -- See the notes with substTyVarBndr for the delSubstEnv
     new_subst | new_id /= old_id
-              = extendVarEnv id_subst old_id (DoneId new_id)
+              = extendVarEnv id_subst old_id
+                             (isJoinId_maybe new_id, DoneId new_id)
               | otherwise
               = delVarEnv id_subst old_id
 
@@ -664,7 +795,8 @@ the letrec.
 -}
 
 getTCvSubst :: SimplEnv -> TCvSubst
-getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env })
+getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env
+                      , seCvSubst = cv_env })
   = mkTCvSubst in_scope (tv_env, cv_env)
 
 substTy :: SimplEnv -> Type -> Type
index 3b48924..2e985c5 100644 (file)
@@ -19,7 +19,7 @@ module SimplUtils (
         -- The continuation type
         SimplCont(..), DupFlag(..),
         isSimplified,
-        contIsDupable, contResultType, contHoleType,
+        contIsDupable, contResultType, contHoleType, applyContToJoinType,
         contIsTrivial, contArgs,
         countArgs,
         mkBoringStop, mkRhsStop, mkLazyArgStop, contIsRhsOrArg,
@@ -47,6 +47,7 @@ import CoreArity
 import CoreUnfold
 import Name
 import Id
+import IdInfo
 import Var
 import Demand
 import SimplMonad
@@ -361,6 +362,10 @@ contHoleType (ApplyToVal { sc_arg = e, sc_env = se, sc_dup = dup, sc_cont = k })
 contHoleType (Select { sc_dup = d, sc_bndr =  b, sc_env = se })
   = perhapsSubstTy d se (idType b)
 
+applyContToJoinType :: JoinArity -> SimplCont -> OutType -> OutType
+applyContToJoinType ar cont ty
+  = setJoinResTy ar (contResultType cont) ty
+
 -------------------
 countArgs :: SimplCont -> Int
 -- Count all arguments, including types, coercions, and other values
@@ -629,7 +634,7 @@ interestingArg env e = go env 0 e
     -- n is # value args to which the expression is applied
     go env n (Var v)
        | SimplEnv { seIdSubst = ids, seInScope = in_scope } <- env
-       = case lookupVarEnv ids v of
+       = case snd <$> lookupVarEnv ids v of
            Nothing                     -> go_var n (refineFromInScope in_scope v)
            Just (DoneId v')            -> go_var n (refineFromInScope in_scope v')
            Just (DoneEx e)             -> go (zapSubstEnv env)             n e
@@ -1054,7 +1059,9 @@ preInlineUnconditionally dflags env top_lvl bndr rhs
   | isCoVar bndr                             = False -- Note [Do not inline CoVars unconditionally]
   | otherwise = case idOccInfo bndr of
                   IAmDead                    -> True -- Happens in ((\x.1) v)
-                  OneOcc in_lam True int_cxt -> try_once in_lam int_cxt
+                  occ@OneOcc { occ_one_br = True }
+                                             -> try_once (occ_in_lam occ)
+                                                         (occ_int_cxt occ)
                   _                          -> False
   where
     mode = getMode env
@@ -1180,7 +1187,8 @@ postInlineUnconditionally dflags env top_lvl bndr occ_info rhs unfolding
         --         False -> case x of ...
         -- This is very important in practice; e.g. wheel-seive1 doubles
         -- in allocation if you miss this out
-      OneOcc in_lam _one_br int_cxt     -- OneOcc => no code-duplication issue
+      OneOcc { occ_in_lam = in_lam, occ_int_cxt = int_cxt }
+               -- OneOcc => no code-duplication issue
         ->     smallEnoughToInline dflags unfolding     -- Small enough to dup
                         -- ToDo: consider discount on smallEnoughToInline if int_cxt is true
                         --
@@ -1398,9 +1406,10 @@ because the latter is not well-kinded.
 ************************************************************************
 -}
 
-tryEtaExpandRhs :: SimplEnv -> OutId -> OutExpr -> SimplM (Arity, OutExpr)
+tryEtaExpandRhs :: SimplEnv -> RecFlag -> OutId -> OutExpr
+                -> SimplM (Arity, OutExpr)
 -- See Note [Eta-expanding at let bindings]
-tryEtaExpandRhs env bndr rhs
+tryEtaExpandRhs env is_rec bndr rhs
   = do { dflags <- getDynFlags
        ; (new_arity, new_rhs) <- try_expand dflags
 
@@ -1419,8 +1428,12 @@ tryEtaExpandRhs env bndr rhs
             new_arity2 = idCallArity bndr
             new_arity  = max new_arity1 new_arity2
       , new_arity > old_arity      -- And the current manifest arity isn't enough
-      = do { tick (EtaExpansion bndr)
-           ; return (new_arity, etaExpand new_arity rhs) }
+      = if is_rec == Recursive && isJoinId bndr
+           then WARN(True, text "Can't eta-expand recursive join point:" <+>
+                             ppr bndr)
+                return (old_arity, rhs)
+           else do { tick (EtaExpansion bndr)
+                   ; return (new_arity, etaExpand new_arity rhs) }
       | otherwise
       = return (old_arity, rhs)
 
index c1f2a9f..7c6f875 100644 (file)
@@ -18,7 +18,7 @@ import SimplUtils
 import FamInstEnv       ( FamInstEnv )
 import Literal          ( litIsLifted ) --, mkMachInt ) -- temporalily commented out. See #8326
 import Id
-import MkId             ( seqId, voidPrimId )
+import MkId             ( seqId )
 import MkCore           ( mkImpossibleExpr, castBottomExpr )
 import IdInfo
 import Name             ( Name, mkSystemVarName, isExternalName, getOccFS )
@@ -37,10 +37,11 @@ import CoreArity
 import CoreSubst        ( pushCoTyArg, pushCoValArg )
 --import PrimOp           ( tagToEnumKey ) -- temporalily commented out. See #8326
 import Rules            ( mkRuleInfo, lookupRule, getRules )
-import TysPrim          ( voidPrimTy ) --, intPrimTy ) -- temporalily commented out. See #8326
-import BasicTypes       ( TopLevelFlag(..), isTopLevel, RecFlag(..) )
+--import TysPrim          ( intPrimTy ) -- temporalily commented out. See #8326
+import BasicTypes       ( TopLevelFlag(..), isNotTopLevel, isTopLevel,
+                          RecFlag(..) )
 import MonadUtils       ( foldlM, mapAccumLM, liftIO )
-import Maybes           ( orElse )
+import Maybes           ( isJust, fromJust, orElse )
 --import Unique           ( hasKey ) -- temporalily commented out. See #8326
 import Control.Monad
 import Outputable
@@ -203,6 +204,35 @@ we should eta expand wherever we find a (value) lambda?  Then the eta
 expansion at a let RHS can concentrate solely on the PAP case.
 
 
+Case-of-case and join points
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we perform the case-of-case transform (or otherwise push continuations
+inward), we want to treat join points specially. Since they're always
+tail-called and we want to maintain this invariant, we can do this (for any
+evaluation context E):
+
+  E[join j = e
+    in case ... of
+         A -> jump j 1
+         B -> jump j 2
+         C -> f 3]
+
+    -->
+
+  join j = E[e]
+  in case ... of
+       A -> jump j 1
+       B -> jump j 2
+       C -> E[f 3]
+
+As is evident from the example, there are two components to this behavior:
+
+  1. When entering the RHS of a join point, copy the context inside.
+  2. When a join point is invoked, discard the outer context.
+
+Clearly we need to be very careful here to remain consistent---neither part is
+optional!
+
 ************************************************************************
 *                                                                      *
 \subsection{Bindings}
@@ -232,9 +262,11 @@ simplTopBinds env0 binds0
     simpl_binds env (bind:binds) = do { env' <- simpl_bind env bind
                                       ; simpl_binds env' binds }
 
-    simpl_bind env (Rec pairs)  = simplRecBind      env  TopLevel pairs
+    simpl_bind env (Rec pairs)  = simplRecBind env TopLevel Nothing pairs
     simpl_bind env (NonRec b r) = do { (env', b') <- addBndrRules env b (lookupRecBndr env b)
-                                     ; simplRecOrTopPair env' TopLevel NonRecursive b b' r }
+                                     ; simplRecOrTopPair env' TopLevel
+                                                         NonRecursive Nothing
+                                                         b b' r }
 
 {-
 ************************************************************************
@@ -247,10 +279,10 @@ simplRecBind is used for
         * recursive bindings only
 -}
 
-simplRecBind :: SimplEnv -> TopLevelFlag
+simplRecBind :: SimplEnv -> TopLevelFlag -> Maybe SimplCont
              -> [(InId, InExpr)]
              -> SimplM SimplEnv
-simplRecBind env0 top_lvl pairs0
+simplRecBind env0 top_lvl mb_cont pairs0
   = do  { (env_with_info, triples) <- mapAccumLM add_rules env0 pairs0
         ; env1 <- go (zapFloats env_with_info) triples
         ; return (env0 `addRecFloats` env1) }
@@ -266,7 +298,8 @@ simplRecBind env0 top_lvl pairs0
     go env [] = return env
 
     go env ((old_bndr, new_bndr, rhs) : pairs)
-        = do { env' <- simplRecOrTopPair env top_lvl Recursive old_bndr new_bndr rhs
+        = do { env' <- simplRecOrTopPair env top_lvl Recursive mb_cont
+                                         old_bndr new_bndr rhs
              ; go env' pairs }
 
 {-
@@ -278,18 +311,18 @@ It assumes the binder has already been simplified, but not its IdInfo.
 -}
 
 simplRecOrTopPair :: SimplEnv
-                  -> TopLevelFlag -> RecFlag
+                  -> TopLevelFlag -> RecFlag -> Maybe SimplCont
                   -> InId -> OutBndr -> InExpr  -- Binder and rhs
                   -> SimplM SimplEnv    -- Returns an env that includes the binding
 
-simplRecOrTopPair env top_lvl is_rec old_bndr new_bndr rhs
+simplRecOrTopPair env top_lvl is_rec mb_cont old_bndr new_bndr rhs
   = do { dflags <- getDynFlags
        ; trace_bind dflags $
            if preInlineUnconditionally dflags env top_lvl old_bndr rhs
                     -- Check for unconditional inline
            then do tick (PreInlineUnconditionally old_bndr)
                    return (extendIdSubst env old_bndr (mkContEx env rhs))
-           else simplLazyBind env top_lvl is_rec old_bndr new_bndr rhs env }
+           else simplBind env top_lvl is_rec mb_cont old_bndr new_bndr rhs env }
   where
     trace_bind dflags thing_inside
       | not (dopt Opt_D_verbose_core2core dflags)
@@ -300,7 +333,7 @@ simplRecOrTopPair env top_lvl is_rec old_bndr new_bndr rhs
         -- helps to locate the tracing for inlining and rule firing
 
 {-
-simplLazyBind is used for
+simplBind is used for
   * [simplRecOrTopPair] recursive bindings (whether top level or not)
   * [simplRecOrTopPair] top-level non-recursive bindings
   * [simplNonRecE]      non-top-level *lazy* non-recursive bindings
@@ -315,6 +348,19 @@ Nota bene:
        that should have been done already.
 -}
 
+simplBind :: SimplEnv
+          -> TopLevelFlag -> RecFlag -> Maybe SimplCont
+          -> InId -> OutId      -- Binder, both pre-and post simpl
+                                -- The OutId has IdInfo, except arity, unfolding
+          -> InExpr -> SimplEnv -- The RHS and its environment
+          -> SimplM SimplEnv
+simplBind env top_lvl is_rec mb_cont bndr bndr1 rhs rhs_se
+  | isJoinId bndr1
+  = ASSERT(isNotTopLevel top_lvl && isJust mb_cont)
+    simplJoinBind env is_rec (fromJust mb_cont) bndr bndr1 rhs rhs_se
+  | otherwise
+  = simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
+
 simplLazyBind :: SimplEnv
               -> TopLevelFlag -> RecFlag
               -> InId -> OutId          -- Binder, both pre-and post simpl
@@ -346,7 +392,10 @@ simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
 
         -- Simplify the RHS
         ; let   rhs_cont = mkRhsStop (substTy body_env (exprType body))
-        ; (body_env1, body1) <- simplExprF body_env body rhs_cont
+        ; (body_env0, body0) <- simplExprF (zapJoinFloats body_env)
+                                           body rhs_cont
+        ; let body1     = wrapJoinFloats body_env0 body0
+              body_env1 = body_env0 `restoreJoinFloats` body_env
         -- ANF-ise a constructor or PAP rhs
         ; (body_env2, body2) <- prepareRhs top_lvl body_env1 bndr1 body1
 
@@ -367,7 +416,24 @@ simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
                         ; env' <- foldlM (addPolyBind top_lvl) env poly_binds
                         ; return (env', rhs') }
 
-        ; completeBind env' top_lvl bndr bndr1 rhs' }
+        ; completeBind env' top_lvl is_rec Nothing bndr bndr1 rhs' }
+
+simplJoinBind :: SimplEnv
+              -> RecFlag
+              -> SimplCont
+              -> InId -> OutId          -- Binder, both pre-and post simpl
+                                        -- The OutId has IdInfo, except arity,
+                                        --   unfolding
+              -> InExpr -> SimplEnv     -- The RHS and its environment
+              -> SimplM SimplEnv
+simplJoinBind env is_rec cont bndr bndr1 rhs rhs_se
+  = -- pprTrace "simplLazyBind" ((ppr bndr <+> ppr bndr1) $$
+    --                           ppr rhs $$ ppr (seIdSubst rhs_se)) $
+    do  { let   rhs_env     = rhs_se `setInScope` env
+
+        -- Simplify the RHS
+        ; rhs' <- simplJoinRhs rhs_env cont bndr rhs
+        ; completeBind env NotTopLevel is_rec (Just cont) bndr bndr1 rhs' }
 
 {-
 A specialised variant of simplNonRec used when the RHS is already simplified,
@@ -402,13 +468,15 @@ completeNonRecX :: TopLevelFlag -> SimplEnv
 --               See Note [CoreSyn let/app invariant] in CoreSyn
 
 completeNonRecX top_lvl env is_strict old_bndr new_bndr new_rhs
-  = do  { (env1, rhs1) <- prepareRhs top_lvl (zapFloats env) new_bndr new_rhs
+  = ASSERT(not (isJoinId new_bndr))
+    do  { (env1, rhs1) <- prepareRhs top_lvl (zapFloats env) new_bndr new_rhs
         ; (env2, rhs2) <-
                 if doFloatFromRhs NotTopLevel NonRecursive is_strict rhs1 env1
                 then do { tick LetFloatFromLet
                         ; return (addFloats env env1, rhs1) }   -- Add the floats to the main env
                 else return (env, wrapFloats env1 rhs1)         -- Wrap the floats around the RHS
-        ; completeBind env2 NotTopLevel old_bndr new_bndr rhs2 }
+        ; completeBind env2 NotTopLevel NonRecursive Nothing
+                       old_bndr new_bndr rhs2 }
 
 {-
 {- No, no, no!  Do not try preInlineUnconditionally in completeNonRecX
@@ -664,6 +732,8 @@ Nor does it do the atomic-argument thing
 
 completeBind :: SimplEnv
              -> TopLevelFlag            -- Flag stuck into unfolding
+             -> RecFlag                 -- Recursive binding?
+             -> Maybe SimplCont         -- Required only for join point
              -> InId                    -- Old binder
              -> OutId -> OutExpr        -- New binder and RHS
              -> SimplM SimplEnv
@@ -672,7 +742,7 @@ completeBind :: SimplEnv
 --      * or by adding to the floats in the envt
 --
 -- Precondition: rhs obeys the let/app invariant
-completeBind env top_lvl old_bndr new_bndr new_rhs
+completeBind env top_lvl is_rec mb_cont old_bndr new_bndr new_rhs
  | isCoVar old_bndr
  = case new_rhs of
      Coercion co -> return (extendCvSubst env old_bndr co)
@@ -686,10 +756,15 @@ completeBind env top_lvl old_bndr new_bndr new_rhs
 
         -- Do eta-expansion on the RHS of the binding
         -- See Note [Eta-expanding at let bindings] in SimplUtils
-      ; (new_arity, final_rhs) <- tryEtaExpandRhs env new_bndr new_rhs
+      ; (new_arity, final_rhs) <- if isJoinId new_bndr
+                                    then return (manifestArity new_rhs, new_rhs)
+                                         -- Note [Don't eta-expand join points]
+                                    else tryEtaExpandRhs env is_rec
+                                                         new_bndr new_rhs
 
         -- Simplify the unfolding
-      ; new_unfolding <- simplLetUnfolding env top_lvl old_bndr final_rhs old_unf
+      ; new_unfolding <- simplLetUnfolding env top_lvl mb_cont old_bndr
+                                           final_rhs old_unf
 
       ; dflags <- getDynFlags
       ; if postInlineUnconditionally dflags env top_lvl new_bndr occ_info
@@ -740,7 +815,8 @@ addPolyBind :: TopLevelFlag -> SimplEnv -> OutBind -> SimplM SimplEnv
 -- INVARIANT: the arity is correct on the incoming binders
 
 addPolyBind top_lvl env (NonRec poly_id rhs)
-  = do  { unfolding <- simplLetUnfolding env top_lvl poly_id rhs noUnfolding
+  = do  { unfolding <- simplLetUnfolding env top_lvl Nothing poly_id rhs
+                                         noUnfolding
                         -- Assumes that poly_id did not have an INLINE prag
                         -- which is perhaps wrong.  ToDo: think about this
         ; let final_id = setIdInfo poly_id $
@@ -793,6 +869,44 @@ After inlining f at some of its call sites the original binding may
 (for example) be no longer strictly demanded.
 The solution here is a bit ad hoc...
 
+Note [Don't eta-expand join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Similarly to CPR (see Note [Don't CPR join points] in WorkWrap), a join point
+stands well to gain from its outer binding's eta-expansion, and eta-expanding a
+join point is fraught with issues like how to deal with a cast:
+
+    let join $j1 :: IO ()
+             $j1 = ...
+             $j2 :: Int -> IO ()
+             $j2 n = if n > 0 then $j1
+                              else ...
+
+    =>
+
+    let join $j1 :: IO ()
+             $j1 = (\eta -> ...)
+                     `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
+                                 ~  IO ()
+             $j2 :: Int -> IO ()
+             $j2 n = (\eta -> if n > 0 then $j1
+                                       else ...)
+                     `cast` N:IO :: State# RealWorld -> (# State# RealWorld, ())
+                                 ~  IO ()
+
+The cast here can't be pushed inside the lambda (since it's not casting to a
+function type), so the lambda has to stay, but it can't because it contains a
+reference to a join point. In fact, $j2 can't be eta-expanded at all. Rather
+than try and detect this situation (and whatever other situations crop up!), we
+don't bother; again, any surrounding eta-expansion will improve these join
+points anyway, since an outer cast can *always* be pushed inside. By the time
+CorePrep comes around, the code is very likely to look more like this:
+
+    let join $j1 :: State# RealWorld -> (# State# RealWorld, ())
+             $j1 = (...) eta
+             $j2 :: Int -> State# RealWorld -> (# State# RealWorld, ())
+             $j2 = if n > 0 then $j1
+                            else (...) eta
 
 ************************************************************************
 *                                                                      *
@@ -917,17 +1031,33 @@ simplExprF1 env (Case scrut bndr _ alts) cont
                                  , sc_env = env, sc_cont = cont })
 
 simplExprF1 env (Let (Rec pairs) body) cont
-  = do  { env' <- simplRecBndrs env (map fst pairs)
-                -- NB: bndrs' don't have unfoldings or rules
-                -- We add them as we go down
-
-        ; env'' <- simplRecBind env' NotTopLevel pairs
-        ; simplExprF env'' body cont }
+  = simplRecE env pairs body cont
 
 simplExprF1 env (Let (NonRec bndr rhs) body) cont
   = simplNonRecE env bndr (rhs, env) ([], body) cont
 
 ---------------------------------
+-- Simplify a join point, adding the context.
+-- Context goes *inside* the lambdas. IOW, if the join point has arity n, we do:
+--   \x1 .. xn -> e => \x1 .. xn -> E[e]
+-- Note that we need the arity of the join point, since e may be a lambda
+-- (though this is unlikely). See Note [Case-of-case and join points].
+simplJoinRhs :: SimplEnv -> SimplCont -> InId -> InExpr
+             -> SimplM OutExpr
+simplJoinRhs env cont bndr expr
+  | Just arity <- isJoinId_maybe bndr
+  = simpl_join_lams arity
+  | otherwise
+  = pprPanic "simplJoinRhs" (ppr bndr)
+  where
+    simpl_join_lams arity
+      = do { (env', join_bndrs') <- simplLamBndrs env join_bndrs
+           ; join_body' <- simplExprC env' join_body cont
+           ; return $ mkLams join_bndrs' join_body' }
+      where
+        (join_bndrs, join_body) = collectNBinders arity expr
+
+---------------------------------
 simplType :: SimplEnv -> InType -> SimplM OutType
         -- Kept monadic just so we can do the seqType
 simplType env ty
@@ -1270,7 +1400,7 @@ simplLamBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 simplLamBndr env bndr
   | isId bndr && isFragileUnfolding old_unf   -- Special case
   = do { (env1, bndr1) <- simplBinder env bndr
-       ; unf'          <- simplUnfolding env1 NotTopLevel bndr old_unf
+       ; unf'          <- simplUnfolding env1 NotTopLevel Nothing bndr old_unf
        ; let bndr2 = bndr1 `setIdUnfolding` unf'
        ; return (modifyInScope env1 bndr2, bndr2) }
 
@@ -1322,6 +1452,25 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
            -> simplExprF (rhs_se `setFloats` env) rhs
                          (StrictBind bndr bndrs body env cont)
 
+           | Just (bndr', rhs') <- matchOrConvertToJoinPoint bndr rhs
+           -> do { let cont_dup_res_ty = resultTypeOfDupableCont (getMode env)
+                                           [bndr'] cont
+                 ; (env1, bndr1) <- simplNonRecJoinBndr env
+                                                        cont_dup_res_ty bndr'
+                 ; (env2, bndr2) <- addBndrRules env1 bndr' bndr1
+                 ; (env3, cont_dup, cont_nodup)
+                     <- prepareLetCont (zapJoinFloats env2) [bndr'] cont
+                 ; MASSERT2(cont_dup_res_ty `eqType` contResultType cont_dup,
+                     ppr cont_dup_res_ty $$ blankLine $$
+                     ppr cont $$ blankLine $$
+                     ppr cont_dup $$ blankLine $$
+                     ppr cont_nodup)
+                 ; env4 <- simplJoinBind env3 NonRecursive cont_dup bndr' bndr2
+                                         rhs' rhs_se
+                 ; (env5, expr) <- simplLam env4 bndrs body cont_dup
+                 ; rebuild (env5 `restoreJoinFloats` env2)
+                           (wrapJoinFloats env5 expr) cont_nodup }
+
            | otherwise
            -> ASSERT( not (isTyVar bndr) )
               do { (env1, bndr1) <- simplNonRecBndr env bndr
@@ -1329,6 +1478,64 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
                  ; env3 <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
                  ; simplLam env3 bndrs body cont }
 
+------------------
+simplRecE :: SimplEnv
+          -> [(InId, InExpr)]
+          -> InExpr
+          -> SimplCont
+          -> SimplM (SimplEnv, OutExpr)
+
+-- simplRecE is used for
+--  * non-top-level recursive lets in expressions
+simplRecE env pairs body cont
+  | Just pairs' <- matchOrConvertToJoinPoints pairs
+  = do  { let bndrs' = map fst pairs'
+              cont_dup_res_ty = resultTypeOfDupableCont (getMode env)
+                                                        bndrs' cont
+        ; env1 <- simplRecJoinBndrs env cont_dup_res_ty bndrs'
+                -- NB: bndrs' don't have unfoldings or rules
+                -- We add them as we go down
+        ; (env2, cont_dup, cont_nodup) <- prepareLetCont (zapJoinFloats env1)
+                                                         bndrs' cont
+        ; MASSERT2(cont_dup_res_ty `eqType` contResultType cont_dup,
+            ppr cont_dup_res_ty $$ blankLine $$
+            ppr cont $$ blankLine $$
+            ppr cont_dup $$ blankLine $$
+            ppr cont_nodup)
+        ; env3 <- simplRecBind env2 NotTopLevel (Just cont_dup) pairs'
+        ; (env4, expr) <- simplExprF env3 body cont_dup
+        ; rebuild (env4 `restoreJoinFloats` env1)
+                  (wrapJoinFloats env4 expr) cont_nodup }
+  | otherwise
+  = do  { let bndrs = map fst pairs
+        ; MASSERT(all (not . isJoinId) bndrs)
+        ; env1 <- simplRecBndrs env bndrs
+                -- NB: bndrs' don't have unfoldings or rules
+                -- We add them as we go down
+        ; env2 <- simplRecBind env1 NotTopLevel (Just cont) pairs
+        ; simplExprF env2 body cont }
+
+-- | Perform the conversion of a value binding to a join point if it's marked
+-- as 'AlwaysTailCalled'. If it's already a join point, return it as is.
+-- Otherwise return 'Nothing'.
+matchOrConvertToJoinPoint :: InBndr -> InExpr -> Maybe (JoinId, InExpr)
+matchOrConvertToJoinPoint bndr rhs
+  | not (isId bndr)
+  = Nothing
+  | isJoinId bndr
+  = -- No point in keeping tailCallInfo around; very fragile
+    Just (zapIdTailCallInfo bndr, rhs)
+  | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
+  , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
+  = Just (zapIdTailCallInfo (bndr `asJoinId` join_arity),
+          mkLams bndrs body)
+  | otherwise
+  = Nothing
+
+matchOrConvertToJoinPoints :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)]
+matchOrConvertToJoinPoints bndrs
+  = mapM (uncurry matchOrConvertToJoinPoint) bndrs
+
 {-
 ************************************************************************
 *                                                                      *
@@ -1351,9 +1558,11 @@ simplVar env var
 simplIdF :: SimplEnv -> InId -> SimplCont -> SimplM (SimplEnv, OutExpr)
 simplIdF env var cont
   = case substId env var of
-        DoneEx e             -> simplExprF (zapSubstEnv env) e cont
+        DoneEx e             -> simplExprF (zapSubstEnv env) e trimmed_cont
         ContEx tvs cvs ids e -> simplExprF (setSubstEnv env tvs cvs ids) e cont
-        DoneId var1          -> completeCall env var1 cont
+                                  -- Don't trim; haven't already simplified
+                                  -- the join, so the cont was never copied
+        DoneId var1          -> completeCall env var1 trimmed_cont
                 -- Note [zapSubstEnv]
                 -- The template is already simplified, so don't re-substitute.
                 -- This is VITAL.  Consider
@@ -1363,6 +1572,24 @@ simplIdF env var cont
                 -- We'll clone the inner \x, adding x->x' in the id_subst
                 -- Then when we inline y, we must *not* replace x by x' in
                 -- the inlined copy!!
+  where
+    trimmed_cont | Just arity <- isJoinIdInEnv_maybe env var
+                 = trim_cont arity cont
+                 | otherwise
+                 = cont
+
+    -- Drop outer context from join point invocation
+    -- Note [Case-of-case and join points]
+    trim_cont 0 cont@(Stop {})
+      = cont
+    trim_cont 0 cont
+      = mkBoringStop (contResultType cont)
+    trim_cont n cont@(ApplyToVal { sc_cont = k })
+      = cont { sc_cont = trim_cont (n-1) k }
+    trim_cont n cont@(ApplyToTy { sc_cont = k })
+      = cont { sc_cont = trim_cont (n-1) k } -- join arity counts types!
+    trim_cont _ cont
+      = pprPanic "completeCall" $ ppr var $$ ppr cont
 
 ---------------------------------------------------------
 --      Dealing with a call site
@@ -1935,7 +2162,8 @@ rebuildCase env scrut case_bndr alts cont
 reallyRebuildCase env scrut case_bndr alts cont
   = do  {       -- Prepare the continuation;
                 -- The new subst_env is in place
-          (env', dup_cont, nodup_cont) <- prepareCaseCont env alts cont
+          (env', dup_cont, nodup_cont) <- prepareCaseCont (zapJoinFloats env)
+                                                          alts cont
 
         -- Simplify the alternatives
         ; (scrut', case_bndr', alts') <- simplAlts env' scrut case_bndr alts dup_cont
@@ -1947,7 +2175,8 @@ reallyRebuildCase env scrut case_bndr alts cont
         -- Notice that rebuild gets the in-scope set from env', not alt_env
         -- (which in any case is only build in simplAlts)
         -- The case binder *not* scope over the whole returned case-expression
-        ; rebuild env' case_expr nodup_cont }
+        ; rebuild (env' `restoreJoinFloats` env)
+                  (wrapJoinFloats env' case_expr) nodup_cont }
 
 {-
 simplCaseBinder checks whether the scrutinee is a variable, v.  If so,
@@ -2348,23 +2577,87 @@ prepareCaseCont :: SimplEnv
 -- The idea is that we'll transform thus:
 --          Knodup[ (case _ of { p1 -> Kdup[r1]; ...; pn -> Kdup[rn] }
 --
--- We may also return some extra bindings in SimplEnv (that scope over
--- the entire continuation)
+-- We may also return some extra value bindings in SimplEnv (that scope over
+-- the entire continuation) as well as some join points (thus must *not* float
+-- past the continuation!).
+-- Hence, the full story is this:
+--     K[case _ of { p1 -> r1; ...; pn -> rn }] ==>
+--     F_v[Knodup[F_j[ (case _ of { p1 -> Kdup[r1]; ...; pn -> Kdup[rn] }) ]]]
+-- Here F_v represents some values that got floated out and F_j represents some
+-- join points that got floated out.
 --
 -- When case-of-case is off, just make the entire continuation non-dupable
 
 prepareCaseCont env alts cont
-  | not (sm_case_case (getMode env)) = return (env, mkBoringStop (contHoleType cont), cont)
-  | not (many_alts alts)             = return (env, cont, mkBoringStop (contResultType cont))
-  | otherwise                        = mkDupableCont env cont
-  where
-    many_alts :: [InAlt] -> Bool  -- True iff strictly > 1 non-bottom alternative
-    many_alts []  = False         -- See Note [Bottom alternatives]
-    many_alts [_] = False
-    many_alts (alt:alts)
-      | is_bot_alt alt = many_alts alts
-      | otherwise      = not (all is_bot_alt alts)
+  | not (sm_case_case (getMode env))
+  = return (env, mkBoringStop (contHoleType cont), cont)
+  | not (altsWouldDup alts)
+  = return (env, cont, mkBoringStop (contResultType cont))
+  | otherwise
+  = mkDupableCont env cont
+
+prepareLetCont :: SimplEnv
+               -> [InBndr] -> SimplCont
+               -> SimplM (SimplEnv,
+                          SimplCont,   -- Dupable part
+                          SimplCont)   -- Non-dupable part
+
+-- Similar to prepareCaseCont, only for
+--     K[let { j1 = r1; ...; jn -> rn } in _]
+-- If the js are join points, this will turn into
+--     Knodup[join { j1 = Kdup[r1]; ...; jn = Kdup[rn] } in Kdup[_]].
+--
+-- When case-of-case is off and it's a join binding, just make the entire
+-- continuation non-dupable. This is necessary because otherwise
+--     case (join j = ... in case e of { A -> jump j 1; ... }) of { B -> ... }
+-- becomes
+--     join j = case ... of { B -> ... } in
+--     case (case e of { A -> jump j 1; ... }) of { B -> ... },
+-- and the reference to j is invalid.
 
+prepareLetCont env bndrs cont
+  | not (isJoinId (head bndrs))
+  = return (env, cont, mkBoringStop (contResultType cont))
+  | not (sm_case_case (getMode env))
+  = return (env, mkBoringStop (contHoleType cont), cont)
+  | otherwise
+  = mkDupableCont env cont
+
+-- Predict the result type of the dupable cont returned by prepareLetCont (= the
+-- hole type of the non-dupable part). Ugly, but sadly necessary so that we can
+-- know what the new type of a recursive join point will be before we start
+-- simplifying it.
+resultTypeOfDupableCont :: SimplifierMode
+                        -> [InBndr]
+                        -> SimplCont
+                        -> OutType   -- INVARIANT: Result type of dupable cont
+                                     -- returned by prepareLetCont
+-- IMPORTANT: This must be kept in sync with mkDupableCont!
+resultTypeOfDupableCont mode bndrs cont
+  | not (any isJoinId bndrs)   = contResultType cont
+  | not (sm_case_case mode)    = contHoleType   cont
+  | otherwise                  = go cont
+  where
+    go cont | contIsDupable cont = contResultType cont
+    go (Stop {}) = panic "typeOfDupableCont" -- Handled by previous eqn
+    go (CastIt _  cont)     = go cont
+    go cont@(TickIt {})     = contHoleType cont
+    go cont@(StrictBind {}) = contHoleType cont
+    go (StrictArg _ _ cont) = go cont
+    go cont@(ApplyToTy  {}) = go (sc_cont cont)
+    go cont@(ApplyToVal {}) = go (sc_cont cont)
+    go (Select { sc_alts = alts, sc_cont = cont })
+      | not (sm_case_case mode) = contHoleType cont
+      | not (altsWouldDup alts) = contResultType cont
+      | otherwise               = go cont
+
+altsWouldDup :: [InAlt] -> Bool -- True iff strictly > 1 non-bottom alternative
+altsWouldDup []  = False        -- See Note [Bottom alternatives]
+altsWouldDup [_] = False
+altsWouldDup (alt:alts)
+  | is_bot_alt alt = altsWouldDup alts
+  | otherwise      = not (all is_bot_alt alts)
+  where
     is_bot_alt (_,_,rhs) = exprIsBottom rhs
 
 {-
@@ -2375,9 +2668,7 @@ When we have
        of alts
 then we can just duplicate those alts because the A and C cases
 will disappear immediately.  This is more direct than creating
-join points and inlining them away; and in some cases we would
-not even create the join points (see Note [Single-alternative case])
-and we would keep the case-of-case which is silly.  See Trac #4930.
+join points and inlining them away.  See Trac #4930.
 -}
 
 mkDupableCont :: SimplEnv -> SimplCont
@@ -2423,15 +2714,6 @@ mkDupableCont env (ApplyToVal { sc_arg = arg, sc_dup = dup, sc_env = se, sc_cont
                                     , sc_dup = OkToDup, sc_cont = dup_cont }
         ; return (env'', app_cont, nodup_cont) }
 
-mkDupableCont env cont@(Select { sc_bndr = case_bndr, sc_alts = [(_, bs, _rhs)] })
---  See Note [Single-alternative case]
---  | not (exprIsDupable rhs && contIsDupable case_cont)
---  | not (isDeadBinder case_bndr)
-  | all isDeadBinder bs  -- InIds
-    && not (isUnliftedType (idType case_bndr))
-    -- Note [Single-alternative-unlifted]
-  = return (env, mkBoringStop (contHoleType cont), cont)
-
 mkDupableCont env (Select { sc_bndr = case_bndr, sc_alts = alts
                           , sc_env = se, sc_cont = cont })
   =     -- e.g.         (case [...hole...] of { pi -> ei })
@@ -2509,19 +2791,16 @@ mkDupableAlt env case_bndr (con, bndrs', rhs') = do
                            -- The case binder is alive but trivial, so why has
                            -- it not been substituted away?
 
-              used_bndrs' | isDeadBinder case_bndr = filter abstract_over bndrs'
-                          | otherwise              = bndrs' ++ [case_bndr_w_unf]
+              final_bndrs'
+                | isDeadBinder case_bndr = filter abstract_over bndrs'
+                | otherwise              = bndrs' ++ [case_bndr_w_unf]
 
               abstract_over bndr
                   | isTyVar bndr = True -- Abstract over all type variables just in case
                   | otherwise    = not (isDeadBinder bndr)
                         -- The deadness info on the new Ids is preserved by simplBinders
-
-        ; (final_bndrs', final_args)    -- Note [Join point abstraction]
-                <- if (any isId used_bndrs')
-                   then return (used_bndrs', varsToCoreExprs used_bndrs')
-                    else do { rw_id <- newId (fsLit "w") voidPrimTy
-                            ; return ([setOneShotLambda rw_id], [Var voidPrimId]) }
+              final_args    -- Note [Join point abstraction]
+                = varsToCoreExprs final_bndrs'
 
         ; join_bndr <- newId (fsLit "$j") (mkLamTypes final_bndrs' rhs_ty')
                 -- Note [Funky mkLamTypes]
@@ -2534,10 +2813,14 @@ mkDupableAlt env case_bndr (con, bndrs', rhs') = do
                 one_shot v | isId v    = setOneShotLambda v
                            | otherwise = v
                 join_rhs   = mkLams really_final_bndrs rhs'
-                join_arity = exprArity join_rhs
-                join_call  = mkApps (Var join_bndr) final_args
-
-        ; env' <- addPolyBind NotTopLevel env (NonRec (join_bndr `setIdArity` join_arity) join_rhs)
+                arity      = length (filter (not . isTyVar) final_bndrs')
+                join_arity = length final_bndrs'
+                final_join_bndr = (join_bndr `setIdArity` arity)
+                                    `asJoinId` join_arity
+                join_call  = mkApps (Var final_join_bndr) final_args
+                final_join_bind = NonRec final_join_bndr join_rhs
+
+        ; env' <- addPolyBind NotTopLevel env final_join_bind
         ; return (env', (con, bndrs', join_call)) }
                 -- See Note [Duplicated env]
 
@@ -2660,6 +2943,12 @@ type variables as well as term variables.
 
 Note [Join point abstraction]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+NB: This note is now historical. Now that "join point" is not a fuzzy concept
+but a formal syntactic construct (as distinguished by the JoinId constructor of
+IdDetails), each of these concerns is handled separately, with no need for a
+vestigial extra argument.
+
 Join points always have at least one value argument,
 for several reasons
 
@@ -2769,114 +3058,6 @@ Unlike StrictArg, there doesn't seem anything to gain from
 duplicating a StrictBind continuation, so we don't.
 
 
-Note [Single-alternative cases]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-This case is just like the ArgOf case.  Here's an example:
-        data T a = MkT !a
-        ...(MkT (abs x))...
-Then we get
-        case (case x of I# x' ->
-              case x' <# 0# of
-                True  -> I# (negate# x')
-                False -> I# x') of y {
-          DEFAULT -> MkT y
-Because the (case x) has only one alternative, we'll transform to
-        case x of I# x' ->
-        case (case x' <# 0# of
-                True  -> I# (negate# x')
-                False -> I# x') of y {
-          DEFAULT -> MkT y
-But now we do *NOT* want to make a join point etc, giving
-        case x of I# x' ->
-        let $j = \y -> MkT y
-        in case x' <# 0# of
-                True  -> $j (I# (negate# x'))
-                False -> $j (I# x')
-In this case the $j will inline again, but suppose there was a big
-strict computation enclosing the orginal call to MkT.  Then, it won't
-"see" the MkT any more, because it's big and won't get duplicated.
-And, what is worse, nothing was gained by the case-of-case transform.
-
-So, in circumstances like these, we don't want to build join points
-and push the outer case into the branches of the inner one. Instead,
-don't duplicate the continuation.
-
-When should we use this strategy?  We should not use it on *every*
-single-alternative case:
-  e.g.  case (case ....) of (a,b) -> (# a,b #)
-Here we must push the outer case into the inner one!
-Other choices:
-
-   * Match [(DEFAULT,_,_)], but in the common case of Int,
-     the alternative-filling-in code turned the outer case into
-                case (...) of y { I# _ -> MkT y }
-
-   * Match on single alternative plus (not (isDeadBinder case_bndr))
-     Rationale: pushing the case inwards won't eliminate the construction.
-     But there's a risk of
-                case (...) of y { (a,b) -> let z=(a,b) in ... }
-     Now y looks dead, but it'll come alive again.  Still, this
-     seems like the best option at the moment.
-
-   * Match on single alternative plus (all (isDeadBinder bndrs))
-     Rationale: this is essentially  seq.
-
-   * Match when the rhs is *not* duplicable, and hence would lead to a
-     join point.  This catches the disaster-case above.  We can test
-     the *un-simplified* rhs, which is fine.  It might get bigger or
-     smaller after simplification; if it gets smaller, this case might
-     fire next time round.  NB also that we must test contIsDupable
-     case_cont *too, because case_cont might be big!
-
-     HOWEVER: I found that this version doesn't work well, because
-     we can get         let x = case (...) of { small } in ...case x...
-     When x is inlined into its full context, we find that it was a bad
-     idea to have pushed the outer case inside the (...) case.
-
-There is a cost to not doing case-of-case; see Trac #10626.
-
-Note [Single-alternative-unlifted]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Here's another single-alternative where we really want to do case-of-case:
-
-data Mk1 = Mk1 Int# | Mk2 Int#
-
-M1.f =
-    \r [x_s74 y_s6X]
-        case
-            case y_s6X of tpl_s7m {
-              M1.Mk1 ipv_s70 -> ipv_s70;
-              M1.Mk2 ipv_s72 -> ipv_s72;
-            }
-        of
-        wild_s7c
-        { __DEFAULT ->
-              case
-                  case x_s74 of tpl_s7n {
-                    M1.Mk1 ipv_s77 -> ipv_s77;
-                    M1.Mk2 ipv_s79 -> ipv_s79;
-                  }
-              of
-              wild1_s7b
-              { __DEFAULT -> ==# [wild1_s7b wild_s7c];
-              };
-        };
-
-So the outer case is doing *nothing at all*, other than serving as a
-join-point.  In this case we really want to do case-of-case and decide
-whether to use a real join point or just duplicate the continuation:
-
-    let $j s7c = case x of
-                   Mk1 ipv77 -> (==) s7c ipv77
-                   Mk1 ipv79 -> (==) s7c ipv79
-    in
-    case y of
-      Mk1 ipv70 -> $j ipv70
-      Mk2 ipv72 -> $j ipv72
-
-Hence: check whether the case binder's type is unlifted, because then
-the outer case is *not* a seq.
-
 ************************************************************************
 *                                                                      *
                     Unfoldings
@@ -2885,12 +3066,13 @@ the outer case is *not* a seq.
 -}
 
 simplLetUnfolding :: SimplEnv-> TopLevelFlag
+                  -> Maybe