Avoid creating dependent types in FloatOut
authorSimon Peyton Jones <simonpj@microsoft.com>
Wed, 11 Oct 2017 13:58:38 +0000 (14:58 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Wed, 11 Oct 2017 14:00:48 +0000 (15:00 +0100)
This bug was exposed by Trac #14270.  The problem and its cure
is described in SetLevels, Note [Floating and kind casts].

It's simple and will affect very few programs.  But the very
fact that it was so unexpected is discomforting.

compiler/simplCore/SetLevels.hs
testsuite/tests/polykinds/T14270.hs [new file with mode: 0644]
testsuite/tests/polykinds/all.T

index 5a09db3..2b73128 100644 (file)
@@ -81,12 +81,13 @@ import Id
 import IdInfo
 import Var
 import VarSet
+import UniqSet          ( nonDetFoldUniqSet )
 import VarEnv
 import Literal          ( litIsTrivial )
 import Demand           ( StrictSig, Demand, isStrictDmd, splitStrictSig, increaseStrictSigArity )
 import Name             ( getOccName, mkSystemVarName )
 import OccName          ( occNameString )
-import Type             ( Type, mkLamTypes, splitTyConApp_maybe )
+import Type             ( Type, mkLamTypes, splitTyConApp_maybe, tyCoVarsOfType )
 import BasicTypes       ( Arity, RecFlag(..), isRec )
 import DataCon          ( dataConOrigResTy )
 import TysWiredIn
@@ -629,13 +630,14 @@ lvlMFE env strict_ctxt ann_expr
     expr         = deAnnotate ann_expr
     expr_ty      = exprType expr
     fvs          = freeVarsOf ann_expr
+    fvs_ty       = tyCoVarsOfType expr_ty
     is_bot       = isBottomThunk mb_bot_str
     is_function  = isFunction ann_expr
     mb_bot_str   = exprBotStrictness_maybe expr
                            -- See Note [Bottoming floats]
                            -- esp Bottoming floats (2)
     expr_ok_for_spec = exprOkForSpeculation expr
-    dest_lvl     = destLevel env fvs is_function is_bot False
+    dest_lvl     = destLevel env fvs fvs_ty is_function is_bot False
     abs_vars     = abstractVars dest_lvl env fvs
 
     -- float_is_new_lam: the floated thing will be a new value lambda
@@ -1028,7 +1030,7 @@ lvlBind env (AnnNonRec bndr rhs)
   || isCoVar bndr   -- Difficult to fix up CoVar occurrences (see extendPolyLvlEnv)
                     -- so we will ignore this case for now
   || not (profitableFloat env dest_lvl)
-  || (isTopLvl dest_lvl && not (exprIsTopLevelBindable deann_rhs (idType bndr)))
+  || (isTopLvl dest_lvl && not (exprIsTopLevelBindable deann_rhs bndr_ty))
           -- We can't float an unlifted binding to top level (except
           -- literal strings), so we don't float it at all.  It's a
           -- bit brutal, but unlifted bindings aren't expensive either
@@ -1057,10 +1059,12 @@ lvlBind env (AnnNonRec bndr rhs)
        ; return (NonRec (TB bndr2 (FloatMe dest_lvl)) rhs', env') }
 
   where
+    bndr_ty    = idType bndr
+    ty_fvs     = tyCoVarsOfType bndr_ty
     rhs_fvs    = freeVarsOf rhs
     bind_fvs   = rhs_fvs `unionDVarSet` dIdFreeVars bndr
     abs_vars   = abstractVars dest_lvl env bind_fvs
-    dest_lvl   = destLevel env bind_fvs (isFunction rhs) is_bot is_join
+    dest_lvl   = destLevel env bind_fvs ty_fvs (isFunction rhs) is_bot is_join
 
     deann_rhs  = deAnnotate rhs
     mb_bot_str = exprBotStrictness_maybe deann_rhs
@@ -1151,7 +1155,8 @@ lvlBind env (AnnRec pairs)
                `delDVarSetList`
                 bndrs
 
-    dest_lvl = destLevel env bind_fvs is_fun is_bot is_join
+    ty_fvs   = foldr (unionVarSet . tyCoVarsOfType . idType) emptyVarSet bndrs
+    dest_lvl = destLevel env bind_fvs ty_fvs is_fun is_bot is_join
     abs_vars = abstractVars dest_lvl env bind_fvs
 
 profitableFloat :: LevelEnv -> Level -> Bool
