Fix #12442.
[ghc.git] / compiler / specialise / Rules.hs
index f7a67ea..7909bdc 100644 (file)
@@ -33,13 +33,13 @@ import Module           ( Module, ModuleSet, elemModuleSet )
 import CoreSubst
 import OccurAnal        ( occurAnalyseExpr )
 import CoreFVs          ( exprFreeVars, exprsFreeVars, bindFreeVars
-                        , rulesFreeVarsDSet, exprsOrphNames )
+                        , rulesFreeVarsDSet, exprsOrphNames, exprFreeVarsList )
 import CoreUtils        ( exprType, eqExpr, mkTick, mkTicks,
                           stripTicksTopT, stripTicksTopE )
 import PprCore          ( pprRules )
-import Type             ( Type, substTy, mkTvSubst )
+import Type             ( Type, substTy, mkTCvSubst )
 import TcType           ( tcSplitTyConApp_maybe )
-import TysPrim          ( anyTypeOfKind )
+import TysWiredIn       ( anyTypeOfKind )
 import Coercion
 import CoreTidy         ( tidyRules )
 import Id
@@ -50,7 +50,8 @@ import VarSet
 import Name             ( Name, NamedThing(..), nameIsLocalOrFrom )
 import NameSet
 import NameEnv
-import Unify            ( ruleMatchTyX, MatchEnv(..) )
+import UniqFM
+import Unify            ( ruleMatchTyKiX )
 import BasicTypes       ( Activation, CompilerPhase, isActive, pprRuleName )
 import StaticFlags      ( opt_PprStyle_Debug )
 import DynFlags         ( DynFlags )
@@ -61,6 +62,7 @@ import Bag
 import Util
 import Data.List
 import Data.Ord
+import Control.Monad    ( guard )
 
 {-
 Note [Overall plumbing for rules]
@@ -179,13 +181,13 @@ mkRule this_mod is_auto is_local name act fn bndrs args rhs
         -- Compute orphanhood.  See Note [Orphans] in InstEnv
         -- A rule is an orphan only if none of the variables
         -- mentioned on its left-hand side are locally defined
-    lhs_names = nameSetElems (extendNameSet (exprsOrphNames args) fn)
+    lhs_names = extendNameSet (exprsOrphNames args) fn
 
         -- Since rules get eventually attached to one of the free names
         -- from the definition when compiling the ABI hash, we should make
         -- it deterministic. This chooses the one with minimal OccName
         -- as opposed to uniq value.
-    local_lhs_names = filter (nameIsLocalOrFrom this_mod) lhs_names
+    local_lhs_names = filterNameSet (nameIsLocalOrFrom this_mod) lhs_names
     orph = chooseOrphanAnchor local_lhs_names
 
 --------------
@@ -356,8 +358,9 @@ extendRuleBase rule_base rule
   = extendNameEnv_Acc (:) singleton rule_base (ruleIdName rule) rule
 
 pprRuleBase :: RuleBase -> SDoc
-pprRuleBase rules = vcat [ pprRules (tidyRules emptyTidyEnv rs)
-                         | rs <- nameEnvElts rules ]
+pprRuleBase rules = pprUFM rules $ \rss ->
+  vcat [ pprRules (tidyRules emptyTidyEnv rs)
+       | rs <- rss ]
 
 {-
 ************************************************************************
@@ -420,10 +423,10 @@ findBest target (rule1,ans1) ((rule2,ans2):prs)
                         | otherwise          = doubleQuotes (ftext (ru_name rule))
                 in pprTrace "Rules.findBest: rule overlap (Rule 1 wins)"
                          (vcat [if opt_PprStyle_Debug then
-                                   ptext (sLit "Expression to match:") <+> ppr fn <+> sep (map ppr args)
+                                   text "Expression to match:" <+> ppr fn <+> sep (map ppr args)
                                 else empty,
-                                ptext (sLit "Rule 1:") <+> pp_rule rule1,
-                                ptext (sLit "Rule 2:") <+> pp_rule rule2]) $
+                                text "Rule 1:" <+> pp_rule rule1,
+                                text "Rule 2:" <+> pp_rule rule2]) $
                 findBest target (rule1,ans1) prs
   | otherwise = findBest target (rule1,ans1) prs
   where
@@ -561,14 +564,26 @@ matchN (in_scope, id_unf) rule_name tmpl_vars tmpl_es target_es
              -- See Note [Unbound template type variables]
         where
           fake_ty = anyTypeOfKind kind
-          kind = Type.substTy (mkTvSubst in_scope tv_subst) (tyVarKind tmpl_var)
+          cv_subst = to_co_env id_subst
+          kind = Type.substTy (mkTCvSubst in_scope (tv_subst, cv_subst))
+                              (tyVarKind tmpl_var)
+
+          to_co_env env = nonDetFoldUFM_Directly to_co emptyVarEnv env
+            -- It's OK to use nonDetFoldUFM_Directly because we forget the
+            -- order immediately by creating a new env
+          to_co uniq expr env
+            | Just co <- exprToCoercion_maybe expr
+            = extendVarEnv_Directly env uniq co
+
+            | otherwise
+            = env
 
     unbound var = pprPanic "Template variable unbound in rewrite rule" $
