Refactor Pattern Match Checker to use ListT
authorMatthew Pickering <matthewtpickering@gmail.com>
Tue, 29 Nov 2016 19:43:43 +0000 (14:43 -0500)
committerBen Gamari <ben@smart-cactus.org>
Tue, 29 Nov 2016 19:43:44 +0000 (14:43 -0500)
Reviewers: bgamari, austin

Reviewed By: bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D2725

compiler/deSugar/Check.hs
compiler/ghc.cabal.in
compiler/utils/ListT.hs [new file with mode: 0644]

index b5f6eac..04ba568 100644 (file)
@@ -50,6 +50,8 @@ import Coercion
 import TcEvidence
 import IOEnv
 
+import ListT (ListT(..), fold)
+
 {-
 This module checks pattern matches for:
 \begin{enumerate}
@@ -72,7 +74,25 @@ The algorithm is based on the paper:
 %************************************************************************
 -}
 
-type PmM a = DsM a
+-- We use the non-determinism monad to apply the algorithm to several
+-- possible sets of constructors. Users can specify complete sets of
+-- constructors by using COMPLETE pragmas.
+-- The algorithm only picks out constructor
+-- sets deep in the bowels which makes a simpler `mapM` more difficult to
+-- implement. The non-determinism is only used in one place, see the ConVar
+-- case in `pmCheckHd`.
+
+type PmM a = ListT DsM a
+
+liftD :: DsM a -> PmM a
+liftD m = ListT $ \sk fk -> m >>= \a -> sk a fk
+
+
+myRunListT :: PmM a -> DsM [a]
+myRunListT pm = fold pm go (return [])
+  where
+    go a mas =
+      mas >>= \as -> return (a:as)
 
 data PatTy = PAT | VA -- Used only as a kind, to index PmPat
 
@@ -122,14 +142,64 @@ type Uncovered = ValSetAbs
 --  C = True             ==> Useful clause (no warning)
 --  C = False, D = True  ==> Clause with inaccessible RHS
 --  C = False, D = False ==> Redundant clause
-type Triple = (Bool, Uncovered, Bool)
+
+data Covered = Covered | NotCovered
+  deriving Show
+
+instance Outputable Covered where
+  ppr (Covered) = text "Covered"
+  ppr (NotCovered) = text "NotCovered"
+
+-- Like the or monoid for booleans
+-- Covered = True, Uncovered = False
+instance Monoid Covered where
+  mempty = NotCovered
+  Covered `mappend` _ = Covered
+  _ `mappend` Covered = Covered
+  NotCovered `mappend` NotCovered = NotCovered
+
+data Diverged = Diverged | NotDiverged
+  deriving Show
+
+instance Outputable Diverged where
+  ppr Diverged = text "Diverged"
+  ppr NotDiverged = text "NotDiverged"
+
+instance Monoid Diverged where
+  mempty = NotDiverged
+  Diverged `mappend` _ = Diverged
+  _ `mappend` Diverged = Diverged
+  NotDiverged `mappend` NotDiverged = NotDiverged
+
+data PartialResult = PartialResult {
+                      presultCovered :: Covered
+                      , presultUncovered :: Uncovered
+                      , presultDivergent :: Diverged }
+
+instance Outputable PartialResult where
+  ppr (PartialResult c vsa d) = text "PartialResult" <+> ppr c
+                                  <+> ppr d <+> ppr vsa
+
+instance Monoid PartialResult where
+  mempty = PartialResult mempty [] mempty
+  (PartialResult cs1 vsa1 ds1)
+    `mappend` (PartialResult cs2 vsa2 ds2)
+      = PartialResult (cs1 `mappend` cs2)
+                      (vsa1 `mappend` vsa2)
+                      (ds1 `mappend` ds2)
+
+-- newtype ChoiceOf a = ChoiceOf [a]
 
 -- | Pattern check result
 --
 -- * Redundant clauses
 -- * Not-covered clauses
 -- * Clauses with inaccessible RHS
