Improve checking of joins in Core Lint
authorSimon Peyton Jones <simonpj@microsoft.com>
Thu, 16 Feb 2017 09:42:32 +0000 (09:42 +0000)
committerSimon Peyton Jones <simonpj@microsoft.com>
Thu, 16 Feb 2017 14:24:57 +0000 (14:24 +0000)
This patch addresses the rather expensive treatment of join points,
identified in Trac #13220 comment:17

Before we were tracking the "bad joins".  Now we track the good ones.
That is easier to think about, and much more efficient; see CoreLint
Note [Join points].

On the way I did some other modest refactoring, among other things
removing a duplicated call of lintIdBndr for let-bindings.

On teh

compiler/coreSyn/CoreLint.hs

index f87989d..053ac21 100644 (file)
@@ -151,7 +151,6 @@ find an occurrence of an Id, we fetch it from the in-scope set.
 
 Note [Bad unsafe coercion]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 For discussion see https://ghc.haskell.org/trac/ghc/wiki/BadUnsafeCoercions
 Linter introduces additional rules that checks improper coercion between
 different types, called bad coercions. Following coercions are forbidden:
@@ -170,12 +169,10 @@ different types, called bad coercions. Following coercions are forbidden:
 
 Note [Join points]
 ~~~~~~~~~~~~~~~~~~
-
 We check the rules listed in Note [Invariants on join points] in CoreSyn. The
 only one that causes any difficulty is the first: All occurrences must be tail
-calls. To this end, along with the in-scope set, we remember in le_bad_joins the
-subset of join ids that are no longer allowed because they were declared "too
-far away." For example:
+calls. To this end, along with the in-scope set, we remember in le_joins the
+subset of in-scope Ids that are valid join ids. For example:
 
   join j x = ... in
   case e of
@@ -184,11 +181,11 @@ far away." For example:
            C -> join h = jump j w in ... -- good
            D -> let x = jump j v in ... -- BAD
 
-A join point remains valid in case branches, so when checking the A branch, j
-is still valid. When we check the scrutinee of the inner case, however, we add j
-to le_bad_joins and catch the error. Similarly, join points can occur free in
-RHSes of other join points but not the RHSes of value bindings (thunks and
-functions).
+A join point remains valid in case branches, so when checking the A
+branch, j is still valid. When we check the scrutinee of the inner
+case, however, we set le_joins to empty, and catch the
+error. Similarly, join points can occur free in RHSes of other join
+points but not the RHSes of value bindings (thunks and functions).
 
 ************************************************************************
 *                                                                      *
@@ -387,10 +384,9 @@ lintCoreBindings :: DynFlags -> CoreToDo -> [Var] -> CoreProgram -> (Bag MsgDoc,
 -- If you edit this function, you may need to update the GHC formalism
 -- See Note [GHC Formalism]
 lintCoreBindings dflags pass local_in_scope binds
-  = initL dflags flags $
-    addLoc TopLevelBindings        $
-    addInScopeVars local_in_scope  $
-    addInScopeVars binders         $
+  = initL dflags flags in_scope_set $
+    addLoc TopLevelBindings         $
+    lintIdBndrs TopLevel binders    $
         -- Put all the top-level binders in scope at the start
         -- This is because transformation rules can bring something
         -- into use 'unexpectedly'
@@ -398,6 +394,8 @@ lintCoreBindings dflags pass local_in_scope binds
        ; checkL (null ext_dups) (dupExtVars ext_dups)
        ; mapM lint_bind binds }
   where
+    in_scope_set = mkInScopeSet (mkVarSet local_in_scope)
+
     flags = LF { lf_check_global_ids = check_globals
                , lf_check_inline_loop_breakers = check_lbs
                , lf_check_static_ptrs = check_static_ptrs }
@@ -463,9 +461,9 @@ lintUnfolding dflags locn vars expr
   | isEmptyBag errs = Nothing
   | otherwise       = Just (pprMessageBag errs)
   where
-    (_warns, errs) = initL dflags defaultLintFlags linter
+    in_scope = mkInScopeSet vars
+    (_warns, errs) = initL dflags defaultLintFlags in_scope linter
     linter = addLoc (ImportedUnfolding locn) $
-             addInScopeVarSet vars           $
              lintCoreExpr expr
 
 lintExpr :: DynFlags
