Fix join-point decision
authorSimon Peyton Jones <simonpj@microsoft.com>
Tue, 9 Jan 2018 13:53:09 +0000 (13:53 +0000)
committerSimon Peyton Jones <simonpj@microsoft.com>
Tue, 9 Jan 2018 16:25:53 +0000 (16:25 +0000)
This patch moves the "ok_unfolding" test
   from  CoreOpt.joinPointBinding_maybe
   to    OccurAnal.decideJoinPointHood

Previously the occurrence analyser was deciding to make
something a join point, but the simplifier was reversing
that decision, which made the decision about /other/ bindings
invalid.

Fixes Trac #14650.

compiler/coreSyn/CoreOpt.hs
compiler/simplCore/OccurAnal.hs
testsuite/tests/simplCore/should_compile/T14650.hs [new file with mode: 0644]
testsuite/tests/simplCore/should_compile/all.T

index 4240647..0f35e8f 100644 (file)
@@ -22,7 +22,7 @@ module CoreOpt (
 
 import GhcPrelude
 
-import CoreArity( joinRhsArity, etaExpandToJoinPoint )
+import CoreArity( etaExpandToJoinPoint )
 
 import CoreSyn
 import CoreSubst
@@ -646,58 +646,18 @@ joinPointBinding_maybe bndr rhs
   = Just (bndr, rhs)
 
   | AlwaysTailCalled join_arity <- tailCallInfo (idOccInfo bndr)
-  , not (bad_unfolding join_arity (idUnfolding bndr))
   , (bndrs, body) <- etaExpandToJoinPoint join_arity rhs
   = Just (bndr `asJoinId` join_arity, mkLams bndrs body)
 
   | otherwise
   = Nothing
 
-  where
-    -- bad_unfolding returns True if we should /not/ convert a non-join-id
-    -- into a join-id, even though it is AlwaysTailCalled
-    -- See Note [Join points and INLINE pragmas]
-    bad_unfolding join_arity (CoreUnfolding { uf_src = src, uf_tmpl = rhs })
-      = isStableSource src && join_arity > joinRhsArity rhs
-    bad_unfolding _ (DFunUnfolding {})
-      = True
-    bad_unfolding _ _
-      = False
-
 joinPointBindings_maybe :: [(InBndr, InExpr)] -> Maybe [(InBndr, InExpr)]
 joinPointBindings_maybe bndrs
   = mapM (uncurry joinPointBinding_maybe) bndrs
 
 
-{- Note [Join points and INLINE pragmas]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider
-   f x = let g = \x. not  -- Arity 1
-             {-# INLINE g #-}
-         in case x of
-              A -> g True True
-              B -> g True False
-              C -> blah2
-
-Here 'g' is always tail-called applied to 2 args, but the stable
-unfolding captured by the INLINE pragma has arity 1.  If we try to
-convert g to be a join point, its unfolding will still have arity 1
-(since it is stable, and we don't meddle with stable unfoldings), and
-Lint will complain (see Note [Invariants on join points], (2a), in
-CoreSyn.  Trac #13413.
-
-Moreover, since g is going to be inlined anyway, there is no benefit
-from making it a join point.
-
-If it is recursive, and uselessly marked INLINE, this will stop us
-making it a join point, which is annoying.  But occasionally
-(notably in class methods; see Note [Instances and loop breakers] in
-TcInstDcls) we mark recursive things as INLINE but the recursion
-unravels; so ignoring INLINE pragmas on recursive things isn't good
-either.
-
-
-************************************************************************
+{- *********************************************************************
 *                                                                      *
          exprIsConApp_maybe
 *                                                                      *
index bcc8410..b0987d5 100644 (file)
@@ -25,6 +25,7 @@ import CoreSyn
 import CoreFVs
 import CoreUtils        ( exprIsTrivial, isDefaultAlt, isExpandableApp,
                           stripTicksTopE, mkTicks )
+import CoreArity        ( joinRhsArity )
 import Id
 import IdInfo
 import Name( localiseName )
@@ -2664,9 +2665,8 @@ tagRecBinders lvl body_uds triples
            , AlwaysTailCalled arity <- tailCallInfo occ
            = Just arity
            | otherwise
-           = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if we're
-                                       -- making join points!
-             Nothing
+           = ASSERT(not will_be_joins) -- Should be AlwaysTailCalled if
+             Nothing                   -- we are making join points!
 
      -- 3. Compute final usage details from adjusted RHS details
      adj_uds   = body_uds +++ combineUsageDetailsList rhs_udss'
@@ -2694,10 +2694,15 @@ setBinderOcc occ_info bndr
 
 -- | Decide whether some bindings should be made into join points or not.
 -- Returns `False` if they can't be join points. Note that it's an
--- all-or-nothing decision, as if multiple binders are given, they're assumed to
--- be mutually recursive.
+-- all-or-nothing decision, as if multiple binders are given, they're
+-- assumed to be mutually recursive.
 --
--- See Note [Invariants for join points] in CoreSyn.
+-- It must, however, be a final decision. If we say "True" for 'f',
+-- and then subsequently decide /not/ make 'f' into a join point, then
+-- the decision about another binding 'g' might be invalidated if (say)
+-- 'f' tail-calls 'g'.
+--
+-- See Note [Invariants on join points] in CoreSyn.
 decideJoinPointHood :: TopLevelFlag -> UsageDetails
                     -> [CoreBndr]
                     -> Bool
@@ -2721,6 +2726,9 @@ decideJoinPointHood NotTopLevel usage bndrs
         AlwaysTailCalled arity <- tailCallInfo (lookupDetails usage bndr)
       , -- Invariant 1 as applied to LHSes of rules
         all (ok_rule arity) (idCoreRules bndr)
+        -- Invariant 2a: stable unfoldings
+        -- See Note [Join points and INLINE pragmas]
+      , ok_unfolding arity (realIdUnfolding bndr)
         -- Invariant 4: Satisfies polymorphism rule
       , isValidJoinPointType arity (idType bndr)
       = True
@@ -2732,14 +2740,52 @@ decideJoinPointHood NotTopLevel usage bndrs
       = args `lengthIs` join_arity
         -- Invariant 1 as applied to LHSes of rules
 
+    -- ok_unfolding returns False if we should /not/ convert a non-join-id
+    -- into a join-id, even though it is AlwaysTailCalled
+    ok_unfolding join_arity (CoreUnfolding { uf_src = src, uf_tmpl = rhs })
+      = not (isStableSource src && join_arity > joinRhsArity rhs)
+    ok_unfolding _ (DFunUnfolding {})
+      = False
+    ok_unfolding _ _
+      = True
+
 willBeJoinId_maybe :: CoreBndr -> Maybe JoinArity
 willBeJoinId_maybe bndr
-  | AlwaysTailCalled arity <- tailCallInfo (idOccInfo bndr)
-  = Just arity
-  | otherwise
-  = isJoinId_maybe bndr
+  = case tailCallInfo (idOccInfo bndr) of
+      AlwaysTailCalled arity -> Just arity
+      _                      -> isJoinId_maybe bndr
+
+
+{- Note [Join points and INLINE pragmas]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+   f x = let g = \x. not  -- Arity 1
+             {-# INLINE g #-}
+         in case x of
+              A -> g True True
+              B -> g True False
+              C -> blah2
+
+Here 'g' is always tail-called applied to 2 args, but the stable
+unfolding captured by the INLINE pragma has arity 1.  If we try to
+convert g to be a join point, its unfolding will still have arity 1
+(since it is stable, and we don't meddle with stable unfoldings), and
+Lint will complain (see Note [Invariants on join points], (2a), in
+CoreSyn.  Trac #13413.
+
+Moreover, since g is going to be inlined anyway, there is no benefit
+from making it a join point.
+
+If it is recursive, and uselessly marked INLINE, this will stop us
+making it a join point, which is annoying.  But occasionally
+(notably in class methods; see Note [Instances and loop breakers] in
+TcInstDcls) we mark recursive things as INLINE but the recursion
+unravels; so ignoring INLINE pragmas on recursive things isn't good
+either.
+
+See Invariant 2a of Note [Invariants on join points] in CoreSyn
+
 
-{-
 ************************************************************************
 *                                                                      *
 \subsection{Operations over OccInfo}
diff --git a/testsuite/tests/simplCore/should_compile/T14650.hs b/testsuite/tests/simplCore/should_compile/T14650.hs
new file mode 100644 (file)
index 0000000..b9eac20
--- /dev/null
@@ -0,0 +1,76 @@
+module MergeSort (\r
+  msortBy\r
+ ) where\r
+\r
+infixl 7 :%\r
+infixr 6 :&\r
+\r
+data LenList a = LL {-# UNPACK #-} !Int Bool [a]\r
+\r
+data LenListAnd a b = {-# UNPACK #-} !(LenList a) :% b\r
+\r
+data Stack a\r
+  = End\r
+  | {-# UNPACK #-} !(LenList a) :& (Stack a)\r
+\r
+msortBy :: (a -> a -> Ordering) -> [a] -> [a]\r
+msortBy cmp = mergeSplit End where\r
+  splitAsc n _ _ _ | n `seq` False = undefined\r
+  splitAsc n as _ [] = LL n True as :% []\r
+  splitAsc n as a bs@(b:bs') = case cmp a b of\r
+    GT -> LL n False as :% bs\r
+    _  -> splitAsc (n + 1) as b bs'\r
+\r
+  splitDesc n _ _ _ | n `seq` False = undefined\r
+  splitDesc n rs a [] = LL n True (a:rs) :% []\r
+  splitDesc n rs a bs@(b:bs') = case cmp a b of\r
+    GT -> splitDesc (n + 1) (a:rs) b bs'\r
+    _  -> LL n True (a:rs) :% bs\r
+\r
+  mergeLL (LL na fa as) (LL nb fb bs) = LL (na + nb) True $ mergeLs na as nb bs where\r
+    mergeLs nx  _ ny  _ | nx `seq` ny `seq` False = undefined\r
+    mergeLs  0  _ ny ys = if fb then ys else take ny ys\r
+    mergeLs  _ [] ny ys = if fb then ys else take ny ys\r
+    mergeLs nx xs  0  _ = if fa then xs else take nx xs\r
+    mergeLs nx xs  _ [] = if fa then xs else take nx xs\r
+    mergeLs nx xs@(x:xs') ny ys@(y:ys') = case cmp x y of\r
+      GT -> y:mergeLs nx xs (ny - 1) ys'\r
+      _  -> x:mergeLs (nx - 1) xs' ny ys\r
+\r
+  push ssx px@(LL nx _ _) = case ssx of\r
+    End -> px :% ssx\r
+    py@(LL ny _ _) :& ssy -> case ssy of\r
+      End\r
+        | nx >= ny -> mergeLL py px :% ssy\r
+      pz@(LL nz _ _) :& ssz\r
+        | nx >= ny || nx + ny >= nz -> case nx > nz of\r
+            False -> push ssy $ mergeLL py px\r
+            _     -> case push ssz $ mergeLL pz py of\r
+              pz' :% ssz' -> push (pz' :& ssz') px\r
+      _ -> px :% ssx\r
+\r
+  mergeAll _ px | px `seq` False = undefined\r
+  mergeAll ssx px@(LL nx _ xs) = case ssx of\r
+    End -> xs\r
+    py@(LL _ _ _) :& ssy -> case ssy of\r
+      End -> case mergeLL py px of\r
+        LL _ _ xys -> xys\r
+      pz@(LL nz _ _) :& ssz -> case nx > nz of\r
+        False -> mergeAll ssy $ mergeLL py px\r
+        _     -> case push ssz $ mergeLL pz py of\r
+          pz' :% ssz' -> mergeAll (pz' :& ssz') px\r
+\r
+  mergeSplit ss _ | ss `seq` False = undefined\r
+  mergeSplit ss [] = case ss of\r
+    End -> []\r
+    px :& ss' -> mergeAll ss' px\r
+  mergeSplit ss as@(a:as') = case as' of\r
+    [] -> mergeAll ss $ LL 1 True as\r
+    b:bs -> case cmp a b of\r
+      GT -> case splitDesc 2 [a] b bs of\r
+        px :% rs -> case push ss px of\r
+          px' :% ss' -> mergeSplit (px' :& ss') rs\r
+      _  -> case splitAsc 2 as b bs of\r
+        px :% rs -> case push ss px of\r
+          px' :% ss' -> mergeSplit (px' :& ss') rs\r
+  {-# INLINABLE mergeSplit #-}\r
index e51e8f7..e681ca7 100644 (file)
@@ -289,3 +289,4 @@ test('T14152a', [extra_files(['T14152.hs']), pre_cmd('cp T14152.hs T14152a.hs'),
                  only_ways(['optasm']), check_errmsg(r'dead code') ],
                 compile, ['-fno-exitification -ddump-simpl'])
 test('T13990', normal, compile, ['-dcore-lint -O'])
+test('T14650', normal, compile, ['-O2'])