Mark non-recursive join lambdas as one-shot
authorSimon Peyton Jones <simonpj@microsoft.com>
Wed, 1 Mar 2017 01:25:33 +0000 (20:25 -0500)
committerDavid Feuer <David.Feuer@gmail.com>
Wed, 1 Mar 2017 01:25:34 +0000 (20:25 -0500)
When we have

  join j x y = rhs in ...

we know that the lambdas for 'x' and 'y' are one-shot.
Let's mark them as such!

This doesn't fix a specific bug, but it feels right to me.

Reviewers: austin, bgamari

Reviewed By: bgamari

Subscribers: lukemaurer, thomie

Differential Revision: https://phabricator.haskell.org/D3196

compiler/simplCore/OccurAnal.hs

index f2f7da6..949cbf1 100644 (file)
@@ -732,7 +732,6 @@ add this analysis if necessary.
 ------------------------------------------------------------
 Note [Adjusting for lambdas]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
 There's a bit of a dance we need to do after analysing a lambda expression or
 a right-hand side. In particular, we need to
 
@@ -802,28 +801,33 @@ occAnalNonRecBind env lvl imp_rule_edges binder rhs body_usage
   | otherwise                   -- It's mentioned in the body
   = (body_usage' +++ rhs_usage', [NonRec tagged_binder rhs'])
   where
-    (bndrs, body) = collectBinders rhs
     (body_usage', tagged_binder) = tagNonRecBinder lvl body_usage binder
+    mb_join_arity = willBeJoinId_maybe tagged_binder
+
+    (bndrs, body) = collectBinders rhs
+
     (rhs_usage1, bndrs', body') = occAnalNonRecRhs env tagged_binder bndrs body
-    rhs' = mkLams bndrs' body'
+    rhs' = mkLams (markJoinOneShots mb_join_arity bndrs') body'
+           -- For a /non-recursive/ join point we can mark all
+           -- its join-lambda as one-shot; and it's a good idea to do so
+
+    -- Unfoldings
+    -- See Note [Unfoldings and join points]
     rhs_usage2 = case occAnalUnfolding env NonRecursive binder of
                    Just unf_usage -> rhs_usage1 +++ unf_usage
                    Nothing        -> rhs_usage1
-       -- See Note [Unfoldings and join points]
 
-    mb_join_arity = willBeJoinId_maybe tagged_binder
+    -- Rules
+    -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
     rules_w_uds = occAnalRules env mb_join_arity NonRecursive tagged_binder
-
     rhs_usage3 = rhs_usage2 +++ combineUsageDetailsList
                                   (map (\(_, l, r) -> l +++ r) rules_w_uds)
-       -- See Note [Rules are extra RHSs] and Note [Rule dependency info]
-
     rhs_usage4 = maybe rhs_usage3 (addManyOccsSet rhs_usage3) $
                  lookupVarEnv imp_rule_edges binder
        -- See Note [Preventing loops due to imported functions rules]
 
-    rhs_usage' = adjustRhsUsage (willBeJoinId_maybe tagged_binder) NonRecursive
-                                bndrs' rhs_usage4
+    -- Final adjustment
+    rhs_usage' = adjustRhsUsage mb_join_arity NonRecursive bndrs' rhs_usage4
 
 -----------------
 occAnalRecBind :: OccEnv -> TopLevelFlag -> ImpRuleEdges -> [(Var,CoreExpr)]
@@ -1550,7 +1554,6 @@ occAnalNonRecRhs env bndr bndrs body
     -- See Note [Sources of one-shot information]
     rhs_env = env1 { occ_one_shots = argOneShots dmd }
 
-
     certainly_inline -- See Note [Cascading inlines]
       = case idOccInfo bndr of
           OneOcc { occ_in_lam = in_lam, occ_one_br = one_br }
@@ -1731,7 +1734,8 @@ occAnal env app@(App _ _)
 --   (a) occurrences inside type lambdas only not marked as InsideLam
 --   (b) type variables not in environment
 
-occAnal env (Lam x body) | isTyVar x
+occAnal env (Lam x body)
+  | isTyVar x
   = case occAnal env body of { (body_usage, body') ->
     (markAllNonTailCalled body_usage, Lam x body')
     }
@@ -1749,14 +1753,14 @@ occAnal env expr@(Lam _ _)
   = case occAnalLamOrRhs env binders body of { (usage, tagged_binders, body') ->
     let
         expr'       = mkLams tagged_binders body'
-        final_usage | all isOneShotBndr tagged_binders
-                    = markAllNonTailCalled usage
-                    | otherwise
-                    = markAllInsideLam $ markAllNonTailCalled usage
+        usage1      = markAllNonTailCalled usage
+        one_shot_gp = all isOneShotBndr tagged_binders
+        final_usage | one_shot_gp = usage1
+                    | otherwise   = markAllInsideLam usage1
     in
     (final_usage, expr') }
   where
-    (binders, body)      = collectBinders expr
+    (binders, body) = collectBinders expr
 
 occAnal env (Case scrut bndr ty alts)
   = case occ_anal_scrut scrut alts     of { (scrut_usage, scrut') ->
@@ -2130,21 +2134,31 @@ oneShotGroup env@(OccEnv { occ_one_shots = ctxt }) bndrs
       = ( env { occ_one_shots = [], occ_encl = OccVanilla }
         , reverse rev_bndrs ++ bndrs )
 
-    go ctxt (bndr:bndrs) rev_bndrs
-      | isId bndr
-
-      = case ctxt of
-          []                -> go []   bndrs (bndr : rev_bndrs)
-          (one_shot : ctxt) -> go ctxt bndrs (bndr': rev_bndrs)
-            where
-               bndr' = updOneShotInfo bndr one_shot
+    go ctxt@(one_shot : ctxt') (bndr : bndrs) rev_bndrs
+      | isId bndr = go ctxt' bndrs (bndr': rev_bndrs)
+      | otherwise = go ctxt  bndrs (bndr : rev_bndrs)
+      where
+        bndr' = updOneShotInfo bndr one_shot
                -- Use updOneShotInfo, not setOneShotInfo, as pre-existing
                -- one-shot info might be better than what we can infer, e.g.
                -- due to explicit use of the magic 'oneShot' function.
                -- See Note [The oneShot function]
 
-       | otherwise
-      = go ctxt bndrs (bndr:rev_bndrs)
+
+markJoinOneShots :: Maybe JoinArity -> [Var] -> [Var]
+-- Mark the lambdas of a non-recursive join point as one-shot.
+-- This is good to prevent gratuitous float-out etc
+markJoinOneShots mb_join_arity bndrs
+  = case mb_join_arity of
+      Nothing -> bndrs
+      Just n  -> go n bndrs
+ where
+   go 0 bndrs  = bndrs
+   go _ []     = WARN( True, ppr mb_join_arity <+> ppr bndrs ) []
+   go n (b:bs) = b' : go (n-1) bs
+     where
+       b' | isId b    = setOneShotLambda b
+          | otherwise = b
 
 addAppCtxt :: OccEnv -> [Arg CoreBndr] -> OccEnv
 addAppCtxt env@(OccEnv { occ_one_shots = ctxt }) args