Join points
[ghc.git] / compiler / stranal / WorkWrap.hs
index 2db3a71..0963df0 100644 (file)
@@ -14,6 +14,7 @@ import CoreFVs          ( exprFreeVars )
 import Var
 import Id
 import IdInfo
+import Type
 import UniqSupply
 import BasicTypes
 import DynFlags
@@ -237,6 +238,48 @@ There is an infelicity though.  We may get something like
 The code for f duplicates that for g, without any real benefit. It
 won't really be executed, because calls to f will go via the inlining.
 
+Note [Don't CPR join points]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+There's no point in doing CPR on a join point. If the whole function is getting
+CPR'd, then the case expression around the worker function will get pushed into
+the join point by the simplifier, which will have the same effect that CPR would
+have - the result will be returned in an unboxed tuple.
+
+  f z = let join j x y = (x+1, y+1)
+        in case z of A -> j 1 2
+                     B -> j 2 3
+
+  =>
+
+  f z = case $wf z of (# a, b #) -> (a, b)
+  $wf z = case (let join j x y = (x+1, y+1)
+                in case z of A -> j 1 2
+                             B -> j 2 3) of (a, b) -> (# a, b #)
+
+  =>
+
+  f z = case $wf z of (# a, b #) -> (a, b)
+  $wf z = let join j x y = (# x+1, y+1 #)
+          in case z of A -> j 1 2
+                       B -> j 2 3
+
+Doing CPR on a join point would be tricky anyway, as the worker could not be
+a join point because it would not be tail-called. However, doing the *argument*
+part of W/W still works for join points, since the wrapper body will make a tail
+call:
+
+  f z = let join j x y = x + y
+        in ...
+
+  =>
+
+  f z = let join $wj x# y# = x# +# y#
+                 j x y = case x of I# x# ->
+                         case y of I# y# ->
+                         $wj x# y#
+        in ...
+
 Note [Wrapper activation]
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 When should the wrapper inlining be active?  It must not be active
@@ -289,12 +332,10 @@ tryWW dflags fam_envs is_rec fn_id rhs
         -- being inlined at a call site.
   = return [ (new_fn_id, rhs) ]
 
-  | not loop_breaker
-  , Just stable_unf <- certainlyWillInline dflags fn_unf
+  | Just stable_unf <- certainlyWillInline dflags fn_info
   = return [ (fn_id `setIdUnfolding` stable_unf, rhs) ]
-        -- Note [Don't w/w inline small non-loop-breaker, or INLINE, things]
-        -- NB: use idUnfolding because we don't want to apply
-        --     this criterion to a loop breaker!
+        -- See Note [Don't w/w INLINE things]
+        -- See Note [Don't w/w inline small non-loop-breaker things]
 
   | is_fun
   = splitFun dflags fam_envs new_fn_id fn_info wrap_dmds res_info rhs
@@ -306,18 +347,17 @@ tryWW dflags fam_envs is_rec fn_id rhs
   = return [ (new_fn_id, rhs) ]
 
   where
-    loop_breaker = isStrongLoopBreaker (occInfo fn_info)
     fn_info      = idInfo fn_id
     inline_act   = inlinePragmaActivation (inlinePragInfo fn_info)
-    fn_unf       = unfoldingInfo fn_info
     (wrap_dmds, res_info) = splitStrictSig (strictnessInfo fn_info)
 
     new_fn_id = zapIdUsedOnceInfo (zapIdUsageEnvInfo fn_id)
         -- See Note [Zapping DmdEnv after Demand Analyzer] and
         -- See Note [Zapping Used Once info in WorkWrap]
 
-    is_fun    = notNull wrap_dmds
-    is_thunk  = not is_fun && not (exprIsHNF rhs)
+    is_fun    = notNull wrap_dmds || isJoinId fn_id
+    is_thunk  = not is_fun && not (exprIsHNF rhs) && not (isJoinId fn_id)
+                           && not (isUnliftedType (idType fn_id))
 
 {-
 Note [Zapping DmdEnv after Demand Analyzer]
@@ -366,9 +406,10 @@ splitFun :: DynFlags -> FamInstEnvs -> Id -> IdInfo -> [Demand] -> DmdResult ->
 splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
   = WARN( not (wrap_dmds `lengthIs` arity), ppr fn_id <+> (ppr arity $$ ppr wrap_dmds $$ ppr res_info) ) do
     -- The arity should match the signature
-    stuff <- mkWwBodies dflags fam_envs rhs_fvs fun_ty wrap_dmds res_info
+    stuff <- mkWwBodies dflags fam_envs rhs_fvs mb_join_arity fun_ty
+                        wrap_dmds use_res_info
     case stuff of
-      Just (work_demands, wrap_fn, work_fn) -> do
+      Just (work_demands, join_arity, wrap_fn, work_fn) -> do
         work_uniq <- getUniqueM
         let work_rhs = work_fn rhs
             work_prag = InlinePragma { inl_src = SourceText "{-# INLINE"
@@ -379,7 +420,10 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
               -- idl_inline: copy from fn_id; see Note [Worker-wrapper for INLINABLE functions]
               -- idl_act: see Note [Activation for INLINABLE workers]
               -- inl_rule: it does not make sense for workers to be constructorlike.
-
+            work_join_arity | isJoinId fn_id = Just join_arity
+                            | otherwise      = Nothing
+              -- worker is join point iff wrapper is join point
+              -- (see Note [Don't CPR join points])
             work_id  = mkWorkerId work_uniq fn_id (exprType work_rhs)
                         `setIdOccInfo` occInfo fn_info
                                 -- Copy over occurrence info from parent
@@ -400,6 +444,9 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
 
                         `setIdArity` work_arity
                                 -- Set the arity so that the Core Lint check that the
+                                -- arity is consistent with the demand type goes
+                                -- through
+                        `asJoinId_maybe` work_join_arity
 
             work_arity = length work_demands
 
@@ -408,7 +455,6 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
             worker_demand | single_call = mkWorkerDemand work_arity
                           | otherwise   = topDmd
 
-                                -- arity is consistent with the demand type goes through
 
             wrap_act  = ActiveAfter NoSourceText 0
             wrap_rhs  = wrap_fn work_id
@@ -422,7 +468,7 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
 
             wrap_id   = fn_id `setIdUnfolding`  mkWwInlineRule wrap_rhs arity
                               `setInlinePragma` wrap_prag
-                              `setIdOccInfo`    NoOccInfo
+                              `setIdOccInfo`    noOccInfo
                                 -- Zap any loop-breaker-ness, to avoid bleating from Lint
                                 -- about a loop breaker with an INLINE rule
 
@@ -433,6 +479,7 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
 
       Nothing -> return [(fn_id, rhs)]
   where
+    mb_join_arity   = isJoinId_maybe fn_id
     rhs_fvs         = exprFreeVars rhs
     fun_ty          = idType fn_id
     inl_prag        = inlinePragInfo fn_info
@@ -441,7 +488,11 @@ splitFun dflags fam_envs fn_id fn_info wrap_dmds res_info rhs
                     -- The arity is set by the simplifier using exprEtaExpandArity
                     -- So it may be more than the number of top-level-visible lambdas
 
-    work_res_info = case returnsCPR_maybe res_info of
+    use_res_info  | isJoinId fn_id = topRes -- Note [Don't CPR join points]
+                  | otherwise      = res_info
+    work_res_info | isJoinId fn_id = res_info -- Worker remains CPR-able
+                  | otherwise
+                  = case returnsCPR_maybe res_info of
                        Just _  -> topRes    -- Cpr stuff done by wrapper; kill it here
                        Nothing -> res_info  -- Preserve exception/divergence
 
@@ -540,7 +591,8 @@ then the splitting will go deeper too.
 
 splitThunk :: DynFlags -> FamInstEnvs -> RecFlag -> Var -> Expr Var -> UniqSM [(Var, Expr Var)]
 splitThunk dflags fam_envs is_rec fn_id rhs
-  = do { (useful,_, wrap_fn, work_fn) <- mkWWstr dflags fam_envs [fn_id]
+  = ASSERT(not (isJoinId fn_id))
+    do { (useful,_, wrap_fn, work_fn) <- mkWWstr dflags fam_envs [fn_id]
        ; let res = [ (fn_id, Let (NonRec fn_id rhs) (wrap_fn (work_fn (Var fn_id)))) ]
        ; if useful then ASSERT2( isNonRec is_rec, ppr fn_id ) -- The thunk must be non-recursive
                    return res