-type PmResult = ([Located [LPat Id]], Uncovered, [Located [LPat Id]])
+data PmResult =
+  PmResult {
+    pmresultRedundant :: [Located [LPat Id]]
+    , pmresultUncovered :: Uncovered
+    , pmresultInaccessible :: [Located [LPat Id]] }
 
 {-
 %************************************************************************
@@ -142,63 +212,67 @@ type PmResult = ([Located [LPat Id]], Uncovered, [Located [LPat Id]])
 -- | Check a single pattern binding (let)
 checkSingle :: DynFlags -> DsMatchContext -> Id -> Pat Id -> DsM ()
 checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do
-  tracePm "checkSingle" (vcat [ppr ctxt, ppr var, ppr p])
-  mb_pm_res <- tryM (checkSingle' locn var p)
+  tracePmD "checkSingle" (vcat [ppr ctxt, ppr var, ppr p])
+  mb_pm_res <- tryM (head <$> myRunListT (checkSingle' locn var p))
   case mb_pm_res of
     Left  _   -> warnPmIters dflags ctxt
     Right res -> dsPmWarn dflags ctxt res
 
 -- | Check a single pattern binding (let)
-checkSingle' :: SrcSpan -> Id -> Pat Id -> DsM PmResult
+checkSingle' :: SrcSpan -> Id -> Pat Id -> PmM PmResult
 checkSingle' locn var p = do
-  resetPmIterDs -- set the iter-no to zero
-  fam_insts <- dsGetFamInstEnvs
-  clause    <- translatePat fam_insts p
+  liftD resetPmIterDs -- set the iter-no to zero
+  fam_insts <- liftD dsGetFamInstEnvs
+  clause    <- liftD $ translatePat fam_insts p
   missing   <- mkInitialUncovered [var]
   tracePm "checkSingle: missing" (vcat (map pprValVecDebug missing))
-  (cs,us,ds) <- runMany (pmcheckI clause []) missing -- no guards
+  PartialResult cs us ds <- runMany (pmcheckI clause []) missing -- no guards
   return $ case (cs,ds) of
-    (True,  _    ) -> ([], us, []) -- useful
-    (False, False) -> ( m, us, []) -- redundant
-    (False, True ) -> ([], us,  m) -- inaccessible rhs
+    (Covered,  _    )         -> PmResult [] us [] -- useful
+    (NotCovered, NotDiverged) -> PmResult m us []  -- redundant
+    (NotCovered, Diverged )   -> PmResult [] us m  -- inaccessible rhs
   where m = [L locn [L locn p]]
 
 -- | Check a matchgroup (case, functions, etc.)
 checkMatches :: DynFlags -> DsMatchContext
              -> [Id] -> [LMatch Id (LHsExpr Id)] -> DsM ()
 checkMatches dflags ctxt vars matches = do
-  tracePm "checkMatches" (hang (vcat [ppr ctxt
+  tracePmD "checkMatches" (hang (vcat [ppr ctxt
                                , ppr vars
                                , text "Matches:"])
                                2
                                (vcat (map ppr matches)))
-  mb_pm_res <- tryM (checkMatches' vars matches)
+  mb_pm_res <- tryM (head <$> myRunListT (checkMatches' vars matches))
   case mb_pm_res of
     Left  _   -> warnPmIters dflags ctxt
     Right res -> dsPmWarn dflags ctxt res
 
 -- | Check a matchgroup (case, functions, etc.)
-checkMatches' :: [Id] -> [LMatch Id (LHsExpr Id)] -> DsM PmResult
+checkMatches' :: [Id] -> [LMatch Id (LHsExpr Id)] -> PmM PmResult
 checkMatches' vars matches
-  | null matches = return ([], [], [])
+  | null matches = return $ PmResult [] [] []
   | otherwise = do
-      resetPmIterDs -- set the iter-no to zero
+      liftD resetPmIterDs -- set the iter-no to zero
       missing    <- mkInitialUncovered vars
       tracePm "checkMatches: missing" (vcat (map pprValVecDebug missing))
       (rs,us,ds) <- go matches missing
-      return (map hsLMatchToLPats rs, us, map hsLMatchToLPats ds)
+      return $ PmResult (map hsLMatchToLPats rs) us (map hsLMatchToLPats ds)
   where
+    go :: [LMatch Id (LHsExpr Id)] -> Uncovered
+       -> PmM ([LMatch Id (LHsExpr Id)] , Uncovered , [LMatch Id (LHsExpr Id)])
     go []     missing = return ([], missing, [])
     go (m:ms) missing = do
       tracePm "checMatches': go" (ppr m $$ ppr missing)
-      fam_insts          <- dsGetFamInstEnvs
-      (clause, guards)   <- translateMatch fam_insts m
-      (cs, missing', ds) <- runMany (pmcheckI clause guards) missing
+      fam_insts          <- liftD dsGetFamInstEnvs
+      (clause, guards)   <- liftD $ translateMatch fam_insts m
+      r@(PartialResult cs missing' ds)
+        <- runMany (pmcheckI clause guards) missing
+      tracePm "checMatches': go: res" (ppr r)
       (rs, final_u, is)  <- go ms missing'
       return $ case (cs, ds) of
-        (True,  _    ) -> (  rs, final_u,   is) -- useful
-        (False, False) -> (m:rs, final_u,   is) -- redundant
-        (False, True ) -> (  rs, final_u, m:is) -- inaccessible
+        (Covered,  _    )        -> (  rs, final_u,   is) -- useful
+        (NotCovered, NotDiverged) -> (m:rs, final_u,   is) -- redundant
+        (NotCovered, Diverged )   -> (  rs, final_u, m:is) -- inaccessible
 
     hsLMatchToLPats :: LMatch id body -> Located [LPat id]
     hsLMatchToLPats (L l (Match _ pats _ _)) = L l pats
@@ -239,7 +313,7 @@ isFakeGuard [PmCon { pm_con_con = c }] (PmExprOther EWildPat)
 isFakeGuard _pats _e = False
 
 -- | Generate a `canFail` pattern vector of a specific type
-mkCanFailPmPat :: Type -> PmM PatVec
+mkCanFailPmPat :: Type -> DsM PatVec
 mkCanFailPmPat ty = do
   var <- mkPmVar ty
   return [var, fake_pat]
@@ -274,7 +348,7 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
 -- -----------------------------------------------------------------------
 -- * Transform (Pat Id) into of (PmPat Id)
 
-translatePat :: FamInstEnvs -> Pat Id -> PmM PatVec
+translatePat :: FamInstEnvs -> Pat Id -> DsM PatVec
 translatePat fam_insts pat = case pat of
   WildPat ty  -> mkPmVars [ty]
   VarPat  id  -> return [PmVar (unLoc id)]
@@ -389,7 +463,7 @@ translatePat fam_insts pat = case pat of
 
 -- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs)
 translateNPat :: FamInstEnvs
-              -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> PmM PatVec
+              -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> DsM PatVec
 translateNPat fam_insts (OverLit val False _ ty) mb_neg outer_ty
   | not type_change, isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg
   = translatePat fam_insts (LitPat (HsString src s))
@@ -407,12 +481,12 @@ translateNPat _ ol mb_neg _
 
 -- | Translate a list of patterns (Note: each pattern is translated
 -- to a pattern vector but we do not concatenate the results).
-translatePatVec :: FamInstEnvs -> [Pat Id] -> PmM [PatVec]
+translatePatVec :: FamInstEnvs -> [Pat Id] -> DsM [PatVec]
 translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
 
 -- | Translate a constructor pattern
 translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
-                   -> DataCon -> HsConPatDetails Id -> PmM PatVec
+                   -> DataCon -> HsConPatDetails Id -> DsM PatVec
 translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
   = concat <$> translatePatVec fam_insts (map unLoc ps)
 translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
@@ -467,7 +541,7 @@ translateConPatVec fam_insts  univ_tys  ex_tvs c (RecCon (HsRecFields fs _))
       | otherwise = subsetOf (x:xs) ys
 
 -- Translate a single match
-translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> PmM (PatVec,[PatVec])
+translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> DsM (PatVec,[PatVec])
 translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do
   pats'   <- concat <$> translatePatVec fam_insts pats
   guards' <- mapM (translateGuards fam_insts) guards
@@ -483,7 +557,7 @@ translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do
 -- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
 
 -- | Translate a list of guard statements to a pattern vector
-translateGuards :: FamInstEnvs -> [GuardStmt Id] -> PmM PatVec
+translateGuards :: FamInstEnvs -> [GuardStmt Id] -> DsM PatVec
 translateGuards fam_insts guards = do
   all_guards <- concat <$> mapM (translateGuard fam_insts) guards
   return (replace_unhandled all_guards)
@@ -523,7 +597,7 @@ cantFailPattern (PmGrd pv _e)
 cantFailPattern _ = False
 
 -- | Translate a guard statement to Pattern
-translateGuard :: FamInstEnvs -> GuardStmt Id -> PmM PatVec
+translateGuard :: FamInstEnvs -> GuardStmt Id -> DsM PatVec
 translateGuard fam_insts guard = case guard of
   BodyStmt   e _ _ _ -> translateBoolGuard e
   LetStmt      binds -> translateLet (unLoc binds)
@@ -535,17 +609,17 @@ translateGuard fam_insts guard = case guard of
   ApplicativeStmt {} -> panic "translateGuard ApplicativeLastStmt"
 
 -- | Translate let-bindings
-translateLet :: HsLocalBinds Id -> PmM PatVec
+translateLet :: HsLocalBinds Id -> DsM PatVec
 translateLet _binds = return []
 
 -- | Translate a pattern guard
-translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> PmM PatVec
+translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> DsM PatVec
 translateBind fam_insts (L _ p) e = do
   ps <- translatePat fam_insts p
   return [mkGuard ps (unLoc e)]
 
 -- | Translate a boolean guard
-translateBoolGuard :: LHsExpr Id -> PmM PatVec
+translateBoolGuard :: LHsExpr Id -> DsM PatVec
 translateBoolGuard e
   | isJust (isTrueLHsExpr e) = return []
     -- The formal thing to do would be to generate (True <- True)
@@ -675,7 +749,7 @@ pmPatType (PmGrd  { pm_grd_pv  = pv })
 
 -- | Generate a value abstraction for a given constructor (generate
 -- fresh variables of the appropriate type for arguments)
-mkOneConFull :: Id -> DataCon -> PmM (ValAbs, ComplexEq, Bag EvVar)
+mkOneConFull :: Id -> DataCon -> DsM (ValAbs, ComplexEq, Bag EvVar)
 --  *  x :: T tys, where T is an algebraic data type
 --     NB: in the case of a data familiy, T is the *representation* TyCon
 --     e.g.   data instance T (a,b) = T1 a b
@@ -738,17 +812,17 @@ mkPosEq x l = (PmExprVar (idName x), PmExprLit l)
 {-# INLINE mkPosEq #-}
 
 -- | Generate a variable pattern of a given type
-mkPmVar :: Type -> PmM (PmPat p)
+mkPmVar :: Type -> DsM (PmPat p)
 mkPmVar ty = PmVar <$> mkPmId ty
 {-# INLINE mkPmVar #-}
 
 -- | Generate many variable patterns, given a list of types
-mkPmVars :: [Type] -> PmM PatVec
+mkPmVars :: [Type] -> DsM PatVec
 mkPmVars tys = mapM mkPmVar tys
 {-# INLINE mkPmVars #-}
 
 -- | Generate a fresh `Id` of a given type
-mkPmId :: Type -> PmM Id
+mkPmId :: Type -> DsM Id
 mkPmId ty = getUniqueM >>= \unique ->
   let occname = mkVarOccFS (fsLit (show unique))
       name    = mkInternalName unique occname noSrcSpan
@@ -757,7 +831,7 @@ mkPmId ty = getUniqueM >>= \unique ->
 -- | Generate a fresh term variable of a given and return it in two forms:
 -- * A variable pattern
 -- * A variable expression
-mkPmId2Forms :: Type -> PmM (Pattern, LHsExpr Id)
+mkPmId2Forms :: Type -> DsM (Pattern, LHsExpr Id)
 mkPmId2Forms ty = do
   x <- mkPmId ty
   return (PmVar x, noLoc (HsVar (noLoc x)))
@@ -802,7 +876,7 @@ allConstructors = tyConDataCons . dataConTyCon
 newEvVar :: Name -> Type -> EvVar
 newEvVar name ty = mkLocalId name (toTcType ty)
 
-nameType :: String -> Type -> PmM EvVar
+nameType :: String -> Type -> DsM EvVar
 nameType name ty = do
   unique <- getUniqueM
   let occname = mkVarOccFS (fsLit (name++"_"++show unique))
@@ -820,7 +894,8 @@ nameType name ty = do
 -- | Check whether a set of type constraints is satisfiable.
 tyOracle :: Bag EvVar -> PmM Bool
 tyOracle evs
-  = do { ((_warns, errs), res) <- initTcDsForSolver $ tcCheckSatisfiability evs
+  = liftD $
+    do { ((_warns, errs), res) <- initTcDsForSolver $ tcCheckSatisfiability evs
        ; case res of
             Just sat -> return sat
             Nothing  -> pprPanic "tyOracle" (vcat $ pprErrMsgBagWithLoc errs) }
@@ -861,7 +936,7 @@ Main functions are:
   are checked, if they are inconsistent, the set is empty, otherwise, the
   set contains only a vector of variables with the constraints in scope.
 
-* pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM Triple
+* pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult
 
   Checks redundancy, coverage and inaccessibility, using auxilary functions
   `pmcheckGuards` and `pmcheckHd`. Mainly handles the guard case which is
@@ -869,12 +944,12 @@ Main functions are:
   whole clause is checked, or `pmcheckHd` when the pattern vector does not
   start with a guard.
 
-* pmcheckGuards :: [PatVec] -> ValVec -> PmM Triple
+* pmcheckGuards :: [PatVec] -> ValVec -> PmM PartialResult
 
   Processes the guards.
 
 * pmcheckHd :: Pattern -> PatVec -> [PatVec]
-          -> ValAbs -> ValVec -> PmM Triple
+          -> ValAbs -> ValVec -> PmM PartialResult
 
   Worker: This function implements functions `covered`, `uncovered` and
   `divergent` from the paper at once. Slightly different from the paper because
@@ -886,17 +961,20 @@ Main functions are:
 -- | Lift a pattern matching action from a single value vector abstration to a
 -- value set abstraction, but calling it on every vector and the combining the
 -- results.
-runMany :: (ValVec -> PmM Triple) -> (Uncovered -> PmM Triple)
-runMany pm us = mapAndUnzip3M pm us >>= \(css, uss, dss) ->
-                  return (or css, concat uss, or dss)
+runMany :: (ValVec -> PmM PartialResult) -> (Uncovered -> PmM PartialResult)
+runMany _ [] = return $ PartialResult mempty mempty mempty
+runMany pm (m:ms) = do
+  (PartialResult c v d) <- pm m
+  (PartialResult cs vs ds) <- runMany pm ms
+  return (PartialResult (c `mappend` cs) (v `mappend` vs) (d `mappend` ds))
 {-# INLINE runMany #-}
 
 -- | Generate the initial uncovered set. It initializes the
 -- delta with all term and type constraints in scope.
 mkInitialUncovered :: [Id] -> PmM Uncovered
 mkInitialUncovered vars = do
-  ty_cs  <- getDictsDs
-  tm_cs  <- map toComplex . bagToList <$> getTmCsDs
+  ty_cs  <- liftD getDictsDs
+  tm_cs  <- map toComplex . bagToList <$> liftD getTmCsDs
   sat_ty <- tyOracle ty_cs
   return $ case (sat_ty, tmOracle initialTmState tm_cs) of
     (True, Just tm_state) -> [ValVec patterns (MkDelta ty_cs tm_state)]
@@ -908,41 +986,45 @@ mkInitialUncovered vars = do
 
 -- | Increase the counter for elapsed algorithm iterations, check that the
 -- limit is not exceeded and call `pmcheck`
-pmcheckI :: PatVec -> [PatVec] -> ValVec -> PmM Triple
+pmcheckI :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult
 pmcheckI ps guards vva = do
-  n <- incrCheckPmIterDs
+  n <- liftD incrCheckPmIterDs
   tracePm "pmCheck" (ppr n <> colon <+> pprPatVec ps
                         $$ hang (text "guards:") 2 (vcat (map pprPatVec guards))
                         $$ pprValVecDebug vva)
-  pmcheck ps guards vva
+  res <- pmcheck ps guards vva
+  tracePm "pmCheckResult:" (ppr res)
+  return res
 {-# INLINE pmcheckI #-}
 
 -- | Increase the counter for elapsed algorithm iterations, check that the
 -- limit is not exceeded and call `pmcheckGuards`
-pmcheckGuardsI :: [PatVec] -> ValVec -> PmM Triple
-pmcheckGuardsI gvs vva = incrCheckPmIterDs >> pmcheckGuards gvs vva
+pmcheckGuardsI :: [PatVec] -> ValVec -> PmM PartialResult
+pmcheckGuardsI gvs vva = liftD incrCheckPmIterDs >> pmcheckGuards gvs vva
 {-# INLINE pmcheckGuardsI #-}
 
 -- | Increase the counter for elapsed algorithm iterations, check that the
 -- limit is not exceeded and call `pmcheckHd`
-pmcheckHdI :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM Triple
+pmcheckHdI :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM PartialResult
 pmcheckHdI p ps guards va vva = do
-  n <- incrCheckPmIterDs
+  n <- liftD incrCheckPmIterDs
   tracePm "pmCheckHdI" (ppr n <> colon <+> pprPmPatDebug p
                         $$ pprPatVec ps
                         $$ hang (text "guards:") 2 (vcat (map pprPatVec guards))
                         $$ pprPmPatDebug va
                         $$ pprValVecDebug vva)
 
-  pmcheckHd p ps guards va vva
+  res <- pmcheckHd p ps guards va vva
+  tracePm "pmCheckHdI: res" (ppr res)
+  return res
 {-# INLINE pmcheckHdI #-}
 
 -- | Matching function: Check simultaneously a clause (takes separately the
 -- patterns and the list of guards) for exhaustiveness, redundancy and
 -- inaccessibility.
-pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM Triple
+pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult
 pmcheck [] guards vva@(ValVec [] _)
-  | null guards = return (True, [], False)
+  | null guards = return $ mempty { presultCovered = Covered }
   | otherwise   = pmcheckGuardsI guards vva
 
 -- Guard
@@ -953,7 +1035,7 @@ pmcheck (p@(PmGrd pv e) : ps) guards vva@(ValVec vas delta)
     -- though. So just have these two cases but do not do all the boilerplate
   | isFakeGuard pv e = forces . mkCons vva <$> pmcheckI ps guards vva
   | otherwise = do
-      y <- mkPmId (pmPatType p)
+      y <- liftD $ mkPmId (pmPatType p)
       let tm_state = extendSubst y e (delta_tm_cs delta)
           delta'   = delta { delta_tm_cs = tm_state }
       utail <$> pmcheckI (pv ++ ps) guards (ValVec (PmVar y : vas) delta')
@@ -965,41 +1047,44 @@ pmcheck (p:ps) guards (ValVec (va:vva) delta)
   = pmcheckHdI p ps guards va (ValVec vva delta)
 
 -- | Check the list of guards
-pmcheckGuards :: [PatVec] -> ValVec -> PmM Triple
-pmcheckGuards []       vva = return (False, [vva], False)
+pmcheckGuards :: [PatVec] -> ValVec -> PmM PartialResult
+pmcheckGuards []       vva = return (usimple [vva])
 pmcheckGuards (gv:gvs) vva = do
-  (cs,  vsa,  ds ) <- pmcheckI gv [] vva
-  (css, vsas, dss) <- runMany (pmcheckGuardsI gvs) vsa
-  return (cs || css, vsas, ds || dss)
+  (PartialResult cs vsa ds) <- pmcheckI gv [] vva
+  (PartialResult css vsas dss) <- runMany (pmcheckGuardsI gvs) vsa
+  return $ PartialResult (cs `mappend` css) vsas (ds `mappend` dss)
 
 -- | Worker function: Implements all cases described in the paper for all three
 -- functions (`covered`, `uncovered` and `divergent`) apart from the `Guard`
 -- cases which are handled by `pmcheck`
-pmcheckHd :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM Triple
+pmcheckHd :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM PartialResult
 
 -- Var
 pmcheckHd (PmVar x) ps guards va (ValVec vva delta)
   | Just tm_state <- solveOneEq (delta_tm_cs delta)
                                 (PmExprVar (idName x), vaToPmExpr va)
   = ucon va <$> pmcheckI ps guards (ValVec vva (delta {delta_tm_cs = tm_state}))
-  | otherwise = return (False, [], False)
+  | otherwise = return mempty
 
 -- ConCon
 pmcheckHd ( p@(PmCon {pm_con_con = c1, pm_con_args = args1})) ps guards
           (va@(PmCon {pm_con_con = c2, pm_con_args = args2})) (ValVec vva delta)
-  | c1 /= c2  = return (False, [ValVec (va:vva) delta], False)
+  | c1 /= c2  =
+    return (usimple [ValVec (va:vva) delta])
   | otherwise = kcon c1 (pm_con_arg_tys p) (pm_con_tvs p) (pm_con_dicts p)
                 <$> pmcheckI (args1 ++ ps) guards (ValVec (args2 ++ vva) delta)
 
 -- LitLit
-pmcheckHd (PmLit l1) ps guards (va@(PmLit l2)) vva = case eqPmLit l1 l2 of
-  True  -> ucon va <$> pmcheckI ps guards vva
-  False -> return $ ucon va (False, [vva], False)
+pmcheckHd (PmLit l1) ps guards (va@(PmLit l2)) vva =
+  case eqPmLit l1 l2 of
+    True  -> ucon va <$> pmcheckI ps guards vva
+    False -> return $ ucon va (usimple [vva])
 
 -- ConVar
 pmcheckHd (p@(PmCon { pm_con_con = con })) ps guards
           (PmVar x) (ValVec vva delta) = do
-  cons_cs  <- mapM (mkOneConFull x) (allConstructors con)
+  cons_cs  <- mapM (liftD . mkOneConFull x) (allConstructors con)
+
   inst_vsa <- flip concatMapM cons_cs $ \(va, tm_ct, ty_cs) -> do
     let ty_state = ty_cs `unionBags` delta_ty_cs delta -- not actually a state
     sat_ty <- if isEmptyBag ty_cs then return True
@@ -1018,13 +1103,13 @@ pmcheckHd (p@(PmLit l)) ps guards (PmVar x) (ValVec vva delta)
         case solveOneEq (delta_tm_cs delta) (mkPosEq x l) of
           Just tm_state -> pmcheckHdI p ps guards (PmLit l) $
                              ValVec vva (delta {delta_tm_cs = tm_state})
-          Nothing       -> return (False, [], False)
+          Nothing       -> return mempty
   where
     us | Just tm_state <- solveOneEq (delta_tm_cs delta) (mkNegEq x l)
        = [ValVec (PmNLit x [l] : vva) (delta { delta_tm_cs = tm_state })]
        | otherwise = []
 
-    non_matched = (False, us, False)
+    non_matched = usimple us
 
 -- LitNLit
 pmcheckHd (p@(PmLit l)) ps guards
@@ -1044,7 +1129,7 @@ pmcheckHd (p@(PmLit l)) ps guards
        = [ValVec (PmNLit x (l:lits) : vva) (delta { delta_tm_cs = tm_state })]
        | otherwise = []
 
-    non_matched = (False, us, False)
+    non_matched = usimple us
 
 -- ----------------------------------------------------------------------------
 -- The following three can happen only in cases like #322 where constructors
@@ -1055,14 +1140,14 @@ pmcheckHd (p@(PmLit l)) ps guards
 
 -- LitCon
 pmcheckHd (PmLit l) ps guards (va@(PmCon {})) (ValVec vva delta)
-  = do y <- mkPmId (pmPatType va)
+  = do y <- liftD $ mkPmId (pmPatType va)
        let tm_state = extendSubst y (PmExprLit l) (delta_tm_cs delta)
            delta'   = delta { delta_tm_cs = tm_state }
        pmcheckHdI (PmVar y) ps guards va (ValVec vva delta')
 
 -- ConLit
 pmcheckHd (p@(PmCon {})) ps guards (PmLit l) (ValVec vva delta)
-  = do y <- mkPmId (pmPatType p)
+  = do y <- liftD $ mkPmId (pmPatType p)
        let tm_state = extendSubst y (PmExprLit l) (delta_tm_cs delta)
            delta'   = delta { delta_tm_cs = tm_state }
        pmcheckHdI p ps guards (PmVar y) (ValVec vva delta')
@@ -1077,54 +1162,66 @@ pmcheckHd (PmGrd {}) _ _ _ _ = panic "pmcheckHd: Guard"
 -- ----------------------------------------------------------------------------
 -- * Utilities for main checking
 
+updateVsa :: (ValSetAbs -> ValSetAbs) -> (PartialResult -> PartialResult)
+updateVsa f p@(PartialResult { presultUncovered = old })
+  = p { presultUncovered = f old }
+
+
+-- | Initialise with default values for covering and divergent information.
+usimple :: ValSetAbs -> PartialResult
+usimple vsa = mempty { presultUncovered = vsa }
+
 -- | Take the tail of all value vector abstractions in the uncovered set
-utail :: Triple -> Triple
-utail (cs, vsa, ds) = (cs, vsa', ds)
-  where vsa' = [ ValVec vva delta | ValVec (_:vva) delta <- vsa ]
+utail :: PartialResult -> PartialResult
+utail = updateVsa upd
+  where upd vsa = [ ValVec vva delta | ValVec (_:vva) delta <- vsa ]
 
 -- | Prepend a value abstraction to all value vector abstractions in the
 -- uncovered set
-ucon :: ValAbs -> Triple -> Triple
-ucon va (cs, vsa, ds) = (cs, vsa', ds)
-  where vsa' = [ ValVec (va:vva) delta | ValVec vva delta <- vsa ]
+ucon :: ValAbs -> PartialResult -> PartialResult
+ucon va = updateVsa upd
+  where
+    upd vsa = [ ValVec (va:vva) delta | ValVec vva delta <- vsa ]
 
 -- | Given a data constructor of arity `a` and an uncovered set containing
 -- value vector abstractions of length `(a+n)`, pass the first `n` value
 -- abstractions to the constructor (Hence, the resulting value vector
 -- abstractions will have length `n+1`)
-kcon :: DataCon -> [Type] -> [TyVar] -> [EvVar] -> Triple -> Triple
-kcon con arg_tys ex_tvs dicts (cs, vsa, ds)
-  = (cs, [ ValVec (va:vva) delta
-         | ValVec vva' delta <- vsa
-         , let (args, vva) = splitAt n vva'
-         , let va = PmCon { pm_con_con     = con
-                          , pm_con_arg_tys = arg_tys
-                          , pm_con_tvs     = ex_tvs
-                          , pm_con_dicts   = dicts
-                          , pm_con_args    = args } ]
-       , ds)
-  where n = dataConSourceArity con
+kcon :: DataCon -> [Type] -> [TyVar] -> [EvVar]
+     -> PartialResult -> PartialResult
+kcon con arg_tys ex_tvs dicts
+  = let n = dataConSourceArity con
+        upd vsa =
+          [ ValVec (va:vva) delta
+          | ValVec vva' delta <- vsa
+          , let (args, vva) = splitAt n vva'
+          , let va = PmCon { pm_con_con     = con
+                            , pm_con_arg_tys = arg_tys
+                            , pm_con_tvs     = ex_tvs
+                            , pm_con_dicts   = dicts
+                            , pm_con_args    = args } ]
+    in updateVsa upd
 
 -- | Get the union of two covered, uncovered and divergent value set
 -- abstractions. Since the covered and divergent sets are represented by a
 -- boolean, union means computing the logical or (at least one of the two is
 -- non-empty).
-mkUnion :: Triple -> Triple -> Triple
-mkUnion (cs1, vsa1, ds1) (cs2, vsa2, ds2)
-  = (cs1 || cs2, vsa1 ++ vsa2, ds1 || ds2)
+
+mkUnion :: PartialResult -> PartialResult -> PartialResult
+mkUnion = mappend
 
 -- | Add a value vector abstraction to a value set abstraction (uncovered).
-mkCons :: ValVec -> Triple -> Triple
-mkCons vva (cs, vsa, ds) = (cs, vva:vsa, ds)
+mkCons :: ValVec -> PartialResult -> PartialResult
+mkCons vva = updateVsa (vva:)
 
 -- | Set the divergent set to not empty
-forces :: Triple -> Triple
-forces (cs, us, _) = (cs, us, True)
+forces :: PartialResult -> PartialResult
+forces pres = pres { presultDivergent = Diverged }
 
 -- | Set the divergent set to non-empty if the flag is `True`
-force_if :: Bool -> Triple -> Triple
-force_if True  (cs,us,_) = (cs,us,True)
-force_if False triple    = triple
+force_if :: Bool -> PartialResult -> PartialResult
+force_if True  pres = forces pres
+force_if False pres = pres
 
 -- ----------------------------------------------------------------------------
 -- * Propagation of term constraints inwards when checking nested matches
@@ -1133,7 +1230,7 @@ force_if False triple    = triple
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 When checking a match it would be great to have all type and term information
 available so we can get more precise results. For this reason we have functions
-`addDictsDs' and `addTmCsDs' in DsMonad that store in the environment type and
+`addDictsDs' and `addTmCsDs' in PmMonad that store in the environment type and
 term constraints (respectively) as we go deeper.
 
 The type constraints we propagate inwards are collected by `collectEvVarsPats'
@@ -1275,7 +1372,10 @@ dsPmWarn dflags ctx@(DsMatchContext kind loc) pm_result
       when exists_u $
         putSrcSpanDs loc (warnDs flag_u_reason (pprEqns uncovered))
   where
-    (redundant, uncovered, inaccessible) = pm_result
+    PmResult
+      { pmresultRedundant = redundant
+      , pmresultUncovered = uncovered
+      , pmresultInaccessible = inaccessible } = pm_result
 
     flag_i = wopt Opt_WarnOverlappingPatterns dflags
     flag_u = exhaustive dflags kind
@@ -1298,7 +1398,7 @@ dsPmWarn dflags ctx@(DsMatchContext kind loc) pm_result
 
 -- | Issue a warning when the predefined number of iterations is exceeded
 -- for the pattern match checker
-warnPmIters :: DynFlags -> DsMatchContext -> PmM ()
+warnPmIters :: DynFlags -> DsMatchContext -> DsM ()
 warnPmIters dflags (DsMatchContext kind loc)
   = when (flag_i || flag_u) $ do
       iters <- maxPmCheckIterations <$> getDynFlags
@@ -1441,7 +1541,11 @@ involved.
 -- Debugging Infrastructre
 
 tracePm :: String -> SDoc -> PmM ()
-tracePm herald doc = do
+tracePm herald doc = liftD $ tracePmD herald doc
+
+
+tracePmD :: String -> SDoc -> DsM ()
+tracePmD herald doc = do
   dflags <- getDynFlags
   printer <- mkPrintUnqualifiedDs
   liftIO $ dumpIfSet_dyn_printer printer dflags
index 9538e2c..c5ca313 100644 (file)
@@ -490,6 +490,7 @@ Library
         GraphPpr
         IOEnv
         ListSetOps
+        ListT
         Maybes
         MonadUtils
         OrdList
diff --git a/compiler/utils/ListT.hs b/compiler/utils/ListT.hs
new file mode 100644 (file)
index 0000000..2b81db1
--- /dev/null
@@ -0,0 +1,71 @@
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE Rank2Types #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+
+-------------------------------------------------------------------------
+-- |
+-- Module      : Control.Monad.Logic
+-- Copyright   : (c) Dan Doel
+-- License     : BSD3
+--
+-- Maintainer  : dan.doel@gmail.com
+-- Stability   : experimental
+-- Portability : non-portable (multi-parameter type classes)
+--
+-- A backtracking, logic programming monad.
+--
+--    Adapted from the paper
+--    /Backtracking, Interleaving, and Terminating
+--        Monad Transformers/, by
+--    Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry
+--    (<http://www.cs.rutgers.edu/~ccshan/logicprog/ListT-icfp2005.pdf>).
+-------------------------------------------------------------------------
+
+module ListT (
+    ListT(..),
+    runListT,
+    select,
+    fold
+  ) where
+
+import Control.Applicative
+
+import Control.Monad
+
+-------------------------------------------------------------------------
+-- | A monad transformer for performing backtracking computations
+-- layered over another monad 'm'
+newtype ListT m a =
+    ListT { unListT :: forall r. (a -> m r -> m r) -> m r -> m r }
+
+select :: Monad m => [a] -> ListT m a
+select xs = foldr (<|>) mzero (map pure xs)
+
+fold :: ListT m a -> (a -> m r -> m r) -> m r -> m r
+fold = runListT
+
+-------------------------------------------------------------------------
+-- | Runs a ListT computation with the specified initial success and
+-- failure continuations.
+runListT :: ListT m a -> (a -> m r -> m r) -> m r -> m r
+runListT = unListT
+
+instance Functor (ListT f) where
+    fmap f lt = ListT $ \sk fk -> unListT lt (sk . f) fk
+
+instance Applicative (ListT f) where
+    pure a = ListT $ \sk fk -> sk a fk
+    f <*> a = ListT $ \sk fk -> unListT f (\g fk' -> unListT a (sk . g) fk') fk
+
+instance Alternative (ListT f) where
+    empty = ListT $ \_ fk -> fk
+    f1 <|> f2 = ListT $ \sk fk -> unListT f1 sk (unListT f2 sk fk)
+
+instance Monad (ListT m) where
+    m >>= f = ListT $ \sk fk -> unListT m (\a fk' -> unListT (f a) sk fk') fk
+    fail _ = ListT $ \_ fk -> fk
+
+instance MonadPlus (ListT m) where
+    mzero = ListT $ \_ fk -> fk
+    m1 `mplus` m2 = ListT $ \sk fk -> unListT m1 sk (unListT m2 sk fk)