Join-point refactoring
[ghc.git] / compiler / simplCore / SimplEnv.hs
index a1a973e..9316ec0 100644 (file)
@@ -21,7 +21,6 @@ module SimplEnv (
 
         -- * Substitution results
         SimplSR(..), mkContEx, substId, lookupRecBndr, refineFromInScope,
-        isJoinIdInEnv_maybe,
 
         -- * Simplifying 'Id' binders
         simplNonRecBndr, simplNonRecJoinBndr, simplRecBndrs, simplRecJoinBndrs,
@@ -31,12 +30,12 @@ module SimplEnv (
 
         -- * Floats
         Floats, emptyFloats, isEmptyFloats,
-        addNonRec, addFloats, extendFloats,
+        addNonRec, addLetFloats, addFloats, extendFloats, addFlts,
         wrapFloats, setFloats, zapFloats, addRecFloats, mapFloats,
         doFloatFromRhs, getFloatBinds,
 
-        JoinFloats, emptyJoinFloats, isEmptyJoinFloats,
-        wrapJoinFloats, zapJoinFloats, restoreJoinFloats, getJoinFloatBinds,
+        JoinFloat, JoinFloats, emptyJoinFloats, isEmptyJoinFloats,
+        wrapJoinFloats, wrapJoinFloatsX, zapJoinFloats, addJoinFloats
     ) where
 
 #include "HsVersions.h"
@@ -92,11 +91,19 @@ data SimplEnv
         -- The current set of in-scope variables
         -- They are all OutVars, and all bound in this module
         seInScope   :: InScopeSet,      -- OutVars only
-                -- Includes all variables bound by seFloats
-        seFloats    :: Floats,
+                -- Includes all variables bound
+                -- by seLetFloats and seJoinFloats
+
+        -- Ordinary bindings
+        seLetFloats  :: Floats,
                 -- See Note [Simplifier floats]
+
+        -- Join points
         seJoinFloats :: JoinFloats
                 -- Handled separately; they don't go very far
+                -- We consider these to be /inside/ seLetFloats
+                -- because join points can refer to ordinary bindings,
+                -- but not vice versa
     }
 
 type StaticEnv = SimplEnv       -- Just the static part is relevant
@@ -110,33 +117,45 @@ pprSimplEnv env
           text "InScope:" <+> in_scope_vars_doc
     ]
   where
-   id_subst_doc = pprUniqFM ppr_id_subst (seIdSubst env)
-   ppr_id_subst (m_ar, sr) = arity_part <+> ppr sr
-     where arity_part = case m_ar of Just ar -> brackets $
-                                                  text "join" <+> int ar
-                                     Nothing -> empty
-
+   id_subst_doc = pprUniqFM ppr (seIdSubst env)
    in_scope_vars_doc = pprVarSet (getInScopeVars (seInScope env))
                                  (vcat . map ppr_one)
    ppr_one v | isId v = ppr v <+> ppr (idUnfolding v)
              | otherwise = ppr v
 
-type SimplIdSubst = IdEnv (Maybe JoinArity, SimplSR) -- IdId |--> OutExpr
+type SimplIdSubst = IdEnv SimplSR -- IdId |--> OutExpr
         -- See Note [Extending the Subst] in CoreSubst
-        -- See Note [Join arity in SimplIdSubst]
 
 -- | A substitution result.
 data SimplSR
-  = DoneEx OutExpr              -- Completed term
-  | DoneId OutId                -- Completed term variable
-  | ContEx TvSubstEnv           -- A suspended substitution
+  = DoneEx OutExpr (Maybe JoinArity)
+       -- If  x :-> DoneEx e ja   is in the SimplIdSubst
+       -- then replace occurrences of x by e
+       -- and  ja = Just a <=> x is a join-point of arity a
+       -- See Note [Join arity in SimplIdSubst]
+
+
+  | DoneId OutId
+       -- If  x :-> DoneId v   is in the SimplIdSubst
+       -- then replace occurrences of x by v
+       -- and  v is a join-point of arity a
+       --      <=> x is a join-point of arity a
+
+  | ContEx TvSubstEnv                 -- A suspended substitution
            CvSubstEnv
            SimplIdSubst
            InExpr
