ApplicativeDo: handle BodyStmt (#12143)
authorSimon Marlow <marlowsd@gmail.com>
Thu, 26 Oct 2017 10:23:23 +0000 (11:23 +0100)
committerSimon Marlow <marlowsd@gmail.com>
Fri, 27 Oct 2017 13:48:23 +0000 (14:48 +0100)
Summary:
It's simple to treat BodyStmt just like a BindStmt with a wildcard
pattern, which is enough to fix #12143 without going all the way to
using `<*` and `*>` (#10892).

Test Plan:
* new test cases in `ado004.hs`
* validate

Reviewers: niteria, simonpj, bgamari, austin, erikd

Subscribers: rwbarton, thomie

GHC Trac Issues: #12143

Differential Revision: https://phabricator.haskell.org/D4128

compiler/deSugar/Coverage.hs
compiler/deSugar/DsExpr.hs
compiler/hsSyn/HsExpr.hs
compiler/hsSyn/HsUtils.hs
compiler/rename/RnExpr.hs
compiler/typecheck/TcHsSyn.hs
compiler/typecheck/TcMatches.hs
testsuite/tests/ado/ado004.hs
testsuite/tests/ado/ado004.stderr

index c58c1a4..862e564 100644 (file)
@@ -767,8 +767,11 @@ addTickApplicativeArg
 addTickApplicativeArg isGuard (op, arg) =
   liftM2 (,) (addTickSyntaxExpr hpcSrcSpan op) (addTickArg arg)
  where
-  addTickArg (ApplicativeArgOne pat expr) =
-    ApplicativeArgOne <$> addTickLPat pat <*> addTickLHsExpr expr
+  addTickArg (ApplicativeArgOne pat expr isBody) =
+    ApplicativeArgOne
+      <$> addTickLPat pat
+      <*> addTickLHsExpr expr
+      <*> pure isBody
   addTickArg (ApplicativeArgMany stmts ret pat) =
     ApplicativeArgMany
       <$> addTickLStmts isGuard stmts
index b2b98f8..635a9c6 100644 (file)
@@ -924,7 +924,7 @@ dsDo stmts
              let
                (pats, rhss) = unzip (map (do_arg . snd) args)
 
-               do_arg (ApplicativeArgOne pat expr) =
+               do_arg (ApplicativeArgOne pat expr _) =
                  (pat, dsLExpr expr)
                do_arg (ApplicativeArgMany stmts ret pat) =
                  (pat, dsDo (stmts ++ [noLoc $ mkLastStmt (noLoc ret)]))
index 1cfaa79..fedaa44 100644 (file)
@@ -1777,13 +1777,18 @@ deriving instance (DataId idL, DataId idR) => Data (ParStmtBlock idL idR)
 
 -- | Applicative Argument
 data ApplicativeArg idL idR
-  = ApplicativeArgOne            -- pat <- expr (pat must be irrefutable)
-      (LPat idL)
+  = ApplicativeArgOne      -- A single statement (BindStmt or BodyStmt)
+      (LPat idL)           -- WildPat if it was a BodyStmt (see below)
       (LHsExpr idL)
-  | ApplicativeArgMany           -- do { stmts; return vars }
-      [ExprLStmt idL]            -- stmts
-      (HsExpr idL)               -- return (v1,..,vn), or just (v1,..,vn)
-      (LPat idL)                 -- (v1,...,vn)
+      Bool                 -- True <=> was a BodyStmt
+                           -- False <=> was a BindStmt
+                           -- See Note [Applicative BodyStmt]
+
+  | ApplicativeArgMany     -- do { stmts; return vars }
+      [ExprLStmt idL]      -- stmts
+      (HsExpr idL)         -- return (v1,..,vn), or just (v1,..,vn)
+      (LPat idL)           -- (v1,...,vn)
+
 deriving instance (DataId idL, DataId idR) => Data (ApplicativeArg idL idR)
 
 {-
@@ -1921,6 +1926,34 @@ Parallel statements require the 'Control.Monad.Zip.mzip' function:
 
 In any other context than 'MonadComp', the fields for most of these
 'SyntaxExpr's stay bottom.
+
+
+Note [Applicative BodyStmt]
+
+(#12143) For the purposes of ApplicativeDo, we treat any BodyStmt
+as if it was a BindStmt with a wildcard pattern.  For example,
+
+  do
+    x <- A
+    B
+    return x
+
+is transformed as if it were
+
+  do
+    x <- A
+    _ <- B
+    return x
+
+so it transforms to
+
+  (\(x,_) -> x) <$> A <*> B
+
+But we have to remember when we treat a BodyStmt like a BindStmt,
+because in error messages we want to emit the original syntax the user
+wrote, not our internal representation.  So ApplicativeArgOne has a
+Bool flag that is True when the original statement was a BodyStmt, so
+that we can pretty-print it correctly.
 -}
 
 instance (SourceTextX idL, OutputableBndrId idL)
@@ -1973,7 +2006,11 @@ pprStmt (ApplicativeStmt args mb_join _)
    flattenStmt (L _ (ApplicativeStmt args _ _)) = concatMap flattenArg args
    flattenStmt stmt = [ppr stmt]
 
-   flattenArg (_, ApplicativeArgOne pat expr) =
+   flattenArg (_, ApplicativeArgOne pat expr isBody)
+     | isBody =  -- See Note [Applicative BodyStmt]
+     [ppr (BodyStmt expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
+             :: ExprStmt idL)]
+     | otherwise =
      [ppr (BindStmt pat expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
              :: ExprStmt idL)]
    flattenArg (_, ApplicativeArgMany stmts _ _) =
@@ -1987,7 +2024,11 @@ pprStmt (ApplicativeStmt args mb_join _)
           then ap_expr
           else text "join" <+> parens ap_expr
 
-   pp_arg (_, ApplicativeArgOne pat expr) =
+   pp_arg (_, ApplicativeArgOne pat expr isBody)
+     | isBody =  -- See Note [Applicative BodyStmt]
+     ppr (BodyStmt expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
+            :: ExprStmt idL)
+     | otherwise =
      ppr (BindStmt pat expr noSyntaxExpr noSyntaxExpr (panic "pprStmt")
             :: ExprStmt idL)
    pp_arg (_, ApplicativeArgMany stmts return pat) =
index 3c1726b..8e17994 100644 (file)
@@ -1197,7 +1197,7 @@ lStmtsImplicits = hs_lstmts
     hs_stmt :: StmtLR GhcRn idR (Located (body idR)) -> NameSet
     hs_stmt (BindStmt pat _ _ _ _) = lPatImplicits pat
     hs_stmt (ApplicativeStmt args _ _) = unionNameSets (map do_arg args)
-      where do_arg (_, ApplicativeArgOne pat _) = lPatImplicits pat
+      where do_arg (_, ApplicativeArgOne pat _ _) = lPatImplicits pat
             do_arg (_, ApplicativeArgMany stmts _ _) = hs_lstmts stmts
     hs_stmt (LetStmt binds)      = hs_local_binds (unLoc binds)
     hs_stmt (BodyStmt {})        = emptyNameSet
index b23762a..cf47932 100644 (file)
@@ -1659,7 +1659,12 @@ stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
                 tail _tail_fvs
   | not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail
   -- See Note [ApplicativeDo and strict patterns]
-  = mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail'
+  = mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs False] False tail'
+stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BodyStmt rhs _ _ _),_))
+                tail _tail_fvs
+  | (False,tail') <- needJoin monad_names tail
+  = mkApplicativeStmt ctxt
+      [ApplicativeArgOne nlWildPatName rhs True] False tail'
 
 stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
   return (s : tail, emptyNameSet)
@@ -1678,7 +1683,9 @@ stmtTreeToStmts monad_names ctxt (StmtTreeApplicative trees) tail tail_fvs = do
    return (stmts, unionNameSets (fvs:fvss))
  where
    stmtTreeArg _ctxt _tail_fvs (StmtTreeOne (L _ (BindStmt pat exp _ _ _), _)) =
-     return (ApplicativeArgOne pat exp, emptyFVs)
+     return (ApplicativeArgOne pat exp False, emptyFVs)
+   stmtTreeArg _ctxt _tail_fvs (StmtTreeOne (L _ (BodyStmt exp _ _ _), _)) =
+     return (ApplicativeArgOne nlWildPatName exp True, emptyFVs)
    stmtTreeArg ctxt tail_fvs tree = do
      let stmts = flattenStmtTree tree
          pvarset = mkNameSet (concatMap (collectStmtBinders.unLoc.fst) stmts)
index 2b56a78..01b7176 100644 (file)
@@ -1098,11 +1098,11 @@ zonkStmt env _zBody (ApplicativeStmt args mb_join body_ty)
     zonk_join env Nothing  = return (env, Nothing)
     zonk_join env (Just j) = second Just <$> zonkSyntaxExpr env j
 
-    get_pat (_, ApplicativeArgOne pat _)    = pat
+    get_pat (_, ApplicativeArgOne pat _ _) = pat
     get_pat (_, ApplicativeArgMany _ _ pat) = pat
 
-    replace_pat pat (op, ApplicativeArgOne _ a)
-      = (op, ApplicativeArgOne pat a)
+    replace_pat pat (op, ApplicativeArgOne _ a isBody)
+      = (op, ApplicativeArgOne pat a isBody)
     replace_pat pat (op, ApplicativeArgMany a b _)
       = (op, ApplicativeArgMany a b pat)
 
@@ -1121,9 +1121,9 @@ zonkStmt env _zBody (ApplicativeStmt args mb_join body_ty)
            ; return (env2, (new_op, new_arg) : new_args) }
     zonk_args_rev env [] = return (env, [])
 
