Improve Core Lint, mainly for join points
authorSimon Peyton Jones <simonpj@microsoft.com>
Fri, 17 Feb 2017 15:03:01 +0000 (15:03 +0000)
committerBen Gamari <ben@smart-cactus.org>
Tue, 21 Feb 2017 14:31:17 +0000 (09:31 -0500)
* lintSingleBinding: check that join points have
                     a valid join-point type
  (Trac #13281)

* lintIdBinder: check that a JoinId is bound by
                a non-top-level let
  i.e.  not a top level binder
        not lambda/case binder

* Check for empty Rec [] bindings

* Rename lintIdBndrs to lintLetBndrs

compiler/coreSyn/CoreLint.hs

index 4aa7d44..aed9382 100644 (file)
@@ -392,7 +392,7 @@ lintCoreBindings :: DynFlags -> CoreToDo -> [Var] -> CoreProgram -> (Bag MsgDoc,
 lintCoreBindings dflags pass local_in_scope binds
   = initL dflags flags in_scope_set $
     addLoc TopLevelBindings         $
-    lintIdBndrs TopLevel binders    $
+    lintLetBndrs TopLevel binders   $
         -- Put all the top-level binders in scope at the start
         -- This is because transformation rules can bring something
         -- into use 'unexpectedly'
@@ -531,9 +531,12 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
 
        ; flags <- getLintFlags
 
-        -- Check that if the binder is top-level, it's not a join point
-       ; checkL (not (isJoinId binder && isTopLevel top_lvl_flag))
-           (mkTopJoinMsg binder)
+         -- Check that a join-point binder has a valid type
+         -- NB: lintIdBinder has checked that it is not top-level bound
+       ; case isJoinId_maybe binder of
+            Nothing    -> return ()
+            Just arity ->  checkL (isValidJoinPointType arity binder_ty)
+                                  (mkInvalidJoinPointMsg binder binder_ty)
 
        ; when (lf_check_inline_loop_breakers flags
                && isStrongLoopBreaker (idOccInfo binder)
@@ -591,11 +594,13 @@ lintRhs bndr rhs
   where
     lint_join_lams 0 _ _ rhs
       = lintCoreExpr rhs
+
     lint_join_lams n tot enforce (Lam var expr)
       = addLoc (LambdaBodyOf var) $
-        lintBinder var $ \ var' ->
+        lintBinder LambdaBind var $ \ var' ->
         do { body_ty <- lint_join_lams (n-1) tot enforce expr
            ; return $ mkLamType var' body_ty }
+
     lint_join_lams n tot True _other
       = failWithL $ mkBadJoinArityMsg bndr tot (tot-n)
     lint_join_lams _ _ False rhs
@@ -617,7 +622,7 @@ lintRhs _bndr rhs = fmap lf_check_static_ptrs getLintFlags >>= go
         -- imitate @lintCoreExpr (Lam ...)@
         (\var loopBinders ->
           addLoc (LambdaBodyOf var) $
-            lintBinder var $ \var' ->
+            lintBinder LambdaBind var $ \var' ->
               do { body_ty <- loopBinders
                  ; return $ mkLamType var' body_ty }
         )
@@ -636,7 +641,7 @@ lintIdUnfolding bndr bndr_ty (CoreUnfolding { uf_tmpl = rhs, uf_src = src })
 
 lintIdUnfolding bndr bndr_ty (DFunUnfolding { df_con = con, df_bndrs = bndrs
                                             , df_args = args })
-  = do { ty <- lintBinders bndrs $ \ bndrs' ->
+  = do { ty <- lintBinders LambdaBind bndrs $ \ bndrs' ->
                do { res_ty <- lintCoreArgs (dataConRepType con) args
                   ; return (mkLamTypes bndrs' res_ty) }
        ; ensureEqTys bndr_ty ty (mkRhsMsg bndr (text "dfun unfolding") ty) }
@@ -724,19 +729,26 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
   | isId bndr
   = do  { lintSingleBinding NotTopLevel NonRecursive (bndr,rhs)
         ; addLoc (BodyOfLetRec [bndr])
-                 (lintIdBndr NotTopLevel bndr $ \_ ->
+                 (lintIdBndr NotTopLevel LetBind bndr $ \_ ->
                   addGoodJoins [bndr] $
                   lintCoreExpr body) }
 
   | otherwise
   = failWithL (mkLetErr bndr rhs)       -- Not quite accurate
 
-lintCoreExpr (Let (Rec pairs) body)
-  = lintIdBndrs NotTopLevel bndrs  $
+lintCoreExpr e@(Let (Rec pairs) body)
+  = lintLetBndrs NotTopLevel bndrs $
     addGoodJoins bndrs             $
-    do  { checkL (null dups) (dupVars dups)
+    do  { -- Check that the list of pairs is non-empty
+          checkL (not (null pairs)) (emptyRec e)
+
+          -- Check that there are no duplicated binders
+        ; checkL (null dups) (dupVars dups)
+
+          -- Check that either all the binders are joins, or none
         ; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
             mkInconsistentRecMsg bndrs
+
         ; mapM_ (lintSingleBinding NotTopLevel Recursive) pairs
         ; addLoc (BodyOfLetRec bndrs) (lintCoreExpr body) }
   where
@@ -753,7 +765,7 @@ lintCoreExpr e@(App _ _)
 lintCoreExpr (Lam var expr)
   = addLoc (LambdaBodyOf var) $
     markAllJoinsBad $
-    lintBinder var $ \ var' ->
+    lintBinder LambdaBind var $ \ var' ->
     do { body_ty <- lintCoreExpr expr
        ; return $ mkLamType var' body_ty }
 
@@ -798,7 +810,7 @@ lintCoreExpr e@(Case scrut var alt_ty alts) =
      ; subst <- getTCvSubst
      ; ensureEqTys var_ty scrut_ty (mkScrutMsg var var_ty scrut_ty subst)
 
-     ; lintIdBndr NotTopLevel var $ \_ ->
+     ; lintIdBndr NotTopLevel CaseBind var $ \_ ->
        do { -- Check the alternatives
             mapM_ (lintCoreAlt scrut_ty alt_ty) alts
           ; checkCaseAlts e scrut_ty alts
@@ -850,7 +862,7 @@ lintCoreFun (Lam var body) nargs
   -- Note [Beta redexes]
   | nargs /= 0
   = addLoc (LambdaBodyOf var) $
-    lintBinder var $ \ var' ->
+    lintBinder LambdaBind var $ \ var' ->
     do { body_ty <- lintCoreFun body (nargs - 1)
        ; return $ mkLamType var' body_ty }
 
@@ -1117,7 +1129,7 @@ lintCoreAlt scrut_ty alt_ty alt@(DataAlt con, args, rhs)
     ; let con_payload_ty = piResultTys (dataConRepType con) tycon_arg_tys
 
         -- And now bring the new binders into scope
-    ; lintBinders args $ \ args' -> do
+    ; lintBinders CasePatBind args $ \ args' -> do
     { addLoc (CasePat alt) (lintAltBinders scrut_ty con_payload_ty args')
     ; lintAltExpr rhs alt_ty } }
 
@@ -1136,19 +1148,19 @@ lintCoreAlt scrut_ty alt_ty alt@(DataAlt con, args, rhs)
 --  1. Lint var types or kinds (possibly substituting)
 --  2. Add the binder to the in scope set, and if its a coercion var,
 --     we may extend the substitution to reflect its (possibly) new kind
-lintBinders :: [Var] -> ([Var] -> LintM a) -> LintM a
-lintBinders [] linterF = linterF []
-lintBinders (var:vars) linterF = lintBinder var $ \var' ->
-                                 lintBinders vars $ \ vars' ->
-                                 linterF (var':vars')
+lintBinders :: BindingSite -> [Var] -> ([Var] -> LintM a) -> LintM a
+lintBinders _    []         linterF = linterF []
+lintBinders site (var:vars) linterF = lintBinder site var $ \var' ->
+                                      lintBinders site vars $ \ vars' ->
+                                      linterF (var':vars')
 
 -- If you edit this function, you may need to update the GHC formalism
 -- See Note [GHC Formalism]
-lintBinder :: Var -> (Var -> LintM a) -> LintM a
-lintBinder var linterF
-  | isTyVar var = lintTyBndr             var linterF
-  | isCoVar var = lintCoBndr             var linterF
-  | otherwise   = lintIdBndr NotTopLevel var linterF
+lintBinder :: BindingSite -> Var -> (Var -> LintM a) -> LintM a
+lintBinder site var linterF
+  | isTyVar var = lintTyBndr                  var linterF
+  | isCoVar var = lintCoBndr                  var linterF
+  | otherwise   = lintIdBndr NotTopLevel site var linterF
 
 lintTyBndr :: InTyVar -> (OutTyVar -> LintM a) -> LintM a
 lintTyBndr tv thing_inside
@@ -1166,20 +1178,20 @@ lintCoBndr cv thing_inside
                (text "CoVar with non-coercion type:" <+> pprTyVar cv)
        ; updateTCvSubst subst' (thing_inside cv') }
 
-lintIdBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
-lintIdBndrs top_lvl ids linterF
+lintLetBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
+lintLetBndrs top_lvl ids linterF
   = go ids
   where
     go []       = linterF
-    go (id:ids) = lintIdBndr  top_lvl id  $ \_ ->
-                  lintIdBndrs top_lvl ids $
-                  linterF
+    go (id:ids) = lintIdBndr top_lvl LetBind id  $ \_ ->
+                  go ids
 
-lintIdBndr :: TopLevelFlag -> InVar -> (OutVar -> LintM a) -> LintM a
+lintIdBndr :: TopLevelFlag -> BindingSite
+           -> InVar -> (OutVar -> LintM a) -> LintM a
 -- Do substitution on the type of a binder and add the var with this
 -- new type to the in-scope set of the second argument
 -- ToDo: lint its rules
-lintIdBndr top_lvl id linterF
+lintIdBndr top_lvl bind_site id linterF
   = ASSERT2( isId id, ppr id )
     do { flags <- getLintFlags
        ; checkL (not (lf_check_global_ids flags) || isLocalId id)
@@ -1187,11 +1199,11 @@ lintIdBndr top_lvl id linterF
                 -- See Note [Checking for global Ids]
 
        -- Check that if the binder is nested, it is not marked as exported
-       ; checkL (not (isExportedId id) || isTopLevel top_lvl)
+       ; checkL (not (isExportedId id) || is_top_lvl)
            (mkNonTopExportedMsg id)
 
        -- Check that if the binder is nested, it does not have an external name
-       ; checkL (not (isExternalName (Var.varName id)) || isTopLevel top_lvl)
+       ; checkL (not (isExternalName (Var.varName id)) || is_top_lvl)
            (mkNonTopExternalNameMsg id)
 
        ; (ty, k) <- lintInTy (idType id)
@@ -1200,8 +1212,18 @@ lintIdBndr top_lvl id linterF
            (text "Levity-polymorphic binder:" <+>
                  (ppr id <+> dcolon <+> parens (ppr ty <+> dcolon <+> ppr k)))
 
+       -- Check that a join-id is a not-top-level let-binding
+       ; when (isJoinId id) $
+         checkL (not is_top_lvl && is_let_bind) $
+         mkBadJoinBindMsg id
+
        ; let id' = setIdType id ty
        ; addInScopeVar id' $ (linterF id') }
+  where
+    is_top_lvl = isTopLevel top_lvl
+    is_let_bind = case bind_site of
+                    LetBind -> True
+                    _       -> False
 
 {-
 %************************************************************************
@@ -1387,7 +1409,7 @@ lintCoreRule _ _ (BuiltinRule {})
 
 lintCoreRule fun fun_ty rule@(Rule { ru_name = name, ru_bndrs = bndrs
                                    , ru_args = args, ru_rhs = rhs })
-  = lintBinders bndrs $ \ _ ->
+  = lintBinders LambdaBind bndrs $ \ _ ->
     do { lhs_ty <- foldM lintCoreArg fun_ty args
        ; rhs_ty <- case isJoinId_maybe fun of
                      Just join_arity
@@ -2225,6 +2247,9 @@ mkTyAppMsg ty arg_ty
               hang (text "Arg type:")
                  4 (ppr arg_ty <+> dcolon <+> ppr (typeKind arg_ty))]
 
+emptyRec :: CoreExpr -> MsgDoc
+emptyRec e = hang (text "Empty Rec binding:") 2 (ppr e)
+
 mkRhsMsg :: Id -> SDoc -> Type -> MsgDoc
 mkRhsMsg binder what ty
   = vcat
@@ -2311,9 +2336,15 @@ mkBadTyVarMsg tv
   = text "Non-tyvar used in TyVarTy:"
       <+> ppr tv <+> dcolon <+> ppr (varType tv)
 
-mkTopJoinMsg :: Var -> SDoc
-mkTopJoinMsg var
-  = text "Join point at top level:" <+> ppr var
+mkBadJoinBindMsg :: Var -> SDoc
+mkBadJoinBindMsg var
+  = vcat [ text "Bad join point binding:" <+> ppr var
+         , text "Join points can be bound only by a non-top-level let" ]
+
+mkInvalidJoinPointMsg :: Var -> Type -> SDoc
+mkInvalidJoinPointMsg var ty
+  = hang (text "Join point has invalid type:")
+        2 (ppr var <+> dcolon <+> ppr ty)
 
 mkBadJoinArityMsg :: Var -> Int -> Int -> SDoc
 mkBadJoinArityMsg var ar nlams