Use a ReaderT in TcDeriv to avoid some tedious plumbing
authorRyan Scott <ryan.gl.scott@gmail.com>
Tue, 15 Aug 2017 00:56:04 +0000 (20:56 -0400)
committerBen Gamari <ben@smart-cactus.org>
Tue, 15 Aug 2017 01:32:17 +0000 (21:32 -0400)
Addresses point (2) of https://phabricator.haskell.org/D3337#107865.

Before, several functions in `TcDeriv` and `TcDerivInfer` which compute
an `EarlyDerivSpec` were manually threading through about 10 different
arguments, which contribute to quite a lot of clutter whenever they need
to be updated. To minimize this plumbing, and to make it clearer which
of these 10 values are being used where, I refactored the code in
`TcDeriv` and `TcDerivInfer` to use a new `DerivM` type:

```lang=haskell
type DerivM = ReaderT DerivEnv TcRn
```

where `DerivEnv` contains the 10 aforementioned values. In addition to
cleaning up the code, this should make some subsequent changes planned
for later less noisy.

Test Plan: ./validate

Reviewers: austin, bgamari

Subscribers: rwbarton, thomie

Differential Revision: https://phabricator.haskell.org/D3846

compiler/typecheck/TcDeriv.hs
compiler/typecheck/TcDerivInfer.hs
compiler/typecheck/TcDerivUtils.hs

index c462256..829b4c9 100644 (file)
@@ -28,7 +28,6 @@ import InstEnv
 import Inst
 import FamInstEnv
 import TcHsType
-import TcMType
 
 import RnNames( extendGlobalRdrEnvRn )
 import RnBinds
@@ -63,6 +62,8 @@ import FV (fvVarList, unionFV, mkFVs)
 import qualified GHC.LanguageExtensions as LangExt
 
 import Control.Monad
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Reader
 import Data.List
 
 {-
@@ -954,13 +955,19 @@ mkEqnHelp overlap_mode tvs cls cls_tys tycon tc_args mtheta deriv_strat
        ; when (isDataFamilyTyCon rep_tc)
               (bale_out (text "No family instance for" <+> quotes (pprTypeApp tycon tc_args)))
 
-       ; dflags <- getDynFlags
-       ; if isDataTyCon rep_tc then
-            mkDataTypeEqn dflags overlap_mode tvs cls cls_tys
-                          tycon tc_args rep_tc rep_tc_args mtheta deriv_strat
-         else
-            mkNewTypeEqn dflags overlap_mode tvs cls cls_tys
-                         tycon tc_args rep_tc rep_tc_args mtheta deriv_strat }
+       ; let deriv_env = DerivEnv
+                         { denv_overlap_mode = overlap_mode
+                         , denv_tvs          = tvs
+                         , denv_cls          = cls
+                         , denv_cls_tys      = cls_tys
+                         , denv_tc           = tycon
+                         , denv_tc_args      = tc_args
+                         , denv_rep_tc       = rep_tc
+                         , denv_rep_tc_args  = rep_tc_args
+                         , denv_mtheta       = mtheta
+                         , denv_strat        = deriv_strat }
+       ; flip runReaderT deriv_env $
+         if isDataTyCon rep_tc then mkDataTypeEqn else mkNewTypeEqn }
   where
      bale_out msg = failWithTc (derivingThingErr False cls cls_tys
                       (mkTyConApp tycon tc_args) deriv_strat msg)
@@ -1031,59 +1038,42 @@ See Note [Eta reduction for data families] in FamInstEnv
 ************************************************************************
 -}
 
-mkDataTypeEqn :: DynFlags
-              -> Maybe OverlapMode
-              -> [TyVar]                -- Universally quantified type variables in the instance
-              -> Class                  -- Class for which we need to derive an instance
-              -> [Type]                 -- Other parameters to the class except the last
-              -> TyCon                  -- Type constructor for which the instance is requested
-                                        --    (last parameter to the type class)
-              -> [Type]                 -- Parameters to the type constructor
-              -> TyCon                  -- rep of the above (for type families)
-              -> [Type]                 -- rep of the above
-              -> DerivContext        -- Context of the instance, for standalone deriving
-              -> Maybe DerivStrategy    -- 'Just' if user requests a particular
-                                        -- deriving strategy.
-                                        -- Otherwise, 'Nothing'.
-              -> TcRn EarlyDerivSpec    -- Return 'Nothing' if error
-
-mkDataTypeEqn dflags overlap_mode tvs cls cls_tys
-              tycon tc_args rep_tc rep_tc_args mtheta deriv_strat
-  = case deriv_strat of
-      Just StockStrategy    -> mk_eqn_stock dflags mtheta cls cls_tys rep_tc
-                                go_for_it bale_out
-      Just AnyclassStrategy -> mk_eqn_anyclass dflags go_for_it bale_out
-      -- GeneralizedNewtypeDeriving makes no sense for non-newtypes
-      Just NewtypeStrategy  -> bale_out gndNonNewtypeErr
-      -- Lacking a user-requested deriving strategy, we will try to pick
-      -- between the stock or anyclass strategies
-      Nothing -> mk_eqn_no_mechanism dflags tycon mtheta cls cls_tys rep_tc
-                   go_for_it bale_out
-  where
-    go_for_it    = mk_data_eqn overlap_mode tvs cls cls_tys tycon tc_args
-                     rep_tc rep_tc_args mtheta (isJust deriv_strat)
-    bale_out msg = failWithTc (derivingThingErr False cls cls_tys
-                     (mkTyConApp tycon tc_args) deriv_strat msg)
-
-mk_data_eqn :: Maybe OverlapMode -> [TyVar] -> Class -> [Type]
-            -> TyCon -> [TcType] -> TyCon -> [TcType] -> DerivContext
-            -> Bool -- True if an explicit deriving strategy keyword was
-                    -- provided
-            -> DerivSpecMechanism -- How GHC should proceed attempting to
+mkDataTypeEqn :: DerivM EarlyDerivSpec
+mkDataTypeEqn
+  = do mb_strat <- asks denv_strat
+       let bale_out msg = do err <- derivingThingErrM False msg
+                             lift $ failWithTc err
+       case mb_strat of
+         Just StockStrategy    -> mk_eqn_stock    mk_data_eqn bale_out
+         Just AnyclassStrategy -> mk_eqn_anyclass mk_data_eqn bale_out
+         -- GeneralizedNewtypeDeriving makes no sense for non-newtypes
+         Just NewtypeStrategy  -> bale_out gndNonNewtypeErr
+         -- Lacking a user-requested deriving strategy, we will try to pick
+         -- between the stock or anyclass strategies
+         Nothing -> mk_eqn_no_mechanism mk_data_eqn bale_out
+
+mk_data_eqn :: DerivSpecMechanism -- How GHC should proceed attempting to
                                   -- derive this instance, determined in
                                   -- mkDataTypeEqn/mkNewTypeEqn