+      -- If   x :-> ContEx tv cv id e   is in the SimplISubst
+      -- then replace occurrences of x by (subst (tv,cv,id) e)
 
 instance Outputable SimplSR where
-  ppr (DoneEx e) = text "DoneEx" <+> ppr e
-  ppr (DoneId v) = text "DoneId" <+> ppr v
+  ppr (DoneId v)    = text "DoneId" <+> ppr v
+  ppr (DoneEx e mj) = text "DoneEx" <> pp_mj <+> ppr e
+    where
+      pp_mj = case mj of
+                Nothing -> empty
+                Just n  -> parens (int n)
+
   ppr (ContEx _tv _cv _id e) = vcat [text "ContEx" <+> ppr e {-,
                                 ppr (filter_env tv), ppr (filter_env id) -}]
         -- where
@@ -211,24 +230,22 @@ seIdSubst:
 
 Note [Join arity in SimplIdSubst]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We have to remember which incoming variables are join points: the occurrences
+may not be marked correctly yet, and we're in change of propagating the change if
+OccurAnal makes something a join point).
 
-We have to remember which incoming variables are join points (the occurrences
-may not be marked correctly yet; we're in change of propagating the change if
-OccurAnal makes something a join point). Normally the in-scope set is where we
-keep the latest information, but the in-scope set tracks only OutVars; if a
-binding is unconditionally inlined, it never makes it into the in-scope set,
-and we need to know at the occurrence site that the variable is a join point so
-that we know to drop the context. Thus we remember which join points we're
-substituting. Clumsily, finding whether an InVar is a join variable may require
-looking in both the substitution *and* the in-scope set (see
-'isJoinIdInEnv_maybe').
--}
+Normally the in-scope set is where we keep the latest information, but
+the in-scope set tracks only OutVars; if a binding is unconditionally
+inlined (via DoneEx), it never makes it into the in-scope set, and we
+need to know at the occurrence site that the variable is a join point
+so that we know to drop the context. Thus we remember which join
+points we're substituting. -}
 
 mkSimplEnv :: SimplifierMode -> SimplEnv
 mkSimplEnv mode
   = SimplEnv { seMode = mode
              , seInScope = init_in_scope
-             , seFloats = emptyFloats
+             , seLetFloats = emptyFloats
              , seJoinFloats = emptyJoinFloats
              , seTvSubst = emptyVarEnv
              , seCvSubst = emptyVarEnv
@@ -272,7 +289,7 @@ updMode upd env = env { seMode = upd (seMode env) }
 extendIdSubst :: SimplEnv -> Id -> SimplSR -> SimplEnv
 extendIdSubst env@(SimplEnv {seIdSubst = subst}) var res
   = ASSERT2( isId var && not (isCoVar var), ppr var )
-    env { seIdSubst = extendVarEnv subst var (isJoinId_maybe var, res) }
+    env { seIdSubst = extendVarEnv subst var res }
 
 extendTvSubst :: SimplEnv -> TyVar -> Type -> SimplEnv
 extendTvSubst env@(SimplEnv {seTvSubst = tsubst}) var res
@@ -295,23 +312,16 @@ setInScopeAndZapFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Set the in-scope set, and *zap* the floats
 setInScopeAndZapFloats env env_with_scope
   = env { seInScope    = seInScope env_with_scope,
-          seFloats     = emptyFloats,
+          seLetFloats  = emptyFloats,
           seJoinFloats = emptyJoinFloats }
 
 setFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Set the in-scope set *and* the floats
 setFloats env env_with_floats
   = env { seInScope    = seInScope env_with_floats,
-          seFloats     = seFloats  env_with_floats,
+          seLetFloats  = seLetFloats  env_with_floats,
           seJoinFloats = seJoinFloats env_with_floats }
 
-restoreJoinFloats :: SimplEnv -> SimplEnv -> SimplEnv
--- Put back floats previously zapped
--- Unlike 'setFloats', does *not* update the in-scope set, since the right-hand
--- env is assumed to be *older*
-restoreJoinFloats env old_env
-  = env { seJoinFloats = seJoinFloats old_env }
-
 addNewInScopeIds :: SimplEnv -> [CoreBndr] -> SimplEnv
         -- The new Ids are guaranteed to be freshly allocated
 addNewInScopeIds env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }) vs
