Join points
[ghc.git] / compiler / simplCore / SimplEnv.hs
index 99d8291..f35d120 100644 (file)
@@ -20,17 +20,22 @@ module SimplEnv (
 
         -- * Substitution results
         SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
+        isJoinIdInEnv_maybe,
 
         -- * Simplifying 'Id' binders
-        simplNonRecBndr, simplRecBndrs,
+        simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
         simplBinder, simplBinders,
         substTy, substTyVar, getTCvSubst,
         substCo, substCoVar,
 
         -- * Floats
-        Floats, emptyFloats, isEmptyFloats, addNonRec, addFloats, extendFloats,
+        Floats, emptyFloats, isEmptyFloats,
+        addNonRec, addFloats, extendFloats,
         wrapFloats, setFloats, zapFloats, addRecFloats, mapFloats,
-        doFloatFromRhs, getFloatBinds
+        doFloatFromRhs, getFloatBinds,
+
+        JoinFloats, emptyJoinFloats, isEmptyJoinFloats,
+        wrapJoinFloats, zapJoinFloats, restoreJoinFloats, getJoinFloatBinds,
     ) where
 
 #include "HsVersions.h"
@@ -54,6 +59,7 @@ import BasicTypes
 import MonadUtils
 import Outputable
 import Util
+import UniqFM                   ( pprUniqFM )
 
 import Data.List
 
@@ -86,8 +92,10 @@ data SimplEnv
         -- They are all OutVars, and all bound in this module
         seInScope   :: InScopeSet,      -- OutVars only
                 -- Includes all variables bound by seFloats
-        seFloats    :: Floats
+        seFloats    :: Floats,
                 -- See Note [Simplifier floats]
+        seJoinFloats :: JoinFloats
+                -- Handled separately; they don't go very far
     }
 
 type StaticEnv = SimplEnv       -- Just the static part is relevant
@@ -97,17 +105,24 @@ pprSimplEnv :: SimplEnv -> SDoc
 pprSimplEnv env
   = vcat [text "TvSubst:" <+> ppr (seTvSubst env),
           text "CvSubst:" <+> ppr (seCvSubst env),
-          text "IdSubst:" <+> ppr (seIdSubst env),
+          text "IdSubst:" <+> id_subst_doc,
           text "InScope:" <+> in_scope_vars_doc
     ]
   where
+   id_subst_doc = pprUniqFM ppr_id_subst (seIdSubst env)
+   ppr_id_subst (m_ar, sr) = arity_part <+> ppr sr
+     where arity_part = case m_ar of Just ar -> brackets $
+                                                  text "join" <+> int ar
+                                     Nothing -> empty
+
    in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
                                  (vcat . map ppr_one)
    ppr_one v | isId v = ppr v <+> ppr (idUnfolding v)
              | otherwise = ppr v
 
-type SimplIdSubst = IdEnv SimplSR       -- IdId |--> OutExpr
+type SimplIdSubst = IdEnv (Maybe JoinArity, SimplSR) -- IdId |--> OutExpr
         -- See Note [Extending the Subst] in CoreSubst
+        -- See Note [Join arity in SimplIdSubst]
 
 -- | A substitution result.
 data SimplSR
@@ -192,6 +207,20 @@ seIdSubst:
   map to the same target:  x->x, y->x.  Notably:
         case y of x { ... }
   That's why the "set" is actually a VarEnv Var
+
+Note [Join arity in SimplIdSubst]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We have to remember which incoming variables are join points (the occurrences
+may not be marked correctly yet; we're in change of propagating the change if
+OccurAnal makes something a join point). Normally the in-scope set is where we
+keep the latest information, but the in-scope set tracks only OutVars; if a
+binding is unconditionally inlined, it never makes it into the in-scope set,
+and we need to know at the occurrence site that the variable is a join point so
+that we know to drop the context. Thus we remember which join points we're
+substituting. Clumsily, finding whether an InVar is a join variable may require
+looking in both the substitution *and* the in-scope set (see
+'isJoinIdInEnv_maybe').
 -}
 
 mkSimplEnv :: SimplifierMode -> SimplEnv
