Merge branch 'master' into type-nats
[ghc.git] / compiler / typecheck / TcSMonad.lhs
index 87cd5eb..c7c667e 100644 (file)
@@ -15,13 +15,15 @@ module TcSMonad (
     CanonicalCt(..), Xi, tyVarsOfCanonical, tyVarsOfCanonicals, tyVarsOfCDicts, 
     deCanonicalise, mkFrozenError,
 
-    isWanted, isGiven, isDerived,
-    isGivenCt, isWantedCt, isDerivedCt, pprFlavorArising,
+    isWanted, isGivenOrSolved, isDerived,
+    isGivenOrSolvedCt, isGivenCt_maybe, 
+    isWantedCt, isDerivedCt, pprFlavorArising,
 
     isFlexiTcsTv,
 
     canRewrite, canSolve,
-    combineCtLoc, mkGivenFlavor, mkWantedFlavor,
+    combineCtLoc, mkSolvedFlavor, mkGivenFlavor,
+    mkWantedFlavor,
     getWantedLoc,
 
     TcS, runTcS, failTcS, panicTcS, traceTcS, -- Basic functionality 
@@ -39,8 +41,11 @@ module TcSMonad (
 
     setWantedTyBind,
 
+    lookupFlatCacheMap, updateFlatCacheMap,
+
     getInstEnvs, getFamInstEnvs,                -- Getting the environments
     getTopEnv, getGblEnv, getTcEvBinds, getUntouchables,
+    tcsLookupClass, tcsLookupTyCon,
     getTcEvBindsBag, getTcSContext, getTcSTyBinds, getTcSTyBindsMap,
 
     newFlattenSkolemTy,                         -- Flatten skolems 
@@ -81,7 +86,9 @@ import FamInstEnv
 import qualified TcRnMonad as TcM
 import qualified TcMType as TcM
 import qualified TcEnv as TcM 
-       ( checkWellStaged, topIdLvl, tcLookupFamInst, tcGetDefaultTys )
+       ( checkWellStaged, topIdLvl, tcLookupFamInst, tcGetDefaultTys
+       , tcLookupClass, tcLookupTyCon )
+import Kind
 import TcType
 import DynFlags
 
@@ -97,14 +104,20 @@ import Outputable
 import Bag
 import MonadUtils
 import VarSet
+import Pair
 import FastString
 
 import HsBinds               -- for TcEvBinds stuff 
 import Id 
-
 import TcRnTypes
-
 import Data.IORef
+
+import qualified Data.Map as Map
+
+#ifdef DEBUG
+import StaticFlags( opt_PprStyle_Debug )
+import Control.Monad( when )
+#endif
 \end{code}
 
 
@@ -204,9 +217,9 @@ instance Outputable CanonicalCt where
   ppr (CIPCan ip fl ip_nm ty)     
       = ppr fl <+> ppr ip <+> dcolon <+> parens (ppr ip_nm <> dcolon <> ppr ty)
   ppr (CTyEqCan co fl tv ty)      
-      = ppr fl <+> ppr co <+> dcolon <+> pprEqPred (mkTyVarTy tv, ty)
+      = ppr fl <+> ppr co <+> dcolon <+> pprEqPred (Pair (mkTyVarTy tv) ty)
   ppr (CFunEqCan co fl tc tys ty) 
-      = ppr fl <+> ppr co <+> dcolon <+> pprEqPred (mkTyConApp tc tys, ty)
+      = ppr fl <+> ppr co <+> dcolon <+> pprEqPred (Pair (mkTyConApp tc tys) ty)
   ppr (CFrozenErr co fl)
       = ppr fl <+> pprEvVarWithType co
 \end{code}
@@ -330,11 +343,16 @@ getWantedLoc ct
 
 isWantedCt :: CanonicalCt -> Bool
 isWantedCt ct = isWanted (cc_flavor ct)
-isGivenCt :: CanonicalCt -> Bool
-isGivenCt ct = isGiven (cc_flavor ct)
 isDerivedCt :: CanonicalCt -> Bool
 isDerivedCt ct = isDerived (cc_flavor ct)
 
+isGivenCt_maybe :: CanonicalCt -> Maybe GivenKind
+isGivenCt_maybe ct = isGiven_maybe (cc_flavor ct)
+
+isGivenOrSolvedCt :: CanonicalCt -> Bool
+isGivenOrSolvedCt ct = isGivenOrSolved (cc_flavor ct)
+
+
 canSolve :: CtFlavor -> CtFlavor -> Bool 
 -- canSolve ctid1 ctid2 
 -- The constraint ctid1 can be used to solve ctid2 
@@ -359,22 +377,27 @@ canRewrite = canSolve
 
 combineCtLoc :: CtFlavor -> CtFlavor -> WantedLoc
 -- Precondition: At least one of them should be wanted 
-combineCtLoc (Wanted loc) _    = loc 
-combineCtLoc _ (Wanted loc)    = loc 
-combineCtLoc (Derived loc ) _  = loc 
-combineCtLoc _ (Derived loc )  = loc 
+combineCtLoc (Wanted loc) _    = loc
+combineCtLoc _ (Wanted loc)    = loc
+combineCtLoc (Derived loc ) _  = loc
+combineCtLoc _ (Derived loc )  = loc
 combineCtLoc _ _ = panic "combineCtLoc: both given"
 
-mkGivenFlavor :: CtFlavor -> SkolemInfo -> CtFlavor
-mkGivenFlavor (Wanted  loc) sk  = Given (setCtLocOrigin loc sk)
-mkGivenFlavor (Derived loc) sk  = Given (setCtLocOrigin loc sk)
-mkGivenFlavor (Given   loc) sk  = Given (setCtLocOrigin loc sk)
+mkSolvedFlavor :: CtFlavor -> SkolemInfo -> CtFlavor
+-- To be called when we actually solve a wanted/derived (perhaps leaving residual goals)
+mkSolvedFlavor (Wanted  loc) sk  = Given (setCtLocOrigin loc sk) GivenSolved
+mkSolvedFlavor (Derived loc) sk  = Given (setCtLocOrigin loc sk) GivenSolved
+mkSolvedFlavor fl@(Given {}) _sk = pprPanic "Solving a given constraint!" $ ppr fl
 
+mkGivenFlavor :: CtFlavor -> SkolemInfo -> CtFlavor
+mkGivenFlavor (Wanted  loc) sk  = Given (setCtLocOrigin loc sk) GivenOrig
+mkGivenFlavor (Derived loc) sk  = Given (setCtLocOrigin loc sk) GivenOrig
+mkGivenFlavor fl@(Given {}) _sk = pprPanic "Solving a given constraint!" $ ppr fl
 
 mkWantedFlavor :: CtFlavor -> CtFlavor
 mkWantedFlavor (Wanted  loc) = Wanted loc
 mkWantedFlavor (Derived loc) = Wanted loc
-mkWantedFlavor fl@(Given {}) = pprPanic "mkWantedFlavour" (ppr fl)
+mkWantedFlavor fl@(Given {}) = pprPanic "mkWantedFlavor" (ppr fl)
 \end{code}
 
 %************************************************************************
@@ -409,10 +432,33 @@ data TcSEnv
                      
       tcs_untch :: TcsUntouchables,
 
-      tcs_ic_depth :: Int,     -- Implication nesting depth
-      tcs_count    :: IORef Int        -- Global step count
+      tcs_ic_depth   :: Int,       -- Implication nesting depth
+      tcs_count      :: IORef Int, -- Global step count
+
+      tcs_flat_map   :: IORef FlatCache
     }
 
+data FlatCache 
+  = FlatCache { givenFlatCache  :: Map.Map FunEqHead (TcType,Coercion,CtFlavor)
+                -- Invariant: all CtFlavors here satisfy isGiven
+              , wantedFlatCache :: Map.Map FunEqHead (TcType,Coercion,CtFlavor) }
+                -- Invariant: all CtFlavors here satisfy isWanted
+
+emptyFlatCache :: FlatCache
+emptyFlatCache 
+ = FlatCache { givenFlatCache  = Map.empty, wantedFlatCache = Map.empty }
+
+newtype FunEqHead = FunEqHead (TyCon,[Xi])
+
+instance Eq FunEqHead where
+  FunEqHead (tc1,xis1) == FunEqHead (tc2,xis2) = tc1 == tc2 && eqTypes xis1 xis2
+
+instance Ord FunEqHead where
+  FunEqHead (tc1,xis1) `compare` FunEqHead (tc2,xis2) 
+    = case compare tc1 tc2 of 
+        EQ    -> cmpTypes xis1 xis2
+        other -> other
+
 type TcsUntouchables = (Untouchables,TcTyVarSet)
 -- Like the TcM Untouchables, 
 -- but records extra TcsTv variables generated during simplification
@@ -421,17 +467,16 @@ type TcsUntouchables = (Untouchables,TcTyVarSet)
 
 \begin{code}
 data SimplContext
-  = SimplInfer         -- Inferring type of a let-bound thing
-  | SimplRuleLhs       -- Inferring type of a RULE lhs
-  | SimplInteractive   -- Inferring type at GHCi prompt
-  | SimplCheck         -- Checking a type signature or RULE rhs
-  deriving Eq
+  = SimplInfer SDoc       -- Inferring type of a let-bound thing
+  | SimplRuleLhs RuleName  -- Inferring type of a RULE lhs
+  | SimplInteractive      -- Inferring type at GHCi prompt
+  | SimplCheck SDoc       -- Checking a type signature or RULE rhs
 
 instance Outputable SimplContext where
-  ppr SimplInfer       = ptext (sLit "SimplInfer")
-  ppr SimplRuleLhs     = ptext (sLit "SimplRuleLhs")
+  ppr (SimplInfer d)   = ptext (sLit "SimplInfer") <+> d
+  ppr (SimplCheck d)   = ptext (sLit "SimplCheck") <+> d
+  ppr (SimplRuleLhs n) = ptext (sLit "SimplRuleLhs") <+> doubleQuotes (ftext n)
   ppr SimplInteractive = ptext (sLit "SimplInteractive")
-  ppr SimplCheck       = ptext (sLit "SimplCheck")
 
 isInteractive :: SimplContext -> Bool
 isInteractive SimplInteractive = True
@@ -441,14 +486,14 @@ simplEqsOnly :: SimplContext -> Bool
 -- Simplify equalities only, not dictionaries
 -- This is used for the LHS of rules; ee
 -- Note [Simplifying RULE lhs constraints] in TcSimplify
-simplEqsOnly SimplRuleLhs = True
-simplEqsOnly _            = False
+simplEqsOnly (SimplRuleLhs {}) = True
+simplEqsOnly _                 = False
 
 performDefaulting :: SimplContext -> Bool
-performDefaulting SimplInfer              = False
-performDefaulting SimplRuleLhs            = False
-performDefaulting SimplInteractive = True
-performDefaulting SimplCheck       = True
+performDefaulting (SimplInfer {})   = False
+performDefaulting (SimplRuleLhs {}) = False
+performDefaulting SimplInteractive  = True
+performDefaulting (SimplCheck {})   = True
 
 ---------------
 newtype TcS a = TcS { unTcS :: TcSEnv -> TcM a } 
@@ -510,12 +555,14 @@ runTcS context untouch tcs
   = do { ty_binds_var <- TcM.newTcRef emptyVarEnv
        ; ev_binds_var@(EvBindsVar evb_ref _) <- TcM.newTcEvBinds
        ; step_count <- TcM.newTcRef 0
+       ; flat_cache_var <- TcM.newTcRef emptyFlatCache
        ; let env = TcSEnv { tcs_ev_binds = ev_binds_var
                           , tcs_ty_binds = ty_binds_var
                           , tcs_context  = context
                           , tcs_untch    = (untouch, emptyVarSet) -- No Tcs untouchables yet
                          , tcs_count    = step_count
                          , tcs_ic_depth = 0
+                          , tcs_flat_map = flat_cache_var
                           }
 
             -- Run the computation
@@ -526,7 +573,9 @@ runTcS context untouch tcs
 
 #ifdef DEBUG
        ; count <- TcM.readTcRef step_count
-       ; TcM.dumpTcRn (ptext (sLit "Constraint solver steps =") <+> int count)
+       ; when (opt_PprStyle_Debug && count > 0) $
+         TcM.debugDumpTcRn (ptext (sLit "Constraint solver steps =") 
+                            <+> int count <+> ppr context)
 #endif
              -- And return
        ; ev_binds      <- TcM.readTcRef evb_ref
@@ -540,21 +589,31 @@ nestImplicTcS ref (inner_range, inner_tcs) (TcS thing_inside)
                   , tcs_untch = (_outer_range, outer_tcs)
                   , tcs_count = count
                   , tcs_ic_depth = idepth
-                   , tcs_context = ctxt } ->
-    let 
-       inner_untch = (inner_range, outer_tcs `unionVarSet` inner_tcs)
+                   , tcs_context = ctxt 
+                   , tcs_flat_map = orig_flat_cache_var
+                   } ->
+    do { let inner_untch = (inner_range, outer_tcs `unionVarSet` inner_tcs)
                           -- The inner_range should be narrower than the outer one
                   -- (thus increasing the set of untouchables) but 
                   -- the inner Tcs-untouchables must be unioned with the
                   -- outer ones!
-       nest_env = TcSEnv { tcs_ev_binds = ref
-                         , tcs_ty_binds = ty_binds
-                         , tcs_untch    = inner_untch
-                         , tcs_count    = count
-                         , tcs_ic_depth = idepth+1
-                         , tcs_context  = ctxtUnderImplic ctxt }
-    in 
-    thing_inside nest_env
+
+       ; orig_flat_cache <- TcM.readTcRef orig_flat_cache_var
+       ; flat_cache_var  <- TcM.newTcRef orig_flat_cache
+       -- One could be more conservative as well: 
+       -- ; flat_cache_var  <- TcM.newTcRef emptyFlatCache 
+
+                            -- Consider copying the results the tcs_flat_map of the 
+                            -- incomping constraint, but we must make sure that we
+                            -- have pushed everything in, which seems somewhat fragile
+       ; let nest_env = TcSEnv { tcs_ev_binds = ref
+                               , tcs_ty_binds = ty_binds
+                               , tcs_untch    = inner_untch
+                               , tcs_count    = count
+                               , tcs_ic_depth = idepth+1
+                               , tcs_context  = ctxtUnderImplic ctxt 
+                               , tcs_flat_map = flat_cache_var }
+       ; thing_inside nest_env }
 
 recoverTcS :: TcS a -> TcS a -> TcS a
 recoverTcS (TcS recovery_code) (TcS thing_inside)