@@ -477,9 +475,9 @@ lintExpr dflags vars expr
   | isEmptyBag errs = Nothing
   | otherwise       = Just (pprMessageBag errs)
   where
-    (_warns, errs) = initL dflags defaultLintFlags linter
+    in_scope = mkInScopeSet (mkVarSet vars)
+    (_warns, errs) = initL dflags defaultLintFlags in_scope linter
     linter = addLoc TopLevelBindings $
-             addInScopeVars vars     $
              lintCoreExpr expr
 
 {-
@@ -499,7 +497,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
   = addLoc (RhsOf binder) $
          -- Check the rhs
     do { ty <- lintRhs binder rhs
-       ; lint_bndr binder -- Check match to RHS type
        ; binder_ty <- applySubstTy (idType binder)
        ; ensureEqTys binder_ty ty (mkRhsMsg binder (text "RHS") ty)
 
@@ -571,11 +568,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
 
         -- We should check the unfolding, if any, but this is tricky because
         -- the unfolding is a SimplifiableCoreExpr. Give up for now.
-   where
-    -- If you edit this function, you may need to update the GHC formalism
-    -- See Note [GHC Formalism]
-    lint_bndr var | isId var  = lintIdBndr top_lvl_flag var $ \_ -> return ()
-                  | otherwise = return ()
 
 -- | Checks the RHS of bindings. It only differs from 'lintCoreExpr'
 -- in that it doesn't reject occurrences of the function 'makeStatic' when they
@@ -680,7 +672,7 @@ lintCoreExpr :: CoreExpr -> LintM OutType
 -- If you edit this function, you may need to update the GHC formalism
 -- See Note [GHC Formalism]
 lintCoreExpr (Var var)
-  = lintCoreVar var 0
+  = lintVarOcc var 0
 
 lintCoreExpr (Lit lit)
   = return (literalType lit)
@@ -726,13 +718,16 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
   | isId bndr
   = do  { lintSingleBinding NotTopLevel NonRecursive (bndr,rhs)
         ; addLoc (BodyOfLetRec [bndr])
-                 (lintIdBndr NotTopLevel bndr $ \_ -> lintCoreExpr body) }
+                 (lintIdBndr NotTopLevel bndr $ \_ ->
+                  addGoodJoins [bndr] $
+                  lintCoreExpr body) }
 
   | otherwise
   = failWithL (mkLetErr bndr rhs)       -- Not quite accurate
 
 lintCoreExpr (Let (Rec pairs) body)
-  = lintIdBndrs bndrs       $ \_ ->
+  = lintIdBndrs NotTopLevel bndrs  $
+    addGoodJoins bndrs             $
     do  { checkL (null dups) (dupVars dups)
         ; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
             mkInconsistentRecMsg bndrs
@@ -812,51 +807,38 @@ lintCoreExpr (Coercion co)
   = do { (k1, k2, ty1, ty2, role) <- lintInCo co
        ; return (mkHeteroCoercionType role k1 k2 ty1 ty2) }
 
-lintCoreVar :: Var -> Int -- Number of arguments (type or value) being passed
+----------------------
+lintVarOcc :: Var -> Int -- Number of arguments (type or value) being passed
             -> LintM Type -- returns type of the *variable*
-lintCoreVar var nargs
+lintVarOcc var nargs
   = do  { checkL (isNonCoVarId var)
                  (text "Non term variable" <+> ppr var)
 
-        ; lf <- getLintFlags
+        -- Cneck that the type of the occurrence is the same
+        -- as the type of the binding site
+        ; ty   <- applySubstTy (idType var)
+        ; var' <- lookupIdInScope var
+        ; let ty' = idType var'
+        ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
+
           -- Check for a nested occurrence of the StaticPtr constructor.
           -- See Note [Checking StaticPtrs].
+        ; lf <- getLintFlags
         ; when (nargs /= 0 && lf_check_static_ptrs lf /= AllowAnywhere) $
             checkL (idName var /= makeStaticName) $
               text "Found makeStatic nested in an expression"
 
         ; checkDeadIdOcc var
-        ; ty   <- applySubstTy (idType var)
-        ; var' <- lookupIdInScope var
-        ; let ty' = idType var'
-        ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
-        ; mb_join_arity
-            <- case isJoinId_maybe var' of
-                 Just join_arity ->
-                   do  { checkL (isJoinId_maybe var == Just join_arity) $
-                           mkJoinBndrOccMismatchMsg var' var
-                       ; return $ Just join_arity }
-                 Nothing ->
-                   case tailCallInfo (idOccInfo var') of
-                     AlwaysTailCalled join_arity -> return $ Just join_arity
-                       -- This function will be turned into a join point by the
-                       -- simplifier; typecheck it as if it already were one
-                     NoTailCallInfo              -> return $ Nothing
-        ; case mb_join_arity of
-            Just join_arity ->
-              do  { bad <- isBadJoin var'
-                  ; checkL (not bad) $ mkJoinOutOfScopeMsg var'
-                  ; checkL (nargs == join_arity) $
-                      mkBadJumpMsg var' join_arity nargs }
-            Nothing ->
-              do  { checkL (not (isJoinId var)) $
-                      mkJoinBndrOccMismatchMsg var' var }
+        ; checkJoinOcc var nargs
+
         ; return (idType var') }
 
-lintCoreFun :: CoreExpr -> Int -- Number of arguments (type or val) being passed
-            -> LintM Type -- returns type of the *function*
+lintCoreFun :: CoreExpr
+            -> Int        -- Number of arguments (type or val) being passed
+            -> LintM Type -- Returns type of the *function*
 lintCoreFun (Var var) nargs
-  = lintCoreVar var nargs
+  = lintVarOcc var nargs
+
 lintCoreFun (Lam var body) nargs
   -- Act like lintCoreExpr of Lam, but *don't* call markAllJoinsBad; see
   -- Note [Beta redexes]
@@ -865,10 +847,47 @@ lintCoreFun (Lam var body) nargs
     lintBinder var $ \ var' ->
     do { body_ty <- lintCoreFun body (nargs - 1)
        ; return $ mkLamType var' body_ty }
+
 lintCoreFun expr nargs
   = markAllJoinsBadIf (nargs /= 0) $
     lintCoreExpr expr
 
+------------------
+checkDeadIdOcc :: Id -> LintM ()
+-- Occurrences of an Id should never be dead....
+-- except when we are checking a case pattern
+checkDeadIdOcc id
+  | isDeadOcc (idOccInfo id)
+  = do { in_case <- inCasePat
+       ; checkL in_case
+                (text "Occurrence of a dead Id" <+> ppr id) }
+  | otherwise
+  = return ()
+
+------------------
+checkJoinOcc :: Id -> JoinArity -> LintM ()
+-- Check that if the occurrence is a JoinId, then so is the
+-- binding site, and it's a valid join Id
+checkJoinOcc var n_args
+  | Just join_arity_occ <- isJoinId_maybe var
+  = do { mb_join_arity_bndr <- lookupJoinId var
+       ; case mb_join_arity_bndr of {
+           Nothing -> -- Binder is not a join point
+                      addErrL (invalidJoinOcc var) ;
+
+           Just join_arity_bndr ->
+
+    do { checkL (join_arity_bndr == join_arity_occ) $
+           -- Arity differs at binding site and occurrence
+         mkJoinBndrOccMismatchMsg var join_arity_bndr join_arity_occ
+
+       ; checkL (n_args == join_arity_occ) $
+           -- Arity doesn't match #args
+         mkBadJumpMsg var join_arity_occ n_args } } }
+
+  | otherwise
+  = return ()
+
 {-
 Note [No alternatives lint check]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1010,17 +1029,6 @@ lintTyKind tyvar arg_ty
   where
     tyvar_kind = tyVarKind tyvar
 
-checkDeadIdOcc :: Id -> LintM ()
--- Occurrences of an Id should never be dead....
--- except when we are checking a case pattern
-checkDeadIdOcc id
-  | isDeadOcc (idOccInfo id)
-  = do { in_case <- inCasePat
-       ; checkL in_case
-                (text "Occurrence of a dead Id" <+> ppr id) }
-  | otherwise
-  = return ()
-
 {-
 ************************************************************************
 *                                                                      *
@@ -1152,21 +1160,22 @@ lintCoBndr cv thing_inside
                (text "CoVar with non-coercion type:" <+> pprTyVar cv)
        ; updateTCvSubst subst' (thing_inside cv') }
 
-lintIdBndrs :: [Var] -> ([Var] -> LintM a) -> LintM a
-lintIdBndrs ids linterF
+lintIdBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
+lintIdBndrs top_lvl ids linterF
   = go ids
   where
-    go []       = linterF []
-    go (id:ids) = lintIdBndr NotTopLevel id $ \id ->
-                  lintIdBndrs           ids $ \ids ->
-                  linterF (id:ids)
+    go []       = linterF
+    go (id:ids) = lintIdBndr  top_lvl id  $ \_ ->
+                  lintIdBndrs top_lvl ids $
+                  linterF
 
 lintIdBndr :: TopLevelFlag -> InVar -> (OutVar -> LintM a) -> LintM a
 -- Do substitution on the type of a binder and add the var with this
 -- new type to the in-scope set of the second argument
 -- ToDo: lint its rules
 lintIdBndr top_lvl id linterF
-  = do { flags <- getLintFlags
+  = ASSERT2( isId id, ppr id )
+    do { flags <- getLintFlags
        ; checkL (not (lf_check_global_ids flags) || isLocalId id)
                 (text "Non-local Id binder" <+> ppr id)
                 -- See Note [Checking for global Ids]
@@ -1784,7 +1793,8 @@ data LintEnv
        , le_subst :: TCvSubst        -- Current type substitution; we also use this
                                      -- to keep track of all the variables in scope,
                                      -- both Ids and TyVars
-       , le_bad_joins :: IdSet       -- Join points that are no longer valid
+       , le_joins :: IdSet           -- Join points in scope that are valid
+                                     -- A subset of teh InScopeSet in le_subst
                                      -- See Note [Join points]
        , le_dynflags :: DynFlags     -- DynamicFlags
        }
@@ -1891,13 +1901,17 @@ data LintLocInfo
   | InType Type         -- Inside a type
   | InCo   Coercion     -- Inside a coercion
 
-initL :: DynFlags -> LintFlags -> LintM a -> WarnsAndErrs    -- Errors and warnings
-initL dflags flags m
+initL :: DynFlags -> LintFlags -> InScopeSet
+       -> LintM a -> WarnsAndErrs    -- Errors and warnings
+initL dflags flags in_scope m
   = case unLintM m env (emptyBag, emptyBag) of
       (_, errs) -> errs
   where
-    env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = []
-             , le_dynflags = dflags, le_bad_joins = emptyVarSet }
+    env = LE { le_flags = flags
+             , le_subst = mkEmptyTCvSubst in_scope
+             , le_joins = emptyVarSet
+             , le_loc = []
+             , le_dynflags = dflags }
 
 getLintFlags :: LintM LintFlags
 getLintFlags = LintM $ \ env errs -> (Just (le_flags env), errs)
@@ -1952,29 +1966,12 @@ inCasePat = LintM $ \ env errs -> (Just (is_case_pat env), errs)
     is_case_pat (LE { le_loc = CasePat {} : _ }) = True
     is_case_pat _other                           = False
 
-addInScopeVars :: [Var] -> LintM a -> LintM a
-addInScopeVars vars m
-  = LintM $ \ env errs ->
-    unLintM m (env { le_subst     = extendTCvInScopeList (le_subst env) vars
-                   , le_bad_joins = bad_joins' env })
-              errs
-  where
-    bad_joins' env = delVarSetList (le_bad_joins env) (filter isJoinId vars)
-
-addInScopeVarSet :: VarSet -> LintM a -> LintM a
-addInScopeVarSet vars m
-  = LintM $ \ env errs ->
-    unLintM m (env { le_subst = extendTCvInScopeSet (le_subst env) vars })
-              errs
-
 addInScopeVar :: Var -> LintM a -> LintM a
 addInScopeVar var m
   = LintM $ \ env errs ->
-    unLintM m (env { le_subst     = extendTCvInScope (le_subst env) var
-                   , le_bad_joins = bad_joins' env }) errs
-  where
-    bad_joins' env | isJoinId var = delVarSet (le_bad_joins env) var
-                   | otherwise    = le_bad_joins env
+    unLintM m (env { le_subst = extendTCvInScope (le_subst env) var
+                   , le_joins = delVarSet        (le_joins env) var
+               }) errs
 
 extendSubstL :: TyVar -> Type -> LintM a -> LintM a
 extendSubstL tv ty m
@@ -1987,16 +1984,25 @@ updateTCvSubst subst' m
 
 markAllJoinsBad :: LintM a -> LintM a
 markAllJoinsBad m
-  = LintM $ \ env errs -> unLintM m (marked env) errs
-  where
-    marked env = env { le_bad_joins = filterVarSet isJoinId in_set }
-      where
-        in_set = getInScopeVars (getTCvInScope (le_subst env))
+  = LintM $ \ env errs -> unLintM m (env { le_joins = emptyVarSet }) errs
 
 markAllJoinsBadIf :: Bool -> LintM a -> LintM a
 markAllJoinsBadIf True  m = markAllJoinsBad m
 markAllJoinsBadIf False m = m
 
+addGoodJoins :: [Var] -> LintM a -> LintM a
+addGoodJoins vars thing_inside
+  | null join_ids
+  = thing_inside
+  | otherwise
+  = LintM $ \ env errs -> unLintM thing_inside (add_joins env) errs
+  where
+    add_joins env = env { le_joins = le_joins env `extendVarSetList` join_ids }
+    join_ids = filter isJoinId vars
+
+getValidJoins :: LintM IdSet
+getValidJoins = LintM (\ env errs -> (Just (le_joins env), errs))
+
 getTCvSubst :: LintM TCvSubst
 getTCvSubst = LintM (\ env errs -> (Just (le_subst env), errs))
 
@@ -2022,9 +2028,14 @@ lookupIdInScope id
   where
     out_of_scope = pprBndr LetBind id <+> text "is out of scope"
 
-isBadJoin :: Id -> LintM Bool
-isBadJoin id = LintM $ \env errs -> (Just (id `elemVarSet` le_bad_joins env),
-                                     errs)
+lookupJoinId :: Id -> LintM (Maybe JoinArity)
+-- Look up an Id which should be a join point, valid here
+-- If so, return its arity, if not return Nothing
+lookupJoinId id
+  = do { join_set <- getValidJoins
+       ; case lookupVarSet join_set id of
+            Just id' -> return (isJoinId_maybe id')
+            Nothing  -> return Nothing }
 
 lintTyCoVarInScope :: Var -> LintM ()
 lintTyCoVarInScope v = lintInScope (text "is out of scope") v
@@ -2294,9 +2305,10 @@ mkBadJoinArityMsg var ar nlams
            text "Join arity:" <+> ppr ar,
            text "Number of lambdas:" <+> ppr nlams ]
 
-mkJoinOutOfScopeMsg :: Var -> SDoc
-mkJoinOutOfScopeMsg var
-  = text "Join variable no longer in scope:" <+> ppr var
+invalidJoinOcc :: Var -> SDoc
+invalidJoinOcc var
+  = vcat [ text "Invalid occurrence of a join variable:" <+> ppr var
+         , text "The binder is either not a join point, or not valid here" ]
 
 mkBadJumpMsg :: Var -> Int -> Int -> SDoc
 mkBadJumpMsg var ar nargs
@@ -2312,17 +2324,12 @@ mkInconsistentRecMsg bndrs
   where
     ppr_with_details bndr = ppr bndr <> ppr (idDetails bndr)
 
-mkJoinBndrOccMismatchMsg :: Var -> Var -> SDoc
-mkJoinBndrOccMismatchMsg bndr var
-  = vcat [ text "Mismatch in join point status between binder and occurrence",
-           text "Var:" <+> ppr bndr,
-           text "Binder:" <+> ppr_join_status bndr,
-           text "Occ:" <+> ppr_join_status var ]
-  where
-    ppr_join_status v = case details of JoinId _ -> ppr details
-                                        _        -> text "not a join id"
-      where
-        details = idDetails v
+mkJoinBndrOccMismatchMsg :: Var -> JoinArity -> JoinArity -> SDoc
+mkJoinBndrOccMismatchMsg bndr join_arity_bndr join_arity_occ
+  = vcat [ text "Mismatch in join point arity between binder and occurrence"
+         , text "Var:" <+> ppr bndr
+         , text "Arity at binding site:" <+> ppr join_arity_bndr
+         , text "Arity at occurrence:  " <+> ppr join_arity_occ ]
 
 mkBndrOccTypeMismatchMsg :: Var -> Var -> OutType -> OutType -> SDoc
 mkBndrOccTypeMismatchMsg bndr var bndr_ty var_ty