-                  vcat [ ptext (sLit "Variable:") <+> ppr var
-                       , ptext (sLit "Rule") <+> pprRuleName rule_name
-                       , ptext (sLit "Rule bndrs:") <+> ppr tmpl_vars
-                       , ptext (sLit "LHS args:") <+> ppr tmpl_es
-                       , ptext (sLit "Actual args:") <+> ppr target_es ]
+                  vcat [ text "Variable:" <+> ppr var
+                       , text "Rule" <+> pprRuleName rule_name
+                       , text "Rule bndrs:" <+> ppr tmpl_vars
+                       , text "LHS args:" <+> ppr tmpl_es
+                       , text "Actual args:" <+> ppr target_es ]
 
 {- Note [Unbound template type variables]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -779,19 +794,20 @@ match_co :: RuleMatchEnv
          -> Coercion
          -> Coercion
          -> Maybe RuleSubst
-match_co renv subst (CoVarCo cv) co
-  = match_var renv subst cv (Coercion co)
-match_co renv subst (Refl r1 ty1) co
-  = case co of
-       Refl r2 ty2
-         | r1 == r2 -> match_ty renv subst ty1 ty2
-       _            -> Nothing
-match_co renv subst (TyConAppCo r1 tc1 cos1) co2
-  = case co2 of
-       TyConAppCo r2 tc2 cos2
-         | r1 == r2 && tc1 == tc2
-         -> match_cos renv subst cos1 cos2
-       _ -> Nothing
+match_co renv subst co1 co2
+  | Just cv <- getCoVar_maybe co1
+  = match_var renv subst cv (Coercion co2)
+  | Just (ty1, r1) <- isReflCo_maybe co1
+  = do { (ty2, r2) <- isReflCo_maybe co2
+       ; guard (r1 == r2)
+       ; match_ty renv subst ty1 ty2 }
+match_co renv subst co1 co2
+  | Just (tc1, cos1) <- splitTyConAppCo_maybe co1
+  = case splitTyConAppCo_maybe co2 of
+      Just (tc2, cos2)
+        |  tc1 == tc2
+        -> match_cos renv subst cos1 cos2
+      _ -> Nothing
 match_co _ _ _co1 _co2
     -- Currently just deals with CoVarCo, TyConAppCo and Refl
 #ifdef DEBUG
@@ -806,13 +822,11 @@ match_cos :: RuleMatchEnv
          -> [Coercion]
          -> Maybe RuleSubst
 match_cos renv subst (co1:cos1) (co2:cos2) =
-    case match_co renv subst co1 co2 of
-       Just subst' -> match_cos renv subst' cos1 cos2
-       Nothing -> Nothing
+  do { subst' <- match_co renv subst co1 co2
+     ; match_cos renv subst' cos1 cos2 }
 match_cos _ subst [] [] = Just subst
 match_cos _ _ cos1 cos2 = pprTrace "match_cos: not same length" (ppr cos1 $$ ppr cos2) Nothing
 
-
 -------------
 rnMatchBndr2 :: RuleMatchEnv -> RuleSubst -> Var -> Var -> RuleMatchEnv
 rnMatchBndr2 renv subst x1 x2
@@ -845,7 +859,7 @@ match_alts _ _ _ _
 ------------------------------------------
 okToFloat :: RnEnv2 -> VarSet -> Bool
 okToFloat rn_env bind_fvs
-  = foldVarSet ((&&) . not_captured) True bind_fvs
+  = varSetAll not_captured bind_fvs
   where
     not_captured fv = not (inRnEnvR rn_env fv)
 
@@ -888,7 +902,7 @@ match_tmpl_var :: RuleMatchEnv
 match_tmpl_var renv@(RV { rv_lcl = rn_env, rv_fltR = flt_env })
                subst@(RS { rs_id_subst = id_subst, rs_bndrs = let_bndrs })
                v1' e2
-  | any (inRnEnvR rn_env) (varSetElems (exprFreeVars e2))
+  | any (inRnEnvR rn_env) (exprFreeVarsList e2)
   = Nothing     -- Occurs check failure
                 -- e.g. match forall a. (\x-> a x) against (\y. y y)
 
@@ -932,11 +946,11 @@ match_ty :: RuleMatchEnv
 -- We only want to replace (f T) with f', not (f Int).
 
 match_ty renv subst ty1 ty2
-  = do  { tv_subst' <- Unify.ruleMatchTyX menv tv_subst ty1 ty2
+  = do  { tv_subst'
+            <- Unify.ruleMatchTyKiX (rv_tmpls renv) (rv_lcl renv) tv_subst ty1 ty2
         ; return (subst { rs_tv_subst = tv_subst' }) }
   where
     tv_subst = rs_tv_subst subst
-    menv = ME { me_tmpls = rv_tmpls renv, me_env = rv_lcl renv }
 
 {-
 Note [Expanding variables]
@@ -1175,9 +1189,9 @@ ruleAppCheck_help env fn args rules
                       rule_herald rule <> colon <+> rule_info dflags rule
 
     rule_herald (BuiltinRule { ru_name = name })
-        = ptext (sLit "Builtin rule") <+> doubleQuotes (ftext name)
+        = text "Builtin rule" <+> doubleQuotes (ftext name)
     rule_herald (Rule { ru_name = name })
-        = ptext (sLit "Rule") <+> doubleQuotes (ftext name)
+        = text "Rule" <+> doubleQuotes (ftext name)
 
     rule_info dflags rule
         | Just _ <- matchRule dflags (emptyInScopeSet, rc_id_unf env)