Allow CSE'ing of work-wrapped bindings (#14186)
[ghc.git] / compiler / simplCore / SimplEnv.hs
index 8a26220..21ba4bc 100644 (file)
@@ -8,35 +8,45 @@
 
 module SimplEnv (
         -- * The simplifier mode
-        setMode, getMode, updMode,
+        setMode, getMode, updMode, seDynFlags,
 
         -- * Environments
         SimplEnv(..), StaticEnv, pprSimplEnv,   -- Temp not abstract
         mkSimplEnv, extendIdSubst,
         SimplEnv.extendTvSubst, SimplEnv.extendCvSubst,
         zapSubstEnv, setSubstEnv,
-        getInScope, setInScope, setInScopeSet, modifyInScope, addNewInScopeIds,
+        getInScope, setInScopeFromE, setInScopeFromF,
+        setInScopeSet, modifyInScope, addNewInScopeIds,
         getSimplRules,
 
         -- * Substitution results
         SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
 
         -- * Simplifying 'Id' binders
-        simplNonRecBndr, simplRecBndrs,
+        simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
         simplBinder, simplBinders,
         substTy, substTyVar, getTCvSubst,
         substCo, substCoVar,
 
         -- * Floats
-        Floats, emptyFloats, isEmptyFloats, addNonRec, addFloats, extendFloats,
-        wrapFloats, setFloats, zapFloats, addRecFloats, mapFloats,
-        doFloatFromRhs, getFloatBinds
+        SimplFloats(..), emptyFloats, mkRecFloats,
+        mkFloatBind, addLetFloats, addJoinFloats, addFloats,
+        extendFloats, wrapFloats,
+        doFloatFromRhs, getTopFloatBinds,
+
+        -- * LetFloats
+        LetFloats, letFloatBinds, emptyLetFloats, unitLetFloat,
+        addLetFlts,  mapLetFloats,
+
+        -- * JoinFloats
+        JoinFloat, JoinFloats, emptyJoinFloats,
+        wrapJoinFloats, wrapJoinFloatsX, unitJoinFloat, addJoinFlts
     ) where
 
 #include "HsVersions.h"
 
 import SimplMonad
-import CoreMonad                ( SimplifierMode(..) )
+import CoreMonad                ( SimplMode(..) )
 import CoreSyn
 import CoreUtils
 import Var
@@ -45,6 +55,7 @@ import VarSet
 import OrdList
 import Id
 import MkCore                   ( mkWildValBinder )
+import DynFlags                 ( DynFlags )
 import TysWiredIn
 import qualified Type
 import Type hiding              ( substTy, substTyVar, substTyVarBndr )
@@ -54,6 +65,7 @@ import BasicTypes
 import MonadUtils
 import Outputable
 import Util
+import UniqFM                   ( pprUniqFM )
 
 import Data.List
 
@@ -71,12 +83,12 @@ data SimplEnv
      -- Static in the sense of lexically scoped,
      -- wrt the original expression
 
-        seMode      :: SimplifierMode,
+        seMode      :: SimplMode
 
         -- The current substitution
-        seTvSubst   :: TvSubstEnv,      -- InTyVar |--> OutType
-        seCvSubst   :: CvSubstEnv,      -- InCoVar |--> OutCoercion
-        seIdSubst   :: SimplIdSubst,    -- InId    |--> OutExpr
+      , seTvSubst   :: TvSubstEnv      -- InTyVar |--> OutType
+      , seCvSubst   :: CvSubstEnv      -- InCoVar |--> OutCoercion
+      , seIdSubst   :: SimplIdSubst    -- InId    |--> OutExpr
 
      ----------- Dynamic part of the environment -----------
      -- Dynamic in the sense of describing the setup where
@@ -84,43 +96,91 @@ data SimplEnv
 
         -- The current set of in-scope variables
         -- They are all OutVars, and all bound in this module
-        seInScope   :: InScopeSet,      -- OutVars only
-                -- Includes all variables bound by seFloats
-        seFloats    :: Floats
-                -- See Note [Simplifier floats]
+      , seInScope   :: InScopeSet       -- OutVars only
     }
 
 type StaticEnv = SimplEnv       -- Just the static part is relevant
 
