Instantiate when inferring types
authorRichard Eisenberg <eir@cis.upenn.edu>
Wed, 29 Jul 2015 19:19:38 +0000 (15:19 -0400)
committerRichard Eisenberg <eir@cis.upenn.edu>
Wed, 29 Jul 2015 19:19:38 +0000 (15:19 -0400)
compiler/deSugar/DsBinds.hs
compiler/hsSyn/HsBinds.hs
compiler/typecheck/TcBinds.hs
compiler/typecheck/TcClassDcl.hs
compiler/typecheck/TcHsSyn.hs
compiler/typecheck/TcInstDcls.hs

index bd1cb84..da1a014 100644 (file)
@@ -137,17 +137,20 @@ dsHsBind (PatBind { pat_lhs = pat, pat_rhs = grhss, pat_rhs_ty = ty
 dsHsBind (AbsBinds { abs_tvs = tyvars, abs_ev_vars = dicts
                    , abs_exports = [export]
                    , abs_ev_binds = ev_binds, abs_binds = binds })
-  | ABE { abe_wrap = wrap, abe_poly = global
+  | ABE { abe_inst_wrap = inst_wrap, abe_wrap = wrap, abe_poly = global
         , abe_mono = local, abe_prags = prags } <- export
+    -- See Note [AbsBinds wrappers] in HsBinds
   = do  { dflags <- getDynFlags
         ; bind_prs <- ds_lhs_binds binds
         ; let core_bind = Rec (fromOL bind_prs)
         ; ds_binds <- dsTcEvBinds_s ev_binds
+        ; inner_rhs <- dsHsWrapper inst_wrap $
+                       mkCoreLets ds_binds $
+                       Let core_bind $
+                       Var local
         ; rhs <- dsHsWrapper wrap $  -- Usually the identity
-                            mkLams tyvars $ mkLams dicts $
-                            mkCoreLets ds_binds $
-                            Let core_bind $
-                            Var local
+                 mkLams tyvars $ mkLams dicts $
+                 inner_rhs
 
         ; (spec_binds, rules) <- dsSpecs rhs prags
 
@@ -178,13 +181,17 @@ dsHsBind (AbsBinds { abs_tvs = tyvars, abs_ev_vars = dicts
 
         ; poly_tup_id <- newSysLocalDs (exprType poly_tup_rhs)
 
-        ; let mk_bind (ABE { abe_wrap = wrap, abe_poly = global
+        ; let mk_bind (ABE { abe_inst_wrap = inst_wrap, abe_wrap = wrap
+                           , abe_poly = global
                            , abe_mono = local, abe_prags = spec_prags })
+                         -- See Note [AbsBinds wrappers] in HsBinds
                 = do { tup_id  <- newSysLocalDs tup_ty
+                     ; inner_rhs <- dsHsWrapper inst_wrap $
+                                    mkTupleSelector locals local tup_id $
+                                    mkVarApps (Var poly_tup_id) (tyvars ++ dicts)
                      ; rhs <- dsHsWrapper wrap $
-                                 mkLams tyvars $ mkLams dicts $
-                                 mkTupleSelector locals local tup_id $
-                                 mkVarApps (Var poly_tup_id) (tyvars ++ dicts)
+                              mkLams tyvars $ mkLams dicts $
+                              inner_rhs
                      ; let rhs_for_spec = Let (NonRec poly_tup_id poly_tup_rhs) rhs
                      ; (spec_binds, rules) <- dsSpecs rhs_for_spec spec_prags
                      ; let global' = (global `setInlinePragma` defaultInlinePragma)
@@ -266,8 +273,8 @@ dictArity :: [Var] -> Arity
 dictArity dicts = count isId dicts
 
 {-
-[Desugaring AbsBinds]
-~~~~~~~~~~~~~~~~~~~~~
+Note [Desugaring AbsBinds]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 In the general AbsBinds case we desugar the binding to this:
 
        tup a (d:Num a) = let fm = ...gm...
index d934418..11ed429 100644 (file)
@@ -237,11 +237,13 @@ deriving instance (DataId idL, DataId idR)
         -- See Note [AbsBinds]
 
 data ABExport id
-  = ABE { abe_poly  :: id           -- ^ Any INLINE pragmas is attached to this Id
-        , abe_mono  :: id
-        , abe_wrap  :: HsWrapper    -- ^ See Note [AbsBinds wrappers]
-             -- Shape: (forall abs_tvs. abs_ev_vars => abe_mono) ~ abe_poly
-        , abe_prags :: TcSpecPrags  -- ^ SPECIALISE pragmas
+  = ABE { abe_poly      :: id    -- ^ Any INLINE pragmas is attached to this Id
+        , abe_mono      :: id
+        , abe_inst_wrap :: HsWrapper
+             -- ^ Shape: abe_mono ~ abe_insted
+        , abe_wrap      :: HsWrapper    -- ^ See Note [AbsBinds wrappers]
+             -- Shape: (forall abs_tvs. abs_ev_vars => abe_insted) ~ abe_poly
+        , abe_prags     :: TcSpecPrags  -- ^ SPECIALISE pragmas
   } deriving (Data, Typeable)
 
 -- | - 'ApiAnnotation.AnnKeywordId' : 'ApiAnnotation.AnnPattern',
@@ -338,6 +340,24 @@ The abe_wrap field deals with impedance-matching between
 and the thing we really want, which may have fewer type
 variables.  The action happens in TcBinds.mkExport.
 
+For abe_inst_wrap, consider this:
+  x = (*)
+The abe_mono type will be  forall a. Num a => a -> a -> a
+because no instantiation happens during typechecking. Before inferring
+a final type, we must instantiate this. See Note [Instantiate when inferring
+a type] in TcBinds. The abe_inst_wrap takes the uninstantiated abe_mono type
+to a proper instantiated type.
+
+It's conceivable that we could combine the two wrappers, but note that there
+is a gap: neither wrapper tacks on the tvs and dicts from the outer AbsBinds.
+These bits are added manually in desugaring. (See DsBinds.dsHsBind.) A problem
+that would arise in combining them is that zonking becomes more challenging:
+we want to zonk the tvs and dicts in the AbsBinds, but then we end up re-zonking
+when we zonk the ABExport. And -- worse -- the combined wrapper would have
+the tvs and dicts in binding positions, so they would shadow the original
+tvs and dicts. This is all resolvable with some plumbing, but it seems simpler
+just to keep the two wrappers distinct.
+
 Note [Bind free vars]
 ~~~~~~~~~~~~~~~~~~~~~
 The bind_fvs field of FunBind and PatBind records the free variables
@@ -510,10 +530,12 @@ ppr_monobind (AbsBinds { abs_tvs = tyvars, abs_ev_vars = dictvars
     , ifPprDebug (ptext (sLit "Evidence:") <+> ppr ev_binds) ]
 
 instance (OutputableBndr id) => Outputable (ABExport id) where
-  ppr (ABE { abe_wrap = wrap, abe_poly = gbl, abe_mono = lcl, abe_prags = prags })
+  ppr (ABE { abe_wrap = wrap, abe_inst_wrap = inst_wrap
+           , abe_poly = gbl, abe_mono = lcl, abe_prags = prags })
     = vcat [ ppr gbl <+> ptext (sLit "<=") <+> ppr lcl
            , nest 2 (pprTcSpecPrags prags)
-           , nest 2 (ppr wrap)]
+           , nest 2 (ppr wrap)
+           , nest 2 (ppr inst_wrap)]
 
 instance (OutputableBndr idL, OutputableBndr idR) => Outputable (PatSynBind idL idR) where
   ppr (PSB{ psb_id = L _ psyn, psb_args = details, psb_def = pat, psb_dir = dir })
@@ -625,7 +647,7 @@ data Sig name
       --          'ApiAnnotation.AnnComma'
 
       -- For details on above see note [Api annotations] in ApiAnnotation
-    TypeSig 
+    TypeSig
        [Located name]         -- LHS of the signature; e.g.  f,g,h :: blah
        (LHsType name)         -- RHS of the signature
        (PostRn name [Name])   -- Wildcards (both named and anonymous) of the RHS
index ad38bfb..d0baa17 100644 (file)
@@ -13,7 +13,7 @@ module TcBinds ( tcLocalBinds, tcTopBinds, tcRecSelBinds,
                  tcVectDecls,
                  TcSigInfo(..), TcSigFun, mkPragFun,
                  instTcTySig, instTcTySigFromId, findScopedTyVars,
-                 badBootDeclErr, mkExport ) where
+                 badBootDeclErr ) where
 
 import {-# SOURCE #-} TcMatches ( tcGRHSsPat, tcMatchesFun )
 import {-# SOURCE #-} TcExpr  ( tcMonoExpr )
@@ -30,7 +30,7 @@ import TcHsType
 import TcPat
 import TcMType
 import ConLike
-import Inst( topInstantiate )
+import Inst( topInstantiate, deeplyInstantiate )
 import FamInstEnv( normaliseType )
 import FamInst( tcGetFamInstEnvs )
 import Type( pprSigmaTypeExtraCts, tidyOpenTypes )
@@ -565,10 +565,11 @@ tcPolyCheck rec_tc prag_fn
        ; poly_id    <- addInlinePrags poly_id prag_sigs
 
        ; let (_, _, mono_id) = mono_info
-             export = ABE { abe_wrap = idHsWrapper
-                          , abe_poly = poly_id
-                          , abe_mono = mono_id
-                          , abe_prags = SpecPrags spec_prags }
+             export = ABE { abe_wrap      = idHsWrapper
+                          , abe_inst_wrap = idHsWrapper
+                          , abe_poly      = poly_id
+                          , abe_mono      = mono_id
+                          , abe_prags     = SpecPrags spec_prags }
              abs_bind = L loc $ AbsBinds
                         { abs_tvs = tvs
                         , abs_ev_vars = ev_vars, abs_ev_binds = [ev_binds]
@@ -579,6 +580,61 @@ tcPolyCheck _rec_tc _prag_fn sig _bind
   = pprPanic "tcPolyCheck" (ppr sig)
 
 ------------------
+{-
+Note [Instantiate when inferring a type]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+  f = (*)
+As there is no incentive to instantiate the RHS, tcMonoBinds will
+produce a type of forall a. Num a => a -> a -> a for `f`. This will then go
+through simplifyInfer and such, remaining unchanged.
+
+There are three problems with this:
+ 1) Even though `f` does not have a type signature, its type variable `a`
+    is considered "specified" (see Note [Visible type application] in TcExpr),
+    allowing it to participate in visible type application. Yet, when
+    `f` is exported, it will be exported with an *inferred* type variable,
+    preventing importing modules from using visible type application with
+    `f`.
+
+ 2) If the definition were `g _ = (*)`, we get a very unusual type of
+    `forall {a}. a -> forall b. Num b => b -> b -> b` for `g`. This is
+    surely confusing for users.
+
+ 3) The monomorphism restriction can't work. The MR is dealt with in
+    simplifyInfer, and simplifyInfer has no way of instantiating. This
+    could perhaps be worked around, but it may be hard to know even
+    when instantiation should happen.
+
+There is an easy solution to all three problems: instantiate (deeply) when
+inferring a type. So that's what we do. Note that this decision is
+user-facing.
+
+Here are the details:
+ * tcMonoBinds produces the "monomorphic" ids to be put in the AbsBinds.
+   It is inconvenient to instantiate in this function or below. So the
+   monomorphic ids will be uninstantiated (and hence actually polymorphic,
+   but that doesn't ruin anyone's day).
+
+ * In the same captureConstraints as the tcMonoBinds, we instantiate all
+   the types of the monomorphic ids. Instantiating will produce constraints
+   to solve and instantiated types. These constraints and the instantiated
+   types go into simplifyInfer. HsWrappers are produced that go from
+   the "mono" types to the instantiated ones.
+
+ * simplifyInfer does its magic, figuring out how to regeneralize.
+
+ * mkExport then does the impedence matching and needs to connect the
+   monomorphic ids to the polymorphic types as decided by simplifyInfer.
+   Because the instantiation happens before simplifyInfer, we also pass in
+   the HsWrappers obtained via instantiating so that mkExport can connect
+   all the pieces.
+
+ * We produce an AbsBinds with the right (instantiated and then, perhaps,
+   regeneralized) polytypes and the not-yet-instantiated "monomorphic" ids,
+   using the built HsWrappers to connect. Done!
+-}
+
 tcPolyInfer
   :: RecFlag       -- Whether it's recursive after breaking
                    -- dependencies based on type signatures
@@ -587,18 +643,28 @@ tcPolyInfer
   -> [LHsBind Name]
   -> TcM (LHsBinds TcId, [TcId])
 tcPolyInfer rec_tc prag_fn tc_sig_fn mono bind_list
-  = do { ((binds', mono_infos), tclvl, wanted)
+  = do { ((binds', mono_infos, wrappers, insted_tys), tclvl, wanted)
              <- pushLevelAndCaptureConstraints  $
-                tcMonoBinds rec_tc tc_sig_fn LetLclBndr bind_list
-
-       ; let name_taus = [(name, idType mono_id) | (name, _, mono_id) <- mono_infos]
+             do { (binds', mono_infos)
+                    <- tcMonoBinds rec_tc tc_sig_fn LetLclBndr bind_list
+                  -- See Note [Instantiate when inferring a type]
+                ; mono_tys <- mapM (zonkTcType . idType . thirdOf3) mono_infos
+                    -- NB: zonk to uncover any foralls
+                ; (wrappers, insted_tys)
+                         -- TODO (RAE): Fix origin
+                    <- mapAndUnzipM (deeplyInstantiate AppOrigin) mono_tys
+                ; return (binds', mono_infos, wrappers, insted_tys) }
+
+       ; let name_taus = [(name, tau) | ((name, _, _), tau)
+                                          <- zip mono_infos insted_tys]
        ; traceTc "simplifyInfer call" (ppr name_taus $$ ppr wanted)
        ; (qtvs, givens, _mr_bites, ev_binds)
                  <- simplifyInfer tclvl mono name_taus wanted
 
        ; let inferred_theta = map evVarPred givens
        ; exports <- checkNoErrs $
-                    mapM (mkExport prag_fn qtvs inferred_theta) mono_infos
+                    zipWith3M (mkExport prag_fn qtvs inferred_theta)
+                              mono_infos wrappers insted_tys
 
        ; loc <- getSrcSpanM
        ; let poly_ids = map abe_poly exports
@@ -615,8 +681,11 @@ tcPolyInfer rec_tc prag_fn tc_sig_fn mono bind_list
 mkExport :: PragFun
          -> [TyVar] -> TcThetaType      -- Both already zonked
          -> MonoBindInfo
+         -> HsWrapper -- the instantiation wrapper;
+                      -- see Note [Instantiate when inferring a type]
+         -> TcTauType -- the instantiated type
          -> TcM (ABExport Id)
--- Only called for generalisation plan IferGen, not by CheckGen or NoGen
+-- Only called for generalisation plan InferGen, not by CheckGen or NoGen
 --
 -- mkExport generates exports with
 --      zonked type variables,
@@ -630,16 +699,17 @@ mkExport :: PragFun
 -- Pre-condition: the qtvs and theta are already zonked
 
 mkExport prag_fn qtvs inferred_theta (poly_name, mb_sig, mono_id)
-  = do  { mono_ty <- zonkTcType (idType mono_id)
+         inst_wrap inst_ty
+  = do  { inst_ty <- zonkTcType inst_ty
 
         ; (poly_id, inferred) <- case mb_sig of
-              Nothing  -> do { poly_id <- mkInferredPolyId poly_name qtvs inferred_theta mono_ty
+              Nothing  -> do { poly_id <- mkInferredPolyId poly_name qtvs inferred_theta inst_ty
                              ; return (poly_id, True) }
               Just sig | Just poly_id <- completeSigPolyId_maybe sig
                        -> return (poly_id, False)
                        | otherwise
                        -> do { final_theta <- completeTheta inferred_theta sig
-                             ; poly_id <- mkInferredPolyId poly_name qtvs final_theta mono_ty
+                             ; poly_id <- mkInferredPolyId poly_name qtvs final_theta inst_ty
                              ; return (poly_id, True) }
 
         -- NB: poly_id has a zonked type
@@ -647,7 +717,7 @@ mkExport prag_fn qtvs inferred_theta (poly_name, mb_sig, mono_id)
         ; spec_prags <- tcSpecPrags poly_id prag_sigs
                 -- tcPrags requires a zonked poly_id
 
-        ; let sel_poly_ty = mkSigmaTy qtvs inferred_theta mono_ty
+        ; let sel_poly_ty = mkSigmaTy qtvs inferred_theta inst_ty
         ; traceTc "mkExport: check sig"
                   (vcat [ ppr poly_name, ppr sel_poly_ty, ppr (idType poly_id) ])
 
@@ -663,10 +733,11 @@ mkExport prag_fn qtvs inferred_theta (poly_name, mb_sig, mono_id)
                             tcSubType_NC sig_ctxt sel_poly_ty (idType poly_id)
         ; ev_binds <- simplifyTop wanted
 
-        ; return (ABE { abe_wrap = mkWpLet (EvBinds ev_binds) <.> wrap
-                      , abe_poly = poly_id
-                      , abe_mono = mono_id
-                      , abe_prags = SpecPrags spec_prags}) }
+        ; return (ABE { abe_wrap      = mkWpLet (EvBinds ev_binds) <.> wrap
+                      , abe_inst_wrap = inst_wrap
+                      , abe_poly      = poly_id
+                      , abe_mono      = mono_id
+                      , abe_prags     = SpecPrags spec_prags}) }
   where
     prag_sigs = prag_fn poly_name
     sig_ctxt  = InfSigCtxt poly_name
@@ -823,7 +894,7 @@ We can get these by "impedance matching":
 
 Suppose the shared quantified tyvars are qtvs and constraints theta.
 Then we want to check that
-   f's polytype  is more polymorphic than   forall qtvs. theta => f_mono_ty
+     forall qtvs. theta => f_mono_ty   is more polymorphic than   f's polytype
 and the proof is the impedance matcher.
 
 Notice that the impedance matcher may do defaulting.  See Trac #7173.
index bc1bac2..7a4e256 100644 (file)
@@ -9,7 +9,7 @@ Typechecking class declarations
 {-# LANGUAGE CPP #-}
 
 module TcClassDcl ( tcClassSigs, tcClassDecl2,
-                    findMethodBind, instantiateMethod, 
+                    findMethodBind, instantiateMethod,
                     tcClassMinimalDef,
                     HsSigFun, mkHsSigFun, lookupHsSig, emptyHsSigs,
                     tcMkDeclCtxt, tcAddDeclCtxt, badMethodErr
@@ -232,13 +232,14 @@ tcDefMeth clas tyvars this_dict binds_in
                   tcPolyCheck NonRecursive no_prag_fn local_dm_sig'
                               (L bind_loc lm_bind)
 
-        ; let export = ABE { abe_poly  = global_dm_id
+        ; let export = ABE { abe_poly      = global_dm_id
                            -- We have created a complete type signature in
                            -- instTcTySig, hence it is safe to call
                            -- completeSigPolyId
-                           , abe_mono  = completeSigPolyId local_dm_sig'
-                           , abe_wrap  = idHsWrapper
-                           , abe_prags = IsDefaultMethod }
+                           , abe_mono      = completeSigPolyId local_dm_sig'
+                           , abe_wrap      = idHsWrapper
+                           , abe_inst_wrap = idHsWrapper
+                           , abe_prags     = IsDefaultMethod }
               full_bind = AbsBinds { abs_tvs      = tyvars
                                    , abs_ev_vars  = [this_dict]
                                    , abs_exports  = [export]
index e616ce8..3dc302a 100644 (file)
@@ -480,12 +480,15 @@ zonk_bind env sig_warn (AbsBinds { abs_tvs = tyvars, abs_ev_vars = evs
                           , abs_ev_binds = new_ev_binds
                           , abs_exports = new_exports, abs_binds = new_val_bind }) }
   where
-    zonkExport env (ABE{ abe_wrap = wrap, abe_poly = poly_id
+    zonkExport env (ABE{ abe_wrap = wrap, abe_inst_wrap = inst_wrap
+                       , abe_poly = poly_id
                        , abe_mono = mono_id, abe_prags = prags })
         = do new_poly_id <- zonkIdBndr env poly_id
              (_, new_wrap) <- zonkCoFn env wrap
+             (_, new_inst_wrap) <- zonkCoFn env inst_wrap
              new_prags <- zonkSpecPrags env prags
-             return (ABE{ abe_wrap = new_wrap, abe_poly = new_poly_id
+             return (ABE{ abe_wrap = new_wrap, abe_inst_wrap = new_inst_wrap
+                        , abe_poly = new_poly_id
                         , abe_mono = zonkIdOcc env mono_id
                         , abe_prags = new_prags })
 
index 88b843e..e30d045 100644 (file)
@@ -895,7 +895,8 @@ tcInstDecl2 (InstInfo { iSpec = ispec, iBinds = ibinds })
                 | otherwise
                 = SpecPrags spec_inst_prags
 
-             export = ABE { abe_wrap = idHsWrapper, abe_poly = dfun_id
+             export = ABE { abe_wrap = idHsWrapper, abe_inst_wrap = idHsWrapper
+                          , abe_poly = dfun_id
                           , abe_mono = self_dict, abe_prags = dfun_spec_prags }
                           -- NB: see Note [SPECIALISE instance pragmas]
              main_bind = AbsBinds { abs_tvs = inst_tyvars
@@ -1007,7 +1008,9 @@ tcSuperClasses dfun_id cls tyvars dfun_evs inst_tys dfun_ev_binds _fam_envs sc_t
            ; sc_top_name <- newName (mkSuperDictAuxOcc n (getOccName cls))
            ; let sc_top_ty = mkForAllTys tyvars (mkPiTypes dfun_evs sc_pred)
                  sc_top_id = mkLocalId sc_top_name sc_top_ty HasSigId
-                 export = ABE { abe_wrap = idHsWrapper, abe_poly = sc_top_id
+                 export = ABE { abe_wrap = idHsWrapper
+                              , abe_inst_wrap = idHsWrapper
+                              , abe_poly = sc_top_id
                               , abe_mono = sc_ev_id
                               , abe_prags = SpecPrags [] }
                  local_ev_binds = TcEvBinds (ic_binds sc_implic)
@@ -1336,7 +1339,8 @@ tcMethods dfun_id clas tyvars dfun_ev_vars inst_tys
                         -- method to this version. Note [INLINE and default methods]
 
 
-                 export = ABE { abe_wrap = hs_wrap, abe_poly = meth_id1
+                 export = ABE { abe_wrap = hs_wrap, abe_inst_wrap = idHsWrapper
+                              , abe_poly = meth_id1
                               , abe_mono = local_meth_id
                               , abe_prags = mk_meth_spec_prags meth_id1 spec_inst_prags [] }
                  bind = AbsBinds { abs_tvs = tyvars, abs_ev_vars = dfun_ev_vars
@@ -1392,10 +1396,11 @@ tcMethodBody clas tyvars dfun_ev_vars inst_tys
                               (L bind_loc lm_bind)
 
         ; let specs  = mk_meth_spec_prags global_meth_id spec_inst_prags spec_prags
-              export = ABE { abe_poly  = global_meth_id
-                           , abe_mono  = local_meth_id
-                           , abe_wrap  = hs_wrap
-                           , abe_prags = specs }
+              export = ABE { abe_poly      = global_meth_id
+                           , abe_mono      = local_meth_id
+                           , abe_wrap      = hs_wrap
+                           , abe_inst_wrap = idHsWrapper
+                           , abe_prags     = specs }
 
               local_ev_binds = TcEvBinds (ic_binds meth_implic)
               full_bind = AbsBinds { abs_tvs      = tyvars