@@ -1314,13 +1319,16 @@ stayPut new_lvl bndr = TB bndr (StayPut new_lvl)
 
   -- Destination level is the max Id level of the expression
   -- (We'll abstract the type variables, if any.)
-destLevel :: LevelEnv -> DVarSet
+destLevel :: LevelEnv
+          -> DVarSet    -- Free vars of the term
+          -> TyCoVarSet -- Free in the /type/ of the term
+                        -- (a subset of the previous argument)
           -> Bool   -- True <=> is function
           -> Bool   -- True <=> is bottom
           -> Bool   -- True <=> is a join point
           -> Level
 -- INVARIANT: if is_join=True then result >= join_ceiling
-destLevel env fvs is_function is_bot is_join
+destLevel env fvs fvs_ty is_function is_bot is_join
   | isTopLvl max_fv_id_level  -- Float even joins if they get to top level
                               -- See Note [Floating join point bindings]
   = tOP_LEVEL
@@ -1332,21 +1340,48 @@ destLevel env fvs is_function is_bot is_join
     else max_fv_id_level
 
   | is_bot              -- Send bottoming bindings to the top
-  = tOP_LEVEL           -- regardless; see Note [Bottoming floats]
+  = as_far_as_poss      -- regardless; see Note [Bottoming floats]
                         -- Esp Bottoming floats (1)
 
   | Just n_args <- floatLams env
   , n_args > 0  -- n=0 case handled uniformly by the 'otherwise' case
   , is_function
   , countFreeIds fvs <= n_args
-  = tOP_LEVEL   -- Send functions to top level; see
-                -- the comments with isFunction
+  = as_far_as_poss  -- Send functions to top level; see
+                    -- the comments with isFunction
 
   | otherwise = max_fv_id_level
   where
-    max_fv_id_level = maxFvLevel isId env fvs -- Max over Ids only; the tyvars
-                                              -- will be abstracted
-    join_ceiling = joinCeilingLevel env
+    join_ceiling    = joinCeilingLevel env
+    max_fv_id_level = maxFvLevel isId env fvs -- Max over Ids only; the
+                                              -- tyvars will be abstracted
+
+    as_far_as_poss = maxFvLevel' isId env fvs_ty
+                     -- See Note [Floating and kind casts]
+
+{- Note [Floating and kind casts]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider this
+   case x of
+     K (co :: * ~# k) -> let v :: Int |> co
+                             v = e
+                         in blah
+
+Then, even if we are abstracting over Ids, or if e is bottom, we can't
+float v outside the 'co' binding.  Reason: if we did we'd get
+    v' :: forall k. (Int ~# Age) => Int |> co
+and now 'co' isn't in scope in that type. The underlying reason is
+that 'co' is a value-level thing and we can't abstract over that in a
+type (else we'd get a dependent type).  So if v's /type/ mentions 'co'
+we can't float it out beyond the binding site of 'co'.
+
+That's why we have this as_far_as_poss stuff.  Usually as_far_as_poss
+is just tOP_LEVEL; but occasionally a coercion variable (which is an
+Id) mentioned in type prevents this.
+
+Example Trac #14270 comment:15.
+-}
+
 
 isFunction :: CoreExprWithFVs -> Bool
 -- The idea here is that we want to float *functions* to
@@ -1480,14 +1515,20 @@ placeJoinCeiling le@(LE { le_ctxt_lvl = lvl })
     lvl' = asJoinCeilLvl (incMinorLvl lvl)
 
 maxFvLevel :: (Var -> Bool) -> LevelEnv -> DVarSet -> Level
-maxFvLevel max_me (LE { le_lvl_env = lvl_env, le_env = id_env }) var_set
-  = foldDVarSet max_in tOP_LEVEL var_set
+maxFvLevel max_me env var_set
+  = foldDVarSet (maxIn max_me env) tOP_LEVEL var_set
+
+maxFvLevel' :: (Var -> Bool) -> LevelEnv -> TyCoVarSet -> Level
+-- Same but for TyCoVarSet
+maxFvLevel' max_me env var_set
+  = nonDetFoldUniqSet (maxIn max_me env) tOP_LEVEL var_set
+
+maxIn :: (Var -> Bool) -> LevelEnv -> InVar -> Level -> Level
+maxIn max_me (LE { le_lvl_env = lvl_env, le_env = id_env }) in_var lvl
+  = case lookupVarEnv id_env in_var of
+      Just (abs_vars, _) -> foldr max_out lvl abs_vars
+      Nothing            -> max_out in_var lvl
   where
-    max_in in_var lvl
-       = foldr max_out lvl (case lookupVarEnv id_env in_var of
-                                Just (abs_vars, _) -> abs_vars
-                                Nothing            -> [in_var])
-
     max_out out_var lvl
         | max_me out_var = case lookupVarEnv lvl_env out_var of
                                 Just lvl' -> maxLvl lvl' lvl
diff --git a/testsuite/tests/polykinds/T14270.hs b/testsuite/tests/polykinds/T14270.hs
new file mode 100644 (file)
index 0000000..2d11a29
--- /dev/null
@@ -0,0 +1,110 @@
+{-# LANGUAGE TypeInType #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE TypeApplications #-}
+module T14270 (pattern App) where
+
+import Data.Kind (Type)
+import GHC.Fingerprint (Fingerprint, fingerprintFingerprints)
+import GHC.Types (RuntimeRep, TYPE, TyCon)
+
+data (a :: k1) :~~: (b :: k2) where
+  HRefl :: a :~~: a
+
+data TypeRep (a :: k) where
+    TrTyCon :: {-# UNPACK #-} !Fingerprint -> !TyCon -> [SomeTypeRep]
+            -> TypeRep (a :: k)
+
+    TrApp   :: forall k1 k2 (a :: k1 -> k2) (b :: k1).
+               {-# UNPACK #-} !Fingerprint
+            -> TypeRep (a :: k1 -> k2)
+            -> TypeRep (b :: k1)
+            -> TypeRep (a b)
+
+    TrFun   :: forall (r1 :: RuntimeRep) (r2 :: RuntimeRep)
+                      (a :: TYPE r1) (b :: TYPE r2).
+               {-# UNPACK #-} !Fingerprint
+            -> TypeRep a
+            -> TypeRep b
+            -> TypeRep (a -> b)
+
+data SomeTypeRep where
+    SomeTypeRep :: forall k (a :: k). !(TypeRep a) -> SomeTypeRep
+
+typeRepFingerprint :: TypeRep a -> Fingerprint
+typeRepFingerprint = undefined
+
+mkTrApp :: forall k1 k2 (a :: k1 -> k2) (b :: k1).
+           TypeRep (a :: k1 -> k2)
+        -> TypeRep (b :: k1)
+        -> TypeRep (a b)
+mkTrApp rep@(TrApp _ (TrTyCon _ con _) (x :: TypeRep x)) (y :: TypeRep y)
+  | con == funTyCon  -- cheap check first
+  , Just (IsTYPE (rx :: TypeRep rx)) <- isTYPE (typeRepKind x)
+  , Just (IsTYPE (ry :: TypeRep ry)) <- isTYPE (typeRepKind y)
+  , Just HRefl <- withTypeable x $ withTypeable rx $ withTypeable ry
+                  $ typeRep @((->) x :: TYPE ry -> Type) `eqTypeRep` rep
+  = undefined
+mkTrApp a b = TrApp fpr a b
+  where
+    fpr_a = typeRepFingerprint a
+    fpr_b = typeRepFingerprint b
+    fpr   = fingerprintFingerprints [fpr_a, fpr_b]
+
+pattern App :: forall k2 (t :: k2). ()
+            => forall k1 (a :: k1 -> k2) (b :: k1). (t ~ a b)
+            => TypeRep a -> TypeRep b -> TypeRep t
+pattern App f x <- (splitApp -> Just (IsApp f x))
+  where App f x = mkTrApp f x
+
+data IsApp (a :: k) where
+    IsApp :: forall k k' (f :: k' -> k) (x :: k'). ()
+          => TypeRep f -> TypeRep x -> IsApp (f x)
+
+splitApp :: forall k (a :: k). ()
+         => TypeRep a
+         -> Maybe (IsApp a)
+splitApp (TrApp _ f x)     = Just (IsApp f x)
+splitApp rep@(TrFun _ a b) = Just (IsApp (mkTrApp arr a) b)
+  where arr = bareArrow rep
+splitApp (TrTyCon{})       = Nothing
+
+withTypeable :: forall a r. TypeRep a -> (Typeable a => r) -> r
+withTypeable = undefined
+
+eqTypeRep :: forall k1 k2 (a :: k1) (b :: k2).
+             TypeRep a -> TypeRep b -> Maybe (a :~~: b)
+eqTypeRep = undefined
+
+typeRepKind :: TypeRep (a :: k) -> TypeRep k
+typeRepKind = undefined
+
+bareArrow :: forall (r1 :: RuntimeRep) (r2 :: RuntimeRep)
+                    (a :: TYPE r1) (b :: TYPE r2). ()
+          => TypeRep (a -> b)
+          -> TypeRep ((->) :: TYPE r1 -> TYPE r2 -> Type)
+bareArrow = undefined
+
+data IsTYPE (a :: Type) where
+    IsTYPE :: forall (r :: RuntimeRep). TypeRep r -> IsTYPE (TYPE r)
+
+isTYPE :: TypeRep (a :: Type) -> Maybe (IsTYPE a)
+isTYPE = undefined
+
+class Typeable (a :: k) where
+
+typeRep :: Typeable a => TypeRep a
+typeRep = undefined
+
+funTyCon :: TyCon
+funTyCon = undefined
+
+instance (Typeable f, Typeable a) => Typeable (f a)
+instance Typeable ((->) :: TYPE r -> TYPE s -> Type)
+instance Typeable TYPE
index fc7249e..66bd9b1 100644 (file)
@@ -172,3 +172,4 @@ test('T14209', normal, compile, [''])
 test('T14265', normal, compile_fail, [''])
 test('T13391', normal, compile_fail, [''])
 test('T13391a', normal, compile, [''])
+test('T14270', normal, compile, [''])