-    zonk_arg env (ApplicativeArgOne pat expr)
+    zonk_arg env (ApplicativeArgOne pat expr isBody)
       = do { new_expr <- zonkLExpr env expr
-           ; return (ApplicativeArgOne pat new_expr) }
+           ; return (ApplicativeArgOne pat new_expr isBody) }
     zonk_arg env (ApplicativeArgMany stmts ret pat)
       = do { (env1, new_stmts) <- zonkStmts env zonkLExpr stmts
            ; new_ret           <- zonkExpr env1 ret
index acc33d9..d938de0 100644 (file)
@@ -1055,13 +1055,13 @@ tcApplicativeStmts ctxt pairs rhs_ty thing_inside
     goArg :: (ApplicativeArg GhcRn GhcRn, Type, Type)
           -> TcM (ApplicativeArg GhcTcId GhcTcId)
 
-    goArg (ApplicativeArgOne pat rhs, pat_ty, exp_ty)
+    goArg (ApplicativeArgOne pat rhs isBody, pat_ty, exp_ty)
       = setSrcSpan (combineSrcSpans (getLoc pat) (getLoc rhs)) $
         addErrCtxt (pprStmtInCtxt ctxt (mkBindStmt pat rhs))   $
         do { rhs' <- tcMonoExprNC rhs (mkCheckExpType exp_ty)
            ; (pat', _) <- tcPat (StmtCtxt ctxt) pat (mkCheckExpType pat_ty) $
                           return ()
-           ; return (ApplicativeArgOne pat' rhs') }
+           ; return (ApplicativeArgOne pat' rhs' isBody) }
 
     goArg (ApplicativeArgMany stmts ret pat, pat_ty, exp_ty)
       = do { (stmts', (ret',pat')) <-
@@ -1075,7 +1075,7 @@ tcApplicativeStmts ctxt pairs rhs_ty thing_inside
            ; return (ApplicativeArgMany stmts' ret' pat') }
 
     get_arg_bndrs :: ApplicativeArg GhcTcId GhcTcId -> [Id]
-    get_arg_bndrs (ApplicativeArgOne pat _)    = collectPatBinders pat
+    get_arg_bndrs (ApplicativeArgOne pat _ _)  = collectPatBinders pat
     get_arg_bndrs (ApplicativeArgMany _ _ pat) = collectPatBinders pat
 
 
index fa3c723..e7166c0 100644 (file)
@@ -16,6 +16,19 @@ test1a f = do
   y <- f 4
   return $ x + y
 
+-- When one of the statements is a BodyStmt
+test1b :: Applicative f => (Int -> f Int) -> f Int
+test1b f = do
+  x <- f 3
+  f 4
+  return x
+
+test1c :: Applicative f => (Int -> f Int) -> f Int
+test1c f = do
+  f 3
+  x <- f 4
+  return x
+
 -- Test we can also infer the Applicative version of the type
 test2 f = do
   x <- f 3
@@ -32,6 +45,11 @@ test2c f = do
   x <- f 3
   return $ x + 1
 
+-- with a BodyStmt
+test2d f = do
+  f 3
+  return 4
+
 -- Test for just one statement
 test2b f = do
   return (f 3)
index 9b95e3b..a3ef9e9 100644 (file)
@@ -3,6 +3,10 @@ TYPE SIGNATURES
     forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
   test1a ::
     forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
+  test1b ::
+    forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
+  test1c ::
+    forall (f :: * -> *). Applicative f => (Int -> f Int) -> f Int
   test2 ::
     forall (f :: * -> *) t b.
     (Num b, Num t, Applicative f) =>
@@ -17,6 +21,10 @@ TYPE SIGNATURES
     forall (f :: * -> *) t b.
     (Num b, Num t, Functor f) =>
     (t -> f b) -> f b
+  test2d ::
+    forall (f :: * -> *) t1 b t2.
+    (Num b, Num t1, Functor f) =>
+    (t1 -> f t2) -> f b
   test3 ::
     forall (m :: * -> *) t1 t2 a.
     (Num t1, Monad m) =>