@@ -371,7 +381,8 @@ Can't happen:
 data Floats = Floats (OrdList OutBind) FloatFlag
         -- See Note [Simplifier floats]
 
-type JoinFloats = OrdList OutBind
+type JoinFloat  = OutBind
+type JoinFloats = OrdList JoinFloat
 
 data FloatFlag
   = FltLifted   -- All bindings are lifted and lazy *or*
@@ -406,7 +417,7 @@ andFF FltLifted  flt        = flt
 
 doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> OutExpr -> SimplEnv -> Bool
 -- If you change this function look also at FloatIn.noFloatFromRhs
-doFloatFromRhs lvl rec str rhs (SimplEnv {seFloats = Floats fs ff})
+doFloatFromRhs lvl rec str rhs (SimplEnv {seLetFloats = Floats fs ff})
   =  not (isNilOL fs) && want_to_float && can_float
   where
      want_to_float = isTopLevel lvl || exprIsCheap rhs || exprIsExpandable rhs
@@ -459,44 +470,62 @@ addNonRec :: SimplEnv -> OutId -> OutExpr -> SimplEnv
 -- The latter is important; the binder may already be in the
 -- in-scope set (although it might also have been created with newId)
 -- but it may now have more IdInfo
-addNonRec env id rhs
-  = id `seq`   -- This seq forces the Id, and hence its IdInfo,
-               -- and hence any inner substitutions
-    env { seFloats = floats',
-          seJoinFloats = jfloats',
-          seInScope = extendInScopeSet (seInScope env) id }
+addNonRec env@(SimplEnv { seLetFloats  = floats
+                        , seJoinFloats = jfloats
+                        , seInScope = in_scope })
+          id rhs
+  | isJoinId id  -- This test incidentally forces the Id, and hence
+                 -- its IdInfo, and hence any inner substitutions
+  = env { seInScope    = in_scope'
+        , seLetFloats  = floats
+        , seJoinFloats = jfloats' }
+  | otherwise
+  = env { seInScope    = in_scope'
+        , seLetFloats  = floats'
+        , seJoinFloats = jfloats }
   where
-    bind = NonRec id rhs
-
-    floats'  | isJoinId id = seFloats env
-             | otherwise   = seFloats env `addFlts` unitFloat bind
-    jfloats' | isJoinId id = seJoinFloats env `addJoinFlts` unitJoinFloat bind
-             | otherwise   = seJoinFloats env
+    bind      = NonRec id rhs
+    in_scope' = extendInScopeSet in_scope id
+    floats'   = floats  `addFlts`     unitFloat     bind
+    jfloats'  = jfloats `addJoinFlts` unitJoinFloat bind
 
 extendFloats :: SimplEnv -> OutBind -> SimplEnv
--- Add these bindings to the floats, and extend the in-scope env too
-extendFloats env bind
-  = ASSERT(all (not . isJoinId) (bindersOf bind))
-    env { seFloats  = floats',
-          seJoinFloats = jfloats',
-          seInScope = extendInScopeSetList (seInScope env) bndrs }
+-- Add this binding to the floats, and extend the in-scope env too
+extendFloats env@(SimplEnv { seLetFloats  = floats
+                           , seJoinFloats = jfloats
+                           , seInScope = in_scope })
+             bind
+  | isJoinBind bind
+  = env { seInScope    = in_scope'
+        , seLetFloats  = floats
+        , seJoinFloats = jfloats' }
+  | otherwise
+  = env { seInScope    = in_scope'
+        , seLetFloats  = floats'
+        , seJoinFloats = jfloats }
   where
     bndrs = bindersOf bind
 
-    floats'  | isJoinBind bind = seFloats env
-             | otherwise       = seFloats env `addFlts` unitFloat bind
-    jfloats' | isJoinBind bind = seJoinFloats env `addJoinFlts`
-                                   unitJoinFloat bind
-             | otherwise       = seJoinFloats env
+    in_scope' = extendInScopeSetList in_scope bndrs
+    floats'   = floats  `addFlts`     unitFloat bind
+    jfloats'  = jfloats `addJoinFlts` unitJoinFloat bind
+
+addLetFloats :: SimplEnv -> SimplEnv -> SimplEnv
+-- Add the let-floats for env2 to env1;
+-- *plus* the in-scope set for env2, which is bigger
+-- than that for env1
+addLetFloats env1 env2
+  = env1 { seLetFloats = seLetFloats env1 `addFlts` seLetFloats env2
+         , seInScope   = seInScope env2 }
 
 addFloats :: SimplEnv -> SimplEnv -> SimplEnv
--- Add the floats for env2 to env1;
+-- Add both let-floats and join-floats for env2 to env1;
 -- *plus* the in-scope set for env2, which is bigger
 -- than that for env1
 addFloats env1 env2
-  = env1 {seFloats = seFloats env1 `addFlts` seFloats env2,
-          seJoinFloats = seJoinFloats env1 `addJoinFlts` seJoinFloats env2,
-          seInScope = seInScope env2 }
+  = env1 { seLetFloats  = seLetFloats env1 `addFlts` seLetFloats env2
+         , seJoinFloats = seJoinFloats env1 `addJoinFlts` seJoinFloats env2
+         , seInScope    = seInScope env2 }
 
 addFlts :: Floats -> Floats -> Floats
 addFlts (Floats bs1 l1) (Floats bs2 l2)
@@ -506,21 +535,25 @@ addJoinFlts :: JoinFloats -> JoinFloats -> JoinFloats
 addJoinFlts = appOL
 
 zapFloats :: SimplEnv -> SimplEnv
-zapFloats env = env { seFloats = emptyFloats
+zapFloats env = env { seLetFloats  = emptyFloats
                     , seJoinFloats = emptyJoinFloats }
 
 zapJoinFloats :: SimplEnv -> SimplEnv
 zapJoinFloats env = env { seJoinFloats = emptyJoinFloats }
 
+addJoinFloats :: SimplEnv -> JoinFloats -> SimplEnv
+addJoinFloats env@(SimplEnv { seJoinFloats = fb1 }) fb2
+  = env { seJoinFloats = fb1 `addJoinFlts` fb2 }
+
 addRecFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Flattens the floats from env2 into a single Rec group,
 -- prepends the floats from env1, and puts the result back in env2
 -- This is all very specific to the way recursive bindings are
 -- handled; see Simplify.simplRecBind
-addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff
+addRecFloats env1 env2@(SimplEnv {seLetFloats  = Floats bs ff
                                  ,seJoinFloats = jbs })
   = ASSERT2( case ff of { FltLifted -> True; _ -> False }, ppr (fromOL bs) )
-    env2 {seFloats = seFloats env1 `addFlts` floats'
+    env2 {seLetFloats = seLetFloats env1 `addFlts` floats'
          ,seJoinFloats = seJoinFloats env1 `addJoinFlts` jfloats'}
   where
     floats'  | isNilOL bs  = emptyFloats
@@ -531,35 +564,39 @@ addRecFloats env1 env2@(SimplEnv {seFloats = Floats bs ff
 wrapFloats :: SimplEnv -> OutExpr -> OutExpr
 -- Wrap the floats around the expression; they should all
 -- satisfy the let/app invariant, so mkLets should do the job just fine
-wrapFloats env@(SimplEnv {seFloats = Floats bs _}) body
-  = foldrOL Let (wrapJoinFloats env body) bs
-      -- Note: Always safe to put the joins on the inside since the values
-      -- can't refer to them
-
-wrapJoinFloats :: SimplEnv -> OutExpr -> OutExpr
-wrapJoinFloats (SimplEnv {seJoinFloats = jbs}) body
-  = foldrOL Let body jbs
+wrapFloats (SimplEnv { seLetFloats  = Floats bs _
+                     , seJoinFloats = jbs }) body
+  = foldrOL Let (wrapJoinFloats jbs body) bs
+     -- Note: Always safe to put the joins on the inside
+     -- since the values can't refer to them
+
+wrapJoinFloatsX :: SimplEnv -> OutExpr -> (SimplEnv, OutExpr)
+-- Wrap the seJoinFloats of the env around the expression,
+-- and take them out of the SimplEnv
+wrapJoinFloatsX env@(SimplEnv { seJoinFloats = jbs }) body
+  = (zapJoinFloats env, wrapJoinFloats jbs body)
+
+wrapJoinFloats :: JoinFloats -> OutExpr -> OutExpr
+-- Wrap the seJoinFloats of the env around the expression,
+-- and take them out of the SimplEnv
+wrapJoinFloats join_floats body
+  = foldrOL Let body join_floats
 
 getFloatBinds :: SimplEnv -> [CoreBind]
-getFloatBinds env@(SimplEnv {seFloats = Floats bs _})
-  = fromOL bs ++ getJoinFloatBinds env
-
-getJoinFloatBinds :: SimplEnv -> [CoreBind]
-getJoinFloatBinds (SimplEnv {seJoinFloats = jbs})
-  = fromOL jbs
+getFloatBinds (SimplEnv {seLetFloats = Floats bs _, seJoinFloats = jbs})
+  = fromOL bs ++ fromOL jbs
 
 isEmptyFloats :: SimplEnv -> Bool
-isEmptyFloats env@(SimplEnv {seFloats = Floats bs _})
+isEmptyFloats env@(SimplEnv {seLetFloats = Floats bs _})
   = isNilOL bs && isEmptyJoinFloats env
 
 isEmptyJoinFloats :: SimplEnv -> Bool
 isEmptyJoinFloats (SimplEnv {seJoinFloats = jbs})
   = isNilOL jbs
 
-mapFloats :: SimplEnv -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> SimplEnv
-mapFloats env@SimplEnv { seFloats = Floats fs ff, seJoinFloats = jfs } fun
-   = env { seFloats = Floats (mapOL app fs) ff
-         , seJoinFloats = mapOL app jfs }
+mapFloats :: Floats -> ((Id,CoreExpr) -> (Id,CoreExpr)) -> Floats
+mapFloats (Floats fs ff) fun
+   = Floats (mapOL app fs) ff
    where
     app (NonRec b e) = case fun (b,e) of (b',e') -> NonRec b' e'
     app (Rec bs)     = Rec (map fun bs)
@@ -586,24 +623,15 @@ find that it has been substituted by b.  (Or conceivably cloned.)
 substId :: SimplEnv -> InId -> SimplSR
 -- Returns DoneEx only on a non-Var expression
 substId (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
-  = case snd <$> lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
+  = case lookupVarEnv ids v of  -- Note [Global Ids in the substitution]
         Nothing               -> DoneId (refineFromInScope in_scope v)
         Just (DoneId v)       -> DoneId (refineFromInScope in_scope v)
-        Just (DoneEx (Var v)) -> DoneId (refineFromInScope in_scope v)
         Just res              -> res    -- DoneEx non-var, or ContEx
 
         -- Get the most up-to-date thing from the in-scope set
         -- Even though it isn't in the substitution, it may be in
         -- the in-scope set with better IdInfo
 
-isJoinIdInEnv_maybe :: SimplEnv -> InId -> Maybe JoinArity
-isJoinIdInEnv_maybe (SimplEnv { seInScope = inScope, seIdSubst = ids }) v
-  | not (isLocalId v)                         = Nothing
-  | Just (m_ar, _) <- lookupVarEnv ids v      = m_ar
-  | Just v'        <- lookupInScope inScope v = isJoinId_maybe v'
-  | otherwise                                 = WARN( True , ppr v )
-                                                isJoinId_maybe v
-
 refineFromInScope :: InScopeSet -> Var -> Var
 refineFromInScope in_scope v
   | isLocalId v = case lookupInScope in_scope v of
@@ -616,7 +644,7 @@ lookupRecBndr :: SimplEnv -> InId -> OutId
 -- but where we have not yet done its RHS
 lookupRecBndr (SimplEnv { seInScope = in_scope, seIdSubst = ids }) v
   = case lookupVarEnv ids v of
-        Just (_, DoneId v) -> v
+        Just (DoneId v) -> v
         Just _ -> pprPanic "lookupRecBndr" (ppr v)
         Nothing -> refineFromInScope in_scope v
 
@@ -731,8 +759,7 @@ substNonCoVarIdBndr new_res_ty
         -- or there's some useful occurrence information
         -- See the notes with substTyVarBndr for the delSubstEnv
     new_subst | new_id /= old_id
-              = extendVarEnv id_subst old_id
-                             (isJoinId_maybe new_id, DoneId new_id)
+              = extendVarEnv id_subst old_id (DoneId new_id)
               | otherwise
               = delVarEnv id_subst old_id