@@ -199,6 +228,7 @@ mkSimplEnv mode
   = SimplEnv { seMode = mode
              , seInScope = init_in_scope
              , seFloats = emptyFloats
+             , seJoinFloats = emptyJoinFloats
              , seTvSubst = emptyVarEnv
              , seCvSubst = emptyVarEnv
              , seIdSubst = emptyVarEnv }
@@ -241,7 +271,7 @@ updMode upd env = env { seMode = upd (seMode env) }
 extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
 extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
   = ASSERT2( isId var && not (isCoVar var), ppr var )
-    env {seIdSubst = extendVarEnv subst var res}
+    env { seIdSubst = extendVarEnv subst var (isJoinId_maybe var, res) }
 
 extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
 extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
@@ -264,13 +294,22 @@ setInScope :: SimplEnv -> SimplEnv -> SimplEnv
 -- Set the in-scope set, and *zap* the floats
 setInScope env env_with_scope
   = env { seInScope = seInScope env_with_scope,
-          seFloats = emptyFloats }
+          seFloats = emptyFloats,
+          seJoinFloats = emptyJoinFloats }
 
 setFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Set the in-scope set *and* the floats
 setFloats env env_with_floats
   = env { seInScope = seInScope env_with_floats,
-          seFloats  = seFloats  env_with_floats }
+          seFloats = seFloats  env_with_floats,
+          seJoinFloats = seJoinFloats env_with_floats }
+
+restoreJoinFloats :: SimplEnv -> SimplEnv -> SimplEnv
+-- Put back floats previously zapped
+-- Unlike 'setFloats', does *not* update the in-scope set, since the right-hand
+-- env is assumed to be *older*
+restoreJoinFloats env old_env
+  = env { seJoinFloats = seJoinFloats old_env }
 
 addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
         -- The new Ids are guaranteed to be freshly allocated
@@ -331,6 +370,8 @@ Can't happen:
 data Floats = Floats (OrdList OutBind) FloatFlag
         -- See Note [Simplifier floats]
 
+type JoinFloats = OrdList OutBind
+
 data FloatFlag
   = FltLifted   -- All bindings are lifted and lazy *or*
                 --     consist of a single primitive string literal
@@ -389,9 +430,13 @@ so we must take the 'or' of the two.
 emptyFloats :: Floats
 emptyFloats = Floats nilOL FltLifted
 
+emptyJoinFloats :: JoinFloats
+emptyJoinFloats = nilOL
+
 unitFloat :: OutBind -> Floats
 -- This key function constructs a singleton float with the right form
-unitFloat bind = Floats (unitOL bind) (flag bind)
+unitFloat bind = ASSERT(all (not . isJoinId) (bindersOf bind))
+                 Floats (unitOL bind) (flag bind)
   where
     flag (Rec {})                = FltLifted
     flag (NonRec bndr rhs)
@@ -404,6 +449,10 @@ unitFloat bind = Floats (unitOL bind) (flag bind)
                                    FltCareful
       -- Unlifted binders can only be let-bound if exprOkForSpeculation holds
 
+unitJoinFloat :: OutBind -> JoinFloats
+unitJoinFloat bind = ASSERT(all isJoinId (bindersOf bind))
+                     unitOL bind
+
 addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
 -- Add a non-recursive binding and extend the in-scope set
 -- The latter is important; the binder may already be in the
@@ -412,58 +461,104 @@ addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
 addNonRec env id rhs
   = id `seq`   -- This seq forces the Id, and hence its IdInfo,
                -- and hence any inner substitutions