@@ -563,18 +622,21 @@ recoverTcS (TcS recovery_code) (TcS thing_inside)
 
 ctxtUnderImplic :: SimplContext -> SimplContext
 -- See Note [Simplifying RULE lhs constraints] in TcSimplify
-ctxtUnderImplic SimplRuleLhs = SimplCheck
-ctxtUnderImplic ctxt         = ctxt
+ctxtUnderImplic (SimplRuleLhs n) = SimplCheck (ptext (sLit "lhs of rule") 
+                                               <+> doubleQuotes (ftext n))
+ctxtUnderImplic ctxt              = ctxt
 
 tryTcS :: TcS a -> TcS a
--- Like runTcS, but from within the TcS monad 
+-- Like runTcS, but from within the TcS monad
 -- Ignore all the evidence generated, and do not affect caller's evidence!
-tryTcS tcs 
+tryTcS tcs
   = TcS (\env -> do { ty_binds_var <- TcM.newTcRef emptyVarEnv
                     ; ev_binds_var <- TcM.newTcEvBinds
+                    ; flat_cache_var <- TcM.newTcRef emptyFlatCache
                     ; let env1 = env { tcs_ev_binds = ev_binds_var
-                                     , tcs_ty_binds = ty_binds_var }
-                    ; unTcS tcs env1 })
+                                     , tcs_ty_binds = ty_binds_var
+                                     , tcs_flat_map = flat_cache_var }
+                   ; unTcS tcs env1 })
 
 -- Update TcEvBinds 
 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -597,12 +659,51 @@ getTcSTyBinds = TcS (return . tcs_ty_binds)
 getTcSTyBindsMap :: TcS (TyVarEnv (TcTyVar, TcType))
 getTcSTyBindsMap = getTcSTyBinds >>= wrapTcS . (TcM.readTcRef) 
 
