Improve SpecConstr when there are many opportunities
authorSimon Peyton Jones <simonpj@microsoft.com>
Thu, 27 Apr 2017 10:15:00 +0000 (11:15 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Tue, 2 May 2017 07:57:42 +0000 (08:57 +0100)
SpecConstr has -fspec-contr-count=N which limits the maximum
number of specialisations we make for any particular function.
But until now, if that limit was exceeded we discarded all the
candidates!  So adding a new specialisaiton opportunity (by
adding a new call site, or improving the optimiser) could result
in less specialisation and worse performance.

This patch instead picks the top N candidates, resulting in
less brittle behaviour.

See Note [Choosing patterns].

compiler/specialise/SpecConstr.hs

index cd5a90c..735a71a 100644 (file)
@@ -62,6 +62,7 @@ import Module
 
 import TyCon ( TyCon )
 import GHC.Exts( SpecConstrAnnotation(..) )
+import Data.Ord( comparing )
 
 {-
 -----------------------------------------------------
@@ -1555,48 +1556,20 @@ specialise env bind_calls (RI { ri_fn = fn, ri_lam_bndrs = arg_bndrs
 
   | Just all_calls <- lookupVarEnv bind_calls fn
   = -- pprTrace "specialise entry {" (ppr fn <+> ppr all_calls) $
-    do  { (boring_call, all_pats) <- callsToPats env specs arg_occs all_calls
-                -- Bale out if too many specialisations
-        ; let pats = filter (is_small_enough . fst) all_pats
-              is_small_enough vars = isWorkerSmallEnough (sc_dflags env) vars
-                  -- We are about to construct w/w pair in 'spec_one'.
-                  -- Omit specialisation leading to high arity workers.
-                  -- See Note [Limit w/w arity] in WwLib
-              n_pats      = length pats
-              spec_count' = n_pats + spec_count
-        ; case sc_count env of
-            Just max | not (sc_force env) && spec_count' > max
-                -- Suppress this scary message for
-                -- ordinary users!  Trac #5125
-                -> if (debugIsOn || hasPprDebug (sc_dflags env))
-                   then pprTrace "SpecConstr" msg $
-                        return (nullUsage, spec_info)
-                   else return (nullUsage, spec_info)
-                where
-                   msg = vcat [ sep [ text "Function" <+> quotes (ppr fn)
-                                    , nest 2 (text "has" <+>
-                                              speakNOf spec_count' (text "call pattern") <> comma <+>
-                                              text "but the limit is" <+> int max) ]
-                              , text "Use -fspec-constr-count=n to set the bound"
-                              , extra ]
-                   extra = sdocWithPprDebug $ \dbg -> if dbg
-                              then text "Specialisations:"
-                                   <+> ppr (pats ++ [p | OS p _ _ _ <- specs])
-                              else text "Use -dppr-debug to see specialisations"
-
-            _normal_case -> do {
-
---        ; if (not (null pats) || isJust mb_unspec) then
---            pprTrace "specialise" (vcat [ ppr fn <+> text "with" <+> int (length pats) <+> text "good patterns"
+    do  { (boring_call, new_pats) <- callsToNewPats env fn spec_info arg_occs all_calls
+
+        ; let n_pats = length new_pats
+--        ; if (not (null new_pats) || isJust mb_unspec) then
+--            pprTrace "specialise" (vcat [ ppr fn <+> text "with" <+> int n_pats <+> text "good patterns"
 --                                        , text "mb_unspec" <+> ppr (isJust mb_unspec)
 --                                        , text "arg_occs" <+> ppr arg_occs
---                                        , text "good pats" <+> ppr pats])  $
+--                                        , text "good pats" <+> ppr new_pats])  $
 --               return ()
 --          else return ()
 
         ; let spec_env = decreaseSpecCount env n_pats
         ; (spec_usgs, new_specs) <- mapAndUnzipM (spec_one spec_env fn arg_bndrs body)
-                                                 (pats `zip` [spec_count..])
+                                                 (new_pats `zip` [spec_count..])
                 -- See Note [Specialise original body]
 
         ; let spec_usg = combineUsages spec_usgs
@@ -1613,13 +1586,16 @@ specialise env bind_calls (RI { ri_fn = fn, ri_lam_bndrs = arg_bndrs
 --        ; pprTrace "specialise return }" (vcat [ ppr fn
 --                                               , text "boring_call:" <+> ppr boring_call
 --                                               , text "new calls:" <+> ppr (scu_calls new_usg)]) $
-          ; return (new_usg, SI (new_specs ++ specs) spec_count' mb_unspec') } }
-
+          ; return (new_usg, SI (new_specs ++ specs)
+                                (spec_count + n_pats)
+                                mb_unspec') }
 
   | otherwise  -- No new seeds, so return nullUsage
   = return (nullUsage, spec_info)
 
 
+
+
 ---------------------
 spec_one :: ScEnv
          -> OutId       -- Function
@@ -1843,29 +1819,123 @@ end up with a rule LHS that doesn't bind the template variables
 The simplifier eliminates such things, but SpecConstr itself constructs
 new terms by substituting.  So the 'mkCast' in the Cast case of scExpr
 is very important!
+
+Note [Choosing patterns]
+~~~~~~~~~~~~~~~~~~~~~~~~
+If we get lots of patterns we may not want to make a specialisation
+for each of them (code bloat), so we choose as follows, implemented
+by trim_pats.
+
+* The flag -fspec-constr-count-N sets the sc_count field
+  of the ScEnv to (Just n).  This limits the total number
+  of specialisations for a given function to N.
+
+* -fno-spec-constr-count sets the sc_count field to Nothing,
+  which switches of the limit.
+
+* The ghastly ForceSpecConstr trick also switches of the limit
+  for a particular function
+
+* Otherwise we sort the patterns to choose the most general
+  ones first; more general => more widely applicable.
 -}
 
 type CallPat = ([Var], [CoreExpr])      -- Quantified variables and arguments
                                         -- See Note [SpecConstr call patterns]
 
-callsToPats :: ScEnv -> [OneSpec] -> [ArgOcc] -> [Call] -> UniqSM (Bool, [CallPat])
+callsToNewPats :: ScEnv -> Id
+               -> RuleInfo
+               -> [ArgOcc] -> [Call]
+               -> UniqSM (Bool, [CallPat])
         -- Result has no duplicate patterns,
         -- nor ones mentioned in done_pats
         -- Bool indicates that there was at least one boring pattern
-callsToPats env done_specs bndr_occs calls
+callsToNewPats env fn spec_info@(SI done_specs _ _) bndr_occs calls
   = do  { mb_pats <- mapM (callToPats env bndr_occs) calls
 
-        ; let good_pats :: [(CallPat, ValueEnv)]
+        ; let have_boring_call = any isNothing mb_pats
+
+              good_pats :: [(CallPat, ValueEnv)]
               good_pats = catMaybes mb_pats
+
+              -- Remove patterns that use too many recursive constructors
+              no_recursive = map fst (filterOut (is_too_recursive env) good_pats)
+
+              -- Remove patterns we have already done
+              new_pats = filterOut is_done no_recursive
               done_pats = [p | OS p _ _ _ <- done_specs]
               is_done p = any (samePat p) done_pats
-              no_recursive = map fst (filterOut (is_too_recursive env) good_pats)
+
+              -- Remove duplicates
+              non_dups = nubBy samePat new_pats
+
+              -- Remove ones that have too many worker variables
+              small_pats = filterOut too_big non_dups
+              too_big (vars,_) = not (isWorkerSmallEnough (sc_dflags env) vars)
+                  -- We are about to construct w/w pair in 'spec_one'.
+                  -- Omit specialisation leading to high arity workers.
+                  -- See Note [Limit w/w arity] in WwLib
+
+                -- Discard specialisations if there are too many of them
+              trimmed_pats = trim_pats env fn spec_info small_pats
 
 --        ; pprTrace "callsToPats" (vcat [ text "calls:" <+> ppr calls
 --                                       , text "good_pats:" <+> ppr good_pats
 --                                       , text "no_recursive:" <+> ppr no_recursive ])  $
-          ; return (any isNothing mb_pats,
-                    filterOut is_done (nubBy samePat no_recursive)) }
+
+        ; return (have_boring_call, trimmed_pats) }
+
+
+trim_pats :: ScEnv -> Id -> RuleInfo -> [CallPat] -> [CallPat]
+-- See Note [Choosing patterns]
+trim_pats env fn (SI _ done_spec_count _) pats
+  | sc_force env
+    || isNothing mb_scc
+    || n_remaining >= n_pats
+  = pats
+  | otherwise
+  = emit_trace $
+    take n_remaining sorted_pats
+  where
+    n_pats         = length pats
+    spec_count'    = n_pats + done_spec_count
+    n_remaining    = max_specs - done_spec_count
+    mb_scc         = sc_count env
+    Just max_specs = mb_scc
+
+    sorted_pats = map fst $
+                  sortBy (comparing snd) $
+                  [(pat, pat_cons pat) | pat <- pats]
+     -- Sort in order of increasing number of constructors
+     -- (i.e. decreasing generality) and pick the initial
+     -- segment of this list
+
+    pat_cons :: CallPat -> Int
+    -- How many data consturorst of literals are in
+    -- the patten.  More data-cons => less general
+    pat_cons (qs, ps) = foldr ((+) . n_cons) 0 ps
+       where
+          q_set = mkVarSet qs
+          n_cons (Var v) | v `elemVarSet` q_set = 0
+                         | otherwise            = 1
+          n_cons (Cast e _)  = n_cons e
+          n_cons (App e1 e2) = n_cons e1 + n_cons e2
+          n_cons (Lit {})    = 1
+          n_cons _           = 0
+
+    emit_trace result
+       | debugIsOn || hasPprDebug (sc_dflags env)
+         -- Suppress this scary message for ordinary users!  Trac #5125
+       = pprTrace "SpecConstr" msg result
+       | otherwise
+       = result
+    msg = vcat [ sep [ text "Function" <+> quotes (ppr fn)
+                     , nest 2 (text "has" <+>
+                               speakNOf spec_count' (text "call pattern") <> comma <+>
+                               text "but the limit is" <+> int max_specs) ]
+               , text "Use -fspec-constr-count=n to set the bound"
+               , text "Discarding:" <+> ppr (drop n_remaining sorted_pats) ]
+
 
 is_too_recursive :: ScEnv -> (CallPat, ValueEnv) -> Bool
     -- Count the number of recursive constructors in a call pattern,