Experimental alternative approach to invoking typechecker plugins
authorAdam Gundry <adam@well-typed.com>
Fri, 14 Nov 2014 16:23:52 +0000 (16:23 +0000)
committerAdam Gundry <adam@well-typed.com>
Fri, 14 Nov 2014 16:49:54 +0000 (16:49 +0000)
The solver is now provided with a boolean flag, which is False when
invoked inside solveFlats and True when invoked on the unflattened
constraints at the end.

compiler/typecheck/TcInteract.lhs
compiler/typecheck/TcRnTypes.lhs
compiler/typecheck/TcSMonad.lhs

index 78fb3f3..9890ab8 100644 (file)
@@ -43,7 +43,7 @@ import Data.List( partition, foldl' )
 
 import VarEnv
 
-import Control.Monad( when, unless, forM )
+import Control.Monad( when, unless, forM, foldM )
 import Pair (Pair(..))
 import Unique( hasKey )
 import FastString ( sLit )
@@ -134,9 +134,13 @@ solveFlatWanteds wanteds
 
        ; zonked <- zonkFlats (others `andCts` unflattened_eqs)
             -- Postcondition is that the wl_flats are zonked
-       ; return (WC { wc_flat  = zonked
-                    , wc_insol = insols
-                    , wc_impl  = implics }) }
+
+       ; (wanteds', rerun) <- runTcPluginsFinal zonked
+       ; if rerun then updInertTcS prepareInertsForImplications >> solveFlatWanteds wanteds'
+                  else return (WC { wc_flat  = wanteds'
+                                  , wc_insol = insols
+                                  , wc_impl  = implics }) }
+
 
 -- The main solver loop implements Note [Basic Simplifier Plan]
 ---------------------------------------------------------------
@@ -181,14 +185,9 @@ runTcPlugin :: TcPluginSolver -> TcS ()
 runTcPlugin solver =
   do iSet <- getTcSInerts
      let iCans    = inert_cans iSet
-         allCts   = foldDicts  (:) (inert_dicts iCans)
-                  $ foldFunEqs (:) (inert_funeqs iCans)
-                  $ concat (varEnvElts (inert_eqs iCans))
+         (given,derived,wanted) = splitInertCans iCans
 
-         (derived,other) = partition isDerivedCt allCts
-         (wanted,given)  = partition isWantedCt  other
-
-     result <- runTcPluginTcS (solver given derived wanted)
+     result <- runTcPluginTcS (solver False given derived wanted)
      case result of
 
        TcPluginContradiction bad_cts ->
@@ -197,7 +196,6 @@ runTcPlugin solver =
 
        TcPluginOk solved_cts new_cts ->
           do setInertCans (removeInertCts iCans (map snd solved_cts))
-             let setEv (ev,ct) = setEvBind (ctev_evar (cc_ev ct)) ev
              mapM_ setEv solved_cts
              updWorkListTcS (extendWorkListCts new_cts)
   where
@@ -225,6 +223,47 @@ runTcPlugin solver =
       CNonCanonical {} -> panic "runTcPlugin/removeInert: CNonCanonical"
       CHoleCan {}      -> panic "runTcPlugin/removeInert: CHoleCan"
 
+
+splitInertCans :: InertCans -> ([Ct], [Ct], [Ct])
+splitInertCans iCans = (given,derived,wanted)
+  where
+    allCts   = foldDicts  (:) (inert_dicts iCans)
+             $ foldFunEqs (:) (inert_funeqs iCans)
+             $ concat (varEnvElts (inert_eqs iCans))
+
+    (derived,other) = partition isDerivedCt allCts
+    (wanted,given)  = partition isWantedCt  other
+
+
+setEv :: (EvTerm,Ct) -> TcS ()
+setEv (ev,ct) = case ctEvidence ct of
+                  CtWanted {ctev_evar = evar} -> setEvBind evar ev
+                  _                           -> return ()
+
+
+runTcPluginsFinal :: Cts -> TcS (Cts, Bool)
+runTcPluginsFinal zonked_wanteds = do
+    gblEnv <- getGblEnv
+    (given,derived,_) <- fmap splitInertCans getInertCans
+    foldM (f given derived) (zonked_wanteds, False) (tcg_tc_plugins gblEnv)
+  where
+    f :: [Ct] -> [Ct] -> (Cts, Bool) -> TcPluginSolver -> TcS (Cts, Bool)
+    f given derived (wanteds, rerun) solver = do
+      result <- runTcPluginTcS (solver True given derived (bagToList wanteds))
+      case result of
+        TcPluginContradiction bad_cts -> do mapM_ emitInsoluble bad_cts
+                                            return (discard bad_cts wanteds, rerun)
+        TcPluginOk [] []              -> return (wanteds, rerun)
+        TcPluginOk solved_cts new_cts -> do
+             mapM_ setEv solved_cts
+             return (discard (map snd solved_cts) wanteds `unionBags` listToBag new_cts
+                    , rerun || notNull new_cts)
+      where
+        discard cs = filterBag (\ c -> not $ any (eqCt c) cs)
+
+        eqCt c c' = ctEvPred (ctEvidence c) `eqType` ctEvPred (ctEvidence c')
+
+
 type WorkItem = Ct
 type SimplifierStage = WorkItem -> TcS (StopOrContinue Ct)
 
index 3e0c053..402b7f3 100644 (file)
@@ -1974,7 +1974,8 @@ Constraint Solver Plugins
 
 \begin{code}
 
-type TcPluginSolver = [Ct]    -- given
+type TcPluginSolver = Bool
+                   -> [Ct]    -- given
                    -> [Ct]    -- derived
                    -> [Ct]    -- wanted
                    -> TcPluginM TcPluginResult
index da79f32..120c248 100644 (file)
@@ -14,7 +14,7 @@ module TcSMonad (
 
     updWorkListTcS, updWorkListTcS_return,
 
-    updInertCans, updInertDicts, updInertIrreds, updInertFunEqs,
+    updInertTcS, updInertCans, updInertDicts, updInertIrreds, updInertFunEqs,
 
     Ct(..), Xi, tyVarsOfCt, tyVarsOfCts,
     emitInsoluble, emitWorkNC,