Small refactor in desugar of pattern matching
[ghc.git] / compiler / deSugar / DsUtils.hs
index cc621d5..f74be0b 100644 (file)
@@ -9,6 +9,8 @@ This module exports some utility functions of no great interest.
 -}
 
 {-# LANGUAGE CPP #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE TypeFamilies #-}
 
 -- | Utility functions for constructing Core syntax, principally for desugaring
 module DsUtils (
@@ -35,19 +37,21 @@ module DsUtils (
         mkSelectorBinds,
 
         selectSimpleMatchVarL, selectMatchVars, selectMatchVar,
-        mkOptTickBox, mkBinaryTickBox, decideBangHood
+        mkOptTickBox, mkBinaryTickBox, decideBangHood, addBang
     ) where
 
 #include "HsVersions.h"
 
-import {-# SOURCE #-}   Match ( matchSimply )
+import GhcPrelude
+
+import {-# SOURCE #-} Match  ( matchSimply )
+import {-# SOURCE #-} DsExpr ( dsLExpr )
 
 import HsSyn
 import TcHsSyn
 import TcType( tcSplitTyConApp )
 import CoreSyn
 import DsMonad
-import {-# SOURCE #-} DsExpr ( dsLExpr )
 
 import CoreUtils
 import MkCore
@@ -55,7 +59,6 @@ import MkId
 import Id
 import Literal
 import TyCon
--- import ConLike
 import DataCon
 import PatSyn
 import Type
@@ -68,6 +71,7 @@ import UniqSet
 import UniqSupply
 import Module
 import PrelNames
+import Name( isInternalName )
 import Outputable
 import SrcLoc
 import Util
@@ -92,7 +96,8 @@ hand, which should indeed be bound to the pattern as a whole, then use it;
 otherwise, make one up.
 -}
 
-selectSimpleMatchVarL :: LPat Id -> DsM Id
+selectSimpleMatchVarL :: LPat GhcTc -> DsM Id
+-- Postcondition: the returned Id has an Internal Name
 selectSimpleMatchVarL pat = selectMatchVar (unLoc pat)
 
 -- (selectMatchVars ps tys) chooses variables of type tys
@@ -111,22 +116,23 @@ selectSimpleMatchVarL pat = selectMatchVar (unLoc pat)
 --    Then we must not choose (x::Int) as the matching variable!
 -- And nowadays we won't, because the (x::Int) will be wrapped in a CoPat
 
-selectMatchVars :: [Pat Id] -> DsM [Id]
+selectMatchVars :: [Pat GhcTc] -> DsM [Id]
+-- Postcondition: the returned Ids have Internal Names
 selectMatchVars ps = mapM selectMatchVar ps
 
-selectMatchVar :: Pat Id -> DsM Id
-selectMatchVar (BangPat pat) = selectMatchVar (unLoc pat)
-selectMatchVar (LazyPat pat) = selectMatchVar (unLoc pat)
-selectMatchVar (ParPat pat)  = selectMatchVar (unLoc pat)
-selectMatchVar (VarPat var)  = return (localiseId (unLoc var))
+selectMatchVar :: Pat GhcTc -> DsM Id
+-- Postcondition: the returned Id has an Internal Name
+selectMatchVar (BangPat _ pat) = selectMatchVar (unLoc pat)
+selectMatchVar (LazyPat _ pat) = selectMatchVar (unLoc pat)
+selectMatchVar (ParPat _ pat)  = selectMatchVar (unLoc pat)
+selectMatchVar (VarPat _ var)  = return (localiseId (unLoc var))
                                   -- Note [Localise pattern binders]
-selectMatchVar (AsPat var _) = return (unLoc var)
-selectMatchVar other_pat     = newSysLocalDs (hsPatType other_pat)
+selectMatchVar (AsPat var _) = return (unLoc var)
+selectMatchVar other_pat       = newSysLocalDsNoLP (hsPatType other_pat)
                                   -- OK, better make up one...
 
-{-
-Note [Localise pattern binders]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Localise pattern binders]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider     module M where
                [Just a] = e
 After renaming it looks like
@@ -162,6 +168,7 @@ In fact, even CoreSubst.simplOptExpr will do this, and simpleOptExpr
 runs on the output of the desugarer, so all is well by the end of
 the desugaring pass.
 
+See also Note [MatchIds] in Match.hs
 
 ************************************************************************
 *                                                                      *
@@ -174,7 +181,7 @@ The ``equation info'' used by @match@ is relatively complicated and
 worthy of a type synonym and a few handy functions.
 -}
 
-firstPat :: EquationInfo -> Pat Id
+firstPat :: EquationInfo -> Pat GhcTc
 firstPat eqn = ASSERT( notNull (eqn_pats eqn) ) head (eqn_pats eqn)
 
 shiftEqns :: [EquationInfo] -> [EquationInfo]
@@ -255,7 +262,7 @@ mkGuardedMatchResult pred_expr (MatchResult _ body_fn)
   = MatchResult CanFail (\fail -> do body <- body_fn fail
                                      return (mkIfThenElse pred_expr body fail))
 
-mkCoPrimCaseMatchResult :: Id                        -- Scrutinee
+mkCoPrimCaseMatchResult :: Id                  -- Scrutinee
                         -> Type                      -- Type of the case
                         -> [(Literal, MatchResult)]  -- Alternatives
                         -> MatchResult               -- Literals are all unlifted
@@ -278,18 +285,15 @@ data CaseAlt a = MkCaseAlt{ alt_pat :: a,
                             alt_result :: MatchResult }
 
 mkCoAlgCaseMatchResult
-  :: DynFlags
-  -> Id                 -- Scrutinee
+  :: Id                 -- Scrutinee
   -> Type               -- Type of exp
   -> [CaseAlt DataCon]  -- Alternatives (bndrs *include* tyvars, dicts)
   -> MatchResult
-mkCoAlgCaseMatchResult dflags var ty match_alts
+mkCoAlgCaseMatchResult var ty match_alts
   | isNewtype  -- Newtype case; use a let
   = ASSERT( null (tail match_alts) && null (tail arg_ids1) )
     mkCoLetMatchResult (NonRec arg_id1 newtype_rhs) match_result1
 
-  | isPArrFakeAlts match_alts
-  = MatchResult CanFail $ mkPArrCase dflags var ty (sort_alts match_alts)
   | otherwise
   = mkDataConCase var ty match_alts
   where
@@ -307,34 +311,6 @@ mkCoAlgCaseMatchResult dflags var ty match_alts
                                                 -- (not that splitTyConApp does, these days)
     newtype_rhs = unwrapNewTypeBody tc ty_args (Var var)
 
-        --- Stuff for parallel arrays
-        --
-        -- Concerning `isPArrFakeAlts':
-        --
-        --  * it is *not* sufficient to just check the type of the type
-        --   constructor, as we have to be careful not to confuse the real
-        --   representation of parallel arrays with the fake constructors;
-        --   moreover, a list of alternatives must not mix fake and real
-        --   constructors (this is checked earlier on)
-        --
-        -- FIXME: We actually go through the whole list and make sure that
-        --        either all or none of the constructors are fake parallel
-        --        array constructors.  This is to spot equations that mix fake
-        --        constructors with the real representation defined in
-        --        `PrelPArr'.  It would be nicer to spot this situation
-        --        earlier and raise a proper error message, but it can really
-        --        only happen in `PrelPArr' anyway.
-        --
-
-    isPArrFakeAlts :: [CaseAlt DataCon] -> Bool
-    isPArrFakeAlts [alt] = isPArrFakeCon (alt_pat alt)
-    isPArrFakeAlts (alt:alts) =
-      case (isPArrFakeCon (alt_pat alt), isPArrFakeAlts alts) of
-        (True , True ) -> True
-        (False, False) -> False
-        _              -> panic "DsUtils: you may not mix `[:...:]' with `PArr' patterns"
-    isPArrFakeAlts [] = panic "DsUtils: unexpectedly found an empty list of PArr fake alternatives"
-
 mkCoSynCaseMatchResult :: Id -> Type -> CaseAlt PatSyn -> MatchResult
 mkCoSynCaseMatchResult var ty alt = MatchResult CanFail $ mkPatSynCase var ty alt
 
@@ -344,7 +320,7 @@ sort_alts = sortWith (dataConTag . alt_pat)
 mkPatSynCase :: Id -> Type -> CaseAlt PatSyn -> CoreExpr -> DsM CoreExpr
 mkPatSynCase var ty alt fail = do
     matcher <- dsLExpr $ mkLHsWrap wrapper $
-                         nlHsTyApp matcher [getRuntimeRep "mkPatSynCase" ty, ty]
+                         nlHsTyApp matcher [getRuntimeRep ty, ty]
     let MatchResult _ mkCont = match_result
     cont <- mkCoreLams bndrs <$> mkCont fail
     return $ mkCoreAppsDs (text "patsyn" <+> ppr var) matcher [Var var, ensure_unstrict cont, Lam voidArgId fail]
@@ -408,48 +384,6 @@ mkDataConCase var ty alts@(alt1:_) = MatchResult fail_flag mk_case
         = mkUniqSet data_cons `minusUniqSet` mentioned_constructors
     exhaustive_case = isEmptyUniqSet un_mentioned_constructors
 
---- Stuff for parallel arrays
---
---  * the following is to desugar cases over fake constructors for
---   parallel arrays, which are introduced by `tidy1' in the `PArrPat'
---   case
---
-mkPArrCase :: DynFlags -> Id -> Type -> [CaseAlt DataCon] -> CoreExpr -> DsM CoreExpr
-mkPArrCase dflags var ty sorted_alts fail = do
-    lengthP <- dsDPHBuiltin lengthPVar
-    alt <- unboxAlt
-    return (mkWildCase (len lengthP) intTy ty [alt])
-  where
-    elemTy      = case splitTyConApp (idType var) of
-        (_, [elemTy]) -> elemTy
-        _             -> panic panicMsg
-    panicMsg    = "DsUtils.mkCoAlgCaseMatchResult: not a parallel array?"
-    len lengthP = mkApps (Var lengthP) [Type elemTy, Var var]
-    --
-    unboxAlt = do
-        l      <- newSysLocalDs intPrimTy
-        indexP <- dsDPHBuiltin indexPVar
-        alts   <- mapM (mkAlt indexP) sorted_alts
-        return (DataAlt intDataCon, [l], mkWildCase (Var l) intPrimTy ty (dft : alts))
-      where
-        dft  = (DEFAULT, [], fail)
-
-    --
-    -- each alternative matches one array length (corresponding to one
-    -- fake array constructor), so the match is on a literal; each
-    -- alternative's body is extended by a local binding for each
-    -- constructor argument, which are bound to array elements starting
-    -- with the first
-    --
-    mkAlt indexP alt@MkCaseAlt{alt_result = MatchResult _ bodyFun} = do
-        body <- bodyFun fail
-        return (LitAlt lit, [], mkCoreLets binds body)
-      where
-        lit   = MachInt $ toInteger (dataConSourceArity (alt_pat alt))
-        binds = [NonRec arg (indexExpr i) | (i, arg) <- zip [1..] (alt_bndrs alt)]
-        --
-        indexExpr i = mkApps (Var indexP) [Type elemTy, Var var, mkIntExpr dflags i]
-
 {-
 ************************************************************************
 *                                                                      *
@@ -470,7 +404,7 @@ mkErrorAppDs err_id ty msg = do
         full_msg = showSDoc dflags (hcat [ppr src_loc, vbar, msg])
         core_msg = Lit (mkMachString full_msg)
         -- mkMachString returns a result of type String#
-    return (mkApps (Var err_id) [Type (getRuntimeRep "mkErrorAppDs" ty), Type ty, core_msg])
+    return (mkApps (Var err_id) [Type (getRuntimeRep ty), Type ty, core_msg])
 
 {-
 'mkCoreAppDs' and 'mkCoreAppsDs' hand the special-case desugaring of 'seq'.
@@ -540,17 +474,20 @@ into
 which stupidly tries to bind the datacon 'True'.
 -}
 
+-- NB: Make sure the argument is not levity polymorphic
 mkCoreAppDs  :: SDoc -> CoreExpr -> CoreExpr -> CoreExpr
 mkCoreAppDs _ (Var f `App` Type ty1 `App` Type ty2 `App` arg1) arg2
   | f `hasKey` seqIdKey            -- Note [Desugaring seq (1), (2)]
   = Case arg1 case_bndr ty2 [(DEFAULT,[],arg2)]
   where
     case_bndr = case arg1 of
-                   Var v1 | isLocalId v1 -> v1        -- Note [Desugaring seq (2) and (3)]
-                   _                     -> mkWildValBinder ty1
+                   Var v1 | isInternalName (idName v1)
+                          -> v1        -- Note [Desugaring seq (2) and (3)]
+                   _      -> mkWildValBinder ty1
 
 mkCoreAppDs s fun arg = mkCoreApp s fun arg  -- The rest is done in MkCore
 
+-- NB: No argument can be levity polymorphic
 mkCoreAppsDs :: SDoc -> CoreExpr -> [CoreExpr] -> CoreExpr
 mkCoreAppsDs s fun args = foldl (mkCoreAppDs s) fun args
 
@@ -722,7 +659,7 @@ work out well:
 -}
 
 mkSelectorBinds :: [[Tickish Id]] -- ^ ticks to add, possibly
-                -> LPat Id        -- ^ The pattern
+                -> LPat GhcTc     -- ^ The pattern
                 -> CoreExpr       -- ^ Expression to which the pattern is bound
                 -> DsM (Id,[(Id,CoreExpr)])
                 -- ^ Id the rhs is bound to, for desugaring strict
@@ -730,12 +667,12 @@ mkSelectorBinds :: [[Tickish Id]] -- ^ ticks to add, possibly
                 -- and all the desugared binds
 
 mkSelectorBinds ticks pat val_expr
-  | L _ (VarPat (L _ v)) <- pat'     -- Special case (A)
+  | L _ (VarPat (L _ v)) <- pat'     -- Special case (A)
   = return (v, [(v, val_expr)])
 
   | is_flat_prod_lpat pat'           -- Special case (B)
   = do { let pat_ty = hsLPatType pat'
-       ; val_var <- newSysLocalDs pat_ty
+       ; val_var <- newSysLocalDsNoLP pat_ty
 
        ; let mk_bind tick bndr_var
                -- (mk_bind sv bv)  generates  bv = case sv of { pat -> bv }
@@ -754,7 +691,7 @@ mkSelectorBinds ticks pat val_expr
 
   | otherwise                          -- General case (C)
   = do { tuple_var  <- newSysLocalDs tuple_ty
-       ; error_expr <- mkErrorAppDs iRREFUT_PAT_ERROR_ID tuple_ty (ppr pat')
+       ; error_expr <- mkErrorAppDs pAT_ERROR_ID tuple_ty (ppr pat')
        ; tuple_expr <- matchSimply val_expr PatBindRhs pat
                                    local_tuple error_expr
        ; let mk_tup_bind tick binder
@@ -777,17 +714,17 @@ mkSelectorBinds ticks pat val_expr
 
 strip_bangs :: LPat a -> LPat a
 -- Remove outermost bangs and parens
-strip_bangs (L _ (ParPat p))  = strip_bangs p
-strip_bangs (L _ (BangPat p)) = strip_bangs p
-strip_bangs lp                = lp
+strip_bangs (L _ (ParPat p))  = strip_bangs p
+strip_bangs (L _ (BangPat p)) = strip_bangs p
+strip_bangs lp                  = lp
 
 is_flat_prod_lpat :: LPat a -> Bool
 is_flat_prod_lpat p = is_flat_prod_pat (unLoc p)
 
 is_flat_prod_pat :: Pat a -> Bool
-is_flat_prod_pat (ParPat p)            = is_flat_prod_lpat p
-is_flat_prod_pat (TuplePat ps Boxed _) = all is_triv_lpat ps
-is_flat_prod_pat (ConPatOut { pat_con = L _ pcon, pat_args = ps})
+is_flat_prod_pat (ParPat _ p)          = is_flat_prod_lpat p
+is_flat_prod_pat (TuplePat _ ps Boxed) = all is_triv_lpat ps
+is_flat_prod_pat (ConPatOut { pat_con  = L _ pcon, pat_args = ps})
   | RealDataCon con <- pcon
   , isProductTyCon (dataConTyCon con)
   = all is_triv_lpat (hsConPatArgs ps)
@@ -797,10 +734,10 @@ is_triv_lpat :: LPat a -> Bool
 is_triv_lpat p = is_triv_pat (unLoc p)
 
 is_triv_pat :: Pat a -> Bool
-is_triv_pat (VarPat _)  = True
-is_triv_pat (WildPat _) = True
-is_triv_pat (ParPat p)  = is_triv_lpat p
-is_triv_pat _           = False
+is_triv_pat (VarPat {})  = True
+is_triv_pat (WildPat{})  = True
+is_triv_pat (ParPat _ p) = is_triv_lpat p
+is_triv_pat _            = False
 
 
 {- *********************************************************************
@@ -811,31 +748,31 @@ is_triv_pat _           = False
 *                                                                      *
 ********************************************************************* -}
 
-mkLHsPatTup :: [LPat Id] -> LPat Id
+mkLHsPatTup :: [LPat GhcTc] -> LPat GhcTc
 mkLHsPatTup []     = noLoc $ mkVanillaTuplePat [] Boxed
 mkLHsPatTup [lpat] = lpat
 mkLHsPatTup lpats  = L (getLoc (head lpats)) $
                      mkVanillaTuplePat lpats Boxed
 
-mkLHsVarPatTup :: [Id] -> LPat Id
+mkLHsVarPatTup :: [Id] -> LPat GhcTc
 mkLHsVarPatTup bs  = mkLHsPatTup (map nlVarPat bs)
 
-mkVanillaTuplePat :: [OutPat Id] -> Boxity -> Pat Id
+mkVanillaTuplePat :: [OutPat GhcTc] -> Boxity -> Pat GhcTc
 -- A vanilla tuple pattern simply gets its type from its sub-patterns
-mkVanillaTuplePat pats box = TuplePat pats box (map hsLPatType pats)
+mkVanillaTuplePat pats box = TuplePat (map hsLPatType pats) pats box
 
 -- The Big equivalents for the source tuple expressions
-mkBigLHsVarTupId :: [Id] -> LHsExpr Id
+mkBigLHsVarTupId :: [Id] -> LHsExpr GhcTc
 mkBigLHsVarTupId ids = mkBigLHsTupId (map nlHsVar ids)
 
-mkBigLHsTupId :: [LHsExpr Id] -> LHsExpr Id
+mkBigLHsTupId :: [LHsExpr GhcTc] -> LHsExpr GhcTc
 mkBigLHsTupId = mkChunkified mkLHsTupleExpr
 
 -- The Big equivalents for the source tuple patterns
-mkBigLHsVarPatTupId :: [Id] -> LPat Id
+mkBigLHsVarPatTupId :: [Id] -> LPat GhcTc
 mkBigLHsVarPatTupId bs = mkBigLHsPatTupId (map nlVarPat bs)
 
-mkBigLHsPatTupId :: [LPat Id] -> LPat Id
+mkBigLHsPatTupId :: [LPat GhcTc] -> LPat GhcTc
 mkBigLHsPatTupId = mkChunkified mkLHsPatTup
 
 {-
@@ -892,6 +829,15 @@ for the primitive case:
 \end{verbatim}
 
 Now @fail.33@ is a function, so it can be let-bound.
+
+We would *like* to use join points here; in fact, these "fail variables" are
+paradigmatic join points! Sadly, this breaks pattern synonyms, which desugar as
+CPS functions - i.e. they take "join points" as parameters. It's not impossible
+to imagine extending our type system to allow passing join points around (very
+carefully), but we certainly don't support it now.
+
+99.99% of the time, the fail variables wind up as join points in short order
+anyway, and the Void# doesn't do much harm.
 -}
 
 mkFailurePair :: CoreExpr       -- Result type of the whole case expression
@@ -911,6 +857,11 @@ mkFailurePair expr
 {-
 Note [Failure thunks and CPR]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+(This note predates join points as formal entities (hence the quotation marks).
+We can't use actual join points here (see above); if we did, this would also
+solve the CPR problem, since join points don't get CPR'd. See Note [Don't CPR
+join points] in WorkWrap.)
+
 When we make a failure point we ensure that it
 does not look like a thunk. Example:
 
@@ -955,16 +906,41 @@ mkBinaryTickBox ixT ixF e = do
 
 -- *******************************************************************
 
+{- Note [decideBangHood]
+~~~~~~~~~~~~~~~~~~~~~~~~
+With -XStrict we may make /outermost/ patterns more strict.
+E.g.
+       let (Just x) = e in ...
+          ==>
+       let !(Just x) = e in ...
+and
+       f x = e
+          ==>
+       f !x = e
+
+This adjustment is done by decideBangHood,
+
+  * Just before constructing an EqnInfo, in Match
+      (matchWrapper and matchSinglePat)
+
+  * When desugaring a pattern-binding in DsBinds.dsHsBind
+
+Note that it is /not/ done recursively.  See the -XStrict
+spec in the user manual.
+
+Specifically:
+   ~pat    => pat    -- when -XStrict (even if pat = ~pat')
+   !pat    => !pat   -- always
+   pat     => !pat   -- when -XStrict
+   pat     => pat    -- otherwise
+-}
+
+
 -- | Use -XStrict to add a ! or remove a ~
---
--- Examples:
--- ~pat    => pat    -- when -XStrict (even if pat = ~pat')
--- !pat    => !pat   -- always
--- pat     => !pat   -- when -XStrict
--- pat     => pat    -- otherwise
+-- See Note [decideBangHood]
 decideBangHood :: DynFlags
-               -> LPat id  -- ^ Original pattern
-               -> LPat id  -- Pattern with bang if necessary
+               -> LPat GhcTc  -- ^ Original pattern
+               -> LPat GhcTc  -- Pattern with bang if necessary
 decideBangHood dflags lpat
   | not (xopt LangExt.Strict dflags)
   = lpat
@@ -973,7 +949,20 @@ decideBangHood dflags lpat
   where
     go lp@(L l p)
       = case p of
-           ParPat p    -> L l (ParPat (go p))
-           LazyPat lp' -> lp'
-           BangPat _   -> lp
-           _           -> L l (BangPat lp)
+           ParPat x p    -> L l (ParPat x (go p))
+           LazyPat _ lp' -> lp'
+           BangPat _ _   -> lp
+           _             -> L l (BangPat noExt lp)
+
+-- | Unconditionally make a 'Pat' strict.
+addBang :: LPat GhcTc -- ^ Original pattern
+        -> LPat GhcTc -- ^ Banged pattern
+addBang = go
+  where
+    go lp@(L l p)
+      = case p of
+           ParPat x p    -> L l (ParPat x (go p))
+           LazyPat _ lp' -> L l (BangPat noExt lp')
+                                  -- Should we bring the extension value over?
+           BangPat _ _   -> lp
+           _             -> L l (BangPat noExt lp)