Fix arguments for unbound binders in RULE application
authorSimon Peyton Jones <simonpj@microsoft.com>
Wed, 25 Sep 2019 15:26:29 +0000 (16:26 +0100)
committerMarge Bot <ben+marge-bot@smart-cactus.org>
Tue, 1 Oct 2019 02:40:30 +0000 (22:40 -0400)
We were failing to correctly implement Note [Unbound RULE binders]
in Rules.hs.  In particular, when cooking up a fake Refl,
were were failing to apply the substitition.

This patch fixes that problem, and simultaneously tidies
up the impedence mis-match between RuleSubst and TCvSubst.

Thanks to Sebastian!

compiler/specialise/Rules.hs

index df6196a..72e2934 100644 (file)
@@ -40,7 +40,8 @@ import CoreUtils        ( exprType, eqExpr, mkTick, mkTicks,
                           stripTicksTopT, stripTicksTopE,
                           isJoinBind )
 import PprCore          ( pprRules )
-import Type             ( Type, Kind, substTy, mkTCvSubst )
+import Type             ( Type, TCvSubst, extendTvSubst, extendCvSubst
+                        , mkEmptyTCvSubst, substTy )
 import TcType           ( tcSplitTyConApp_maybe )
 import TysWiredIn       ( anyTypeOfKind )
 import Coercion
@@ -448,8 +449,9 @@ isMoreSpecific :: CoreRule -> CoreRule -> Bool
 isMoreSpecific (BuiltinRule {}) _                = False
 isMoreSpecific (Rule {})        (BuiltinRule {}) = True
 isMoreSpecific (Rule { ru_bndrs = bndrs1, ru_args = args1 })
-               (Rule { ru_bndrs = bndrs2, ru_args = args2, ru_name = rule_name2 })
-  = isJust (matchN (in_scope, id_unfolding_fun) rule_name2 bndrs2 args2 args1)
+               (Rule { ru_bndrs = bndrs2, ru_args = args2
+                     , ru_name = rule_name2, ru_rhs = rhs })
+  = isJust (matchN (in_scope, id_unfolding_fun) rule_name2 bndrs2 args2 args1 rhs)
   where
    id_unfolding_fun _ = NoUnfolding     -- Don't expand in templates
    in_scope = mkInScopeSet (mkVarSet bndrs1)
@@ -516,29 +518,26 @@ matchRule _ in_scope is_active _ args rough_args
                 , ru_bndrs = tpl_vars, ru_args = tpl_args, ru_rhs = rhs })
   | not (is_active act)               = Nothing
   | ruleCantMatch tpl_tops rough_args = Nothing
-  | otherwise
-  = case matchN in_scope rule_name tpl_vars tpl_args args of
-        Nothing                       -> Nothing
-        Just (bind_wrapper, tpl_vals) -> Just (bind_wrapper $
-                                               rule_fn `mkApps` tpl_vals)
-  where
-    rule_fn = mkLams tpl_vars rhs
+  | otherwise = matchN in_scope rule_name tpl_vars tpl_args args rhs
 
 ---------------------------------------
 matchN  :: InScopeEnv
         -> RuleName -> [Var] -> [CoreExpr]
-        -> [CoreExpr]           -- ^ Target; can have more elements than the template
-        -> Maybe (BindWrapper,  -- Floated bindings; see Note [Matching lets]
-                  [CoreExpr])
+        -> [CoreExpr] -> CoreExpr           -- ^ Target; can have more elements than the template
+        -> Maybe CoreExpr
 -- For a given match template and context, find bindings to wrap around
 -- the entire result and what should be substituted for each template variable.
 -- Fail if there are two few actual arguments from the target to match the template
 
-matchN (in_scope, id_unf) rule_name tmpl_vars tmpl_es target_es
-  = do  { subst <- go init_menv emptyRuleSubst tmpl_es target_es
-        ; let (_, matched_es) = mapAccumL lookup_tmpl subst $
+matchN (in_scope, id_unf) rule_name tmpl_vars tmpl_es target_es rhs
+  = do  { rule_subst <- go init_menv emptyRuleSubst tmpl_es target_es
+        ; let (_, matched_es) = mapAccumL (lookup_tmpl rule_subst)
+                                          (mkEmptyTCvSubst in_scope) $
                                 tmpl_vars `zip` tmpl_vars1
-        ; return (rs_binds subst, matched_es) }
+              bind_wrapper = rs_binds rule_subst
+                             -- Floated bindings; see Note [Matching lets]
+       ; return (bind_wrapper $
+                 mkLams tmpl_vars rhs `mkApps` matched_es) }
   where
     (init_rn_env, tmpl_vars1) = mapAccumL rnBndrL (mkRnEnv2 in_scope) tmpl_vars
                   -- See Note [Cloning the template binders]
@@ -553,29 +552,32 @@ matchN (in_scope, id_unf) rule_name tmpl_vars tmpl_es target_es
     go menv subst (t:ts) (e:es) = do { subst1 <- match menv subst t e
                                      ; go menv subst1 ts es }
 
