RnExpr: Fix ApplicativeDo desugaring with RebindableSyntax
authorBen Gamari <bgamari.foss@gmail.com>
Wed, 31 Aug 2016 20:03:33 +0000 (16:03 -0400)
committerBen Gamari <ben@smart-cactus.org>
Wed, 31 Aug 2016 20:34:53 +0000 (16:34 -0400)
We need to compare against the local return and pure, not returnMName
and pureAName.

Fixes #12490.

Test Plan: Validate, add testcase

Reviewers: austin, simonmar

Reviewed By: simonmar

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D2499

GHC Trac Issues: #12490

compiler/rename/RnEnv.hs
compiler/rename/RnExpr.hs
testsuite/tests/ado/T12490.hs [new file with mode: 0644]
testsuite/tests/ado/all.T

index b0a7281..63b1f1f 100644 (file)
@@ -25,7 +25,8 @@ module RnEnv (
         lookupFieldFixityRn, lookupTyFixityRn,
         lookupInstDeclBndr, lookupRecFieldOcc, lookupFamInstName,
         lookupConstructorFields,
-        lookupSyntaxName, lookupSyntaxNames, lookupIfThenElse,
+        lookupSyntaxName, lookupSyntaxName', lookupSyntaxNames,
+        lookupIfThenElse,
         lookupGreAvailRn,
         getLookupOccRn,mkUnboundName, mkUnboundNameRdr, isUnboundName,
         addUsedGRE, addUsedGREs, addUsedDataCons,
@@ -1600,6 +1601,16 @@ lookupIfThenElse
                  ; return ( Just (mkRnSyntaxExpr ite)
                           , unitFV ite ) } }
 
+lookupSyntaxName' :: Name          -- ^ The standard name
+                  -> RnM Name      -- ^ Possibly a non-standard name
+lookupSyntaxName' std_name
+  = do { rebindable_on <- xoptM LangExt.RebindableSyntax
+       ; if not rebindable_on then
+           return std_name
+         else
+            -- Get the similarly named thing from the local environment
+           lookupOccRn (mkRdrUnqual (nameOccName std_name)) }
+
 lookupSyntaxName :: Name                                -- The standard name
                  -> RnM (SyntaxExpr Name, FreeVars)     -- Possibly a non-standard name
 lookupSyntaxName std_name
index f964e77..87e5507 100644 (file)
@@ -1452,6 +1452,10 @@ dsDo {(arg_1 | ... | arg_n); stmts} expr =
 
 -}
 
+-- | The 'Name's of @return@ and @pure@. These may not be 'returnName' and
+-- 'pureName' due to @RebindableSyntax@.
+data MonadNames = MonadNames { return_name, pure_name :: Name }
+
 -- | rearrange a list of statements using ApplicativeDoStmt.  See
 -- Note [ApplicativeDo].
 rearrangeForApplicativeDo
@@ -1465,7 +1469,11 @@ rearrangeForApplicativeDo ctxt stmts0 = do
   optimal_ado <- goptM Opt_OptimalApplicativeDo
   let stmt_tree | optimal_ado = mkStmtTreeOptimal stmts
                 | otherwise = mkStmtTreeHeuristic stmts
-  stmtTreeToStmts ctxt stmt_tree [last] last_fvs
+  return_name <- lookupSyntaxName' returnMName
+  pure_name   <- lookupSyntaxName' pureAName
+  let monad_names = MonadNames { return_name = return_name
+                               , pure_name   = pure_name }
+  stmtTreeToStmts monad_names ctxt stmt_tree [last] last_fvs
   where
     (stmts,(last,last_fvs)) = findLast stmts0
     findLast [] = error "findLast"
@@ -1568,7 +1576,8 @@ mkStmtTreeOptimal stmts =
 -- | Turn the ExprStmtTree back into a sequence of statements, using
 -- ApplicativeStmt where necessary.
 stmtTreeToStmts
-  :: HsStmtContext Name
+  :: MonadNames
+  -> HsStmtContext Name
   -> ExprStmtTree
   -> [ExprLStmt Name]             -- ^ the "tail"
   -> FreeVars                     -- ^ free variables of the tail
@@ -1581,9 +1590,9 @@ stmtTreeToStmts
 -- In the spec, but we do it here rather than in the desugarer,
 -- because we need the typechecker to typecheck the <$> form rather than
 -- the bind form, which would give rise to a Monad constraint.
-stmtTreeToStmts ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
+stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
                 tail _tail_fvs
