Fix deriveTyData's kind unification when two kind variables are unified
authorRyanGlScott <ryan.gl.scott@gmail.com>
Wed, 11 May 2016 13:57:24 +0000 (15:57 +0200)
committerBen Gamari <ben@smart-cactus.org>
Tue, 23 Aug 2016 21:03:19 +0000 (17:03 -0400)
When `deriveTyData` attempts to unify two kind variables (which can
happen if both the typeclass and the datatype are poly-kinded), it
mistakenly adds an extra mapping to its substitution which causes the
unification to fail when applying the substitution. This can be
prevented by checking both the domain and the range of the original
substitution to see which kind variables shouldn't be put into the
domain of the substitution. A more in-depth explanation is included in
`Note [Unification of two kind variables in deriving]`.

Fixes #11837.

Test Plan: ./validate

Reviewers: simonpj, hvr, goldfire, niteria, austin, bgamari

Reviewed By: bgamari

Subscribers: niteria, thomie

Differential Revision: https://phabricator.haskell.org/D2117

GHC Trac Issues: #11837

(cherry picked from commit e53f2180e89652c72e51ffa614c56294ba67cf37)

compiler/typecheck/TcDeriv.hs
compiler/types/TyCoRep.hs
compiler/types/Type.hs
testsuite/tests/deriving/should_compile/T11837.hs [new file with mode: 0644]
testsuite/tests/deriving/should_compile/all.T

index 944c513..d37762a 100644 (file)
@@ -651,17 +651,23 @@ deriveTyData tvs tc tc_args deriv_pred
               -- We are assuming the tycon tyvars and the class tyvars are distinct
               mb_match        = tcUnifyTy inst_ty_kind cls_arg_kind
               Just kind_subst = mb_match
+              ki_subst_range  = getTCvSubstRangeFVs kind_subst
 
               all_tkvs        = toposortTyVars $
                                 fvVarList $ unionFV
                                   (tyCoFVsOfTypes tc_args_to_keep)
                                   (FV.mkFVs deriv_tvs)
 
-              unmapped_tkvs   = filter (`notElemTCvSubst` kind_subst) all_tkvs
-              (subst, tkvs)   = mapAccumL substTyVarBndr
+              -- See Note [Unification of two kind variables in deriving]
+              unmapped_tkvs   = filter (\v -> v `notElemTCvSubst` kind_subst
+                                      && not (v `elemVarSet` ki_subst_range))
+                                       all_tkvs
+              (subst, _)      = mapAccumL substTyVarBndr
                                           kind_subst unmapped_tkvs
               final_tc_args   = substTys subst tc_args_to_keep
               final_cls_tys   = substTys subst cls_tys
