Improve demand analysis for join points
authorSimon Peyton Jones <simonpj@microsoft.com>
Mon, 10 Apr 2017 07:51:49 +0000 (08:51 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Wed, 12 Apr 2017 15:16:14 +0000 (16:16 +0100)
I realised (Trac #13543) that we can improve demand analysis for
join point quite straightforwardly.

The idea is explained in
    Note [Demand analysis for join points]
in DmdAnal

compiler/stranal/DmdAnal.hs
testsuite/tests/simplCore/should_compile/T13543.hs [new file with mode: 0644]
testsuite/tests/simplCore/should_compile/T13543.stderr [new file with mode: 0644]
testsuite/tests/simplCore/should_compile/all.T

index 304a2be..78eefe3 100644 (file)
@@ -64,20 +64,20 @@ dmdAnalProgram dflags fam_envs binds
 dmdAnalTopBind :: AnalEnv
                -> CoreBind
                -> (AnalEnv, CoreBind)
-dmdAnalTopBind sigs (NonRec id rhs)
-  = (extendAnalEnv TopLevel sigs id2 (idStrictness id2), NonRec id2 rhs2)
+dmdAnalTopBind env (NonRec id rhs)
+  = (extendAnalEnv TopLevel env id2 (idStrictness id2), NonRec id2 rhs2)
   where
-    ( _, _,   rhs1) = dmdAnalRhsLetDown TopLevel Nothing sigs             id rhs
-    ( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin sigs) id rhs1
+    ( _, _,   rhs1) = dmdAnalRhsLetDown TopLevel Nothing env             cleanEvalDmd id rhs
+    ( _, id2, rhs2) = dmdAnalRhsLetDown TopLevel Nothing (nonVirgin env) cleanEvalDmd id rhs1
         -- Do two passes to improve CPR information
         -- See Note [CPR for thunks]
         -- See Note [Optimistic CPR in the "virgin" case]
         -- See Note [Initial CPR for strict binders]
 
-dmdAnalTopBind sigs (Rec pairs)
-  = (sigs', Rec pairs')
+dmdAnalTopBind env (Rec pairs)
+  = (env', Rec pairs')
   where
-    (sigs', _, pairs')  = dmdFix TopLevel sigs pairs
+    (env', _, pairs')  = dmdFix TopLevel env cleanEvalDmd pairs
                 -- We get two iterations automatically
                 -- c.f. the NonRec case above
 
@@ -308,7 +308,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body)
 dmdAnal' env dmd (Let (NonRec id rhs) body)
   = (body_ty2, Let (NonRec id2 rhs') body')
   where
-    (lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env id rhs
+    (lazy_fv, id1, rhs') = dmdAnalRhsLetDown NotTopLevel Nothing env dmd id rhs
     env1                 = extendAnalEnv NotTopLevel env id1 (idStrictness id1)
     (body_ty, body')     = dmdAnal env1 dmd body
     (body_ty1, id2)      = annotateBndr env body_ty id1
@@ -329,7 +329,7 @@ dmdAnal' env dmd (Let (NonRec id rhs) body)
 
 dmdAnal' env dmd (Let (Rec pairs) body)
   = let
-        (env', lazy_fv, pairs') = dmdFix NotTopLevel env pairs
+        (env', lazy_fv, pairs') = dmdFix NotTopLevel env dmd pairs
         (body_ty, body')        = dmdAnal env' dmd body
         body_ty1                = deleteFVs body_ty (map fst pairs)
         body_ty2                = addLazyFVs body_ty1 lazy_fv -- see Note [Lazy and unleasheable free variables]
@@ -509,17 +509,17 @@ dmdTransform env var dmd
 -- Recursive bindings
 dmdFix :: TopLevelFlag
        -> AnalEnv                            -- Does not include bindings for this binding
+       -> CleanDemand
        -> [(Id,CoreExpr)]
        -> (AnalEnv, DmdEnv, [(Id,CoreExpr)]) -- Binders annotated with stricness info
 
-dmdFix top_lvl env orig_pairs
+dmdFix top_lvl env let_dmd orig_pairs
   = loop 1 initial_pairs
   where
     bndrs = map fst orig_pairs
 
     -- See Note [Initialising strictness]
     initial_pairs | ae_virgin env = [(setIdStrictness id botSig, rhs) | (id, rhs) <- orig_pairs ]
-
                   | otherwise     = orig_pairs
 
     -- If fixed-point iteration does not yield a result we use this instead
@@ -562,7 +562,7 @@ dmdFix top_lvl env orig_pairs
         my_downRhs (env, lazy_fv) (id,rhs)
           = ((env', lazy_fv'), (id', rhs'))
           where
-            (lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env id rhs
+            (lazy_fv1, id', rhs') = dmdAnalRhsLetDown top_lvl (Just bndrs) env let_dmd id rhs
             lazy_fv'              = plusVarEnv_C bothDmd lazy_fv lazy_fv1
             env'                  = extendAnalEnv top_lvl env id (idStrictness id')
 
@@ -621,18 +621,27 @@ dmdAnalTrivialRhs env id rhs fn
 -- This is the LetDown rule in the paper “Higher-Order Cardinality Analysis”.
 dmdAnalRhsLetDown :: TopLevelFlag
            -> Maybe [Id]   -- Just bs <=> recursive, Nothing <=> non-recursive
-           -> AnalEnv -> Id -> CoreExpr
+           -> AnalEnv -> CleanDemand
+           -> Id -> CoreExpr
            -> (DmdEnv, Id, CoreExpr)
 -- Process the RHS of the binding, add the strictness signature
 -- to the Id, and augment the environment with the signature as well.
-dmdAnalRhsLetDown top_lvl rec_flag env id rhs
+dmdAnalRhsLetDown top_lvl rec_flag env let_dmd id rhs
   | Just fn <- unpackTrivial rhs   -- See Note [Demand analysis for trivial right-hand sides]
   = dmdAnalTrivialRhs env id rhs fn
 
   | otherwise
   = (lazy_fv, id', mkLams bndrs' body')
   where
-    (bndrs, body)    = collectBinders rhs
+    (bndrs, body, body_dmd)
+       = case isJoinId_maybe id of
+           Just join_arity  -- See Note [Demand analysis for join points]
+                   | (bndrs, body) <- collectNBinders join_arity rhs
+                   -> (bndrs, body, let_dmd)
+
+           Nothing | (bndrs, body) <- collectBinders rhs
+                   -> (bndrs, body, mkBodyDmd env body)
+
     env_body         = foldl extendSigsWithLam env bndrs
     (body_ty, body') = dmdAnal env_body body_dmd body
     body_ty'         = removeDmdTyArgs body_ty -- zap possible deep CPR info
@@ -642,10 +651,6 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
     id'              = set_idStrictness env id sig_ty
         -- See Note [NOINLINE and strictness]
 
-    -- See Note [Product demands for function body]
-    body_dmd = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
-                 Nothing            -> cleanEvalDmd
-                 Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
 
     -- See Note [Aggregated demand for cardinality]
     rhs_fv1 = case rec_flag of
@@ -667,6 +672,13 @@ dmdAnalRhsLetDown top_lvl rec_flag env id rhs
        || not (isStrictDmd (idDemandInfo id) || ae_virgin env)
           -- See Note [Optimistic CPR in the "virgin" case]
 
+mkBodyDmd :: AnalEnv -> CoreExpr -> CleanDemand
+-- See Note [Product demands for function body]
+mkBodyDmd env body
+  = case deepSplitProductType_maybe (ae_fam_envs env) (exprType body) of
+       Nothing            -> cleanEvalDmd
+       Just (dc, _, _, _) -> cleanEvalProdDmd (dataConRepArity dc)
+
 unpackTrivial :: CoreExpr -> Maybe Id
 -- Returns (Just v) if the arg is really equal to v, modulo
 -- casts, type applications etc
@@ -691,7 +703,37 @@ useLetUp _ (Lam _ _)              = False
 useLetUp _ _                      = True
 
 
-{-
+{- Note [Demand analysis for join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+   g :: (Int,Int) -> Int
+   g (p,q) = p+q
+
+   f :: T -> Int -> Int
+   f x p = g (join j y = (p,y)
+              in case x of
+                   A -> j 3
+                   B -> j 4
+                   C -> (p,7))
+
+If j was a vanilla function definition, we'd analyse its body with
+evalDmd, and think that it was lazy in p.  But for join points we can
+do better!  We know that j's body will (if called at all) be evaluated
+with the demand that consumes the entire join-binding, in this case
+the argument demand from g.  Whizzo!  g evaluates both components of
+its arugment pair, so p will certainly be evaluated if j is called.
+
+For f to be strict in p, we need /all/ paths to evaluate p; in this
+case the C branch does so too, so we are fine.  So, as usual, we need
+to transport demands on free variables to the call site(s).  Compare
+Note [Lazy and unleasheable free variables].
+
+The implementation is easy.  Wwhen analysing a join point, we can
+analyse its body with the demand from the entire join-binding (written
+let_dmd here).
+
+Another win for join points!  Trac #13543.
+
 Note [Demand analysis for trivial right-hand sides]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider
diff --git a/testsuite/tests/simplCore/should_compile/T13543.hs b/testsuite/tests/simplCore/should_compile/T13543.hs
new file mode 100644 (file)
index 0000000..88a0b14
--- /dev/null
@@ -0,0 +1,17 @@
+{-# LANGUAGE RankNTypes, GADTs #-}
+
+module Foo where
+
+g :: (Int, Int) -> Int
+{-# NOINLINE g #-}
+g (p,q) = p+q
+
+f :: Int -> Int -> Int -> Int
+f x p q
+  = g (let j y = (p,q)
+           {-# NOINLINE j #-}
+          in
+          case x of
+            2 -> j 3
+            _ -> j 4)
+
diff --git a/testsuite/tests/simplCore/should_compile/T13543.stderr b/testsuite/tests/simplCore/should_compile/T13543.stderr
new file mode 100644 (file)
index 0000000..0519ecb
--- /dev/null
@@ -0,0 +1 @@
\ No newline at end of file
index 7a079c7..1b45930 100644 (file)
@@ -259,3 +259,4 @@ test('T13468',
      normal,
      run_command,
      ['$MAKE -s --no-print-directory T13468'])
+test('T13543', normal, compile, ['-ddump-str-signatures'])