Fix loss-of-SpecConstr bug
authorSimon Peyton Jones <simonpj@microsoft.com>
Tue, 2 May 2017 11:04:44 +0000 (12:04 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Tue, 2 May 2017 11:04:44 +0000 (12:04 +0100)
This bug, reported in Trac #13623 has been present since

  commit b8b3e30a6eedf9f213b8a718573c4827cfa230ba
  Author: Edward Z. Yang <ezyang@cs.stanford.edu>
  Date:   Fri Jun 24 11:03:47 2016 -0700

      Axe RecFlag on TyCons.

SpecConstr tries not to specialise indefinitely, and had a
limit (see Note [Limit recursive specialisation]) that made
use of info about whether or not a data constructor was
"recursive".  This info vanished in the above commit, making
the limit fire much more often -- and indeed it fired in this
test case, in a situation where specialisation is /highly/
desirable.

I refactored the test, to look instead at the number of
iterations of the loop of "and now specialise calls that
arise from the specialisation".  Actually less code, and
more robust.

I also added record field names to a couple of constructors,
and renamed RuleInfo to SpecInfo.

compiler/specialise/SpecConstr.hs
testsuite/tests/perf/should_run/T13623.hs [new file with mode: 0644]
testsuite/tests/perf/should_run/T13623.stdout [new file with mode: 0644]
testsuite/tests/perf/should_run/all.T

index 735a71a..dd6f191 100644 (file)
@@ -575,14 +575,19 @@ which can continue indefinitely.
 Roman's suggestion to fix this was to stop after a couple of times on recursive types,
 but still specialising on non-recursive types as much as possible.
 
-To implement this, we count the number of recursive constructors in each
-function argument. If the maximum is greater than the specConstrRecursive limit,
-do not specialise on that pattern.
+To implement this, we count the number of times we have gone round the
+"specialise recursively" loop ('go' in 'specRec').  Once have gone round
+more than N times (controlled by -fspec-constr-recursive=N) we check
 
-This is only necessary when ForceSpecConstr is on: otherwise the specConstrCount
-will force termination anyway.
+  - If sc_force is off, and sc_count is (Just max) then we don't
+    need to do anything: trim_pats will limit the number of specs
 
-See Trac #5550.
+  - Otherwise check if any function has now got more than (sc_count env)
+    specialisations.  If sc_count is "no limit" then we arbitrarily
+    choose 10 as the limit (ugh).
+
+See Trac #5550.   Also Trac #13623, where this test had become over-agressive,
+and we lost a wonderful specialisation that we really wanted!
 
 Note [NoSpecConstr]
 ~~~~~~~~~~~~~~~~~~~
@@ -793,7 +798,10 @@ the function is applied to a data constructor.
 data ScEnv = SCE { sc_dflags    :: DynFlags,
                    sc_module    :: !Module,
                    sc_size      :: Maybe Int,   -- Size threshold
+                                                -- Nothing => no limit
+
                    sc_count     :: Maybe Int,   -- Max # of specialisations for any one fn
+                                                -- Nothing => no limit
                                                 -- See Note [Avoiding exponential blowup]
 
                    sc_recursive :: Int,         -- Max # of specialisations over recursive type.
@@ -1424,15 +1432,16 @@ scRecRhs env (bndr,rhs)
                 -- Two pats are the same if they match both ways
 
 ----------------------
-ruleInfoBinds :: RhsInfo -> [OneSpec] -> [(Id,CoreExpr)]
-ruleInfoBinds (RI { ri_fn = fn, ri_new_rhs = new_rhs }) specs
-  = [(id,rhs) | OS _ _ id rhs <- specs] ++
+ruleInfoBinds :: RhsInfo -> SpecInfo -> [(Id,CoreExpr)]
+ruleInfoBinds (RI { ri_fn = fn, ri_new_rhs = new_rhs })
+              (SI { si_specs = specs })
+  = [(id,rhs) | OS { os_id = id, os_rhs = rhs } <- specs] ++
               -- First the specialised bindings
 
     [(fn `addIdSpecialisations` rules, new_rhs)]
               -- And now the original binding
   where
-    rules = [r | OS _ r _ _ <- specs]
+    rules = [r | OS { os_rule = r } <- specs]
 
 {-
 ************************************************************************
@@ -1452,12 +1461,13 @@ data RhsInfo
        , ri_arg_occs  :: [ArgOcc]      -- Info on how the xs occur in body
     }
 
-data RuleInfo = SI [OneSpec]            -- The specialisations we have generated
+data SpecInfo       -- Info about specialisations for a particular Id
+  = SI { si_specs :: [OneSpec]          -- The specialisations we have generated
 
-                   Int                  -- Length of specs; used for numbering them
+       , si_n_specs :: Int              -- Length of si_specs; used for numbering them
 
-                   (Maybe ScUsage)      -- Just cs  => we have not yet used calls in the
-                                        --             from calls in the *original* RHS as
+       , si_mb_unspec :: Maybe ScUsage  -- Just cs  => we have not yet used calls in the
+       }                                --             from calls in the *original* RHS as
                                         --             seeds for new specialisations;
                                         --             if you decide to do so, here is the
                                         --             RHS usage (which has not yet been
@@ -1467,67 +1477,93 @@ data RuleInfo = SI [OneSpec]            -- The specialisations we have generated
                                         -- See Note [spec_usg includes rhs_usg]
 
         -- One specialisation: Rule plus definition
-data OneSpec  = OS CallPat              -- Call pattern that generated this specialisation
-                   CoreRule             -- Rule connecting original id with the specialisation
-                   OutId OutExpr        -- Spec id + its rhs
+data OneSpec =
+  OS { os_pat  :: CallPat    -- Call pattern that generated this specialisation
+     , os_rule :: CoreRule   -- Rule connecting original id with the specialisation
+     , os_id   :: OutId      -- Spec id
+     , os_rhs  :: OutExpr }  -- Spec rhs
 
+noSpecInfo :: SpecInfo
+noSpecInfo = SI { si_specs = [], si_n_specs = 0, si_mb_unspec = Nothing }
 
 ----------------------
 specNonRec :: ScEnv
            -> ScUsage         -- Body usage
            -> RhsInfo         -- Structure info usage info for un-specialised RHS
-           -> UniqSM (ScUsage, [OneSpec])       -- Usage from RHSs (specialised and not)
-                                                --     plus details of specialisations
+           -> UniqSM (ScUsage, SpecInfo)       -- Usage from RHSs (specialised and not)
+                                               --     plus details of specialisations
 
 specNonRec env body_usg rhs_info
-  = do { (spec_usg, SI specs _ _) <- specialise env (scu_calls body_usg)
-                                                rhs_info
-                                                (SI [] 0 (Just (ri_rhs_usg rhs_info)))
-       ; return (spec_usg, specs) }
+  = specialise env (scu_calls body_usg) rhs_info
+               (noSpecInfo { si_mb_unspec = Just (ri_rhs_usg rhs_info) })
 
 ----------------------
 specRec :: TopLevelFlag -> ScEnv
-        -> ScUsage                             -- Body usage
-        -> [RhsInfo]                           -- Structure info and usage info for un-specialised RHSs
-        -> UniqSM (ScUsage, [[OneSpec]])       -- Usage from all RHSs (specialised and not)
-                                               --     plus details of specialisations
+        -> ScUsage                         -- Body usage
+        -> [RhsInfo]                       -- Structure info and usage info for un-specialised RHSs
+        -> UniqSM (ScUsage, [SpecInfo])    -- Usage from all RHSs (specialised and not)
+                                           --     plus details of specialisations
 
 specRec top_lvl env body_usg rhs_infos
-  = do { (spec_usg, spec_infos) <- go seed_calls nullUsage init_spec_infos
-       ; return (spec_usg, [ s | SI s _ _ <- spec_infos ]) }
+  = go 1 seed_calls nullUsage init_spec_infos
   where
     (seed_calls, init_spec_infos)    -- Note [Seeding top-level recursive groups]
        | isTopLevel top_lvl
        , any (isExportedId . ri_fn) rhs_infos   -- Seed from body and RHSs
-       = (all_calls,     [SI [] 0 Nothing | _ <- rhs_infos])
+       = (all_calls,     [noSpecInfo | _ <- rhs_infos])
        | otherwise                              -- Seed from body only
-       = (calls_in_body, [SI [] 0 (Just (ri_rhs_usg ri)) | ri <- rhs_infos])
+       = (calls_in_body, [noSpecInfo { si_mb_unspec = Just (ri_rhs_usg ri) }
+                         | ri <- rhs_infos])
 
     calls_in_body = scu_calls body_usg
     calls_in_rhss = foldr (combineCalls . scu_calls . ri_rhs_usg) emptyVarEnv rhs_infos
     all_calls = calls_in_rhss `combineCalls` calls_in_body
 
     -- Loop, specialising, until you get no new specialisations
-    go seed_calls usg_so_far spec_infos
+    go :: Int   -- Which iteration of the "until no new specialisations"
+                -- loop we are on; first iteration is 1
+       -> CallEnv   -- Seed calls
+                    -- Two accumulating parameters:
+       -> ScUsage      -- Usage from earlier specialisations
+       -> [SpecInfo]   -- Details of specialisations so far
+       -> UniqSM (ScUsage, [SpecInfo])
+    go n_iter seed_calls usg_so_far spec_infos
       | isEmptyVarEnv seed_calls
-      = -- pprTrace "specRec" (vcat [ ppr (map ri_fn rhs_infos)
-        --                         , ppr seed_calls
-        --                         , ppr body_usg ]) $
+      = -- pprTrace "specRec1" (vcat [ ppr (map ri_fn rhs_infos)
+        --                           , ppr seed_calls
+        --                           , ppr body_usg ]) $
+        return (usg_so_far, spec_infos)
+
+      -- Limit recursive specialisation
+      -- See Note [Limit recursive specialisation]
+      | n_iter > sc_recursive env  -- Too many iterations of the 'go' loop
+      , sc_force env || isNothing (sc_count env)
+           -- If both of these are false, the sc_count
+           -- threshold will prevent non-termination
+      , any ((> the_limit) . si_n_specs) spec_infos
+      = -- pprTrace "specRec2" (ppr (map (map os_pat . si_specs) spec_infos)) $
         return (usg_so_far, spec_infos)
+
       | otherwise
       = do  { specs_w_usg <- zipWithM (specialise env seed_calls) rhs_infos spec_infos
             ; let (extra_usg_s, new_spec_infos) = unzip specs_w_usg
                   extra_usg = combineUsages extra_usg_s
                   all_usg   = usg_so_far `combineUsage` extra_usg
-            ; go (scu_calls extra_usg) all_usg new_spec_infos }
+            ; go (n_iter + 1) (scu_calls extra_usg) all_usg new_spec_infos }
+
+    -- See Note [Limit recursive specialisation]
+    the_limit = case sc_count env of
+                  Nothing  -> 10    -- Ugh!
+                  Just max -> max
+
 
 ----------------------
 specialise
    :: ScEnv
    -> CallEnv                     -- Info on newly-discovered calls to this function
    -> RhsInfo
-   -> RuleInfo                    -- Original RHS plus patterns dealt with
-   -> UniqSM (ScUsage, RuleInfo)  -- New specialised versions and their usage
+   -> SpecInfo                    -- Original RHS plus patterns dealt with
+   -> UniqSM (ScUsage, SpecInfo)  -- New specialised versions and their usage
 
 -- See Note [spec_usg includes rhs_usg]
 
@@ -1540,7 +1576,8 @@ specialise
 
 specialise env bind_calls (RI { ri_fn = fn, ri_lam_bndrs = arg_bndrs
                               , ri_lam_body = body, ri_arg_occs = arg_occs })
-               spec_info@(SI specs spec_count mb_unspec)
+               spec_info@(SI { si_specs = specs, si_n_specs = spec_count
+                             , si_mb_unspec = mb_unspec })
   | isBottomingId fn      -- Note [Do not specialise diverging functions]
                           -- and do not generate specialisation seeds from its RHS
   = -- pprTrace "specialise bot" (ppr fn) $
@@ -1550,7 +1587,7 @@ specialise env bind_calls (RI { ri_fn = fn, ri_lam_bndrs = arg_bndrs
     || null arg_bndrs                     -- Only specialise functions
   = -- pprTrace "specialise inactive" (ppr fn) $
     case mb_unspec of    -- Behave as if there was a single, boring call
-      Just rhs_usg -> return (rhs_usg, SI specs spec_count Nothing)
+      Just rhs_usg -> return (rhs_usg, spec_info { si_mb_unspec = Nothing })
                          -- See Note [spec_usg includes rhs_usg]
       Nothing      -> return (nullUsage, spec_info)
 
@@ -1583,12 +1620,15 @@ specialise env bind_calls (RI { ri_fn = fn, ri_lam_bndrs = arg_bndrs
                       Just rhs_usg | boring_call -> (spec_usg `combineUsage` rhs_usg, Nothing)
                       _                          -> (spec_usg,                      mb_unspec)
 
---        ; pprTrace "specialise return }" (vcat [ ppr fn
---                                               , text "boring_call:" <+> ppr boring_call
---                                               , text "new calls:" <+> ppr (scu_calls new_usg)]) $
-          ; return (new_usg, SI (new_specs ++ specs)
-                                (spec_count + n_pats)
-                                mb_unspec') }
+--        ; pprTrace "specialise return }"
+--             (vcat [ ppr fn
+--                   , text "boring_call:" <+> ppr boring_call
+--                   , text "new calls:" <+> ppr (scu_calls new_usg)]) $
+--          return ()
+
+          ; return (new_usg, SI { si_specs = new_specs ++ specs
+                                , si_n_specs = spec_count + n_pats
+                                , si_mb_unspec = mb_unspec' }) }
 
   | otherwise  -- No new seeds, so return nullUsage
   = return (nullUsage, spec_info)
@@ -1640,7 +1680,8 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
               -- changes (#4012).
               rule_name  = mkFastString ("SC:" ++ occNameString fn_occ ++ show rule_number)
               spec_name  = mkInternalName spec_uniq spec_occ fn_loc
---      ; pprTrace "{spec_one" (ppr (sc_count env) <+> ppr fn <+> ppr pats <+> text "-->" <+> ppr spec_name) $
+--      ; pprTrace "{spec_one" (ppr (sc_count env) <+> ppr fn
+--                              <+> ppr pats <+> text "-->" <+> ppr spec_name) $
 --        return ()
 
         -- Specialise the body
@@ -1679,7 +1720,9 @@ spec_one env fn arg_bndrs body (call_pat@(qvars, pats), rule_number)
               rule       = mkRule this_mod True {- Auto -} True {- Local -}
                                   rule_name inline_act fn_name qvars pats rule_rhs
                            -- See Note [Transfer activation]
-        ; return (spec_usg, OS call_pat rule spec_id spec_rhs) }
+        ; return (spec_usg, OS { os_pat = call_pat, os_rule = rule
+                               , os_id = spec_id
+                               , os_rhs = spec_rhs }) }
 
 
 -- See Note [Strictness information in worker binders]
@@ -1720,7 +1763,7 @@ calcSpecStrictness fn qvars pats
 Note [spec_usg includes rhs_usg]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 In calls to 'specialise', the returned ScUsage must include the rhs_usg in
-the passed-in RuleInfo, unless there are no calls at all to the function.
+the passed-in SpecInfo, unless there are no calls at all to the function.
 
 The caller can, indeed must, assume this.  He should not combine in rhs_usg
 himself, or he'll get rhs_usg twice -- and that can lead to an exponential
@@ -1844,27 +1887,23 @@ type CallPat = ([Var], [CoreExpr])      -- Quantified variables and arguments
                                         -- See Note [SpecConstr call patterns]
 
 callsToNewPats :: ScEnv -> Id
-               -> RuleInfo
+               -> SpecInfo
                -> [ArgOcc] -> [Call]
                -> UniqSM (Bool, [CallPat])
         -- Result has no duplicate patterns,
         -- nor ones mentioned in done_pats
         -- Bool indicates that there was at least one boring pattern
-callsToNewPats env fn spec_info@(SI done_specs _ _) bndr_occs calls
+callsToNewPats env fn spec_info@(SI { si_specs = done_specs }) bndr_occs calls
   = do  { mb_pats <- mapM (callToPats env bndr_occs) calls
 
         ; let have_boring_call = any isNothing mb_pats
 
-              good_pats :: [(CallPat, ValueEnv)]
+              good_pats :: [CallPat]
               good_pats = catMaybes mb_pats
 
-              -- Remove patterns that use too many recursive constructors
-              no_recursive = map fst (filterOut (is_too_recursive env) good_pats)
-
               -- Remove patterns we have already done
-              new_pats = filterOut is_done no_recursive
-              done_pats = [p | OS p _ _ _ <- done_specs]
-              is_done p = any (samePat p) done_pats
+              new_pats = filterOut is_done good_pats
+              is_done p = any (samePat p . os_pat) done_specs
 
               -- Remove duplicates
               non_dups = nubBy samePat new_pats
@@ -1880,22 +1919,24 @@ callsToNewPats env fn spec_info@(SI done_specs _ _) bndr_occs calls
               trimmed_pats = trim_pats env fn spec_info small_pats
 
 --        ; pprTrace "callsToPats" (vcat [ text "calls:" <+> ppr calls
---                                       , text "good_pats:" <+> ppr good_pats
---                                       , text "no_recursive:" <+> ppr no_recursive ])  $
+--                                       , text "good_pats:" <+> ppr good_pats ]) $
+--          return ()
 
         ; return (have_boring_call, trimmed_pats) }
 
 
-trim_pats :: ScEnv -> Id -> RuleInfo -> [CallPat] -> [CallPat]
+trim_pats :: ScEnv -> Id -> SpecInfo -> [CallPat] -> [CallPat]
 -- See Note [Choosing patterns]
-trim_pats env fn (SI _ done_spec_count _) pats
+trim_pats env fn (SI { si_n_specs = done_spec_count }) pats
   | sc_force env
     || isNothing mb_scc
     || n_remaining >= n_pats
-  = pats
+  = pats          -- No need to trim
+
   | otherwise
-  = emit_trace $
+  = emit_trace $  -- Need to trim, so keep the best ones
     take n_remaining sorted_pats
+
   where
     n_pats         = length pats
     spec_count'    = n_pats + done_spec_count
@@ -1937,27 +1978,7 @@ trim_pats env fn (SI _ done_spec_count _) pats
                , text "Discarding:" <+> ppr (drop n_remaining sorted_pats) ]
 
 
-is_too_recursive :: ScEnv -> (CallPat, ValueEnv) -> Bool
-    -- Count the number of recursive constructors in a call pattern,
-    -- filter out if there are more than the maximum.
-    -- This is only necessary if ForceSpecConstr is in effect:
-    -- otherwise specConstrCount will cause specialisation to terminate.
-    -- See Note [Limit recursive specialisation]
--- TODO: make me more accurate
-is_too_recursive env ((_,exprs), val_env)
- = sc_force env && maximum (map go exprs) > sc_recursive env
- where
-  go e
-   | Just (ConVal (DataAlt _) args) <- isValue val_env e
-   = 1 + sum (map go args)
-
-   | App f a                         <- e
-   = go f + go a
-
-   | otherwise
-   = 0
-
-callToPats :: ScEnv -> [ArgOcc] -> Call -> UniqSM (Maybe (CallPat, ValueEnv))
+callToPats :: ScEnv -> [ArgOcc] -> Call -> UniqSM (Maybe CallPat)
         -- The [Var] is the variables to quantify over in the rule
         --      Type variables come first, since they may scope
         --      over the following term variables
@@ -1993,7 +2014,7 @@ callToPats env bndr_occs (Call _ args con_env)
 
         ; -- pprTrace "callToPats"  (ppr args $$ ppr bndr_occs) $
           if interesting
-          then return (Just ((qvars', pats), con_env))
+          then return (Just (qvars', pats))
           else return Nothing }
 
     -- argToPat takes an actual argument, and returns an abstracted
diff --git a/testsuite/tests/perf/should_run/T13623.hs b/testsuite/tests/perf/should_run/T13623.hs
new file mode 100644 (file)
index 0000000..7a048b2
--- /dev/null
@@ -0,0 +1,82 @@
+{-# LANGUAGE BangPatterns, GADTs, ExistentialQuantification #-}
+{-# OPTIONS_GHC -cpp #-}
+
+module Main where
+
+
+import GHC.Types
+
+
+foo :: Int -> Int -> IO Int
+foo = \i j -> sfoldl' (+) 0 $ xs i j +++ ys i j
+  where xs k l = senumFromStepN k l 200000
+        ys k l = senumFromStepN k l 300000
+        {-# Inline xs #-}
+        {-# Inline ys #-}
+{-# Inline foo #-}
+
+
+main = do { n <- foo 1 1; print n }
+
+
+
+-------------------------------------------------------------------------------
+-- vector junk
+-------------------------------------------------------------------------------
+
+#define PHASE_FUSED [1]
+#define PHASE_INNER [0]
+
+#define INLINE_FUSED INLINE PHASE_FUSED
+#define INLINE_INNER INLINE PHASE_INNER
+
+data Stream m a = forall s. Stream (s -> m (Step s a)) s
+
+data Step s a where
+  Yield :: a -> s -> Step s a
+  Skip  :: s -> Step s a
+  Done  :: Step s a
+
+senumFromStepN :: (Num a, Monad m) => a -> a -> Int -> Stream m a
+{-# INLINE_FUSED senumFromStepN #-}
+senumFromStepN x y n = x `seq` y `seq` n `seq` Stream step (x,n)
+  where
+    {-# INLINE_INNER step #-}
+    step (w,m) | m > 0     = return $ Yield w (w+y,m-1)
+               | otherwise = return $ Done
+
+sfoldl' :: Monad m => (a -> b -> a) -> a -> Stream m b -> m a
+{-# INLINE sfoldl' #-}
+sfoldl' f = sfoldlM' (\a b -> return (f a b))
+
+sfoldlM' :: Monad m => (a -> b -> m a) -> a -> Stream m b -> m a
+{-# INLINE_FUSED sfoldlM' #-}
+sfoldlM' m w (Stream step t) = foldlM'_loop SPEC w t
+  where
+    foldlM'_loop !_ z s
+      = z `seq`
+        do
+          r <- step s
+          case r of
+            Yield x s' -> do { z' <- m z x; foldlM'_loop SPEC z' s' }
+            Skip    s' -> foldlM'_loop SPEC z s'
+            Done       -> return z
+
+infixr 5 +++
+(+++) :: Monad m => Stream m a -> Stream m a -> Stream m a
+{-# INLINE_FUSED (+++) #-}
+Stream stepa ta +++ Stream stepb tb = Stream step (Left ta)
+  where
+    {-# INLINE_INNER step #-}
+    step (Left  sa) = do
+                        r <- stepa sa
+                        case r of
+                          Yield x sa' -> return $ Yield x (Left  sa')
+                          Skip    sa' -> return $ Skip    (Left  sa')
+                          Done        -> return $ Skip    (Right tb)
+    step (Right sb) = do
+                        r <- stepb sb
+                        case r of
+                          Yield x sb' -> return $ Yield x (Right sb')
+                          Skip    sb' -> return $ Skip    (Right sb')
+                          Done        -> return $ Done
diff --git a/testsuite/tests/perf/should_run/T13623.stdout b/testsuite/tests/perf/should_run/T13623.stdout
new file mode 100644 (file)
index 0000000..ac3eff3
--- /dev/null
@@ -0,0 +1 @@
+65000250000
index 0451348..9c92cd6 100644 (file)
@@ -539,3 +539,11 @@ test('DeriveNull',
     ['-O'])
 
 test('DeriveNullTermination', normal, compile_and_run, [''])
+
+test('T13623',
+    [stats_num_field('bytes allocated',
+                    [ (wordsize(64), 50936, 5) ]),
+                    # 2017-05-02     50936 initial
+     only_ways(['normal'])],
+    compile_and_run,
+    ['-O2'])