+getFlatCacheMapVar :: TcS (IORef FlatCache)
+getFlatCacheMapVar
+  = TcS (return . tcs_flat_map)
+
+lookupFlatCacheMap :: TyCon -> [Xi] -> CtFlavor 
+                   -> TcS (Maybe (TcType,Coercion,CtFlavor))
+-- For givens, we lookup in given flat cache
+lookupFlatCacheMap tc xis (Given {})
+  = do { cache_ref <- getFlatCacheMapVar
+       ; cache_map <- wrapTcS $ TcM.readTcRef cache_ref
+       ; return $ Map.lookup (FunEqHead (tc,xis)) (givenFlatCache cache_map) }
+-- For wanteds, we first lookup in givenFlatCache.
+-- If we get nothing back then we lookup in wantedFlatCache.
+lookupFlatCacheMap tc xis (Wanted {})
+  = do { cache_ref <- getFlatCacheMapVar
+       ; cache_map <- wrapTcS $ TcM.readTcRef cache_ref
+       ; case Map.lookup (FunEqHead (tc,xis)) (givenFlatCache cache_map) of
+           Nothing -> return $ Map.lookup (FunEqHead (tc,xis)) (wantedFlatCache cache_map)
+           other   -> return other }
+lookupFlatCacheMap _tc _xis (Derived {}) = return Nothing
+
+updateFlatCacheMap :: TyCon -> [Xi]
+                   -> TcType -> CtFlavor -> Coercion -> TcS ()
+updateFlatCacheMap _tc _xis _tv (Derived {}) _co
+  = return () -- Not caching deriveds
+updateFlatCacheMap tc xis ty fl co
+  = do { cache_ref <- getFlatCacheMapVar
+       ; cache_map <- wrapTcS $ TcM.readTcRef cache_ref
+       ; let new_cache_map
+              | isGivenOrSolved fl
+              = cache_map { givenFlatCache = Map.insert (FunEqHead (tc,xis)) (ty,co,fl) $
+                                             givenFlatCache cache_map }
+              | isWanted fl
+              = cache_map { wantedFlatCache = Map.insert (FunEqHead (tc,xis)) (ty,co,fl) $
+                                              wantedFlatCache cache_map }
+              | otherwise = pprPanic "updateFlatCacheMap, met Derived!" $ empty
+       ; wrapTcS $ TcM.writeTcRef cache_ref new_cache_map }
+
 
 getTcEvBindsBag :: TcS EvBindMap
 getTcEvBindsBag
   = do { EvBindsVar ev_ref _ <- getTcEvBinds 
        ; wrapTcS $ TcM.readTcRef ev_ref }
 
