Deal with join points with RULES
authorSimon Peyton Jones <simonpj@microsoft.com>
Mon, 26 Mar 2018 16:16:14 +0000 (17:16 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Tue, 27 Mar 2018 08:29:13 +0000 (09:29 +0100)
Trac #13900 showed that when we have a join point that
has a RULE, we must push the continuation into the RHS
of the RULE.

See Note [Rules and unfolding for join points]

It's hard to tickle this bug, so I have not added a regression test.

compiler/simplCore/SimplCore.hs
compiler/simplCore/Simplify.hs

index a34baa8..fe6d446 100644 (file)
@@ -767,7 +767,7 @@ simplifyPgmIO pass@(CoreDoSimplify max_iterations mode)
                       -- for imported Ids.  Eg  RULE map my_f = blah
                       -- If we have a substitution my_f :-> other_f, we'd better
                       -- apply it to the rule to, or it'll never match
-                  ; rules1 <- simplRules env1 Nothing rules
+                  ; rules1 <- simplRules env1 Nothing rules Nothing
 
                   ; return (getTopFloatBinds floats, rules1) } ;
 
index 53e3a21..a60df1c 100644 (file)
@@ -24,7 +24,7 @@ import Id
 import MkId             ( seqId )
 import MkCore           ( mkImpossibleExpr, castBottomExpr )
 import IdInfo
-import Name             ( Name, mkSystemVarName, isExternalName, getOccFS )
+import Name             ( mkSystemVarName, isExternalName, getOccFS )
 import Coercion hiding  ( substCo, substCoVar )
 import OptCoercion      ( optCoercion )
 import FamInstEnv       ( topNormaliseType_maybe )
@@ -143,11 +143,11 @@ simplTopBinds env0 binds0
                                       ; (floats, env2) <- simpl_binds env1 binds
                                       ; return (float `addFloats` floats, env2) }
 
-    simpl_bind env (Rec pairs)  = simplRecBind env TopLevel Nothing pairs
-    simpl_bind env (NonRec b r) = do { (env', b') <- addBndrRules env b (lookupRecBndr env b)
-                                     ; simplRecOrTopPair env' TopLevel
-                                                         NonRecursive Nothing
-                                                         b b' r }
+    simpl_bind env (Rec pairs)
+      = simplRecBind env TopLevel Nothing pairs
+    simpl_bind env (NonRec b r)
+      = do { (env', b') <- addBndrRules env b (lookupRecBndr env b) Nothing
+           ; simplRecOrTopPair env' TopLevel NonRecursive Nothing b b' r }
 
 {-
 ************************************************************************
@@ -160,7 +160,7 @@ simplRecBind is used for
         * recursive bindings only
 -}
 
-simplRecBind :: SimplEnv -> TopLevelFlag -> Maybe SimplCont
+simplRecBind :: SimplEnv -> TopLevelFlag -> MaybeJoinCont
              -> [(InId, InExpr)]
              -> SimplM (SimplFloats, SimplEnv)
 simplRecBind env0 top_lvl mb_cont pairs0
@@ -171,7 +171,7 @@ simplRecBind env0 top_lvl mb_cont pairs0
     add_rules :: SimplEnv -> (InBndr,InExpr) -> SimplM (SimplEnv, (InBndr, OutBndr, InExpr))
         -- Add the (substituted) rules to the binder
     add_rules env (bndr, rhs)
-        = do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr)
+        = do { (env', bndr') <- addBndrRules env bndr (lookupRecBndr env bndr) mb_cont
              ; return (env', (bndr, bndr', rhs)) }
 
     go env [] = return (emptyFloats env, env)
@@ -191,7 +191,7 @@ It assumes the binder has already been simplified, but not its IdInfo.
 -}
 
 simplRecOrTopPair :: SimplEnv
-                  -> TopLevelFlag -> RecFlag -> Maybe SimplCont
+                  -> TopLevelFlag -> RecFlag -> MaybeJoinCont
                   -> InId -> OutBndr -> InExpr  -- Binder and rhs
                   -> SimplM (SimplFloats, SimplEnv)
 
@@ -616,7 +616,7 @@ Nor does it do the atomic-argument thing
 
 completeBind :: SimplEnv
              -> TopLevelFlag            -- Flag stuck into unfolding
-             -> Maybe SimplCont         -- Required only for join point
+             -> MaybeJoinCont           -- Required only for join point
              -> InId                    -- Old binder
              -> OutId -> OutExpr        -- New binder and RHS
              -> SimplM (SimplFloats, SimplEnv)
@@ -645,7 +645,7 @@ completeBind env top_lvl mb_cont old_bndr new_bndr new_rhs
 
         -- Simplify the unfolding
       ; new_unfolding <- simplLetUnfolding env top_lvl mb_cont old_bndr
-                                           final_rhs old_unf
+                                           final_rhs (idType new_bndr) old_unf
 
       ; let final_bndr = addLetBndrInfo new_bndr new_arity is_bot new_unfolding
 