-    lookup_tmpl :: RuleSubst -> (InVar,OutVar) -> (RuleSubst, CoreExpr)
+    lookup_tmpl :: RuleSubst -> TCvSubst -> (InVar,OutVar) -> (TCvSubst, CoreExpr)
                    -- Need to return a RuleSubst solely for the benefit of mk_fake_ty
-    lookup_tmpl rs@(RS { rs_tv_subst = tv_subst, rs_id_subst = id_subst })
-                (tmpl_var, tmpl_var1)
+    lookup_tmpl (RS { rs_tv_subst = tv_subst, rs_id_subst = id_subst })
+                tcv_subst (tmpl_var, tmpl_var1)
         | isId tmpl_var1
         = case lookupVarEnv id_subst tmpl_var1 of
-             Just e -> (rs, e)
-             Nothing | Just refl_co <- isReflCoVar_maybe tmpl_var1
-                     , let co_expr   = Coercion refl_co
-                           id_subst' = extendVarEnv id_subst tmpl_var1 co_expr
-                           rs'       = rs { rs_id_subst = id_subst' }
-                     -> (rs', co_expr) -- See Note [Unbound RULE binders]
-                     | otherwise
-                     -> unbound tmpl_var
+            Just e | Coercion co <- e
+                   -> (Type.extendCvSubst tcv_subst tmpl_var1 co, Coercion co)
+                   | otherwise
+                   -> (tcv_subst, e)
+            Nothing | Just refl_co <- isReflCoVar_maybe tmpl_var1
+                    , let co = Coercion.substCo tcv_subst refl_co
+                    -> -- See Note [Unbound RULE binders]
+                       (Type.extendCvSubst tcv_subst tmpl_var1 co, Coercion co)
+                    | otherwise
+                    -> unbound tmpl_var
+
         | otherwise
-        = case lookupVarEnv tv_subst tmpl_var1 of
-             Just ty -> (rs, Type ty)
-             Nothing -> (rs', Type fake_ty) -- See Note [Unbound RULE binders]
+        = (Type.extendTvSubst tcv_subst tmpl_var1 ty', Type ty')
         where
-          rs'     = rs { rs_tv_subst = extendVarEnv tv_subst tmpl_var1 fake_ty }
-          fake_ty = mk_fake_ty in_scope rs tmpl_var1
-                    -- This call is the sole reason we accumulate
-                    -- RuleSubst in lookup_tmpl
+          ty' = case lookupVarEnv tv_subst tmpl_var1 of
+                  Just ty -> ty
+                  Nothing -> fake_ty   -- See Note [Unbound RULE binders]
+          fake_ty = anyTypeOfKind (Type.substTy tcv_subst (tyVarKind tmpl_var1))
+                    -- This substitution is the sole reason we accumulate
+                    -- TCvSubst in lookup_tmpl
 
     unbound tmpl_var
        = pprPanic "Template variable unbound in rewrite rule" $
@@ -586,33 +588,6 @@ matchN (in_scope, id_unf) rule_name tmpl_vars tmpl_es target_es
               , text "Actual args:" <+> ppr target_es ]
 
 
-mk_fake_ty :: InScopeSet -> RuleSubst -> TyVar -> Kind
--- Roughly:
---    mk_fake_ty subst tv = Any @(subst (tyVarKind tv))
--- That is: apply the substitution to the kind of the given tyvar,
--- and make an 'any' type of that kind.
--- Tiresomely, the RuleSubst is not well adapted to substTy, leading to
--- horrible impedence matching.
---
--- Happily, this function is seldom called
-mk_fake_ty in_scope (RS { rs_tv_subst = tv_subst, rs_id_subst = id_subst }) tmpl_var1
-  = anyTypeOfKind kind
-  where
-    kind = Type.substTy (mkTCvSubst in_scope (tv_subst, cv_subst))
-                        (tyVarKind tmpl_var1)
-
-    cv_subst = to_co_env id_subst
-
-    to_co_env :: IdSubstEnv -> CvSubstEnv
-    to_co_env env = nonDetFoldUFM_Directly to_co emptyVarEnv env
-      -- It's OK to use nonDetFoldUFM_Directly because we forget the
-      -- order immediately by creating a new env
-
-    to_co uniq expr env
-      = case exprToCoercion_maybe expr of
-          Just co -> extendVarEnv_Directly env uniq co
-          Nothing -> env
-
 {- Note [Unbound RULE binders]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 It can be the case that the binder in a rule is not actually
@@ -643,7 +618,7 @@ bound on the LHS:
   Now, if that binding is inlined, so that a=b=Int, we'd get
     RULE forall (c :: Int~Int). f (x |> c) = e
   and now when we simplify the LHS (Simplify.simplRule) we
-  optCoercion will turn that 'c' into Refl:
+  optCoercion (look at the CoVarCo case) will turn that 'c' into Refl:
     RULE forall (c :: Int~Int). f (x |> <Int>) = e
   and then perhaps drop it altogether.  Now 'c' is unbound.
 
@@ -655,7 +630,6 @@ bound on the LHS:
   This actually happened (in a RULE for a local function)
   in #13410, and also in test T10602.
 
-
 Note [Cloning the template binders]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider the following match (example 1):