Fix #11754 by adding an additional check.
authorRichard Eisenberg <eir@cis.upenn.edu>
Fri, 25 Mar 2016 19:11:24 +0000 (15:11 -0400)
committerBen Gamari <ben@smart-cactus.org>
Mon, 28 Mar 2016 09:31:42 +0000 (11:31 +0200)
This was just plain wrong previously.

Test case: typecheck/should_compile/T11754

(cherry picked from commit 4da8e73d5235b0000ae27aa8ff8438a3687b6e9c)

compiler/types/OptCoercion.hs
testsuite/tests/typecheck/should_compile/T11754.hs [new file with mode: 0644]
testsuite/tests/typecheck/should_compile/all.T

index fb6c68e..e39f0aa 100644 (file)
@@ -874,10 +874,11 @@ etaTyConAppCo_maybe tc (TyConAppCo _ tc2 cos2)
 
 etaTyConAppCo_maybe tc co
   | mightBeUnsaturatedTyCon tc
-  , Pair ty1 ty2     <- coercionKind co
-  , Just (tc1, tys1) <- splitTyConApp_maybe ty1
-  , Just (tc2, tys2) <- splitTyConApp_maybe ty2
+  , (Pair ty1 ty2, r) <- coercionKindRole co
+  , Just (tc1, tys1)  <- splitTyConApp_maybe ty1
+  , Just (tc2, tys2)  <- splitTyConApp_maybe ty2
   , tc1 == tc2
+  , isInjectiveTyCon tc r  -- See Note [NthCo and newtypes] in TyCoRep
   , let n = length tys1
   = ASSERT( tc == tc1 )
     ASSERT( n == length tys2 )
diff --git a/testsuite/tests/typecheck/should_compile/T11754.hs b/testsuite/tests/typecheck/should_compile/T11754.hs
new file mode 100644 (file)
index 0000000..248be2b
--- /dev/null
@@ -0,0 +1,28 @@
+{-# LANGUAGE TypeOperators, UndecidableSuperClasses, KindSignatures,
+TypeFamilies, FlexibleContexts #-}
+
+module T11754 where
+
+import Data.Kind
+import Data.Void
+
+newtype K a x = K a
+newtype I   x = I x
+
+data (f + g) x = L (f x) | R (g x)
+data (f × g) x = f x :×: g x
+
+class Differentiable (D f) => Differentiable f where
+  type D (f :: Type -> Type) :: Type -> Type
+
+instance Differentiable (K a) where
+  type D (K a) = K Void
+
+instance Differentiable I where
+  type D I = K ()
+
+instance (Differentiable f₁, Differentiable f₂) => Differentiable (f₁ + f₂) where
+  type D (f₁ + f₂) = D f₁ + D f₂
+
+instance (Differentiable f₁, Differentiable f₂) => Differentiable (f₁ × f₂) where
+  type D (f₁ × f₂) = (D f₁ × f₂) + (f₁ × D f₂)
index f1403da..158de37 100644 (file)
@@ -508,3 +508,4 @@ test('T11608', normal, compile, [''])
 test('T11401', normal, compile, [''])
 test('T11699', normal, compile, [''])
 test('T11512', normal, compile, [''])
+test('T11754', normal, compile, [''])