-  | isIrrefutableHsPat pat, (False,tail') <- needJoin tail
+  | isIrrefutableHsPat pat, (False,tail') <- needJoin monad_names tail
     -- WARNING: isIrrefutableHsPat on (HsPat Name) doesn't have enough info
     --          to know which types have only one constructor.  So only
     --          tuples come out as irrefutable; other single-constructor
@@ -1591,19 +1600,19 @@ stmtTreeToStmts ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
     --          isIrrefuatableHsPat
   = mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail'
 
-stmtTreeToStmts _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
+stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
   return (s : tail, emptyNameSet)
 
-stmtTreeToStmts ctxt (StmtTreeBind before after) tail tail_fvs = do
-  (stmts1, fvs1) <- stmtTreeToStmts ctxt after tail tail_fvs
+stmtTreeToStmts monad_names ctxt (StmtTreeBind before after) tail tail_fvs = do
+  (stmts1, fvs1) <- stmtTreeToStmts monad_names ctxt after tail tail_fvs
   let tail1_fvs = unionNameSets (tail_fvs : map snd (flattenStmtTree after))
-  (stmts2, fvs2) <- stmtTreeToStmts ctxt before stmts1 tail1_fvs
+  (stmts2, fvs2) <- stmtTreeToStmts monad_names ctxt before stmts1 tail1_fvs
   return (stmts2, fvs1 `plusFV` fvs2)
 
-stmtTreeToStmts ctxt (StmtTreeApplicative trees) tail tail_fvs = do
+stmtTreeToStmts monad_names ctxt (StmtTreeApplicative trees) tail tail_fvs = do
    pairs <- mapM (stmtTreeArg ctxt tail_fvs) trees
    let (stmts', fvss) = unzip pairs
-   let (need_join, tail') = needJoin tail
+   let (need_join, tail') = needJoin monad_names tail
    (stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail'
    return (stmts, unionNameSets (fvs:fvss))
  where
@@ -1617,7 +1626,7 @@ stmtTreeToStmts ctxt (StmtTreeApplicative trees) tail tail_fvs = do
            -- See Note [Deterministic ApplicativeDo and RecursiveDo desugaring]
          pat = mkBigLHsVarPatTup pvars
          tup = mkBigLHsVarTup pvars
-     (stmts',fvs2) <- stmtTreeToStmts ctxt tree [] pvarset
+     (stmts',fvs2) <- stmtTreeToStmts monad_names ctxt tree [] pvarset
      (mb_ret, fvs1) <-
         if | L _ ApplicativeStmt{} <- last stmts' ->
              return (unLoc tup, emptyNameSet)
@@ -1763,17 +1772,22 @@ mkApplicativeStmt ctxt args need_join body_stmts
 
 -- | Given the statements following an ApplicativeStmt, determine whether
 -- we need a @join@ or not, and remove the @return@ if necessary.
-needJoin :: [ExprLStmt Name] -> (Bool, [ExprLStmt Name])
-needJoin [] = (False, [])  -- we're in an ApplicativeArg
-needJoin [L loc (LastStmt e _ t)]
- | Just arg <- isReturnApp e = (False, [L loc (LastStmt arg True t)])
-needJoin stmts = (True, stmts)
+needJoin :: MonadNames
+         -> [ExprLStmt Name]
+         -> (Bool, [ExprLStmt Name])
+needJoin _monad_names [] = (False, [])  -- we're in an ApplicativeArg
+needJoin monad_names  [L loc (LastStmt e _ t)]
+ | Just arg <- isReturnApp monad_names e =
+       (False, [L loc (LastStmt arg True t)])
+needJoin _monad_names stmts = (True, stmts)
 
 -- | @Just e@, if the expression is @return e@ or @return $ e@,
 -- otherwise @Nothing@
-isReturnApp :: LHsExpr Name -> Maybe (LHsExpr Name)
-isReturnApp (L _ (HsPar expr)) = isReturnApp expr
-isReturnApp (L _ e) = case e of
+isReturnApp :: MonadNames
+            -> LHsExpr Name
+            -> Maybe (LHsExpr Name)
+isReturnApp monad_names (L _ (HsPar expr)) = isReturnApp monad_names expr
+isReturnApp monad_names (L _ e) = case e of
   OpApp l op _ r | is_return l, is_dollar op -> Just r
   HsApp f arg    | is_return f               -> Just arg
   _otherwise -> Nothing
@@ -1784,7 +1798,8 @@ isReturnApp (L _ e) = case e of
        -- TODO: I don't know how to get this right for rebindable syntax
   is_var _ _ = False
 
-  is_return = is_var (\n -> n == returnMName || n == pureAName)
+  is_return = is_var (\n -> n == return_name monad_names
+                         || n == pure_name monad_names)
   is_dollar = is_var (`hasKey` dollarIdKey)
 
 {-
diff --git a/testsuite/tests/ado/T12490.hs b/testsuite/tests/ado/T12490.hs
new file mode 100644 (file)
index 0000000..e1bb022
--- /dev/null
@@ -0,0 +1,30 @@
+{-# LANGUAGE RebindableSyntax #-}
+{-# LANGUAGE ApplicativeDo #-}
+
+module T12490 where
+
+import Prelude (Int, String, Functor(..), ($), undefined, (+))
+
+join :: Monad f => f (f a) -> f a
+join = undefined
+
+class Functor f => Applicative f where
+    pure :: a -> f a
+    (<*>) :: f (a -> b) -> f a -> f b
+
+class Applicative f => Monad f where
+    return :: a -> f a
+    (>>=) :: f a -> (a -> f b) -> f b
+    fail :: String -> f a
+
+f_app :: Applicative f => f Int -> f Int -> f Int
+f_app a b = do
+    a' <- a
+    b' <- b
+    pure (a' + b')
+
+f_monad :: Monad f => f Int -> f Int -> f Int
+f_monad a b = do
+    a' <- a
+    b' <- b
+    return $ a' + b'
index 06cdbf9..67697b9 100644 (file)
@@ -7,3 +7,4 @@ test('ado006', normal, compile, [''])
 test('ado007', normal, compile, [''])
 test('T11607', normal, compile_and_run, [''])
 test('ado-optimal', normal, compile_and_run, [''])
+test('T12490', normal, compile, [''])