-            -> TcM EarlyDerivSpec
-mk_data_eqn overlap_mode tvs cls cls_tys tycon tc_args rep_tc rep_tc_args
-            mtheta strat_used mechanism
-  = do doDerivInstErrorChecks1 cls cls_tys tycon tc_args rep_tc mtheta
-                               strat_used mechanism
-       loc                  <- getSrcSpanM
-       dfun_name            <- newDFunName' cls tycon
+            -> DerivM EarlyDerivSpec
+mk_data_eqn mechanism
+  = do DerivEnv { denv_overlap_mode = overlap_mode
+                , denv_tvs          = tvs
+                , denv_tc           = tc
+                , denv_tc_args      = tc_args
+                , denv_rep_tc       = rep_tc
+                , denv_cls          = cls
+                , denv_cls_tys      = cls_tys
+                , denv_mtheta       = mtheta } <- ask
+       let inst_ty  = mkTyConApp tc tc_args
+           inst_tys = cls_tys ++ [inst_ty]
+       doDerivInstErrorChecks1 mechanism
+       loc       <- lift getSrcSpanM
+       dfun_name <- lift $ newDFunName' cls tc
        case mtheta of
         Nothing -> -- Infer context
           do { (inferred_constraints, tvs', inst_tys')
-                 <- inferConstraints tvs cls cls_tys inst_ty
-                                     rep_tc rep_tc_args mechanism
+                 <- inferConstraints mechanism
              ; return $ InferTheta $ DS
                    { ds_loc = loc
                    , ds_name = dfun_name, ds_tvs = tvs'
@@ -1102,58 +1092,66 @@ mk_data_eqn overlap_mode tvs cls cls_tys tycon tc_args rep_tc rep_tc_args
                    , ds_theta = theta
                    , ds_overlap = overlap_mode
                    , ds_mechanism = mechanism }
-  where
-    inst_ty  = mkTyConApp tycon tc_args
-    inst_tys = cls_tys ++ [inst_ty]
-
-mk_eqn_stock :: DynFlags -> DerivContext -> Class -> [Type] -> TyCon
-             -> (DerivSpecMechanism -> TcRn EarlyDerivSpec)
-             -> (SDoc -> TcRn EarlyDerivSpec)
-             -> TcRn EarlyDerivSpec
-mk_eqn_stock dflags mtheta cls cls_tys rep_tc go_for_it bale_out
-  = case checkSideConditions dflags mtheta cls cls_tys rep_tc of
-        CanDerive               -> mk_eqn_stock' cls go_for_it
-        DerivableClassError msg -> bale_out msg
-        _                       -> bale_out (nonStdErr cls)
-
-mk_eqn_stock' :: Class -> (DerivSpecMechanism -> TcRn EarlyDerivSpec)
-                -> TcRn EarlyDerivSpec
-mk_eqn_stock' cls go_for_it
-  = go_for_it $ case hasStockDeriving cls of
-        Just gen_fn -> DerivSpecStock gen_fn
-        Nothing ->
-          pprPanic "mk_eqn_stock': Not a stock class!" (ppr cls)
-
-mk_eqn_anyclass :: DynFlags
-                -> (DerivSpecMechanism -> TcRn EarlyDerivSpec)
-                -> (SDoc -> TcRn EarlyDerivSpec)
-                -> TcRn EarlyDerivSpec
-mk_eqn_anyclass dflags go_for_it bale_out
-  = case canDeriveAnyClass dflags of
-        IsValid      -> go_for_it DerivSpecAnyClass
-        NotValid msg -> bale_out msg
-
-mk_eqn_no_mechanism :: DynFlags -> TyCon -> DerivContext
-                    -> Class -> [Type] -> TyCon
-                    -> (DerivSpecMechanism -> TcRn EarlyDerivSpec)
-                    -> (SDoc -> TcRn EarlyDerivSpec)
-                    -> TcRn EarlyDerivSpec
-mk_eqn_no_mechanism dflags tc mtheta cls cls_tys rep_tc go_for_it bale_out
-  = case checkSideConditions dflags mtheta cls cls_tys rep_tc of
-        -- NB: pass the *representation* tycon to checkSideConditions
-        NonDerivableClass   msg -> bale_out (dac_error msg)
-        DerivableClassError msg -> bale_out msg
-        CanDerive               -> mk_eqn_stock' cls go_for_it
-        DerivableViaInstance    -> go_for_it DerivSpecAnyClass
-  where
-    -- See Note [Deriving instances for classes themselves]
-    dac_error msg
-      | isClassTyCon rep_tc
-      = quotes (ppr tc) <+> text "is a type class,"
-                        <+> text "and can only have a derived instance"
-                        $+$ text "if DeriveAnyClass is enabled"
-      | otherwise
-      = nonStdErr cls $$ msg
+
+mk_eqn_stock :: (DerivSpecMechanism -> DerivM EarlyDerivSpec)
+             -> (SDoc -> DerivM EarlyDerivSpec)
+             -> DerivM EarlyDerivSpec
+mk_eqn_stock go_for_it bale_out
+  = do DerivEnv { denv_rep_tc  = rep_tc
+                , denv_cls     = cls
+                , denv_cls_tys = cls_tys
+                , denv_mtheta  = mtheta } <- ask
+       dflags <- getDynFlags
+       case checkSideConditions dflags mtheta cls cls_tys rep_tc of
+         CanDerive               -> mk_eqn_stock' go_for_it
+         DerivableClassError msg -> bale_out msg
+         _                       -> bale_out (nonStdErr cls)
+
+mk_eqn_stock' :: (DerivSpecMechanism -> DerivM EarlyDerivSpec)
+              -> DerivM EarlyDerivSpec
+mk_eqn_stock' go_for_it
+  = do cls <- asks denv_cls
+       go_for_it $
+         case hasStockDeriving cls of
+           Just gen_fn -> DerivSpecStock gen_fn
+           Nothing ->
+             pprPanic "mk_eqn_stock': Not a stock class!" (ppr cls)
+
+mk_eqn_anyclass :: (DerivSpecMechanism -> DerivM EarlyDerivSpec)
+                -> (SDoc -> DerivM EarlyDerivSpec)
+                -> DerivM EarlyDerivSpec
+mk_eqn_anyclass go_for_it bale_out
+  = do dflags <- getDynFlags
+       case canDeriveAnyClass dflags of
+         IsValid      -> go_for_it DerivSpecAnyClass
+         NotValid msg -> bale_out msg
+
+mk_eqn_no_mechanism :: (DerivSpecMechanism -> DerivM EarlyDerivSpec)
+                    -> (SDoc -> DerivM EarlyDerivSpec)
+                    -> DerivM EarlyDerivSpec
+mk_eqn_no_mechanism go_for_it bale_out
+  = do DerivEnv { denv_tc      = tc
+                , denv_rep_tc  = rep_tc
+                , denv_cls     = cls
+                , denv_cls_tys = cls_tys
+                , denv_mtheta  = mtheta } <- ask
+       dflags <- getDynFlags
+
+           -- See Note [Deriving instances for classes themselves]
+       let dac_error msg
+             | isClassTyCon rep_tc
+             = quotes (ppr tc) <+> text "is a type class,"
+                               <+> text "and can only have a derived instance"
+                               $+$ text "if DeriveAnyClass is enabled"
+             | otherwise
+             = nonStdErr cls $$ msg
+
+       case checkSideConditions dflags mtheta cls cls_tys rep_tc of
+           -- NB: pass the *representation* tycon to checkSideConditions
+           NonDerivableClass   msg -> bale_out (dac_error msg)
+           DerivableClassError msg -> bale_out msg
+           CanDerive               -> mk_eqn_stock' go_for_it
+           DerivableViaInstance    -> go_for_it DerivSpecAnyClass
 
 {-
 ************************************************************************
@@ -1163,244 +1161,249 @@ mk_eqn_no_mechanism dflags tc mtheta cls cls_tys rep_tc go_for_it bale_out
 ************************************************************************
 -}
 
-mkNewTypeEqn :: DynFlags -> Maybe OverlapMode -> [TyVar] -> Class
-             -> [Type] -> TyCon -> [Type] -> TyCon -> [Type]
-             -> DerivContext -> Maybe DerivStrategy
-             -> TcRn EarlyDerivSpec
-mkNewTypeEqn dflags overlap_mode tvs
-             cls cls_tys tycon tc_args rep_tycon rep_tc_args
-             mtheta deriv_strat
+mkNewTypeEqn :: DerivM EarlyDerivSpec
+mkNewTypeEqn
 -- Want: instance (...) => cls (cls_tys ++ [tycon tc_args]) where ...
-  = ASSERT( cls_tys `lengthIs` (classArity cls - 1) )
-    case deriv_strat of
-      Just StockStrategy    -> mk_eqn_stock dflags mtheta cls cls_tys rep_tycon
-                                 go_for_it_other bale_out
-      Just AnyclassStrategy -> mk_eqn_anyclass dflags go_for_it_other bale_out
-      Just NewtypeStrategy  ->
-        -- Since the user explicitly asked for GeneralizedNewtypeDeriving, we
-        -- don't need to perform all of the checks we normally would, such as
-        -- if the class being derived is known to produce ill-roled coercions
-        -- (e.g., Traversable), since we can just derive the instance and let
-        -- it error if need be.
-        -- See Note [Determining whether newtype-deriving is appropriate]
-        if coercion_looks_sensible && newtype_deriving
-          then go_for_it_gnd
-          else bale_out (cant_derive_err $$
-                         if newtype_deriving then empty else suggest_gnd)
-      Nothing
-        | might_derive_via_coercible
-          && ((newtype_deriving && not deriveAnyClass)
-               || std_class_via_coercible cls)
-       -> go_for_it_gnd
-        | otherwise
-       -> case checkSideConditions dflags mtheta cls cls_tys rep_tycon of
-            DerivableClassError msg
-              -- There's a particular corner case where
-              --
-              -- 1. -XGeneralizedNewtypeDeriving and -XDeriveAnyClass are both
-              --    enabled at the same time
-              -- 2. We're deriving a particular stock derivable class
-              --    (such as Functor)
-              --
-              -- and the previous cases won't catch it. This fixes the bug
-              -- reported in Trac #10598.
-              | might_derive_via_coercible && newtype_deriving
-             -> go_for_it_gnd
-              -- Otherwise, throw an error for a stock class
-              | might_derive_via_coercible && not newtype_deriving
-             -> bale_out (msg $$ suggest_gnd)
-              | otherwise
-             -> bale_out msg
-
-            -- Must use newtype deriving or DeriveAnyClass
-            NonDerivableClass _msg
-              -- Too hard, even with newtype deriving
-              | newtype_deriving           -> bale_out cant_derive_err
-              -- Try newtype deriving!
-              -- Here we suggest GeneralizedNewtypeDeriving even in cases where
-              -- it may not be applicable. See Trac #9600.
-              | otherwise                  -> bale_out (non_std $$ suggest_gnd)
-
-            -- DerivableViaInstance
-            DerivableViaInstance -> do
-              -- If both DeriveAnyClass and GeneralizedNewtypeDeriving are
-              -- enabled, we take the diplomatic approach of defaulting to
-              -- DeriveAnyClass, but emitting a warning about the choice.
-              -- See Note [Deriving strategies]
-              when (newtype_deriving && deriveAnyClass) $
-                addWarnTc NoReason $ sep
-                  [ text "Both DeriveAnyClass and"
-                    <+> text "GeneralizedNewtypeDeriving are enabled"
-                  , text "Defaulting to the DeriveAnyClass strategy"
-                    <+> text "for instantiating" <+> ppr cls ]
-              go_for_it_other DerivSpecAnyClass
-            -- CanDerive
-            CanDerive -> mk_eqn_stock' cls go_for_it_other
-  where
-        newtype_deriving  = xopt LangExt.GeneralizedNewtypeDeriving dflags
-        deriveAnyClass    = xopt LangExt.DeriveAnyClass             dflags
-        go_for_it_gnd     = do
-          traceTc "newtype deriving:" $
-            ppr tycon <+> ppr rep_tys <+> ppr all_thetas
-          let mechanism = DerivSpecNewtype rep_inst_ty
-          doDerivInstErrorChecks1 cls cls_tys tycon tc_args rep_tycon mtheta
-                                  strat_used mechanism
-          dfun_name <- newDFunName' cls tycon
-          loc <- getSrcSpanM
-          case mtheta of
-           Just theta -> return $ GivenTheta $ DS
-               { ds_loc = loc
-               , ds_name = dfun_name, ds_tvs = tvs
-               , ds_cls = cls, ds_tys = inst_tys
-               , ds_tc = rep_tycon
-               , ds_theta = theta
-               , ds_overlap = overlap_mode
-               , ds_mechanism = mechanism }
-           Nothing -> return $ InferTheta $ DS
-               { ds_loc = loc
-               , ds_name = dfun_name, ds_tvs = tvs
-               , ds_cls = cls, ds_tys = inst_tys
-               , ds_tc = rep_tycon
-               , ds_theta = all_thetas
-               , ds_overlap = overlap_mode
-               , ds_mechanism = mechanism }
-        go_for_it_other = mk_data_eqn overlap_mode tvs cls cls_tys tycon
-                            tc_args rep_tycon rep_tc_args mtheta strat_used
-        bale_out    = bale_out' newtype_deriving
-        bale_out' b = failWithTc . derivingThingErr b cls cls_tys inst_ty
-                                                    deriv_strat
-
-        strat_used  = isJust deriv_strat
-        non_std     = nonStdErr cls
-        suggest_gnd = text "Try GeneralizedNewtypeDeriving for GHC's newtype-deriving extension"
-
-        -- Here is the plan for newtype derivings.  We see
-        --        newtype T a1...an = MkT (t ak+1...an) deriving (.., C s1 .. sm, ...)
-        -- where t is a type,
-        --       ak+1...an is a suffix of a1..an, and are all tyvars
-        --       ak+1...an do not occur free in t, nor in the s1..sm
-        --       (C s1 ... sm) is a  *partial applications* of class C
-        --                      with the last parameter missing
-        --       (T a1 .. ak) matches the kind of C's last argument
-        --              (and hence so does t)
-        -- The latter kind-check has been done by deriveTyData already,
-        -- and tc_args are already trimmed
-        --
-        -- We generate the instance
-        --       instance forall ({a1..ak} u fvs(s1..sm)).
-        --                C s1 .. sm t => C s1 .. sm (T a1...ak)
-        -- where T a1...ap is the partial application of
-        --       the LHS of the correct kind and p >= k
-        --
-        --      NB: the variables below are:
-        --              tc_tvs = [a1, ..., an]
-        --              tyvars_to_keep = [a1, ..., ak]
-        --              rep_ty = t ak .. an
-        --              deriv_tvs = fvs(s1..sm) \ tc_tvs
-        --              tys = [s1, ..., sm]
-        --              rep_fn' = t
-        --
-        -- Running example: newtype T s a = MkT (ST s a) deriving( Monad )
-        -- We generate the instance
-        --      instance Monad (ST s) => Monad (T s) where
-
-        nt_eta_arity = newTyConEtadArity rep_tycon
-                -- For newtype T a b = MkT (S a a b), the TyCon machinery already
-                -- eta-reduces the representation type, so we know that
-                --      T a ~ S a a
-                -- That's convenient here, because we may have to apply
-                -- it to fewer than its original complement of arguments
-
-        -- Note [Newtype representation]
-        -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        -- Need newTyConRhs (*not* a recursive representation finder)
-        -- to get the representation type. For example
-        --      newtype B = MkB Int
-        --      newtype A = MkA B deriving( Num )
-        -- We want the Num instance of B, *not* the Num instance of Int,
-        -- when making the Num instance of A!
-        rep_inst_ty = newTyConInstRhs rep_tycon rep_tc_args
-        rep_tys     = cls_tys ++ [rep_inst_ty]
-        rep_pred    = mkClassPred cls rep_tys
-        rep_pred_o  = mkPredOrigin DerivOrigin TypeLevel rep_pred
-                -- rep_pred is the representation dictionary, from where
-                -- we are gong to get all the methods for the newtype
-                -- dictionary
-
-        -- Next we figure out what superclass dictionaries to use
-        -- See Note [Newtype deriving superclasses] above
-        sc_preds   :: [PredOrigin]
-        cls_tyvars = classTyVars cls
-        inst_ty    = mkTyConApp tycon tc_args
-        inst_tys   = cls_tys ++ [inst_ty]
-        sc_preds   = map (mkPredOrigin DerivOrigin TypeLevel) $
-                     substTheta (zipTvSubst cls_tyvars inst_tys) $
-                     classSCTheta cls
-
-        -- Next we collect constraints for the class methods
-        -- If there are no methods, we don't need any constraints
-        -- Otherwise we need (C rep_ty), for the representation methods,
-        -- and constraints to coerce each individual method
-        meth_preds :: [PredOrigin]
-        meths = classMethods cls
-        meth_preds | null meths = [] -- No methods => no constraints
-                                     -- (Trac #12814)
-                   | otherwise = rep_pred_o : coercible_constraints
-        coercible_constraints
-          = [ mkPredOrigin (DerivOriginCoerce meth t1 t2) TypeLevel
-                           (mkReprPrimEqPred t1 t2)
-            | meth <- meths
-            , let (Pair t1 t2) = mkCoerceClassMethEqn cls tvs
-                                         inst_tys rep_inst_ty meth ]
-
-        all_thetas :: [ThetaOrigin]
-        all_thetas = [mkThetaOriginFromPreds $ meth_preds ++ sc_preds]
-
-        -------------------------------------------------------------------
-        --  Figuring out whether we can only do this newtype-deriving thing
-
-        -- See Note [Determining whether newtype-deriving is appropriate]
-        might_derive_via_coercible
-           =  not (non_coercible_class cls)
-           && coercion_looks_sensible
---         && not (isRecursiveTyCon tycon)      -- Note [Recursive newtypes]
-        coercion_looks_sensible
-           =  eta_ok
-              -- Check (a) from Note [GND and associated type families]
-           && ats_ok
-              -- Check (b) from Note [GND and associated type families]
-           && isNothing at_without_last_cls_tv
-
-        -- Check that eta reduction is OK
-        eta_ok = rep_tc_args `lengthAtLeast` nt_eta_arity
-                -- The newtype can be eta-reduced to match the number
-                --     of type argument actually supplied
-                --        newtype T a b = MkT (S [a] b) deriving( Monad )
-                --     Here the 'b' must be the same in the rep type (S [a] b)
-                --     And the [a] must not mention 'b'.  That's all handled
-                --     by nt_eta_rity.
-
-        (adf_tcs, atf_tcs) = partition isDataFamilyTyCon at_tcs
-        ats_ok             = null adf_tcs
-               -- We cannot newtype-derive data family instances
-
-        at_without_last_cls_tv
-          = find (\tc -> last_cls_tv `notElem` tyConTyVars tc) atf_tcs
-        at_tcs = classATs cls
-        last_cls_tv = ASSERT( notNull cls_tyvars )
-                      last cls_tyvars
-
-        cant_derive_err
-           = vcat [ ppUnless eta_ok eta_msg
-                  , ppUnless ats_ok ats_msg
-                  , maybe empty at_tv_msg
-                          at_without_last_cls_tv]
-        eta_msg   = text "cannot eta-reduce the representation type enough"
-        ats_msg   = text "the class has associated data types"
-        at_tv_msg at_tc = hang
-          (text "the associated type" <+> quotes (ppr at_tc)
-           <+> text "is not parameterized over the last type variable")
-          2 (text "of the class" <+> quotes (ppr cls))
+  = do DerivEnv { denv_overlap_mode = overlap_mode
+                , denv_tvs          = tvs
+                , denv_tc           = tycon
+                , denv_tc_args      = tc_args
+                , denv_rep_tc       = rep_tycon
+                , denv_rep_tc_args  = rep_tc_args
+                , denv_cls          = cls
+                , denv_cls_tys      = cls_tys
+                , denv_mtheta       = mtheta
+                , denv_strat        = mb_strat } <- ask
+       dflags <- getDynFlags
+
+       let newtype_deriving  = xopt LangExt.GeneralizedNewtypeDeriving dflags
+           deriveAnyClass    = xopt LangExt.DeriveAnyClass             dflags
+           go_for_it_gnd     = do
+             lift $ traceTc "newtype deriving:" $
+               ppr tycon <+> ppr rep_tys <+> ppr all_thetas
+             let mechanism = DerivSpecNewtype rep_inst_ty
+             doDerivInstErrorChecks1 mechanism
+             dfun_name <- lift $ newDFunName' cls tycon
+             loc       <- lift getSrcSpanM
+             case mtheta of
+              Just theta -> return $ GivenTheta $ DS
+                  { ds_loc = loc
+                  , ds_name = dfun_name, ds_tvs = tvs
+                  , ds_cls = cls, ds_tys = inst_tys
+                  , ds_tc = rep_tycon
+                  , ds_theta = theta
+                  , ds_overlap = overlap_mode
+                  , ds_mechanism = mechanism }
+              Nothing -> return $ InferTheta $ DS
+                  { ds_loc = loc
+                  , ds_name = dfun_name, ds_tvs = tvs
+                  , ds_cls = cls, ds_tys = inst_tys
+                  , ds_tc = rep_tycon
+                  , ds_theta = all_thetas
+                  , ds_overlap = overlap_mode
+                  , ds_mechanism = mechanism }
+           bale_out        = bale_out' newtype_deriving
+           bale_out' b msg = do err <- derivingThingErrM b msg
+                                lift $ failWithTc err
+
+           non_std     = nonStdErr cls
+           suggest_gnd = text "Try GeneralizedNewtypeDeriving for GHC's"
+                     <+> text "newtype-deriving extension"
+
+           -- Here is the plan for newtype derivings.  We see
+           --        newtype T a1...an = MkT (t ak+1...an)
+           --          deriving (.., C s1 .. sm, ...)
+           -- where t is a type,
+           --       ak+1...an is a suffix of a1..an, and are all tyvars
+           --       ak+1...an do not occur free in t, nor in the s1..sm
+           --       (C s1 ... sm) is a  *partial applications* of class C
+           --                      with the last parameter missing
+           --       (T a1 .. ak) matches the kind of C's last argument
+           --              (and hence so does t)
+           -- The latter kind-check has been done by deriveTyData already,
+           -- and tc_args are already trimmed
+           --
+           -- We generate the instance
+           --       instance forall ({a1..ak} u fvs(s1..sm)).
+           --                C s1 .. sm t => C s1 .. sm (T a1...ak)
+           -- where T a1...ap is the partial application of
+           --       the LHS of the correct kind and p >= k
+           --
+           --      NB: the variables below are:
+           --              tc_tvs = [a1, ..., an]
+           --              tyvars_to_keep = [a1, ..., ak]
+           --              rep_ty = t ak .. an
+           --              deriv_tvs = fvs(s1..sm) \ tc_tvs
+           --              tys = [s1, ..., sm]
+           --              rep_fn' = t
+           --
+           -- Running example: newtype T s a = MkT (ST s a) deriving( Monad )
+           -- We generate the instance
+           --      instance Monad (ST s) => Monad (T s) where
+
+           nt_eta_arity = newTyConEtadArity rep_tycon
+                   -- For newtype T a b = MkT (S a a b), the TyCon
+                   -- machinery already eta-reduces the representation type, so
+                   -- we know that
+                   --      T a ~ S a a
+                   -- That's convenient here, because we may have to apply
+                   -- it to fewer than its original complement of arguments
+
+           -- Note [Newtype representation]
+           -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+           -- Need newTyConRhs (*not* a recursive representation finder)
+           -- to get the representation type. For example
+           --      newtype B = MkB Int
+           --      newtype A = MkA B deriving( Num )
+           -- We want the Num instance of B, *not* the Num instance of Int,
+           -- when making the Num instance of A!
+           rep_inst_ty = newTyConInstRhs rep_tycon rep_tc_args
+           rep_tys     = cls_tys ++ [rep_inst_ty]
+           rep_pred    = mkClassPred cls rep_tys
+           rep_pred_o  = mkPredOrigin DerivOrigin TypeLevel rep_pred
+                   -- rep_pred is the representation dictionary, from where
+                   -- we are gong to get all the methods for the newtype
+                   -- dictionary
+
+           -- Next we figure out what superclass dictionaries to use
+           -- See Note [Newtype deriving superclasses] above
+           sc_preds   :: [PredOrigin]
+           cls_tyvars = classTyVars cls
+           inst_ty    = mkTyConApp tycon tc_args
+           inst_tys   = cls_tys ++ [inst_ty]
+           sc_preds   = map (mkPredOrigin DerivOrigin TypeLevel) $
+                        substTheta (zipTvSubst cls_tyvars inst_tys) $
+                        classSCTheta cls
+
+           -- Next we collect constraints for the class methods
+           -- If there are no methods, we don't need any constraints
+           -- Otherwise we need (C rep_ty), for the representation methods,
+           -- and constraints to coerce each individual method
+           meth_preds :: [PredOrigin]
+           meths = classMethods cls
+           meth_preds | null meths = [] -- No methods => no constraints
+                                        -- (Trac #12814)
+                      | otherwise = rep_pred_o : coercible_constraints
+           coercible_constraints
+             = [ mkPredOrigin (DerivOriginCoerce meth t1 t2) TypeLevel
+                              (mkReprPrimEqPred t1 t2)
+               | meth <- meths
+               , let (Pair t1 t2) = mkCoerceClassMethEqn cls tvs
+                                            inst_tys rep_inst_ty meth ]
+
+           all_thetas :: [ThetaOrigin]
+           all_thetas = [mkThetaOriginFromPreds $ meth_preds ++ sc_preds]
+
+           -------------------------------------------------------------------
+           --  Figuring out whether we can only do this newtype-deriving thing
+
+           -- See Note [Determining whether newtype-deriving is appropriate]
+           might_derive_via_coercible
+              =  not (non_coercible_class cls)
+              && coercion_looks_sensible
+--            && not (isRecursiveTyCon tycon)      -- Note [Recursive newtypes]
+           coercion_looks_sensible
+              =  eta_ok
+                 -- Check (a) from Note [GND and associated type families]
+              && ats_ok
+                 -- Check (b) from Note [GND and associated type families]
+              && isNothing at_without_last_cls_tv
+
+           -- Check that eta reduction is OK
+           eta_ok = rep_tc_args `lengthAtLeast` nt_eta_arity
+             -- The newtype can be eta-reduced to match the number
+             --     of type argument actually supplied
+             --        newtype T a b = MkT (S [a] b) deriving( Monad )
+             --     Here the 'b' must be the same in the rep type (S [a] b)
+             --     And the [a] must not mention 'b'.  That's all handled
+             --     by nt_eta_rity.
+
+           (adf_tcs, atf_tcs) = partition isDataFamilyTyCon at_tcs
+           ats_ok             = null adf_tcs
+                  -- We cannot newtype-derive data family instances
+
+           at_without_last_cls_tv
+             = find (\tc -> last_cls_tv `notElem` tyConTyVars tc) atf_tcs
+           at_tcs = classATs cls
+           last_cls_tv = ASSERT( notNull cls_tyvars )
+                         last cls_tyvars
+
+           cant_derive_err
+              = vcat [ ppUnless eta_ok eta_msg
+                     , ppUnless ats_ok ats_msg
+                     , maybe empty at_tv_msg
+                             at_without_last_cls_tv]
+           eta_msg   = text "cannot eta-reduce the representation type enough"
+           ats_msg   = text "the class has associated data types"
+           at_tv_msg at_tc = hang
+             (text "the associated type" <+> quotes (ppr at_tc)
+              <+> text "is not parameterized over the last type variable")
+             2 (text "of the class" <+> quotes (ppr cls))
+
+       MASSERT( cls_tys `lengthIs` (classArity cls - 1) )
+       case mb_strat of
+         Just StockStrategy    -> mk_eqn_stock    mk_data_eqn bale_out
+         Just AnyclassStrategy -> mk_eqn_anyclass mk_data_eqn bale_out
+         Just NewtypeStrategy  ->
+           -- Since the user explicitly asked for GeneralizedNewtypeDeriving,
+           -- we don't need to perform all of the checks we normally would,
+           -- such as if the class being derived is known to produce ill-roled
+           -- coercions (e.g., Traversable), since we can just derive the
+           -- instance and let it error if need be.
+           -- See Note [Determining whether newtype-deriving is appropriate]
+           if coercion_looks_sensible && newtype_deriving
+             then go_for_it_gnd
+             else bale_out (cant_derive_err $$
+                            if newtype_deriving then empty else suggest_gnd)
+         Nothing
+           | might_derive_via_coercible
+             && ((newtype_deriving && not deriveAnyClass)
+                  || std_class_via_coercible cls)
+          -> go_for_it_gnd
+           | otherwise
+          -> case checkSideConditions dflags mtheta cls cls_tys rep_tycon of
+               DerivableClassError msg
+                 -- There's a particular corner case where
+                 --
+                 -- 1. -XGeneralizedNewtypeDeriving and -XDeriveAnyClass are
+                 --    both enabled at the same time
+                 -- 2. We're deriving a particular stock derivable class
+                 --    (such as Functor)
+                 --
+                 -- and the previous cases won't catch it. This fixes the bug
+                 -- reported in Trac #10598.
+                 | might_derive_via_coercible && newtype_deriving
+                -> go_for_it_gnd
+                 -- Otherwise, throw an error for a stock class
+                 | might_derive_via_coercible && not newtype_deriving
+                -> bale_out (msg $$ suggest_gnd)
+                 | otherwise
+                -> bale_out msg
+
+               -- Must use newtype deriving or DeriveAnyClass
+               NonDerivableClass _msg
+                 -- Too hard, even with newtype deriving
+                 | newtype_deriving           -> bale_out cant_derive_err
+                 -- Try newtype deriving!
+                 -- Here we suggest GeneralizedNewtypeDeriving even in cases
+                 -- where it may not be applicable. See Trac #9600.
+                 | otherwise                  -> bale_out (non_std $$ suggest_gnd)
+
+               -- DerivableViaInstance
+               DerivableViaInstance -> do
+                 -- If both DeriveAnyClass and GeneralizedNewtypeDeriving are
+                 -- enabled, we take the diplomatic approach of defaulting to
+                 -- DeriveAnyClass, but emitting a warning about the choice.
+                 -- See Note [Deriving strategies]
+                 when (newtype_deriving && deriveAnyClass) $
+                   lift $ addWarnTc NoReason $ sep
+                     [ text "Both DeriveAnyClass and"
+                       <+> text "GeneralizedNewtypeDeriving are enabled"
+                     , text "Defaulting to the DeriveAnyClass strategy"
+                       <+> text "for instantiating" <+> ppr cls ]
+                 mk_data_eqn DerivSpecAnyClass
+               -- CanDerive
+               CanDerive -> mk_eqn_stock' mk_data_eqn
 
 {-
 Note [Recursive newtypes]
@@ -1628,31 +1631,31 @@ genInst spec@(DS { ds_tvs = tvs, ds_tc = rep_tycon
     set_span_and_ctxt :: TcM a -> TcM a
     set_span_and_ctxt = setSrcSpan loc . addErrCtxt (instDeclCtxt3 clas tys)
 
-doDerivInstErrorChecks1 :: Class -> [Type] -> TyCon -> [Type] -> TyCon
-                        -> DerivContext -> Bool -> DerivSpecMechanism
-                        -> TcM ()
-doDerivInstErrorChecks1 cls cls_tys tc tc_args rep_tc mtheta
-                        strat_used mechanism = do
+doDerivInstErrorChecks1 :: DerivSpecMechanism -> DerivM ()
+doDerivInstErrorChecks1 mechanism = do
+    DerivEnv { denv_tc      = tc
+             , denv_rep_tc  = rep_tc
+             , denv_mtheta  = mtheta } <- ask
+    let anyclass_strategy = isDerivSpecAnyClass mechanism
+        bale_out msg = do err <- derivingThingErrMechanism mechanism msg
+                          lift $ failWithTc err
+
     -- For standalone deriving (mtheta /= Nothing),
     -- check that all the data constructors are in scope...
-    rdr_env <- getGlobalRdrEnv
+    rdr_env <- lift getGlobalRdrEnv
     let data_con_names = map dataConName (tyConDataCons rep_tc)
         hidden_data_cons = not (isWiredInName (tyConName rep_tc)) &&
                            (isAbstractTyCon rep_tc ||
                             any not_in_scope data_con_names)
         not_in_scope dc  = isNothing (lookupGRE_Name rdr_env dc)
 
-    addUsedDataCons rdr_env rep_tc
+    lift $ addUsedDataCons rdr_env rep_tc
+
     -- ...however, we don't perform this check if we're using DeriveAnyClass,
     -- since it doesn't generate any code that requires use of a data
     -- constructor.
     unless (anyclass_strategy || isNothing mtheta || not hidden_data_cons) $
            bale_out $ derivingHiddenErr tc
-  where
-    anyclass_strategy = isDerivSpecAnyClass mechanism
-
-    bale_out msg = failWithTc (derivingThingErrMechanism cls cls_tys
-                     (mkTyConApp tc tc_args) strat_used mechanism msg)
 
 doDerivInstErrorChecks2 :: Class -> ClsInst -> DerivSpecMechanism -> TcM ()
 doDerivInstErrorChecks2 clas clas_inst mechanism
@@ -1871,33 +1874,45 @@ derivingEtaErr cls cls_tys inst_ty
          nest 2 (text "instance (...) =>"
                 <+> pprClassPred cls (cls_tys ++ [inst_ty]))]
 
-derivingThingErr :: Bool -> Class -> [Type] -> Type -> Maybe DerivStrategy
-                 -> MsgDoc -> MsgDoc
-derivingThingErr newtype_deriving clas tys ty deriv_strat why
-  = derivingThingErr' newtype_deriving clas tys ty (isJust deriv_strat)
-                      (maybe empty ppr deriv_strat) why
-
-derivingThingErrMechanism :: Class -> [Type] -> Type
-                          -> Bool -- True if an explicit deriving strategy
-                                  -- keyword was provided
-                          -> DerivSpecMechanism
-                          -> MsgDoc -> MsgDoc
-derivingThingErrMechanism clas tys ty strat_used mechanism why
-  = derivingThingErr' (isDerivSpecNewtype mechanism) clas tys ty strat_used
-                      (ppr mechanism) why
-
-derivingThingErr' :: Bool -> Class -> [Type] -> Type -> Bool -> MsgDoc
-                  -> MsgDoc -> MsgDoc
-derivingThingErr' newtype_deriving clas tys ty strat_used strat_msg why
+derivingThingErr :: Bool -> Class -> [Type] -> Type
+                 -> Maybe DerivStrategy -> MsgDoc -> MsgDoc
+derivingThingErr newtype_deriving cls cls_tys inst_ty mb_strat why
+  = derivingThingErr' newtype_deriving cls cls_tys inst_ty mb_strat
+                      (maybe empty ppr mb_strat) why
+
+derivingThingErrM :: Bool -> MsgDoc -> DerivM MsgDoc
+derivingThingErrM newtype_deriving why
+  = do DerivEnv { denv_tc      = tc
+                , denv_tc_args = tc_args
+                , denv_cls     = cls
+                , denv_cls_tys = cls_tys
+                , denv_strat   = mb_strat } <- ask
+       pure $ derivingThingErr newtype_deriving cls cls_tys
+                               (mkTyConApp tc tc_args) mb_strat why
+
+derivingThingErrMechanism :: DerivSpecMechanism -> MsgDoc -> DerivM MsgDoc
+derivingThingErrMechanism mechanism why
+  = do DerivEnv { denv_tc      = tc
+                , denv_tc_args = tc_args
+                , denv_cls     = cls
+                , denv_cls_tys = cls_tys
+                , denv_strat   = mb_strat } <- ask
+       pure $ derivingThingErr' (isDerivSpecNewtype mechanism) cls cls_tys
+                (mkTyConApp tc tc_args) mb_strat (ppr mechanism) why
+
+derivingThingErr' :: Bool -> Class -> [Type] -> Type
+                  -> Maybe DerivStrategy -> MsgDoc -> MsgDoc -> MsgDoc
+derivingThingErr' newtype_deriving cls cls_tys inst_ty mb_strat strat_msg why
   = sep [(hang (text "Can't make a derived instance of")
              2 (quotes (ppr pred) <+> via_mechanism)
           $$ nest 2 extra) <> colon,
          nest 2 why]
   where
+    strat_used = isJust mb_strat
     extra | not strat_used, newtype_deriving
           = text "(even with cunning GeneralizedNewtypeDeriving)"
           | otherwise = empty
-    pred = mkClassPred clas (tys ++ [ty])
+    pred = mkClassPred cls (cls_tys ++ [inst_ty])
     via_mechanism | strat_used
                   = text "with the" <+> strat_msg <+> text "strategy"
                   | otherwise
index 7d39c31..85ff250 100644 (file)
@@ -7,6 +7,7 @@ Functions for inferring (and simplifying) the context for derived instances.
 -}
 
 {-# LANGUAGE CPP #-}
+{-# LANGUAGE MultiWayIf #-}
 
 module TcDerivInfer (inferConstraints, simplifyInstanceContexts) where
 
@@ -41,14 +42,15 @@ import VarEnv
 import VarSet
 
 import Control.Monad
+import Control.Monad.Trans.Class  (lift)
+import Control.Monad.Trans.Reader (ask)
 import Data.List
 import Data.Maybe
 
 ----------------------
 
-inferConstraints :: [TyVar] -> Class -> [TcType] -> TcType
-                 -> TyCon -> [TcType] -> DerivSpecMechanism
-                 -> TcM ([ThetaOrigin], [TyVar], [TcType])
+inferConstraints :: DerivSpecMechanism
+                 -> DerivM ([ThetaOrigin], [TyVar], [TcType])
 -- inferConstraints figures out the constraints needed for the
 -- instance declaration generated by a 'deriving' clause on a
 -- data type declaration. It also returns the new in-scope type
@@ -64,191 +66,215 @@ inferConstraints :: [TyVar] -> Class -> [TcType] -> TcType
 -- Generate a sufficiently large set of constraints that typechecking the
 -- generated method definitions should succeed.   This set will be simplified
 -- before being used in the instance declaration
-inferConstraints tvs main_cls cls_tys inst_ty
-                 rep_tc rep_tc_args
-                 mechanism
-  = do { (inferred_constraints, tvs', inst_tys') <- infer_constraints
-       ; traceTc "inferConstraints" $ vcat
+inferConstraints mechanism
+  = do { DerivEnv { denv_tc          = tc
+                  , denv_tc_args     = tc_args
+                  , denv_cls         = main_cls
+                  , denv_cls_tys     = cls_tys } <- ask
+       ; let is_anyclass = isDerivSpecAnyClass mechanism
+             infer_constraints
+               | is_anyclass = inferConstraintsDAC inst_tys
+               | otherwise   = inferConstraintsDataConArgs inst_ty inst_tys
+
+             inst_ty  = mkTyConApp tc tc_args
+             inst_tys = cls_tys ++ [inst_ty]
+
+             -- Constraints arising from superclasses
+             -- See Note [Superclasses of derived instance]
+             cls_tvs  = classTyVars main_cls
+             sc_constraints = ASSERT2( equalLength cls_tvs inst_tys
+                                     , ppr main_cls <+> ppr inst_tys )
+                              [ mkThetaOrigin DerivOrigin TypeLevel [] [] $
+                                substTheta cls_subst (classSCTheta main_cls) ]
+             cls_subst = ASSERT( equalLength cls_tvs inst_tys )
+                         zipTvSubst cls_tvs inst_tys
+
+       ; (inferred_constraints, tvs', inst_tys') <- infer_constraints
+       ; lift $ traceTc "inferConstraints" $ vcat
               [ ppr main_cls <+> ppr inst_tys'
               , ppr inferred_constraints
               ]
        ; return ( sc_constraints ++ inferred_constraints
                 , tvs', inst_tys' ) }
-  where
-    is_anyclass = isDerivSpecAnyClass mechanism
-    infer_constraints
-      | is_anyclass = inferConstraintsDAC tvs main_cls inst_tys
-      | otherwise   = inferConstraintsDataConArgs tvs main_cls cls_tys inst_ty
-                                                  rep_tc rep_tc_args
-
-    inst_tys = cls_tys ++ [inst_ty]
-
-        -- Constraints arising from superclasses
-        -- See Note [Superclasses of derived instance]
-    cls_tvs  = classTyVars main_cls
-    sc_constraints = ASSERT2( equalLength cls_tvs inst_tys
-                            , ppr main_cls <+> ppr inst_tys )
-                     [ mkThetaOrigin DerivOrigin TypeLevel [] [] $
-                       substTheta cls_subst (classSCTheta main_cls) ]
-    cls_subst = ASSERT( equalLength cls_tvs inst_tys )
-                zipTvSubst cls_tvs inst_tys
 
 -- | Like 'inferConstraints', but used only in the case of deriving strategies
 -- where the constraints are inferred by inspecting the fields of each data
 -- constructor (i.e., stock- and newtype-deriving).
-inferConstraintsDataConArgs
-  :: [TyVar] -> Class -> [TcType] -> TcType -> TyCon -> [TcType]
-  -> TcM ([ThetaOrigin], [TyVar], [TcType])
-inferConstraintsDataConArgs tvs main_cls cls_tys inst_ty rep_tc rep_tc_args
-  | is_generic                             -- Generic constraints are easy
-  = return ([], tvs, inst_tys)
-
-  | is_generic1                            -- Generic1 needs Functor
-  = ASSERT( rep_tc_tvs `lengthExceeds` 0 ) -- See Note [Getting base classes]
-    ASSERT( cls_tys `lengthIs` 1 )         -- Generic1 has a single kind variable
-    do { functorClass <- tcLookupClass functorClassName
-       ; con_arg_constraints (get_gen1_constraints functorClass) }
-
-  | otherwise  -- The others are a bit more complicated
-  = -- See the comment with all_rep_tc_args for an explanation of
-    -- this assertion
-    ASSERT2( equalLength rep_tc_tvs all_rep_tc_args
-           , ppr main_cls <+> ppr rep_tc
-             $$ ppr rep_tc_tvs $$ ppr all_rep_tc_args )
-      do { (arg_constraints, tvs', inst_tys')
-             <- con_arg_constraints get_std_constrained_tys
-         ; traceTc "inferConstraintsDataConArgs" $ vcat
-                [ ppr main_cls <+> ppr inst_tys'
-                , ppr arg_constraints
-                ]
-         ; return ( stupid_constraints ++ extra_constraints ++ arg_constraints
-                  , tvs', inst_tys') }
-  where
-    tc_binders = tyConBinders rep_tc
-    choose_level bndr
-      | isNamedTyConBinder bndr = KindLevel
-      | otherwise               = TypeLevel
-    t_or_ks = map choose_level tc_binders ++ repeat TypeLevel
-       -- want to report *kind* errors when possible
-
-       -- Constraints arising from the arguments of each constructor
-    con_arg_constraints :: (CtOrigin -> TypeOrKind
-                                     -> Type
-                                     -> [([PredOrigin], Maybe TCvSubst)])
-                        -> TcM ([ThetaOrigin], [TyVar], [TcType])
-    con_arg_constraints get_arg_constraints
-      = let (predss, mbSubsts) = unzip
-              [ preds_and_mbSubst
-              | data_con <- tyConDataCons rep_tc
-              , (arg_n, arg_t_or_k, arg_ty)
-                  <- zip3 [1..] t_or_ks $
-                     dataConInstOrigArgTys data_con all_rep_tc_args
-                -- No constraints for unlifted types
-                -- See Note [Deriving and unboxed types]
-              , not (isUnliftedType arg_ty)
-              , let orig = DerivOriginDC data_con arg_n
-              , preds_and_mbSubst <- get_arg_constraints orig arg_t_or_k arg_ty
-              ]
-            preds = concat predss
-            -- If the constraints require a subtype to be of kind (* -> *)
-            -- (which is the case for functor-like constraints), then we
-            -- explicitly unify the subtype's kinds with (* -> *).
-            -- See Note [Inferring the instance context]
-            subst        = foldl' composeTCvSubst
-                                  emptyTCvSubst (catMaybes mbSubsts)
-            unmapped_tvs = filter (\v -> v `notElemTCvSubst` subst
-                                      && not (v `isInScope` subst)) tvs
-            (subst', _)  = mapAccumL substTyVarBndr subst unmapped_tvs
-            preds'       = map (substPredOrigin subst') preds
-            inst_tys'    = substTys subst' inst_tys
-            tvs'         = tyCoVarsOfTypesWellScoped inst_tys'
-        in return ([mkThetaOriginFromPreds preds'], tvs', inst_tys')
-
-    is_generic  = main_cls `hasKey` genClassKey
-    is_generic1 = main_cls `hasKey` gen1ClassKey
-    -- is_functor_like: see Note [Inferring the instance context]
-    is_functor_like = typeKind inst_ty `tcEqKind` typeToTypeKind
-                   || is_generic1
-
-    get_gen1_constraints :: Class -> CtOrigin -> TypeOrKind -> Type
-                         -> [([PredOrigin], Maybe TCvSubst)]
-    get_gen1_constraints functor_cls orig t_or_k ty
-       = mk_functor_like_constraints orig t_or_k functor_cls $
-         get_gen1_constrained_tys last_tv ty
-
-    get_std_constrained_tys :: CtOrigin -> TypeOrKind -> Type
-                            -> [([PredOrigin], Maybe TCvSubst)]
-    get_std_constrained_tys orig t_or_k ty
-        | is_functor_like = mk_functor_like_constraints orig t_or_k main_cls $
-                            deepSubtypesContaining last_tv ty
-        | otherwise       = [( [mk_cls_pred orig t_or_k main_cls ty]
-                             , Nothing )]
-
-    mk_functor_like_constraints :: CtOrigin -> TypeOrKind
-                                -> Class -> [Type]
+inferConstraintsDataConArgs :: TcType -> [TcType]
+                            -> DerivM ([ThetaOrigin], [TyVar], [TcType])
+inferConstraintsDataConArgs inst_ty inst_tys
+  = do DerivEnv { denv_tvs         = tvs
+                , denv_rep_tc      = rep_tc
+                , denv_rep_tc_args = rep_tc_args
+                , denv_cls         = main_cls
+                , denv_cls_tys     = cls_tys } <- ask
+
+       let tc_binders = tyConBinders rep_tc
+           choose_level bndr
+             | isNamedTyConBinder bndr = KindLevel
+             | otherwise               = TypeLevel
+           t_or_ks = map choose_level tc_binders ++ repeat TypeLevel
+              -- want to report *kind* errors when possible
+
+              -- Constraints arising from the arguments of each constructor
+           con_arg_constraints
+             :: (CtOrigin -> TypeOrKind
+                          -> Type
+                          -> [([PredOrigin], Maybe TCvSubst)])
+             -> ([ThetaOrigin], [TyVar], [TcType])
+           con_arg_constraints get_arg_constraints
+             = let (predss, mbSubsts) = unzip
+                     [ preds_and_mbSubst
+                     | data_con <- tyConDataCons rep_tc
+                     , (arg_n, arg_t_or_k, arg_ty)
+                         <- zip3 [1..] t_or_ks $
+                            dataConInstOrigArgTys data_con all_rep_tc_args
+                       -- No constraints for unlifted types
+                       -- See Note [Deriving and unboxed types]
+                     , not (isUnliftedType arg_ty)
+                     , let orig = DerivOriginDC data_con arg_n
+                     , preds_and_mbSubst
+                         <- get_arg_constraints orig arg_t_or_k arg_ty
+                     ]
+                   preds = concat predss
+                   -- If the constraints require a subtype to be of kind
+                   -- (* -> *) (which is the case for functor-like
+                   -- constraints), then we explicitly unify the subtype's
+                   -- kinds with (* -> *).
+                   -- See Note [Inferring the instance context]
+                   subst        = foldl' composeTCvSubst
+                                         emptyTCvSubst (catMaybes mbSubsts)
+                   unmapped_tvs = filter (\v -> v `notElemTCvSubst` subst
+                                             && not (v `isInScope` subst)) tvs
+                   (subst', _)  = mapAccumL substTyVarBndr subst unmapped_tvs
+                   preds'       = map (substPredOrigin subst') preds
+                   inst_tys'    = substTys subst' inst_tys
+                   tvs'         = tyCoVarsOfTypesWellScoped inst_tys'
+               in ([mkThetaOriginFromPreds preds'], tvs', inst_tys')
+
+           is_generic  = main_cls `hasKey` genClassKey
+           is_generic1 = main_cls `hasKey` gen1ClassKey
+           -- is_functor_like: see Note [Inferring the instance context]
+           is_functor_like = typeKind inst_ty `tcEqKind` typeToTypeKind
+                          || is_generic1
+
+           get_gen1_constraints :: Class -> CtOrigin -> TypeOrKind -> Type
                                 -> [([PredOrigin], Maybe TCvSubst)]
-    -- 'cls' is usually main_cls (Functor or Traversable etc), but if
-    -- main_cls = Generic1, then 'cls' can be Functor; see get_gen1_constraints
-    --
-    -- For each type, generate two constraints, [cls ty, kind(ty) ~ (*->*)],
-    -- and a kind substitution that results from unifying kind(ty) with * -> *.
-    -- If the unification is successful, it will ensure that the resulting
-    -- instance is well kinded. If not, the second constraint will result
-    -- in an error message which points out the kind mismatch.
-    -- See Note [Inferring the instance context]
-    mk_functor_like_constraints orig t_or_k cls
-       = map $ \ty -> let ki = typeKind ty in
-                      ( [ mk_cls_pred orig t_or_k cls ty
-                        , mkPredOrigin orig KindLevel
-                            (mkPrimEqPred ki typeToTypeKind) ]
-                      , tcUnifyTy ki typeToTypeKind
-                      )
-
-    rep_tc_tvs      = tyConTyVars rep_tc
-    last_tv         = last rep_tc_tvs
-    -- When we first gather up the constraints to solve, most of them contain
-    -- rep_tc_tvs, i.e., the type variables from the derived datatype's type
-    -- constructor. We don't want these type variables to appear in the final
-    -- instance declaration, so we must substitute each type variable with its
-    -- counterpart in the derived instance. rep_tc_args lists each of these
-    -- counterpart types in the same order as the type variables.
-    all_rep_tc_args = rep_tc_args ++ map mkTyVarTy
-                                     (drop (length rep_tc_args) rep_tc_tvs)
-
-    inst_tys = cls_tys ++ [inst_ty]
-
-        -- Stupid constraints
-    stupid_constraints = [ mkThetaOrigin DerivOrigin TypeLevel [] [] $
-                           substTheta tc_subst (tyConStupidTheta rep_tc) ]
-    tc_subst = -- See the comment with all_rep_tc_args for an explanation of
-               -- this assertion
-               ASSERT( equalLength rep_tc_tvs all_rep_tc_args )
-               zipTvSubst rep_tc_tvs all_rep_tc_args
-
-        -- Extra Data constraints
-        -- The Data class (only) requires that for
-        --    instance (...) => Data (T t1 t2)
-        -- IF   t1:*, t2:*
-        -- THEN (Data t1, Data t2) are among the (...) constraints
-        -- Reason: when the IF holds, we generate a method
-        --             dataCast2 f = gcast2 f
-        --         and we need the Data constraints to typecheck the method
-    extra_constraints = [mkThetaOriginFromPreds constrs]
-      where
-        constrs
-          | main_cls `hasKey` dataClassKey
-          , all (isLiftedTypeKind . typeKind) rep_tc_args
-          = [ mk_cls_pred DerivOrigin t_or_k main_cls ty
-            | (t_or_k, ty) <- zip t_or_ks rep_tc_args]
-          | otherwise
-          = []
-
-    mk_cls_pred orig t_or_k cls ty   -- Don't forget to apply to cls_tys' too
-       = mkPredOrigin orig t_or_k (mkClassPred cls (cls_tys' ++ [ty]))
-    cls_tys' | is_generic1 = [] -- In the awkward Generic1 case, cls_tys'
-                                -- should be empty, since we are applying the
-                                -- class Functor.
-             | otherwise   = cls_tys
+           get_gen1_constraints functor_cls orig t_or_k ty
+              = mk_functor_like_constraints orig t_or_k functor_cls $
+                get_gen1_constrained_tys last_tv ty
+
+           get_std_constrained_tys :: CtOrigin -> TypeOrKind -> Type
+                                   -> [([PredOrigin], Maybe TCvSubst)]
+           get_std_constrained_tys orig t_or_k ty
+               | is_functor_like
+               = mk_functor_like_constraints orig t_or_k main_cls $
+                 deepSubtypesContaining last_tv ty
+               | otherwise
+               = [( [mk_cls_pred orig t_or_k main_cls ty]
+                  , Nothing )]
+
+           mk_functor_like_constraints :: CtOrigin -> TypeOrKind
+                                       -> Class -> [Type]
+                                       -> [([PredOrigin], Maybe TCvSubst)]
+           -- 'cls' is usually main_cls (Functor or Traversable etc), but if
+           -- main_cls = Generic1, then 'cls' can be Functor; see
+           -- get_gen1_constraints
+           --
+           -- For each type, generate two constraints,
+           -- [cls ty, kind(ty) ~ (*->*)], and a kind substitution that results
+           -- from unifying  kind(ty) with * -> *. If the unification is
+           -- successful, it will ensure that the resulting instance is well
+           -- kinded. If not, the second constraint will result in an error
+           -- message which points out the kind mismatch.
+           -- See Note [Inferring the instance context]
+           mk_functor_like_constraints orig t_or_k cls
+              = map $ \ty -> let ki = typeKind ty in
+                             ( [ mk_cls_pred orig t_or_k cls ty
+                               , mkPredOrigin orig KindLevel
+                                   (mkPrimEqPred ki typeToTypeKind) ]
+                             , tcUnifyTy ki typeToTypeKind
+                             )
+
+           rep_tc_tvs      = tyConTyVars rep_tc
+           last_tv         = last rep_tc_tvs
+           -- When we first gather up the constraints to solve, most of them
+           -- contain rep_tc_tvs, i.e., the type variables from the derived
+           -- datatype's type constructor. We don't want these type variables
+           -- to appear in the final instance declaration, so we must
+           -- substitute each type variable with its counterpart in the derived
+           -- instance. rep_tc_args lists each of these counterpart types in
+           -- the same order as the type variables.
+           all_rep_tc_args
+             = rep_tc_args ++ map mkTyVarTy
+                                  (drop (length rep_tc_args) rep_tc_tvs)
+
+               -- Stupid constraints
+           stupid_constraints
+             = [ mkThetaOrigin DerivOrigin TypeLevel [] [] $
+                 substTheta tc_subst (tyConStupidTheta rep_tc) ]
+           tc_subst = -- See the comment with all_rep_tc_args for an
+                      -- explanation of this assertion
+                      ASSERT( equalLength rep_tc_tvs all_rep_tc_args )
+                      zipTvSubst rep_tc_tvs all_rep_tc_args
+
+           -- Extra Data constraints
+           -- The Data class (only) requires that for
+           --    instance (...) => Data (T t1 t2)
+           -- IF   t1:*, t2:*
+           -- THEN (Data t1, Data t2) are among the (...) constraints
+           -- Reason: when the IF holds, we generate a method
+           --             dataCast2 f = gcast2 f
+           --         and we need the Data constraints to typecheck the method
+           extra_constraints = [mkThetaOriginFromPreds constrs]
+             where
+               constrs
+                 | main_cls `hasKey` dataClassKey
+                 , all (isLiftedTypeKind . typeKind) rep_tc_args
+                 = [ mk_cls_pred DerivOrigin t_or_k main_cls ty
+                   | (t_or_k, ty) <- zip t_or_ks rep_tc_args]
+                 | otherwise
+                 = []
+
+           mk_cls_pred orig t_or_k cls ty
+                -- Don't forget to apply to cls_tys' too
+              = mkPredOrigin orig t_or_k (mkClassPred cls (cls_tys' ++ [ty]))
+           cls_tys' | is_generic1 = []
+                      -- In the awkward Generic1 case, cls_tys' should be
+                      -- empty, since we are applying the class Functor.
+
+                    | otherwise   = cls_tys
+
+       if    -- Generic constraints are easy
+          |  is_generic
+          -> return ([], tvs, inst_tys)
+
+             -- Generic1 needs Functor
+             -- See Note [Getting base classes]
+          |  is_generic1
+          -> ASSERT( rep_tc_tvs `lengthExceeds` 0 )
+             -- Generic1 has a single kind variable
+             ASSERT( cls_tys `lengthIs` 1 )
+             do { functorClass <- lift $ tcLookupClass functorClassName
+                ; pure $ con_arg_constraints
+                       $ get_gen1_constraints functorClass }
+
+             -- The others are a bit more complicated
+          |  otherwise
+          -> -- See the comment with all_rep_tc_args for an explanation of
+             -- this assertion
+             ASSERT2( equalLength rep_tc_tvs all_rep_tc_args
+                    , ppr main_cls <+> ppr rep_tc
+                      $$ ppr rep_tc_tvs $$ ppr all_rep_tc_args )
+               do { let (arg_constraints, tvs', inst_tys')
+                          = con_arg_constraints get_std_constrained_tys
+                  ; lift $ traceTc "inferConstraintsDataConArgs" $ vcat
+                         [ ppr main_cls <+> ppr inst_tys'
+                         , ppr arg_constraints
+                         ]
+                  ; return ( stupid_constraints ++ extra_constraints
+                                                ++ arg_constraints
+                           , tvs', inst_tys') }
 
 typeToTypeKind :: Kind
 typeToTypeKind = liftedTypeKind `mkFunTy` liftedTypeKind
@@ -260,43 +286,49 @@ typeToTypeKind = liftedTypeKind `mkFunTy` liftedTypeKind
 -- See Note [Gathering and simplifying constraints for DeriveAnyClass]
 -- for an explanation of how these constraints are used to determine the
 -- derived instance context.
-inferConstraintsDAC :: [TyVar] -> Class -> [TcType]
-                    -> TcM ([ThetaOrigin], [TyVar], [TcType])
-inferConstraintsDAC tvs cls inst_tys
-  = do { let gen_dms = [ (sel_id, dm_ty)
+inferConstraintsDAC :: [TcType] -> DerivM ([ThetaOrigin], [TyVar], [TcType])
+inferConstraintsDAC inst_tys
+  = do { DerivEnv { denv_tvs = tvs
+                  , denv_cls = cls } <- ask
+
+       ; let gen_dms = [ (sel_id, dm_ty)
                        | (sel_id, Just (_, GenericDM dm_ty)) <- classOpItems cls ]
 
-       ; theta_origins <- pushTcLevelM_ (mapM do_one_meth gen_dms)
+             cls_tvs = classTyVars cls
+             empty_subst = mkEmptyTCvSubst (mkInScopeSet (mkVarSet tvs))
+
+             do_one_meth :: (Id, Type) -> TcM ThetaOrigin
+               -- (Id,Type) are the selector Id and the generic default method type
+               -- NB: the latter is /not/ quantified over the class variables
+               -- See Note [Gathering and simplifying constraints for DeriveAnyClass]
+             do_one_meth (sel_id, gen_dm_ty)
+               = do { let (sel_tvs, _cls_pred, meth_ty)
+                                   = tcSplitMethodTy (varType sel_id)
+                          meth_ty' = substTyWith sel_tvs inst_tys meth_ty
+                          (meth_tvs, meth_theta, meth_tau)
+                                   = tcSplitNestedSigmaTys meth_ty'
+
+                          gen_dm_ty' = substTyWith cls_tvs inst_tys gen_dm_ty
+                          (dm_tvs, dm_theta, dm_tau)
+                                     = tcSplitNestedSigmaTys gen_dm_ty'
+
+                    ; (subst, _meta_tvs) <- pushTcLevelM_ $
+                                            newMetaTyVarsX empty_subst dm_tvs
+                      -- Yuk: the pushTcLevel is to match the one in mk_wanteds
+                      --      simplifyDeriv.  If we don't, the unification
+                      --      variables will bogusly be untouchable.
+
+                    ; let dm_theta' = substTheta subst dm_theta
+                          tau_eq = mkPrimEqPred meth_tau (substTy subst dm_tau)
+                    ; return (mkThetaOrigin DerivOrigin TypeLevel
+                                meth_tvs meth_theta (tau_eq:dm_theta')) }
+
+       ; theta_origins <- lift $ pushTcLevelM_ (mapM do_one_meth gen_dms)
             -- Yuk: the pushTcLevel is to match the one wrapping the call
             --      to mk_wanteds in simplifyDeriv.  If we omit this, the
             --      unification variables will wrongly be untouchable.
 
        ; return (theta_origins, tvs, inst_tys) }
-  where
-    cls_tvs = classTyVars cls
-    empty_subst = mkEmptyTCvSubst (mkInScopeSet (mkVarSet tvs))
-
-    do_one_meth :: (Id, Type) -> TcM ThetaOrigin
-      -- (Id,Type) are the selector Id and the generic default method type
-      -- NB: the latter is /not/ quantified over the class variables
-      -- See Note [Gathering and simplifying constraints for DeriveAnyClass]
-    do_one_meth (sel_id, gen_dm_ty)
-      = do { let (sel_tvs, _cls_pred, meth_ty) = tcSplitMethodTy (varType sel_id)
-                 meth_ty' = substTyWith sel_tvs inst_tys meth_ty
-                 (meth_tvs, meth_theta, meth_tau) = tcSplitNestedSigmaTys meth_ty'
-
-                 gen_dm_ty' = substTyWith cls_tvs inst_tys gen_dm_ty
-                 (dm_tvs, dm_theta, dm_tau) = tcSplitNestedSigmaTys gen_dm_ty'
-
-           ; (subst, _meta_tvs) <- pushTcLevelM_ $
-                                   newMetaTyVarsX empty_subst dm_tvs
-                -- Yuk: the pushTcLevel is to match the one in mk_wanteds
-                --      simplifyDeriv.  If we don't, the unification variables
-                --      will bogusly be untouchable.
-           ; let dm_theta' = substTheta subst dm_theta
-                 tau_eq    = mkPrimEqPred meth_tau (substTy subst dm_tau)
-           ; return (mkThetaOrigin DerivOrigin TypeLevel
-                                   meth_tvs meth_theta (tau_eq:dm_theta')) }
 
 {- Note [Inferring the instance context]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -314,7 +346,7 @@ There are two sorts of 'deriving':
     the instance context is user-supplied
 
 For a deriving clause (InferTheta) we must figure out the
-instance context (inferConstraints). Suppose we are inferring
+instance context (inferConstraintsDataConArgs). Suppose we are inferring
 the instance context for
     C t1 .. tn (T s1 .. sm)
 There are two cases
@@ -424,7 +456,7 @@ Let's call the context reqd for the T instance of class C at types
         Eq (T a b) = (Ping a, Pong b, ...)
 
 Now we can get a (recursive) equation from the data decl.  This part
-is done by inferConstraints.
+is done by inferConstraintsDataConArgs.
 
         Eq (T a b) = Eq (Foo a) u Eq (Bar b)    -- From C1
                    u Eq (T b a) u Eq Int        -- From C2
index 05d323c..fd0bf04 100644 (file)
@@ -9,6 +9,7 @@ Error-checking and other utilities for @deriving@ clauses or declarations.
 {-# LANGUAGE TypeFamilies #-}
 
 module TcDerivUtils (
+        DerivM, DerivEnv(..),
         DerivSpec(..), pprDerivSpec,
         DerivSpecMechanism(..), isDerivSpecStock,
         isDerivSpecNewtype, isDerivSpecAnyClass,
@@ -48,9 +49,68 @@ import Type
 import Util
 import VarSet
 
+import Control.Monad.Trans.Reader
 import qualified GHC.LanguageExtensions as LangExt
 import ListSetOps (assocMaybe)
 
+-- | To avoid having to manually plumb everything in 'DerivEnv' throughout
+-- various functions in @TcDeriv@ and @TcDerivInfer@, we use 'DerivM', which
+-- is a simple reader around 'TcRn'.
+type DerivM = ReaderT DerivEnv TcRn
+
+-- | Contains all of the information known about a derived instance when
+-- determining what its @EarlyDerivSpec@ should be.
+data DerivEnv = DerivEnv
+  { denv_overlap_mode :: Maybe OverlapMode
+    -- ^ Is this an overlapping instance?
+  , denv_tvs          :: [TyVar]
+    -- ^ Universally quantified type variables in the instance
+  , denv_cls          :: Class
+    -- ^ Class for which we need to derive an instance
+  , denv_cls_tys      :: [Type]
+    -- ^ Other arguments to the class except the last
+  , denv_tc           :: TyCon
+    -- ^ Type constructor for which the instance is requested
+    --   (last arguments to the type class)
+  , denv_tc_args      :: [Type]
+    -- ^ Arguments to the type constructor
+  , denv_rep_tc       :: TyCon
+    -- ^ The representation tycon for 'denv_tc'
+    --   (for data family instances)
+  , denv_rep_tc_args  :: [Type]
+    -- ^ The representation types for 'denv_tc_args'
+    --   (for data family instances)
+  , denv_mtheta       :: DerivContext
+    -- ^ 'Just' the context of the instance, for standalone deriving.
+    --   'Nothing' for @deriving@ clauses.
+  , denv_strat        :: Maybe DerivStrategy
+    -- ^ 'Just' if user requests a particular deriving strategy.
+    --   Otherwise, 'Nothing'.
+  }
+
+instance Outputable DerivEnv where
+  ppr (DerivEnv { denv_overlap_mode = overlap_mode
+                , denv_tvs          = tvs
+                , denv_cls          = cls
+                , denv_cls_tys      = cls_tys
+                , denv_tc           = tc
+                , denv_tc_args      = tc_args
+                , denv_rep_tc       = rep_tc
+                , denv_rep_tc_args  = rep_tc_args
+                , denv_mtheta       = mtheta
+                , denv_strat        = mb_strat })
+    = hang (text "DerivEnv")
+         2 (vcat [ text "denv_overlap_mode" <+> ppr overlap_mode
+                 , text "denv_tvs"          <+> ppr tvs
+                 , text "denv_cls"          <+> ppr cls
+                 , text "denv_cls_tys"      <+> ppr cls_tys
+                 , text "denv_tc"           <+> ppr tc
+                 , text "denv_tc_args"      <+> ppr tc_args
+                 , text "denv_rep_tc"       <+> ppr rep_tc
+                 , text "denv_rep_tc_args"  <+> ppr rep_tc_args
+                 , text "denv_mtheta"       <+> ppr mtheta
+                 , text "denv_strat"        <+> ppr mb_strat ])
+
 data DerivSpec theta = DS { ds_loc       :: SrcSpan
                           , ds_name      :: Name         -- DFun name
                           , ds_tvs       :: [TyVar]