+data SimplFloats
+  = SimplFloats
+      { -- Ordinary let bindings
+        sfLetFloats  :: LetFloats
+                -- See Note [LetFloats]
+
+        -- Join points
+      , sfJoinFloats :: JoinFloats
+                -- Handled separately; they don't go very far
+                -- We consider these to be /inside/ sfLetFloats
+                -- because join points can refer to ordinary bindings,
+                -- but not vice versa
+
+        -- Includes all variables bound by sfLetFloats and
+        -- sfJoinFloats, plus at least whatever is in scope where
+        -- these bindings land up.
+      , sfInScope :: InScopeSet  -- All OutVars
+      }
+
+instance Outputable SimplFloats where
+  ppr (SimplFloats { sfLetFloats = lf, sfJoinFloats = jf, sfInScope = is })
+    = text "SimplFloats"
+      <+> braces (vcat [ text "lets: " <+> ppr lf
+                       , text "joins:" <+> ppr jf
+                       , text "in_scope:" <+> ppr is ])
+
+emptyFloats :: SimplEnv -> SimplFloats
+emptyFloats env
+  = SimplFloats { sfLetFloats  = emptyLetFloats
+                , sfJoinFloats = emptyJoinFloats
+                , sfInScope    = seInScope env }
+
 pprSimplEnv :: SimplEnv -> SDoc
 -- Used for debugging; selective
 pprSimplEnv env
   = vcat [text "TvSubst:" <+> ppr (seTvSubst env),
           text "CvSubst:" <+> ppr (seCvSubst env),
-          text "IdSubst:" <+> ppr (seIdSubst env),
+          text "IdSubst:" <+> id_subst_doc,
           text "InScope:" <+> in_scope_vars_doc
     ]
   where
+   id_subst_doc = pprUniqFM ppr (seIdSubst env)
    in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
                                  (vcat . map ppr_one)
    ppr_one v | isId v = ppr v <+> ppr (idUnfolding v)
              | otherwise = ppr v
 
-type SimplIdSubst = IdEnv SimplSR       -- IdId |--> OutExpr
+type SimplIdSubst = IdEnv SimplSR -- IdId |--> OutExpr
         -- See Note [Extending the Subst] in CoreSubst
 
 -- | A substitution result.
 data SimplSR
-  = DoneEx OutExpr              -- Completed term
-  | DoneId OutId                -- Completed term variable
-  | ContEx TvSubstEnv           -- A suspended substitution
+  = DoneEx OutExpr (Maybe JoinArity)
+       -- If  x :-> DoneEx e ja   is in the SimplIdSubst
+       -- then replace occurrences of x by e
+       -- and  ja = Just a <=> x is a join-point of arity a
+       -- See Note [Join arity in SimplIdSubst]
+
+
+  | DoneId OutId
+       -- If  x :-> DoneId v   is in the SimplIdSubst
+       -- then replace occurrences of x by v
+       -- and  v is a join-point of arity a
+       --      <=> x is a join-point of arity a
+
+  | ContEx TvSubstEnv                 -- A suspended substitution
            CvSubstEnv
            SimplIdSubst
            InExpr
+      -- If   x :-> ContEx tv cv id e   is in the SimplISubst
+      -- then replace occurrences of x by (subst (tv,cv,id) e)
 
 instance Outputable SimplSR where
-  ppr (DoneEx e) = text "DoneEx" <+> ppr e
-  ppr (DoneId v) = text "DoneId" <+> ppr v
+  ppr (DoneId v)    = text "DoneId" <+> ppr v
+  ppr (DoneEx e mj) = text "DoneEx" <> pp_mj <+> ppr e
+    where
+      pp_mj = case mj of
+                Nothing -> empty
+                Just n  -> parens (int n)
+
   ppr (ContEx _tv _cv _id e) = vcat [text "ContEx" <+> ppr e {-,
                                 ppr (filter_env tv), ppr (filter_env id) -}]
         -- where
@@ -192,13 +252,24 @@ seIdSubst:
   map to the same target:  x->x, y->x.  Notably:
         case y of x { ... }
   That's why the "set" is actually a VarEnv Var
--}
 