@@ -1319,7 +1319,8 @@ simplLamBndr :: SimplEnv -> InBndr -> SimplM (SimplEnv, OutBndr)
 simplLamBndr env bndr
   | isId bndr && isFragileUnfolding old_unf   -- Special case
   = do { (env1, bndr1) <- simplBinder env bndr
-       ; unf'          <- simplStableUnfolding env1 NotTopLevel Nothing bndr old_unf
+       ; unf'          <- simplStableUnfolding env1 NotTopLevel Nothing bndr
+                                               old_unf (idType bndr1)
        ; let bndr2 = bndr1 `setIdUnfolding` unf'
        ; return (modifyInScope env1 bndr2, bndr2) }
 
@@ -1378,7 +1379,7 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
   | otherwise
   = ASSERT( not (isTyVar bndr) )
     do { (env1, bndr1) <- simplNonRecBndr env bndr
-       ; (env2, bndr2) <- addBndrRules env1 bndr bndr1
+       ; (env2, bndr2) <- addBndrRules env1 bndr bndr1 Nothing
        ; (floats1, env3) <- simplLazyBind env2 NotTopLevel NonRecursive bndr bndr2 rhs rhs_se
        ; (floats2, expr') <- simplLam env3 bndrs body cont
        ; return (floats1 `addFloats` floats2, expr') }
@@ -1450,6 +1451,33 @@ Here it'd be far better to drop the unfolding and use the actual RHS.
 *                                                                      *
 ********************************************************************* -}
 
+{- Note [Rules and unfolding for join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have
+
+   simplExpr (join j x = rhs                         ) cont
+             (      {- RULE j (p:ps) = blah -}       )
+             (      {- StableUnfolding j = blah -}   )
+             (in blah                                )
+
+Then we will push 'cont' into the rhs of 'j'.  But we should *also* push
+'cont' into the RHS of
+  * Any RULEs for j, e.g. generated by SpecConstr
+  * Any stable unfolding for j, e.g. the result of an INLINE pragma
+
+Simplifying rules and stable-unfoldings happens a bit after
+simplifying the right-hand side, so we remember whether or not it
+is a join point, and what 'cont' is, in a value of type MaybeJoinCont
+
+Trac #13900 wsa caused by forgetting to push 'cont' into the RHS
+of a SpecConstr-generated RULE for a join point.
+-}
+
+type MaybeJoinCont = Maybe SimplCont
+  -- Nothing => Not a join point
+  -- Just k  => This is a join binding with continuation k
+  -- See Note [Rules and unfolding for join points]
+
 simplNonRecJoinPoint :: SimplEnv -> InId -> InExpr
                      -> InExpr -> SimplCont
                      -> SimplM (SimplFloats, OutExpr)
@@ -1465,7 +1493,7 @@ simplNonRecJoinPoint env bndr rhs body cont
           -- and wrap wrap_cont around the whole thing
         ; let res_ty = contResultType cont
         ; (env1, bndr1)    <- simplNonRecJoinBndr env res_ty bndr
-        ; (env2, bndr2)    <- addBndrRules env1 bndr bndr1
+        ; (env2, bndr2)    <- addBndrRules env1 bndr bndr1 (Just cont)
         ; (floats1, env3)  <- simplJoinBind env2 cont bndr bndr2 rhs env
         ; (floats2, body') <- simplExprF env3 body cont
         ; return (floats1 `addFloats` floats2, body') }
@@ -3235,13 +3263,13 @@ because we don't know its usage in each RHS separately
 -}
 
 simplLetUnfolding :: SimplEnv-> TopLevelFlag
-                  -> Maybe SimplCont
+                  -> MaybeJoinCont
                   -> InId
-                  -> OutExpr
+                  -> OutExpr -> OutType
                   -> Unfolding -> SimplM Unfolding
-simplLetUnfolding env top_lvl cont_mb id new_rhs unf
+simplLetUnfolding env top_lvl cont_mb id new_rhs rhs_ty unf
   | isStableUnfolding unf
-  = simplStableUnfolding env top_lvl cont_mb id unf
+  = simplStableUnfolding env top_lvl cont_mb id unf rhs_ty
   | isExitJoinId id
   = return noUnfolding -- see Note [Do not inline exit join points]
   | otherwise
@@ -3265,26 +3293,26 @@ mkLetUnfolding dflags top_lvl src id new_rhs
 
 -------------------
 simplStableUnfolding :: SimplEnv -> TopLevelFlag
-                     -> Maybe SimplCont  -- Just k => a join point with continuation k
+                     -> MaybeJoinCont  -- Just k => a join point with continuation k
                      -> InId
-                     -> Unfolding -> SimplM Unfolding
+                     -> Unfolding -> OutType -> SimplM Unfolding
 -- Note [Setting the new unfolding]
-simplStableUnfolding env top_lvl mb_cont id unf
+simplStableUnfolding env top_lvl mb_cont id unf rhs_ty
   = case unf of
       NoUnfolding   -> return unf
       BootUnfolding -> return unf
       OtherCon {}   -> return unf
 
       DFunUnfolding { df_bndrs = bndrs, df_con = con, df_args = args }
-        -> do { (env', bndrs') <- simplBinders rule_env bndrs
+        -> do { (env', bndrs') <- simplBinders unf_env bndrs
               ; args' <- mapM (simplExpr env') args
               ; return (mkDFunUnfolding bndrs' con args') }
 
       CoreUnfolding { uf_tmpl = expr, uf_src = src, uf_guidance = guide }
         | isStableSource src
-        -> do { expr' <- case mb_cont of
-                           Just cont -> simplJoinRhs rule_env id expr cont
-                           Nothing   -> simplExpr rule_env expr
+        -> do { expr' <- case mb_cont of -- See Note [Rules and unfolding for join points]
+                           Just cont -> simplJoinRhs unf_env id expr cont
+                           Nothing   -> simplExprC unf_env expr (mkBoringStop rhs_ty)
               ; case guide of
                   UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok }  -- Happens for INLINE things
                      -> let guide' = UnfWhen { ug_arity = arity, ug_unsat_ok = sat_ok
@@ -3308,7 +3336,7 @@ simplStableUnfolding env top_lvl mb_cont id unf
     dflags     = seDynFlags env
     is_top_lvl = isTopLevel top_lvl
     act        = idInlineActivation id
-    rule_env   = updMode (updModeForStableUnfoldings act) env
+    unf_env    = updMode (updModeForStableUnfoldings act) env
          -- See Note [Simplifying inside stable unfoldings] in SimplUtils
 
 {-
@@ -3350,20 +3378,24 @@ to apply in that function's own right-hand side.
 See Note [Forming Rec groups] in OccurAnal
 -}
 
-addBndrRules :: SimplEnv -> InBndr -> OutBndr -> SimplM (SimplEnv, OutBndr)
+addBndrRules :: SimplEnv -> InBndr -> OutBndr
+             -> MaybeJoinCont   -- Just k for a join point binder
+                                -- Nothing otherwise
+             -> SimplM (SimplEnv, OutBndr)
 -- Rules are added back into the bin
-addBndrRules env in_id out_id
+addBndrRules env in_id out_id mb_cont
   | null old_rules
   = return (env, out_id)
   | otherwise
-  = do { new_rules <- simplRules env (Just (idName out_id)) old_rules
+  = do { new_rules <- simplRules env (Just out_id) old_rules mb_cont
        ; let final_id  = out_id `setIdSpecialisation` mkRuleInfo new_rules
        ; return (modifyInScope env final_id, final_id) }
   where
     old_rules = ruleInfoRules (idSpecialisation in_id)
 
-simplRules :: SimplEnv -> Maybe Name -> [CoreRule] -> SimplM [CoreRule]
-simplRules env mb_new_nm rules
+simplRules :: SimplEnv -> Maybe OutId -> [CoreRule]
+           -> MaybeJoinCont -> SimplM [CoreRule]
+simplRules env mb_new_id rules mb_cont
   = mapM simpl_rule rules
   where
     simpl_rule rule@(BuiltinRule {})
@@ -3373,11 +3405,29 @@ simplRules env mb_new_nm rules
                           , ru_fn = fn_name, ru_rhs = rhs })
       = do { (env', bndrs') <- simplBinders env bndrs
            ; let rhs_ty = substTy env' (exprType rhs)
-                 rule_cont = mkBoringStop rhs_ty
-                 rule_env  = updMode updModeForRules env'
+                 rhs_cont = case mb_cont of  -- See Note [Rules and unfolding for join points]
+                                Nothing   -> mkBoringStop rhs_ty
+                                Just cont -> ASSERT2( join_ok, bad_join_msg )
+                                             cont
+                 rule_env = updMode updModeForRules env'
+                 fn_name' = case mb_new_id of
+                              Just id -> idName id
+                              Nothing -> fn_name
+
+                 -- join_ok is an assertion check that the join-arity of the
+                 -- binder matches that of the rule, so that pushing the
+                 -- continuation into the RHS makes sense
+                 join_ok = case mb_new_id of
+                             Just id | Just join_arity <- isJoinId_maybe id
+                                     -> length args == join_arity
+                             _ -> False
+                 bad_join_msg = vcat [ ppr mb_new_id, ppr rule
+                                     , ppr (fmap isJoinId_maybe mb_new_id) ]
+
            ; args' <- mapM (simplExpr rule_env) args
-           ; rhs'  <- simplExprC rule_env rhs rule_cont
+           ; rhs'  <- simplExprC rule_env rhs rhs_cont
            ; return (rule { ru_bndrs = bndrs'
-                          , ru_fn    = mb_new_nm `orElse` fn_name
+                          , ru_fn    = fn_name'
                           , ru_args  = args'
                           , ru_rhs   = rhs' }) }
+