+              tkvs            = tyCoVarsOfTypesWellScoped $
+                                final_cls_tys ++ final_tc_args
 
         ; traceTc "derivTyData1" (vcat [ pprTvBndrs tvs, ppr tc, ppr tc_args, ppr deriv_pred
                                        , pprTvBndrs (tyCoVarsOfTypesList tc_args)
@@ -800,6 +806,46 @@ Even though we requested an derived instance of the form (Cat k Fun), the
 kind unification will actually generate (Cat * Fun) (i.e., the same thing as if
 the user wrote deriving (Cat *)).
 
+Note [Unification of two kind variables in deriving]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+As a special case of the Note above, it is possible to derive an instance of
+a poly-kinded typeclass for a poly-kinded datatype. For example:
+
+    class Category (cat :: k -> k -> *) where
+    newtype T (c :: k -> k -> *) a b = MkT (c a b) deriving Category
+
+This case is suprisingly tricky. To see why, let's write out what instance GHC
+will attempt to derive (using -fprint-explicit-kinds syntax):
+
+    instance Category k1 (T k2 c) where ...
+
+GHC will attempt to unify k1 and k2, which produces a substitution (kind_subst)
+that looks like [k2 :-> k1]. Importantly, we need to apply this substitution to
+the type variable binder for c, since its kind is (k2 -> k2 -> *).
+
+We used to accomplish this by doing the following:
+
+    unmapped_tkvs = filter (`notElemTCvSubst` kind_subst) all_tkvs
+    (subst, _)    = mapAccumL substTyVarBndr kind_subst unmapped_tkvs
+
+Where all_tkvs contains all kind variables in the class and instance types (in
+this case, all_tkvs = [k1,k2]). But since kind_subst only has one mapping,
+this results in unmapped_tkvs being [k1], and as a consequence, k1 gets mapped
+to another kind variable in subst! That is, subst = [k2 :-> k1, k1 :-> k_new].
+This is bad, because applying that substitution yields the following instance:
+
+   instance Category k_new (T k1 c) where ...
+
+In other words, keeping k1 in unmapped_tvks taints the substitution, resulting
+in an ill-kinded instance (this caused Trac #11837).
+
+To prevent this, we need to filter out any variable from all_tkvs which either
+
+1. Appears in the domain of kind_subst. notElemTCvSubst checks this.
+2. Appears in the range of kind_subst. To do this, we compute the free
+   variable set of the range of kind_subst with getTCvSubstRangeFVs, and check
+   if a kind variable appears in that set.
+
 Note [Eta-reducing type synonyms]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 One can instantiate a type in a data family instance with a type synonym that
@@ -1070,11 +1116,15 @@ inferConstraints tvs main_cls cls_tys inst_ty rep_tc rep_tc_args mkTheta
             -- (which is the case for functor-like constraints), then we
             -- explicitly unify the subtype's kinds with (* -> *).
             -- See Note [Inferring the instance context]
-            subst = foldl' composeTCvSubst emptyTCvSubst (catMaybes mbSubsts)
-            unmapped_tvs   = filter (`notElemTCvSubst` subst) tvs
-            (subst', tvs') = mapAccumL substTyVarBndr subst unmapped_tvs
+            subst          = foldl' composeTCvSubst
+                                    emptyTCvSubst (catMaybes mbSubsts)
+            subst_range    = getTCvSubstRangeFVs subst
+            unmapped_tvs   = filter (\v -> v `notElemTCvSubst` subst
+                                   && not (v `elemVarSet` subst_range)) tvs
+            (subst', _)    = mapAccumL substTyVarBndr subst unmapped_tvs
             preds'         = substThetaOrigin subst' preds
             inst_tys'      = substTys subst' inst_tys
+            tvs'           = tyCoVarsOfTypesWellScoped inst_tys'
         in mkTheta preds' tvs' inst_tys'
 
     -- is_functor_like: see Note [Inferring the instance context]
index 9dd54fb..470b240 100644 (file)
@@ -80,7 +80,8 @@ module TyCoRep (
         emptyTCvSubst, mkEmptyTCvSubst, isEmptyTCvSubst,
         mkTCvSubst, mkTvSubst,
         getTvSubstEnv,
-        getCvSubstEnv, getTCvInScope, isInScope, notElemTCvSubst,
+        getCvSubstEnv, getTCvInScope, getTCvSubstRangeFVs,
+        isInScope, notElemTCvSubst,
         setTvSubstEnv, setCvSubstEnv, zapTCvSubst,
         extendTCvInScope, extendTCvInScopeList, extendTCvInScopeSet,
         extendTCvSubst,
@@ -1654,6 +1655,15 @@ getCvSubstEnv (TCvSubst _ _ env) = env
 getTCvInScope :: TCvSubst -> InScopeSet
 getTCvInScope (TCvSubst in_scope _ _) = in_scope
 
+-- | Returns the free variables of the types in the range of a substitution as
+-- a non-deterministic set.
+getTCvSubstRangeFVs :: TCvSubst -> VarSet
+getTCvSubstRangeFVs (TCvSubst _ tenv cenv)
+    = unionVarSet tenvFVs cenvFVs
+  where
+    tenvFVs = tyCoVarsOfTypes $ varEnvElts tenv
+    cenvFVs = tyCoVarsOfCos   $ varEnvElts cenv
+
 isInScope :: Var -> TCvSubst -> Bool
 isInScope v (TCvSubst in_scope _ _) = v `elemInScopeSet` in_scope
 
index 69cf69f..d0089f4 100644 (file)
@@ -155,7 +155,7 @@ module Type (
         mkTCvSubst, zipTvSubst, mkTvSubstPrs,
         notElemTCvSubst,
         getTvSubstEnv, setTvSubstEnv,
-        zapTCvSubst, getTCvInScope,
+        zapTCvSubst, getTCvInScope, getTCvSubstRangeFVs,
         extendTCvInScope, extendTCvInScopeList, extendTCvInScopeSet,
         extendTCvSubst, extendCvSubst,
         extendTvSubst, extendTvSubstList, extendTvSubstAndInScope,
diff --git a/testsuite/tests/deriving/should_compile/T11837.hs b/testsuite/tests/deriving/should_compile/T11837.hs
new file mode 100644 (file)
index 0000000..917f9cb
--- /dev/null
@@ -0,0 +1,9 @@
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE PolyKinds #-}
+module T11837 where
+
+class Category (cat :: k -> k -> *) where
+  catId   :: cat a a
+  catComp :: cat b c -> cat a b -> cat a c
+
+newtype T (c :: k -> k -> *) a b = MkT (c a b) deriving Category
index 07242ec..9017687 100644 (file)
@@ -70,3 +70,4 @@ test('T11732a', normal, compile, [''])
 test('T11732b', normal, compile, [''])
 test('T11732c', normal, compile, [''])
 test('T11833', normal, compile, [''])
+test('T11837', normal, compile, [''])