-    env { seFloats = seFloats env `addFlts` unitFloat (NonRec id rhs),
+    env { seFloats = floats',
+          seJoinFloats = jfloats',
           seInScope = extendInScopeSet (seInScope env) id }
+  where
+    bind = NonRec id rhs
+
+    floats'  | isJoinId id = seFloats env
+             | otherwise   = seFloats env `addFlts` unitFloat bind
+    jfloats' | isJoinId id = seJoinFloats env `addJoinFlts` unitJoinFloat bind
+             | otherwise   = seJoinFloats env
 
 extendFloats :: SimplEnv -> OutBind -> SimplEnv
 -- Add these bindings to the floats, and extend the in-scope env too
 extendFloats env bind
-  = env { seFloats  = seFloats env `addFlts` unitFloat bind,
+  = ASSERT(all (not . isJoinId) (bindersOf bind))
+    env { seFloats  = floats',
+          seJoinFloats = jfloats',
           seInScope = extendInScopeSetList (seInScope env) bndrs }
   where
     bndrs = bindersOf bind
 
+    floats'  | isJoinBind bind = seFloats env
+             | otherwise       = seFloats env `addFlts` unitFloat bind
+    jfloats' | isJoinBind bind = seJoinFloats env `addJoinFlts`
+                                   unitJoinFloat bind
+             | otherwise       = seJoinFloats env
+
 addFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Add the floats for env2 to env1;
 -- *plus* the in-scope set for env2, which is bigger
 -- than that for env1
 addFloats env1 env2
   = env1 {seFloats = seFloats env1 `addFlts` seFloats env2,
+          seJoinFloats = seJoinFloats env1 `addJoinFlts` seJoinFloats env2,
           seInScope = seInScope env2 }
 
 addFlts :: Floats -> Floats -> Floats
 addFlts (Floats bs1 l1) (Floats bs2 l2)
   = Floats (bs1 `appOL` bs2) (l1 `andFF` l2)
 
+addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
+addJoinFlts = appOL
+
 zapFloats :: SimplEnv -> SimplEnv
-zapFloats env = env { seFloats = emptyFloats }
+zapFloats env = env { seFloats = emptyFloats
+                    , seJoinFloats = emptyJoinFloats }
+
+zapJoinFloats :: SimplEnv -> SimplEnv
+zapJoinFloats env = env { seJoinFloats = emptyJoinFloats }
 
 addRecFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Flattens the floats from env2 into a single Rec group,
 -- prepends the floats from env1, and puts the result back in env2
 -- This is all very specific to the way recursive bindings are
 -- handled; see Simplify.simplRecBind
-addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff})
+addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff
+                                 ,seJoinFloats = jbs })
   = ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