-mkSimplEnv :: SimplifierMode -> SimplEnv
+Note [Join arity in SimplIdSubst]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We have to remember which incoming variables are join points: the occurrences
+may not be marked correctly yet, and we're in change of propagating the change if
+OccurAnal makes something a join point).
+
+Normally the in-scope set is where we keep the latest information, but
+the in-scope set tracks only OutVars; if a binding is unconditionally
+inlined (via DoneEx), it never makes it into the in-scope set, and we
+need to know at the occurrence site that the variable is a join point
+so that we know to drop the context. Thus we remember which join
+points we're substituting. -}
+
+mkSimplEnv :: SimplMode -> SimplEnv
 mkSimplEnv mode
   = SimplEnv { seMode = mode
              , seInScope = init_in_scope
-             , seFloats = emptyFloats
              , seTvSubst = emptyVarEnv
              , seCvSubst = emptyVarEnv
              , seIdSubst = emptyVarEnv }
@@ -218,7 +289,7 @@ occurrences of 'wild-id' (with wildCardKey).  The easy
 way to do that is to start of with a representative
 Id in the in-scope set
 
-There can be be *occurrences* of wild-id.  For example,
+There can be *occurrences* of wild-id.  For example,
 MkCore.mkCoreApp transforms
    e (a /# b)   -->   case (a /# b) of wild { DEFAULT -> e wild }
 This is ok provided 'wild' isn't free in 'e', and that's the delicate
@@ -228,20 +299,23 @@ wild-ids before doing much else.
 It's a very dark corner of GHC.  Maybe it should be cleaned up.
 -}
 
-getMode :: SimplEnv -> SimplifierMode
+getMode :: SimplEnv -> SimplMode
 getMode env = seMode env
 
-setMode :: SimplifierMode -> SimplEnv -> SimplEnv
+seDynFlags :: SimplEnv -> DynFlags
+seDynFlags env = sm_dflags (seMode env)
+
+setMode :: SimplMode -> SimplEnv -> SimplEnv
 setMode mode env = env { seMode = mode }
 
-updMode :: (SimplifierMode -> SimplifierMode) -> SimplEnv -> SimplEnv
+updMode :: (SimplMode -> SimplMode) -> SimplEnv -> SimplEnv
 updMode upd env = env { seMode = upd (seMode env) }
 
 ---------------------
 extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
 extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
   = ASSERT2( isId var && not (isCoVar var), ppr var )
-    env {seIdSubst = extendVarEnv subst var res}
+    env { seIdSubst = extendVarEnv subst var res }
 
 extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
 extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
@@ -260,17 +334,11 @@ getInScope env = seInScope env
 setInScopeSet :: SimplEnv -> InScopeSet -> SimplEnv
 setInScopeSet env in_scope = env {seInScope = in_scope}
 
-setInScope :: SimplEnv -> SimplEnv -> SimplEnv
--- Set the in-scope set, and *zap* the floats
-setInScope env env_with_scope
-  = env { seInScope = seInScope env_with_scope,
-          seFloats = emptyFloats }
+setInScopeFromE :: SimplEnv -> SimplEnv -> SimplEnv
+setInScopeFromE env env' = env { seInScope = seInScope env' }
 
-setFloats :: SimplEnv -> SimplEnv -> SimplEnv
--- Set the in-scope set *and* the floats
-setFloats env env_with_floats
-  = env { seInScope = seInScope env_with_floats,
-          seFloats  = seFloats  env_with_floats }
+setInScopeFromF :: SimplEnv -> SimplFloats -> SimplEnv
+setInScopeFromF env floats = env { seInScope = sfInScope floats }
 
 addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
         -- The new Ids are guaranteed to be freshly allocated
@@ -303,13 +371,13 @@ mkContEx (SimplEnv { seTvSubst = tvs, seCvSubst = cvs, seIdSubst = ids }) e = Co
 {-
 ************************************************************************
 *                                                                      *
-\subsection{Floats}
+\subsection{LetFloats}
 *                                                                      *
 ************************************************************************
 
-Note [Simplifier floats]
-~~~~~~~~~~~~~~~~~~~~~~~~~
-The Floats is a bunch of bindings, classified by a FloatFlag.
+Note [LetFloats]
+~~~~~~~~~~~~~~~~
+The LetFloats is a bunch of bindings, classified by a FloatFlag.
 
 * All of them satisfy the let/app invariant
 
@@ -328,11 +396,15 @@ Can't happen:
   NonRec x# (f y)       -- Might diverge; does not satisfy let/app
 -}
 
-data Floats = Floats (OrdList OutBind) FloatFlag
-        -- See Note [Simplifier floats]
+data LetFloats = LetFloats (OrdList OutBind) FloatFlag
+                 -- See Note [LetFloats]
+
+type JoinFloat  = OutBind
+type JoinFloats = OrdList JoinFloat
 
 data FloatFlag
