separate tcplugin passes for improving givens vs solving wanteds
authorEric Seidel <gridaphobe@gmail.com>
Sun, 16 Nov 2014 18:45:29 +0000 (10:45 -0800)
committerEric Seidel <gridaphobe@gmail.com>
Sun, 16 Nov 2014 19:22:41 +0000 (11:22 -0800)
compiler/typecheck/TcInteract.lhs
compiler/typecheck/TcRnTypes.lhs

index 9890ab8..c8fedfb 100644 (file)
@@ -39,7 +39,7 @@ import TcErrors
 import TcSMonad
 import Bag
 
-import Data.List( partition, foldl' )
+import Data.List( partition )
 
 import VarEnv
 
@@ -119,11 +119,15 @@ solveFlatGivens loc givens
   | null givens  -- Shortcut for common case
   = return ()
   | otherwise
-  = solveFlats (listToBag (map mk_given_ct givens))
+  = go (listToBag (map mk_given_ct givens))
   where
     mk_given_ct ev_id = mkNonCanonical (CtGiven { ctev_evtm = EvId ev_id
                                                 , ctev_pred = evVarPred ev_id
                                                 , ctev_loc  = loc })
+    go givens = do { solveFlats givens
+                   ; (upd_givens, rerun) <- runTcPluginsGiven givens
+                   ; when rerun (go upd_givens)
+                   }
 
 solveFlatWanteds :: Cts -> TcS WantedConstraints
 solveFlatWanteds wanteds
@@ -135,7 +139,7 @@ solveFlatWanteds wanteds
        ; zonked <- zonkFlats (others `andCts` unflattened_eqs)
             -- Postcondition is that the wl_flats are zonked
 
-       ; (wanteds', rerun) <- runTcPluginsFinal zonked
+       ; (wanteds', rerun) <- runTcPluginsWanted zonked
        ; if rerun then updInertTcS prepareInertsForImplications >> solveFlatWanteds wanteds'
                   else return (WC { wc_flat  = wanteds'
                                   , wc_insol = insols
@@ -153,75 +157,20 @@ solveFlats cts
   = {-# SCC "solveFlats" #-}
     do { dyn_flags <- getDynFlags
        ; updWorkListTcS (\wl -> foldrBag extendWorkListCt wl cts)
-       ; solve_loop False (maxSubGoalDepth dyn_flags) }
+       ; solve_loop (maxSubGoalDepth dyn_flags) }
   where
-    solve_loop inertsModified max_depth
+    solve_loop max_depth
       = {-# SCC "solve_loop" #-}
         do { sel <- selectNextWorkItem max_depth
            ; case sel of
-
-              NoWorkRemaining
-                | inertsModified ->
-                    do gblEnv <- getGblEnv
-                       mapM_ runTcPlugin (tcg_tc_plugins gblEnv)
-                       solve_loop False max_depth
-
-                -- Done, successfuly (modulo frozen)
-                | otherwise -> return ()
-
+              NoWorkRemaining     -- Done, successfuly (modulo frozen)
+                -> return ()
 
               MaxDepthExceeded cnt ct -- Failure, depth exceeded
                 -> wrapErrTcS $ solverDepthErrorTcS cnt (ctEvidence ct)
 
               NextWorkItem ct     -- More work, loop around!
-                -> do { changes <- runSolverPipeline thePipeline ct
-                      ; let newMod = changes || inertsModified
-                      ; newMod `seq` solve_loop newMod max_depth } }
-
-
--- | Try to make progress using type-checker plugings.
--- The plugin is provided only with CTyEq and CFunEq constraints.
-runTcPlugin :: TcPluginSolver -> TcS ()
-runTcPlugin solver =
-  do iSet <- getTcSInerts
-     let iCans    = inert_cans iSet
-         (given,derived,wanted) = splitInertCans iCans
-
-     result <- runTcPluginTcS (solver False given derived wanted)
-     case result of
-
-       TcPluginContradiction bad_cts ->
-          do setInertCans (removeInertCts iCans bad_cts)
-             mapM_ emitInsoluble bad_cts
-
-       TcPluginOk solved_cts new_cts ->
-          do setInertCans (removeInertCts iCans (map snd solved_cts))
-             mapM_ setEv solved_cts
-             updWorkListTcS (extendWorkListCts new_cts)
-  where
-  removeInertCts :: InertCans -> [Ct] -> InertCans
-  removeInertCts = foldl' removeInertCt
-
-  -- Remove the constraint from the inert set.  We use this either when:
-  --   * a wanted constraint was solved, or
-  --   * some constraint was marked as insoluable, and so it will be
-  --     put right back into InertSet, but in the insoluable section.
-  removeInertCt :: InertCans -> Ct -> InertCans
-  removeInertCt is ct =
-    case ct of
-
-      CDictCan  { cc_class = cl, cc_tyargs = tys } ->
-        is { inert_dicts = delDict (inert_dicts is) cl tys }
-
-      CFunEqCan { cc_fun  = tf,  cc_tyargs = tys } ->
-        is { inert_funeqs = delFunEq (inert_funeqs is) tf tys }
-
-      CTyEqCan  { cc_tyvar = x,  cc_rhs    = ty  } ->
-        is { inert_eqs = delTyEq (inert_eqs is) x ty }
-
-      CIrredEvCan {}   -> panic "runTcPlugin/removeInert: CIrredEvCan"
-      CNonCanonical {} -> panic "runTcPlugin/removeInert: CNonCanonical"
-      CHoleCan {}      -> panic "runTcPlugin/removeInert: CHoleCan"
+                -> do { runSolverPipeline thePipeline ct; solve_loop max_depth } }
 
 
 splitInertCans :: InertCans -> ([Ct], [Ct], [Ct])
@@ -240,24 +189,49 @@ setEv (ev,ct) = case ctEvidence ct of
                   CtWanted {ctev_evar = evar} -> setEvBind evar ev
                   _                           -> return ()
 
+runTcPluginsGiven :: Cts -> TcS (Cts, Bool)
+runTcPluginsGiven givens = do
+    gblEnv <- getGblEnv
+    foldM f (givens, False) (tcg_tc_plugins gblEnv)
+  where
+    f :: (Cts, Bool) -> TcPluginSolver -> TcS (Cts, Bool)
+    f (givens, rerun) solver = do
+      result <- runTcPluginTcS (solver (bagToList givens) [] [])
+      case result of
+        TcPluginContradiction bad_cts -> do
+          mapM_ emitInsoluble bad_cts
+          return (discard bad_cts givens, rerun)
+        TcPluginOk [] []              -> return (givens, rerun)
+        TcPluginOk [] new_cts         -> do
+          let new_facts = [ct | ct <- new_cts, not (any (eqCt ct) (bagToList givens))]
+          updWorkListTcS (extendWorkListCts new_facts)
+          return ( unionBags givens (listToBag new_facts)
+                 , rerun || notNull new_facts)
+        TcPluginOk _solved_cts _new_cts ->
+          panic "runTcPluginsGiven: plugin solved a given constraint"
+      where
+        discard cs = filterBag (\ c -> not $ any (eqCt c) cs)
+        eqCt c c'  = ctEvPred (ctEvidence c) `eqType` ctEvPred (ctEvidence c')
 
-runTcPluginsFinal :: Cts -> TcS (Cts, Bool)
-runTcPluginsFinal zonked_wanteds = do
+runTcPluginsWanted :: Cts -> TcS (Cts, Bool)
+runTcPluginsWanted zonked_wanteds = do
     gblEnv <- getGblEnv
     (given,derived,_) <- fmap splitInertCans getInertCans
     foldM (f given derived) (zonked_wanteds, False) (tcg_tc_plugins gblEnv)
   where
     f :: [Ct] -> [Ct] -> (Cts, Bool) -> TcPluginSolver -> TcS (Cts, Bool)
     f given derived (wanteds, rerun) solver = do
-      result <- runTcPluginTcS (solver True given derived (bagToList wanteds))
+      result <- runTcPluginTcS (solver given derived (bagToList wanteds))
       case result of
         TcPluginContradiction bad_cts -> do mapM_ emitInsoluble bad_cts
                                             return (discard bad_cts wanteds, rerun)
         TcPluginOk [] []              -> return (wanteds, rerun)
         TcPluginOk solved_cts new_cts -> do
              mapM_ setEv solved_cts
-             return (discard (map snd solved_cts) wanteds `unionBags` listToBag new_cts
-                    , rerun || notNull new_cts)
+             let new_facts = [ct | ct <- new_cts, not (any (eqCt ct) (given ++ derived ++ bagToList wanteds))]
+             updWorkListTcS (extendWorkListCts new_facts)
+             return ( discard (map snd solved_cts) wanteds
+                    , rerun || notNull new_facts)
       where
         discard cs = filterBag (\ c -> not $ any (eqCt c) cs)
 
@@ -293,7 +267,7 @@ selectNextWorkItem max_depth
 
 runSolverPipeline :: [(String,SimplifierStage)] -- The pipeline
                   -> WorkItem                   -- The work item
-                  -> TcS Bool                   -- Did we modify the inert set
+                  -> TcS ()
 -- Run this item down the pipeline, leaving behind new work and inerts
 runSolverPipeline pipeline workItem
   = do { initial_is <- getTcSInerts
@@ -309,14 +283,13 @@ runSolverPipeline pipeline workItem
            Stop ev s       -> do { traceFireTcS ev s
                                  ; traceTcS "End solver pipeline (discharged) }"
                                        (ptext (sLit "inerts =") <+> ppr final_is)
-                                 ; return False }
+                                 ; return () }
            ContinueWith ct -> do { traceFireTcS (ctEvidence ct) (ptext (sLit "Kept as inert"))
                                  ; traceTcS "End solver pipeline (not discharged) }" $
                                        vcat [ ptext (sLit "final_item =") <+> ppr ct
                                             , pprTvBndrs (varSetElems $ tyVarsOfCt ct)
                                             , ptext (sLit "inerts     =") <+> ppr final_is]
-                                 ; insertInertItemTcS ct
-                                 ; return True }
+                                 ; insertInertItemTcS ct }
        }
   where run_pipeline :: [(String,SimplifierStage)] -> StopOrContinue Ct 
                      -> TcS (StopOrContinue Ct)
index 402b7f3..3e0c053 100644 (file)
@@ -1974,8 +1974,7 @@ Constraint Solver Plugins
 
 \begin{code}
 
-type TcPluginSolver = Bool
-                   -> [Ct]    -- given
+type TcPluginSolver = [Ct]    -- given
                    -> [Ct]    -- derived
                    -> [Ct]    -- wanted
                    -> TcPluginM TcPluginResult