-    env2 {seFloats = seFloats env1 `addFlts` unitFloat (Rec (flattenBinds (fromOL bs)))}
+    env2 {seFloats = seFloats env1 `addFlts` floats'
+         ,seJoinFloats = seJoinFloats env1 `addJoinFlts` jfloats'}
+  where
+    floats'  | isNilOL bs  = emptyFloats
+             | otherwise   = unitFloat (Rec (flattenBinds (fromOL bs)))
+    jfloats' | isNilOL jbs = emptyJoinFloats
+             | otherwise   = unitJoinFloat (Rec (flattenBinds (fromOL jbs)))
 
 wrapFloats :: SimplEnv -> OutExpr -> OutExpr
 -- Wrap the floats around the expression; they should all
 -- satisfy the let/app invariant, so mkLets should do the job just fine
-wrapFloats (SimplEnv {seFloats = Floats bs _}) body
-  = foldrOL Let body bs
+wrapFloats env@(SimplEnv {seFloats = Floats bs _}) body
+  = foldrOL Let (wrapJoinFloats env body) bs
+      -- Note: Always safe to put the joins on the inside since the values
+      -- can't refer to them
+
+wrapJoinFloats :: SimplEnv -> OutExpr -> OutExpr
+wrapJoinFloats (SimplEnv {seJoinFloats = jbs}) body
+  = foldrOL Let body jbs
 
 getFloatBinds :: SimplEnv -> [CoreBind]
-getFloatBinds (SimplEnv {seFloats = Floats bs _})
-  = fromOL bs
+getFloatBinds env@(SimplEnv {seFloats = Floats bs _})
+  = fromOL bs ++ getJoinFloatBinds env
+
+getJoinFloatBinds :: SimplEnv -> [CoreBind]
+getJoinFloatBinds (SimplEnv {seJoinFloats = jbs})
+  = fromOL jbs
 
 isEmptyFloats :: SimplEnv -> Bool
-isEmptyFloats (SimplEnv {seFloats = Floats bs _})
-  = isNilOL bs
+isEmptyFloats env@(SimplEnv {seFloats = Floats bs _})
+  = isNilOL bs && isEmptyJoinFloats env
+
+isEmptyJoinFloats :: SimplEnv -> Bool
+isEmptyJoinFloats (SimplEnv {seJoinFloats = jbs})
+  = isNilOL jbs
 
 mapFloats :: SimplEnv -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> SimplEnv
-mapFloats env@SimplEnv { seFloats = Floats fs ff } fun
-   = env { seFloats = Floats (mapOL app fs) ff }
+mapFloats env@SimplEnv { seFloats = Floats fs ff, seJoinFloats = jfs } fun
+   = env { seFloats = Floats (mapOL app fs) ff
+         , seJoinFloats = mapOL app jfs }
    where
     app (NonRec b e) = case fun (b,e) of (b',e') -> NonRec b' e'
     app (Rec bs)     = Rec (map fun bs)
@@ -490,7 +585,7 @@ find that it has been substituted by b.  (Or conceivably cloned.)
 substId :: SimplEnv -> InId -> SimplSR
 -- Returns DoneEx only on a non-Var expression
 substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
-  = case lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
+  = case snd <$> lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
         Nothing               -> DoneId (refineFromInScope in_scope v)
         Just (DoneId v)       -> DoneId (refineFromInScope in_scope v)
         Just (DoneEx (Var v)) -> DoneId (refineFromInScope in_scope v)
@@ -499,6 +594,15 @@ substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
         -- Get the most up-to-date thing from the in-scope set
         -- Even though it isn't in the substitution, it may be in
         -- the in-scope set with better IdInfo
+
+isJoinIdInEnv_maybe :: SimplEnv -> InId -> Maybe JoinArity
+isJoinIdInEnv_maybe (SimplEnv { seInScope = inScope, seIdSubst = ids }) v
+  | not (isLocalId v)                         = Nothing
+  | Just (m_ar, _) <- lookupVarEnv ids v      = m_ar
+  | Just v'        <- lookupInScope inScope v = isJoinId_maybe v'
+  | otherwise                                 = WARN( True , ppr v )
+                                                isJoinId_maybe v
+
 refineFromInScope :: InScopeSet -> Var -> Var
 refineFromInScope in_scope v
   | isLocalId v = case lookupInScope in_scope v of
@@ -511,7 +615,7 @@ lookupRecBndr :: SimplEnv -> InId -> OutId
 -- but where we have not yet done its RHS
 lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
   = case lookupVarEnv ids v of
-        Just (DoneId v) -> v
+        Just (_, DoneId v) -> v
         Just _ -> pprPanic "lookupRecBndr" (ppr v)
         Nothing -> refineFromInScope in_scope v
 
@@ -539,33 +643,53 @@ simplBinder :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 simplBinder env bndr
   | isTyVar bndr  = do  { let (env', tv) = substTyVarBndr env bndr
                         ; seqTyVar tv `seq` return (env', tv) }
-  | otherwise     = do  { let (env', id) = substIdBndr env bndr
+  | otherwise     = do  { let (env', id) = substIdBndr Nothing env bndr
                         ; seqId id `seq` return (env', id) }
 
 ---------------
 simplNonRecBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 -- A non-recursive let binder
 simplNonRecBndr env id
-  = do  { let (env1, id1) = substIdBndr env id
+  = do  { let (env1, id1) = substIdBndr Nothing env id
+        ; seqId id1 `seq` return (env1, id1) }
+
+---------------
+simplNonRecJoinBndr :: SimplEnv -> OutType -> InBndr
+                    -> SimplM (SimplEnv, OutBndr)
+-- A non-recursive let binder for a join point; context being pushed inward may
+-- change the type
+simplNonRecJoinBndr env res_ty id
+  = do  { let (env1, id1) = substIdBndr (Just res_ty) env id
         ; seqId id1 `seq` return (env1, id1) }
 
 ---------------
 simplRecBndrs :: SimplEnv -> [InBndr] -> SimplM SimplEnv
 -- Recursive let binders
 simplRecBndrs env@(SimplEnv {}) ids
-  = do  { let (env1, ids1) = mapAccumL substIdBndr env ids
+  = ASSERT(all (not . isJoinId) ids)
+    do  { let (env1, ids1) = mapAccumL (substIdBndr Nothing) env ids
+        ; seqIds ids1 `seq` return env1 }
+
+---------------
+simplRecJoinBndrs :: SimplEnv -> OutType -> [InBndr] -> SimplM SimplEnv
+-- Recursive let binders for join points; context being pushed inward may
+-- change types
+simplRecJoinBndrs env@(SimplEnv {}) res_ty ids
+  = ASSERT(all isJoinId ids)
+    do  { let (env1, ids1) = mapAccumL (substIdBndr (Just res_ty)) env ids
         ; seqIds ids1 `seq` return env1 }
 
 ---------------
-substIdBndr :: SimplEnv -> InBndr -> (SimplEnv, OutBndr)
+substIdBndr :: Maybe OutType -> SimplEnv -> InBndr -> (SimplEnv, OutBndr)
 -- Might be a coercion variable
-substIdBndr env bndr
+substIdBndr new_res_ty env bndr
   | isCoVar bndr  = substCoVarBndr env bndr
-  | otherwise     = substNonCoVarIdBndr env bndr
+  | otherwise     = substNonCoVarIdBndr new_res_ty env bndr
 
 ---------------
 substNonCoVarIdBndr
-   :: SimplEnv
+   :: Maybe OutType -- New result type, if a join binder
+   -> SimplEnv
    -> InBndr    -- Env and binder to transform
    -> (SimplEnv, OutBndr)
 -- Clone Id if necessary, substitute its type
@@ -585,7 +709,9 @@ substNonCoVarIdBndr
 -- Similar to CoreSubst.substIdBndr, except that
 --      the type of id_subst differs
 --      all fragile info is zapped
-substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst })
+substNonCoVarIdBndr new_res_ty
+                    env@(SimplEnv { seInScope = in_scope
+                                  , seIdSubst = id_subst })
                     old_id
   = ASSERT2( not (isCoVar old_id), ppr old_id )
     (env { seInScope = in_scope `extendInScopeSet` new_id,
@@ -593,14 +719,19 @@ substNonCoVarIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }
   where
     id1    = uniqAway in_scope old_id
     id2    = substIdType env id1
-    new_id = zapFragileIdInfo id2       -- Zaps rules, worker-info, unfolding
+    id3    | Just res_ty <- new_res_ty
+           = id2 `setIdType` setJoinResTy (idJoinArity id2) res_ty (idType id2)
+           | otherwise
+           = id2
+    new_id = zapFragileIdInfo id3       -- Zaps rules, worker-info, unfolding
                                         -- and fragile OccInfo
 
         -- Extend the substitution if the unique has changed,
         -- or there's some useful occurrence information
         -- See the notes with substTyVarBndr for the delSubstEnv
     new_subst | new_id /= old_id
-              = extendVarEnv id_subst old_id (DoneId new_id)
+              = extendVarEnv id_subst old_id
+                             (isJoinId_maybe new_id, DoneId new_id)
               | otherwise
               = delVarEnv id_subst old_id
 
@@ -664,7 +795,8 @@ the letrec.
 -}
 
 getTCvSubst :: SimplEnv -> TCvSubst
-getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env, seCvSubst = cv_env })
+getTCvSubst (SimplEnv { seInScope = in_scope, seTvSubst = tv_env
+                      , seCvSubst = cv_env })
   = mkTCvSubst in_scope (tv_env, cv_env)
 
 substTy :: SimplEnv -> Type -> Type