-  = FltLifted   -- All bindings are lifted and lazy
+  = FltLifted   -- All bindings are lifted and lazy *or*
+                --     consist of a single primitive string literal
                 --  Hence ok to float to top level, or recursive
 
   | FltOkSpec   -- All bindings are FltLifted *or*
@@ -347,12 +419,12 @@ data FloatFlag
                 --      and not guaranteed cheap
                 --      Do not float these bindings out of a lazy let
 
-instance Outputable Floats where
-  ppr (Floats binds ff) = ppr ff $$ ppr (fromOL binds)
+instance Outputable LetFloats where
+  ppr (LetFloats binds ff) = ppr ff $$ ppr (fromOL binds)
 
 instance Outputable FloatFlag where
-  ppr FltLifted = text "FltLifted"
-  ppr FltOkSpec = text "FltOkSpec"
+  ppr FltLifted  = text "FltLifted"
+  ppr FltOkSpec  = text "FltOkSpec"
   ppr FltCareful = text "FltCareful"
 
 andFF :: FloatFlag -> FloatFlag -> FloatFlag
@@ -361,9 +433,9 @@ andFF FltOkSpec  FltCareful = FltCareful
 andFF FltOkSpec  _          = FltOkSpec
 andFF FltLifted  flt        = flt
 
-doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> OutExpr -> SimplEnv -> Bool
+doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> SimplFloats -> OutExpr -> Bool
 -- If you change this function look also at FloatIn.noFloatFromRhs
-doFloatFromRhs lvl rec str rhs (SimplEnv {seFloats = Floats fs ff})
+doFloatFromRhs lvl rec str (SimplFloats { sfLetFloats = LetFloats fs ff }) rhs
   =  not (isNilOL fs) && want_to_float && can_float
   where
      want_to_float = isTopLevel lvl || exprIsCheap rhs || exprIsExpandable rhs
@@ -385,81 +457,158 @@ But there are
 so we must take the 'or' of the two.
 -}
 
-emptyFloats :: Floats
-emptyFloats = Floats nilOL FltLifted
+emptyLetFloats :: LetFloats
+emptyLetFloats = LetFloats nilOL FltLifted
 
-unitFloat :: OutBind -> Floats
+emptyJoinFloats :: JoinFloats
+emptyJoinFloats = nilOL
+
+unitLetFloat :: OutBind -> LetFloats
 -- This key function constructs a singleton float with the right form
-unitFloat bind = Floats (unitOL bind) (flag bind)
+unitLetFloat bind = ASSERT(all (not . isJoinId) (bindersOf bind))
+                    LetFloats (unitOL bind) (flag bind)
   where
     flag (Rec {})                = FltLifted
     flag (NonRec bndr rhs)
       | not (isStrictId bndr)    = FltLifted
+      | exprIsLiteralString rhs  = FltLifted
+          -- String literals can be floated freely.
+          -- See Note [CoreSyn top-level string ltierals] in CoreSyn.
       | exprOkForSpeculation rhs = FltOkSpec  -- Unlifted, and lifted but ok-for-spec (eg HNF)
       | otherwise                = ASSERT2( not (isUnliftedType (idType bndr)), ppr bndr )
                                    FltCareful
       -- Unlifted binders can only be let-bound if exprOkForSpeculation holds
 
-addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
--- Add a non-recursive binding and extend the in-scope set
--- The latter is important; the binder may already be in the
--- in-scope set (although it might also have been created with newId)
--- but it may now have more IdInfo
-addNonRec env id rhs
-  = id `seq`   -- This seq forces the Id, and hence its IdInfo,
-               -- and hence any inner substitutions
-    env { seFloats = seFloats env `addFlts` unitFloat (NonRec id rhs),
-          seInScope = extendInScopeSet (seInScope env) id }
-
-extendFloats :: SimplEnv -> OutBind -> SimplEnv
--- Add these bindings to the floats, and extend the in-scope env too
-extendFloats env bind
-  = env { seFloats  = seFloats env `addFlts` unitFloat bind,
-          seInScope = extendInScopeSetList (seInScope env) bndrs }
+unitJoinFloat :: OutBind -> JoinFloats
+unitJoinFloat bind = ASSERT(all isJoinId (bindersOf bind))
+                     unitOL bind
+
+mkFloatBind :: SimplEnv -> OutBind -> (SimplFloats, SimplEnv)
+-- Make a singleton SimplFloats, and
+-- extend the incoming SimplEnv's in-scope set with its binders
+-- These binders may already be in the in-scope set,
+-- but may have by now been augmented with more IdInfo
+mkFloatBind env bind
+  = (floats, env { seInScope = in_scope' })
+  where
+    floats
+      | isJoinBind bind
+      = SimplFloats { sfLetFloats  = emptyLetFloats
+                    , sfJoinFloats = unitJoinFloat bind
+                    , sfInScope    = in_scope' }
+      | otherwise
+      = SimplFloats { sfLetFloats  = unitLetFloat bind
+                    , sfJoinFloats = emptyJoinFloats
+                    , sfInScope    = in_scope' }
+
+    in_scope' = seInScope env `extendInScopeSetBind` bind
+
+extendFloats :: SimplFloats -> OutBind -> SimplFloats
+-- Add this binding to the floats, and extend the in-scope env too
+extendFloats (SimplFloats { sfLetFloats  = floats
+                          , sfJoinFloats = jfloats
+                          , sfInScope    = in_scope })
+             bind
+  | isJoinBind bind
+  = SimplFloats { sfInScope    = in_scope'
+                , sfLetFloats  = floats
+                , sfJoinFloats = jfloats' }
+  | otherwise
+  = SimplFloats { sfInScope    = in_scope'
+                , sfLetFloats  = floats'
+                , sfJoinFloats = jfloats }
   where
-    bndrs = bindersOf bind
+    in_scope' = in_scope `extendInScopeSetBind` bind
+    floats'   = floats  `addLetFlts`  unitLetFloat bind
+    jfloats'  = jfloats `addJoinFlts` unitJoinFloat bind
 
-addFloats :: SimplEnv -> SimplEnv -> SimplEnv
--- Add the floats for env2 to env1;
+addLetFloats :: SimplFloats -> LetFloats -> SimplFloats
+-- Add the let-floats for env2 to env1;
+-- *plus* the in-scope set for env2, which is bigger
+-- than that for env1
+addLetFloats floats let_floats@(LetFloats binds _)
+  = floats { sfLetFloats = sfLetFloats floats `addLetFlts` let_floats
+           , sfInScope   = foldlOL extendInScopeSetBind
+                                   (sfInScope floats) binds }
+
+addJoinFloats :: SimplFloats -> JoinFloats -> SimplFloats
+addJoinFloats floats join_floats
+  = floats { sfJoinFloats = sfJoinFloats floats `addJoinFlts` join_floats
+           , sfInScope    = foldlOL extendInScopeSetBind
+                                    (sfInScope floats) join_floats }
+
+extendInScopeSetBind :: InScopeSet -> CoreBind -> InScopeSet
+extendInScopeSetBind in_scope bind
+  = extendInScopeSetList in_scope (bindersOf bind)
+
+addFloats :: SimplFloats -> SimplFloats -> SimplFloats
+-- Add both let-floats and join-floats for env2 to env1;
 -- *plus* the in-scope set for env2, which is bigger
 -- than that for env1
-addFloats env1 env2
-  = env1 {seFloats = seFloats env1 `addFlts` seFloats env2,
-          seInScope = seInScope env2 }
+addFloats (SimplFloats { sfLetFloats = lf1, sfJoinFloats = jf1 })
+          (SimplFloats { sfLetFloats = lf2, sfJoinFloats = jf2, sfInScope = in_scope })
+  = SimplFloats { sfLetFloats  = lf1 `addLetFlts` lf2
+                , sfJoinFloats = jf1 `addJoinFlts` jf2
+                , sfInScope    = in_scope }
+
+addLetFlts :: LetFloats -> LetFloats -> LetFloats
+addLetFlts (LetFloats bs1 l1) (LetFloats bs2 l2)
+  = LetFloats (bs1 `appOL` bs2) (l1 `andFF` l2)
 
-addFlts :: Floats -> Floats -> Floats
-addFlts (Floats bs1 l1) (Floats bs2 l2)
-  = Floats (bs1 `appOL` bs2) (l1 `andFF` l2)
+letFloatBinds :: LetFloats -> [CoreBind]
+letFloatBinds (LetFloats bs _) = fromOL bs
 
-zapFloats :: SimplEnv -> SimplEnv
-zapFloats env = env { seFloats = emptyFloats }
+addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
+addJoinFlts = appOL
 
-addRecFloats :: SimplEnv -> SimplEnv -> SimplEnv
+mkRecFloats :: SimplFloats -> SimplFloats
 -- Flattens the floats from env2 into a single Rec group,
--- prepends the floats from env1, and puts the result back in env2
--- This is all very specific to the way recursive bindings are
--- handled; see Simplify.simplRecBind
-addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff})
+-- They must either all be lifted LetFloats or all JoinFloats
+mkRecFloats floats@(SimplFloats { sfLetFloats  = LetFloats bs ff
+                                , sfJoinFloats = jbs
+                                , sfInScope    = in_scope })
   = ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
-    env2 {seFloats = seFloats env1 `addFlts` unitFloat (Rec (flattenBinds (fromOL bs)))}
+    ASSERT2( isNilOL bs || isNilOL jbs, ppr floats )
+    SimplFloats { sfLetFloats  = floats'
+                , sfJoinFloats = jfloats'
+                , sfInScope    = in_scope }
+  where
+    floats'  | isNilOL bs  = emptyLetFloats
+             | otherwise   = unitLetFloat (Rec (flattenBinds (fromOL bs)))
+    jfloats' | isNilOL jbs = emptyJoinFloats
+             | otherwise   = unitJoinFloat (Rec (flattenBinds (fromOL jbs)))
 
-wrapFloats :: SimplEnv -> OutExpr -> OutExpr
+wrapFloats :: SimplFloats -> OutExpr -> OutExpr
 -- Wrap the floats around the expression; they should all
 -- satisfy the let/app invariant, so mkLets should do the job just fine
-wrapFloats (SimplEnv {seFloats = Floats bs _}) body
-  = foldrOL Let body bs
-
-getFloatBinds :: SimplEnv -> [CoreBind]
-getFloatBinds (SimplEnv {seFloats = Floats bs _})
-  = fromOL bs
-
-isEmptyFloats :: SimplEnv -> Bool
-isEmptyFloats (SimplEnv {seFloats = Floats bs _})
-  = isNilOL bs
-
-mapFloats :: SimplEnv -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> SimplEnv
-mapFloats env@SimplEnv { seFloats = Floats fs ff } fun
-   = env { seFloats = Floats (mapOL app fs) ff }
+wrapFloats (SimplFloats { sfLetFloats  = LetFloats bs _
+                        , sfJoinFloats = jbs }) body
+  = foldrOL Let (wrapJoinFloats jbs body) bs
+     -- Note: Always safe to put the joins on the inside
+     -- since the values can't refer to them
+
+wrapJoinFloatsX :: SimplFloats -> OutExpr -> (SimplFloats, OutExpr)
+-- Wrap the sfJoinFloats of the env around the expression,
+-- and take them out of the SimplEnv
+wrapJoinFloatsX floats body
+  = ( floats { sfJoinFloats = emptyJoinFloats }
+    , wrapJoinFloats (sfJoinFloats floats) body )
+
+wrapJoinFloats :: JoinFloats -> OutExpr -> OutExpr
+-- Wrap the sfJoinFloats of the env around the expression,
+-- and take them out of the SimplEnv
+wrapJoinFloats join_floats body
+  = foldrOL Let body join_floats
+
+getTopFloatBinds :: SimplFloats -> [CoreBind]
+getTopFloatBinds (SimplFloats { sfLetFloats  = lbs
+                              , sfJoinFloats = jbs})
+  = ASSERT( isNilOL jbs )  -- Can't be any top-level join bindings
+    letFloatBinds lbs
+
+mapLetFloats :: LetFloats -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> LetFloats
+mapLetFloats (LetFloats fs ff) fun
+   = LetFloats (mapOL app fs) ff
    where
     app (NonRec b e) = case fun (b,e) of (b',e') -> NonRec b' e'
     app (Rec bs)     = Rec (map fun bs)
@@ -489,12 +638,12 @@ substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
   = case lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
         Nothing               -> DoneId (refineFromInScope in_scope v)
         Just (DoneId v)       -> DoneId (refineFromInScope in_scope v)
-        Just (DoneEx (Var v)) -> DoneId (refineFromInScope in_scope v)
         Just res              -> res    -- DoneEx non-var, or ContEx
 
         -- Get the most up-to-date thing from the in-scope set
         -- Even though it isn't in the substitution, it may be in
         -- the in-scope set with better IdInfo
+
 refineFromInScope :: InScopeSet -> Var -> Var
 refineFromInScope in_scope v
   | isLocalId v = case lookupInScope in_scope v of
@@ -535,33 +684,53 @@ simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 simplBinder env bndr
   | isTyVar bndr  = do  { let (env', tv) = substTyVarBndr env bndr
                         ; seqTyVar tv `seq` return (env', tv) }
-  | otherwise     = do  { let (env', id) = substIdBndr env bndr
+  | otherwise     = do  { let (env', id) = substIdBndr Nothing env bndr
                         ; seqId id `seq` return (env', id) }
 
 ---------------
 simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 -- A non-recursive let binder
 simplNonRecBndr env id
-  = do  { let (env1, id1) = substIdBndr env id
+  = do  { let (env1, id1) = substIdBndr Nothing env id
+        ; seqId id1 `seq` return (env1, id1) }
+
+---------------
+simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
+                    -> SimplM (SimplEnv, OutBndr)
+-- A non-recursive let binder for a join point; context being pushed inward may
+-- change the type
+simplNonRecJoinBndr env res_ty id
+  = do  { let (env1, id1) = substIdBndr (Just res_ty) env id
         ; seqId id1 `seq` return (env1, id1) }
 
 ---------------
 simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
 -- Recursive let binders
 simplRecBndrs env@(SimplEnv {}) ids
-  = do  { let (env1, ids1) = mapAccumL substIdBndr env ids
+  = ASSERT(all (not . isJoinId) ids)
+    do  { let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids
+        ; seqIds ids1 `seq` return env1 }
+
+---------------
+simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
+-- Recursive let binders for join points; context being pushed inward may
+-- change types
+simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
+  = ASSERT(all isJoinId ids)
+    do  { let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids
         ; seqIds ids1 `seq` return env1 }
 
 ---------------
-substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
+substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
 -- Might be a coercion variable
-substIdBndr env bndr
+substIdBndr new_res_ty env bndr
   | isCoVar bndr  = substCoVarBndr env bndr
-  | otherwise     = substNonCoVarIdBndr env bndr
+  | otherwise     = substNonCoVarIdBndr new_res_ty env bndr
 
 ---------------
 substNonCoVarIdBndr
-   :: SimplEnv
+   :: Maybe OutType -- New result type, if a join binder
+   -> SimplEnv
    -> InBndr    -- Env and binder to transform
    -> (SimplEnv, OutBndr)
 -- Clone Id if necessary, substitute its type
@@ -581,7 +750,9 @@ substNonCoVarIdBndr
 -- Similar to CoreSubst.substIdBndr, except that
 --      the type of id_subst differs
 --      all fragile info is zapped
-substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
+substNonCoVarIdBndr new_res_ty
+                    env@(SimplEnv { seInScope = in_scope
+                                  , seIdSubst = id_subst })
                     old_id
   = ASSERT2( not (isCoVar old_id), ppr old_id )
     (env { seInScope = in_scope `extendInScopeSet` new_id,
@@ -589,7 +760,11 @@ substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }
   where
     id1    = uniqAway in_scope old_id
     id2    = substIdType env id1
-    new_id = zapFragileIdInfo id2       -- Zaps rules, worker-info, unfolding
+    id3    | Just res_ty <- new_res_ty
+           = id2 `setIdType` setJoinResTy (idJoinArity id2) res_ty (idType id2)
+           | otherwise
+           = id2
+    new_id = zapFragileIdInfo id3       -- Zaps rules, worker-info, unfolding
                                         -- and fragile OccInfo
 
         -- Extend the substitution if the unique has changed,
@@ -654,13 +829,14 @@ the letrec.
 {-
 ************************************************************************
 *                                                                      *
-                Impedence matching to type substitution
+                Impedance matching to type substitution
 *                                                                      *
 ************************************************************************
 -}
 
 getTCvSubst :: SimplEnv -> TCvSubst
-getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env })
+getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env
+                      , seCvSubst = cv_env })
   = mkTCvSubst in_scope (tv_env, cv_env)
 
 substTy :: SimplEnv -> Type -> Type