Use NonEmpty lists to represent lists of duplicate elements
[ghc.git] / compiler / rename / RnExpr.hs
index e1a314f..6eabc89 100644 (file)
@@ -57,6 +57,7 @@ import qualified GHC.LanguageExtensions as LangExt
 
 import Data.Ord
 import Data.Array
+import qualified Data.List.NonEmpty as NE
 
 {-
 ************************************************************************
@@ -577,7 +578,7 @@ methodNamesMatch :: MatchGroup GhcRn (LHsCmd GhcRn) -> FreeVars
 methodNamesMatch (MG { mg_alts = L _ ms })
   = plusFVs (map do_one ms)
  where
-    do_one (L _ (Match _ _ _ grhss)) = methodNamesGRHSs grhss
+    do_one (L _ (Match { m_grhss = grhss })) = methodNamesGRHSs grhss
 
 -------------------------------------------------
 -- gaw 2004
@@ -970,7 +971,7 @@ rnParallelStmts ctxt return_op segs thing_inside
 
     cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
     dupErr vs = addErr (text "Duplicate binding in parallel list comprehension for:"
-                    <+> quotes (ppr (head vs)))
+                    <+> quotes (ppr (NE.head vs)))
 
 lookupStmtName :: HsStmtContext Name -> Name -> RnM (SyntaxExpr GhcRn, FreeVars)
 -- Like lookupSyntaxName, but respects contexts
@@ -1605,7 +1606,7 @@ mkStmtTreeOptimal stmts =
               (StmtTreeOne (stmt_arr ! hi), 1))
            | left_cost < right_cost
            = ((left,left_cost), (StmtTreeOne (stmt_arr ! hi), 1))
-           | otherwise -- left_cost > right_cost
+           | left_cost > right_cost
            = ((StmtTreeOne (stmt_arr ! lo), 1), (right,right_cost))
            | otherwise = minimumBy (comparing cost) alternatives
            where
@@ -1635,12 +1636,8 @@ stmtTreeToStmts
 -- the bind form, which would give rise to a Monad constraint.
 stmtTreeToStmts monad_names ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
                 tail _tail_fvs
-  | isIrrefutableHsPat pat, (False,tail') <- needJoin monad_names tail
-    -- WARNING: isIrrefutableHsPat on (HsPat Name) doesn't have enough info
-    --          to know which types have only one constructor.  So only
-    --          tuples come out as irrefutable; other single-constructor
-    --          types, and newtypes, will not.  See the code for
-    --          isIrrefuatableHsPat
+  | not (isStrictPattern pat), (False,tail') <- needJoin monad_names tail
+  -- See Note [ApplicativeDo and strict patterns]
   = mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail'
 
 stmtTreeToStmts _monad_names _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
@@ -1715,6 +1712,8 @@ segments stmts = map fst $ merge $ reverse $ map reverse $ walk (reverse stmts)
     chunter _ [] = ([], [])
     chunter vars ((stmt,fvs) : rest)
        | not (isEmptyNameSet vars)
+       || isStrictPatternBind stmt
+           -- See Note [ApplicativeDo and strict patterns]
        = ((stmt,fvs) : chunk, rest')
        where (chunk,rest') = chunter vars' rest
              (pvars, evars) = stmtRefs stmt fvs
@@ -1727,6 +1726,58 @@ segments stmts = map fst $ merge $ reverse $ map reverse $ walk (reverse stmts)
       where fvs' = fvs `intersectNameSet` allvars
             pvars = mkNameSet (collectStmtBinders (unLoc stmt))
 
+    isStrictPatternBind :: ExprLStmt GhcRn -> Bool
+    isStrictPatternBind (L _ (BindStmt pat _ _ _ _)) = isStrictPattern pat
+    isStrictPatternBind _ = False
+
+{-
+Note [ApplicativeDo and strict patterns]
+
+A strict pattern match is really a dependency.  For example,
+
+do
+  (x,y) <- A
+  z <- B
+  return C
+
+The pattern (_,_) must be matched strictly before we do B.  If we
+allowed this to be transformed into
+
+  (\(x,y) -> \z -> C) <$> A <*> B
+
+then it could be lazier than the standard desuraging using >>=.  See #13875
+for more examples.
+
+Thus, whenever we have a strict pattern match, we treat it as a
+dependency between that statement and the following one.  The
+dependency prevents those two statements from being performed "in
+parallel" in an ApplicativeStmt, but doesn't otherwise affect what we
+can do with the rest of the statements in the same "do" expression.
+-}
+
+isStrictPattern :: LPat id -> Bool
+isStrictPattern (L _ pat) =
+  case pat of
+    WildPat{} -> False
+    VarPat{}  -> False
+    LazyPat{} -> False
+    AsPat _ p -> isStrictPattern p
+    ParPat p  -> isStrictPattern p
+    ViewPat _ p _ -> isStrictPattern p
+    SigPatIn p _ -> isStrictPattern p
+    SigPatOut p _ -> isStrictPattern p
+    BangPat{} -> True
+    TuplePat{} -> True
+    SumPat{} -> True
+    PArrPat{} -> True
+    ConPatIn{} -> True
+    ConPatOut{} -> True
+    LitPat{} -> True
+    NPat{} -> True
+    NPlusKPat{} -> True
+    SplicePat{} -> True
+    _otherwise -> panic "isStrictPattern"
+
 isLetStmt :: LStmt a b -> Bool
 isLetStmt (L _ LetStmt{}) = True
 isLetStmt _ = False