+
 setCoBind :: CoVar -> Coercion -> TcS () 
 setCoBind cv co = setEvBind cv (EvCoercion co)
 
@@ -659,6 +760,12 @@ getTopEnv = wrapTcS $ TcM.getTopEnv
 getGblEnv :: TcS TcGblEnv 
 getGblEnv = wrapTcS $ TcM.getGblEnv 
 
+tcsLookupClass :: Name -> TcS Class
+tcsLookupClass name = wrapTcS (TcM.tcLookupClass name)
+
+tcsLookupTyCon :: Name -> TcS TyCon
+tcsLookupTyCon name = wrapTcS (TcM.tcLookupTyCon name)
+
 -- Various smaller utilities [TODO, maybe will be absorbed in the instance matcher]
 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -672,7 +779,7 @@ checkWellStagedDFun pred dfun_id loc
     bind_lvl = TcM.topIdLvl dfun_id
 
 pprEq :: TcType -> TcType -> SDoc
-pprEq ty1 ty2 = pprPred $ mkEqPred (ty1,ty2)
+pprEq ty1 ty2 = pprPredTy $ mkEqPred (ty1,ty2)
 
 isTouchableMetaTyVar :: TcTyVar -> TcS Bool
 isTouchableMetaTyVar tv 
@@ -817,13 +924,13 @@ matchClass clas tys
   = do { let pred = mkClassPred clas tys 
         ; instEnvs <- getInstEnvs
         ; case lookupInstEnv instEnvs clas tys of {
-            ([], unifs)               -- Nothing matches  
+            ([], unifs, _)               -- Nothing matches  
                 -> do { traceTcS "matchClass not matching"
                                  (vcat [ text "dict" <+> ppr pred, 
                                          text "unifs" <+> ppr unifs ]) 
                       ; return MatchInstNo  
                       } ;  
-           ([(ispec, inst_tys)], []) -- A single match 
+           ([(ispec, inst_tys)], [], _) -- A single match 
                -> do   { let dfun_id = is_dfun ispec
                        ; traceTcS "matchClass success"
                                   (vcat [text "dict" <+> ppr pred, 
@@ -832,7 +939,7 @@ matchClass clas tys
                                  -- Record that this dfun is needed
                         ; return $ MatchInstSingle (dfun_id, inst_tys)
                         } ;
-           (matches, unifs)          -- More than one matches 
+           (matches, unifs, _)          -- More than one matches 
                -> do   { traceTcS "matchClass multiple matches, deferring choice"
                                   (vcat [text "dict" <+> ppr pred,
                                          text "matches" <+> ppr matches,