Let the specialiser work on dicts under lambdas
authorSandy Maguire <sandy@sandymaguire.me>
Thu, 16 May 2019 16:12:10 +0000 (12:12 -0400)
committerMarge Bot <ben+marge-bot@smart-cactus.org>
Sun, 26 May 2019 12:57:20 +0000 (08:57 -0400)
Following the discussion under #16473, this change allows the
specializer to work on any dicts in a lambda, not just those that occur
at the beginning.

For example, if you use data types which contain dictionaries and
higher-rank functions then once these are erased by the optimiser you
end up with functions such as:

```
  go_s4K9
  Int#
  -> forall (m :: * -> *).
     Monad m =>
     (forall x. Union '[State (Sum Int)] x -> m x) -> m ()
```

The dictionary argument is after the Int# value argument, this patch
allows `go` to be specialised.

compiler/specialise/Specialise.hs
testsuite/tests/perf/compiler/Makefile
testsuite/tests/perf/compiler/T16473.hs [new file with mode: 0644]
testsuite/tests/perf/compiler/T16473.stdout [new file with mode: 0644]
testsuite/tests/perf/compiler/all.T
testsuite/tests/simplCore/should_compile/T7785.stderr
testsuite/tests/warnings/should_compile/T16282/T16282.stderr

index 9d87abc..c1396e4 100644 (file)
@@ -5,6 +5,7 @@
 -}
 
 {-# LANGUAGE CPP #-}
+{-# LANGUAGE ViewPatterns #-}
 module Specialise ( specProgram, specUnfolding ) where
 
 #include "HsVersions.h"
@@ -25,13 +26,13 @@ import VarEnv
 import CoreSyn
 import Rules
 import CoreOpt          ( collectBindersPushingCo )
-import CoreUtils        ( exprIsTrivial, applyTypeToArgs, mkCast )
+import CoreUtils        ( exprIsTrivial, mkCast, exprType )
 import CoreFVs
 import CoreArity        ( etaExpandToJoinPointRule )
 import UniqSupply
 import Name
 import MkId             ( voidArgId, voidPrimId )
-import Maybes           ( catMaybes, isJust )
+import Maybes           ( mapMaybe, isJust )
 import MonadUtils       ( foldlM )
 import BasicTypes
 import HscTypes
@@ -42,6 +43,7 @@ import Outputable
 import FastString
 import State
 import UniqDFM
+import TyCoRep (TyCoBinder (..))
 
 import Control.Monad
 import qualified Control.Monad.Fail as MonadFail
@@ -631,6 +633,190 @@ bitten by such instances to revert to the pre-7.10 behavior.
 See #10491
 -}
 
+-- | An argument that we might want to specialise.
+-- See Note [Specialising Calls] for the nitty gritty details.
+data SpecArg
+  =
+    -- | Type arguments that should be specialised, due to appearing
+    -- free in the type of a 'SpecDict'.
+    SpecType Type
+    -- | Type arguments that should remain polymorphic.
+  | UnspecType
+    -- | Dictionaries that should be specialised.
+  | SpecDict DictExpr
+    -- | Value arguments that should not be specialised.
+  | UnspecArg
+
+instance Outputable SpecArg where
+  ppr (SpecType t) = text "SpecType" <+> ppr t
+  ppr UnspecType   = text "UnspecType"
+  ppr (SpecDict d) = text "SpecDict" <+> ppr d
+  ppr UnspecArg    = text "UnspecArg"
+
+getSpecDicts :: [SpecArg] -> [DictExpr]
+getSpecDicts = mapMaybe go
+  where
+    go (SpecDict d) = Just d
+    go _            = Nothing
+
+getSpecTypes :: [SpecArg] -> [Type]
+getSpecTypes = mapMaybe go
+  where
+    go (SpecType t) = Just t
+    go _            = Nothing
+
+isUnspecArg :: SpecArg -> Bool
+isUnspecArg UnspecArg  = True
+isUnspecArg UnspecType = True
+isUnspecArg _          = False
+
+isValueArg :: SpecArg -> Bool
+isValueArg UnspecArg    = True
+isValueArg (SpecDict _) = True
+isValueArg _            = False
+
+-- | Given binders from an original function 'f', and the 'SpecArg's
+-- corresponding to its usage, compute everything necessary to build
+-- a specialisation.
+--
+-- We will use a running example. Consider the function
+--
+--    foo :: forall a b. Eq a => Int -> blah
+--    foo @a @b dEqA i = blah
+--
+-- which is called with the 'CallInfo'
+--
+--    [SpecType T1, UnspecType, SpecDict dEqT1, UnspecArg]
+--
+-- We'd eventually like to build the RULE
+--
+--    RULE "SPEC foo @T1 _"
+--      forall @a @b (dEqA' :: Eq a).
+--        foo @T1 @b dEqA' = $sfoo @b
+--
+-- and the specialisation '$sfoo'
+--
+--    $sfoo :: forall b. Int -> blah
+--    $sfoo @b = \i -> SUBST[a->T1, dEqA->dEqA'] blah
+--
+-- The cases for 'specHeader' below are presented in the same order as this
+-- running example. The result of 'specHeader' for this example is as follows:
+--
+--    ( -- Returned arguments
+--      env + [a -> T1, deqA -> dEqA']
+--    , []
+--
+--      -- RULE helpers
+--    , [b, dx', i]
+--    , [T1, b, dx', i]
+--
+--      -- Specialised function helpers
+--    , [b, i]
+--    , [dx]
+--    , [T1, b, dx_spec, i]
+--    )
+specHeader
+     :: SpecEnv
+     -> [CoreBndr]  -- The binders from the original function 'f'
+     -> [SpecArg]   -- From the CallInfo
+     -> SpecM ( -- Returned arguments
+                SpecEnv      -- Substitution to apply to the body of 'f'
+              , [CoreBndr]   -- All the remaining unspecialised args from the original function 'f'
+
+                -- RULE helpers
+              , [CoreBndr]   -- Binders for the RULE
+              , [CoreArg]    -- Args for the LHS of the rule
+
+                -- Specialised function helpers
+              , [CoreBndr]   -- Binders for $sf
+              , [DictBind]   -- Auxiliary dictionary bindings
+              , [CoreExpr]   -- Specialised arguments for unfolding
+              )
+
+-- We want to specialise on type 'T1', and so we must construct a substitution
+-- 'a->T1', as well as a LHS argument for the resulting RULE and unfolding
+-- details.
+specHeader env (bndr : bndrs) (SpecType t : args)
+  = do { let env' = extendTvSubstList env [(bndr, t)]
+       ; (env'', unused_bndrs, rule_bs, rule_es, bs', dx, spec_args)
+            <- specHeader env' bndrs args
+       ; pure ( env''
+              , unused_bndrs
+              , rule_bs
+              , Type t : rule_es
+              , bs'
+              , dx
+              , Type t : spec_args
+              )
+       }
+
+-- Next we have a type that we don't want to specialise. We need to perform
+-- a substitution on it (in case the type refers to 'a'). Additionally, we need
+-- to produce a binder, LHS argument and RHS argument for the resulting rule,
+-- /and/ a binder for the specialised body.
+specHeader env (bndr : bndrs) (UnspecType : args)
+  = do { let (env', bndr') = substBndr env bndr
+       ; (env'', unused_bndrs, rule_bs, rule_es, bs', dx, spec_args)
+            <- specHeader env' bndrs args
+       ; pure ( env''
+              , unused_bndrs
+              , bndr' : rule_bs
+              , varToCoreExpr bndr' : rule_es
+              , bndr' : bs'
+              , dx
+              , varToCoreExpr bndr' : spec_args
+              )
+       }
+
+-- Next we want to specialise the 'Eq a' dict away. We need to construct
+-- a wildcard binder to match the dictionary (See Note [Specialising Calls] for
+-- the nitty-gritty), as a LHS rule and unfolding details.
+specHeader env (bndr : bndrs) (SpecDict d : args)
+  = do { inst_dict_id <- newDictBndr env bndr
+       ; let (rhs_env2, dx_binds, spec_dict_args')
+                = bindAuxiliaryDicts env [bndr] [d] [inst_dict_id]
+       ; (env', unused_bndrs, rule_bs, rule_es, bs', dx, spec_args)
+             <- specHeader rhs_env2 bndrs args
+       ; pure ( env'
+              , unused_bndrs
+              -- See Note [Evidence foralls]
+              , exprFreeIdsList (varToCoreExpr inst_dict_id) ++ rule_bs
+              , varToCoreExpr inst_dict_id : rule_es
+              , bs'
+              , dx_binds ++ dx
+              , spec_dict_args' ++ spec_args
+              )
+       }
+
+-- Finally, we have the unspecialised argument 'i'. We need to produce
+-- a binder, LHS and RHS argument for the RULE, and a binder for the
+-- specialised body.
+--
+-- NB: Calls to 'specHeader' will trim off any trailing 'UnspecArg's, which is
+-- why 'i' doesn't appear in our RULE above. But we have no guarantee that
+-- there aren't 'UnspecArg's which come /before/ all of the dictionaries, so
+-- this case must be here.
+specHeader env (bndr : bndrs) (UnspecArg : args)
+  = do { let (env', bndr') = substBndr env bndr
+       ; (env'', unused_bndrs, rule_bs, rule_es, bs', dx, spec_args)
+             <- specHeader env' bndrs args
+       ; pure ( env''
+              , unused_bndrs
+              , bndr' : rule_bs
+              , varToCoreExpr bndr' : rule_es
+              , bndr' : bs'
+              , dx
+              , varToCoreExpr bndr' : spec_args
+              )
+       }
+
+-- Return all remaining binders from the original function. These have the
+-- invariant that they should all correspond to unspecialised arguments, so
+-- it's safe to stop processing at this point.
+specHeader env bndrs [] = pure (env, bndrs, [], [], [], [], [])
+specHeader env [] _     = pure (env, [], [], [], [], [], [])
+
+
 -- | Specialise a set of calls to imported bindings
 specImports :: DynFlags
             -> Module
@@ -1171,8 +1357,7 @@ type SpecInfo = ( [CoreRule]       -- Specialisation rules
 
 specCalls mb_mod env existing_rules calls_for_me fn rhs
         -- The first case is the interesting one
-  |  rhs_tyvars `lengthIs`      n_tyvars -- Rhs of fn's defn has right number of big lambdas
-  && rhs_bndrs1 `lengthAtLeast` n_dicts -- and enough dict args
+  |  callSpecArity pis <= fn_arity      -- See Note [Specialisation Must Preserve Sharing]
   && notNull calls_for_me               -- And there are some calls to specialise
   && not (isNeverActive (idInlineActivation fn))
         -- Don't specialise NOINLINE things
@@ -1193,15 +1378,14 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
     -- pprTrace "specDefn: none" (ppr fn <+> ppr calls_for_me) $
     return ([], [], emptyUDs)
   where
-    _trace_doc = sep [ ppr rhs_tyvars, ppr n_tyvars
-                     , ppr rhs_bndrs, ppr n_dicts
+    _trace_doc = sep [ ppr rhs_tyvars, ppr rhs_bndrs
                      , ppr (idInlineActivation fn) ]
 
     fn_type                 = idType fn
     fn_arity                = idArity fn
     fn_unf                  = realIdUnfolding fn  -- Ignore loop-breaker-ness here
-    (tyvars, theta, _)      = tcSplitSigmaTy fn_type
-    n_tyvars                = length tyvars
+    pis                     = fst $ splitPiTys fn_type
+    theta                   = getTheta pis
     n_dicts                 = length theta
     inl_prag                = idInlinePragma fn
     inl_act                 = inlinePragmaActivation inl_prag
@@ -1212,10 +1396,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
 
     (rhs_bndrs, rhs_body)      = collectBindersPushingCo rhs
                                  -- See Note [Account for casts in binding]
-    (rhs_tyvars, rhs_bndrs1)   = span isTyVar rhs_bndrs
-    (rhs_dict_ids, rhs_bndrs2) = splitAt n_dicts rhs_bndrs1
-    body                       = mkLams rhs_bndrs2 rhs_body
-                                 -- Glue back on the non-dict lambdas
+    rhs_tyvars = filter isTyVar rhs_bndrs
 
     in_scope = CoreSubst.substInScope (se_subst env)
 
@@ -1227,59 +1408,19 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
          -- NB: we look both in the new_rules (generated by this invocation
          --     of specCalls), and in existing_rules (passed in to specCalls)
 
-    mk_ty_args :: [Maybe Type] -> [TyVar] -> [CoreExpr]
-    mk_ty_args [] poly_tvs
-      = ASSERT( null poly_tvs ) []
-    mk_ty_args (Nothing : call_ts) (poly_tv : poly_tvs)
-      = Type (mkTyVarTy poly_tv) : mk_ty_args call_ts poly_tvs
-    mk_ty_args (Just ty : call_ts) poly_tvs
-      = Type ty : mk_ty_args call_ts poly_tvs
-    mk_ty_args (Nothing : _) [] = panic "mk_ty_args"
-
     ----------------------------------------------------------
         -- Specialise to one particular call pattern
     spec_call :: SpecInfo                         -- Accumulating parameter
               -> CallInfo                         -- Call instance
               -> SpecM SpecInfo
     spec_call spec_acc@(rules_acc, pairs_acc, uds_acc)
-              (CI { ci_key = CallKey call_ts, ci_args = call_ds })
-      = ASSERT( call_ts `lengthIs` n_tyvars  && call_ds `lengthIs` n_dicts )
-
-        -- Suppose f's defn is  f = /\ a b c -> \ d1 d2 -> rhs
-        -- Suppose the call is for f [Just t1, Nothing, Just t3] [dx1, dx2]
-
-        -- Construct the new binding
-        --      f1 = SUBST[a->t1,c->t3, d1->d1', d2->d2'] (/\ b -> rhs)
-        -- PLUS the rule
-        --      RULE "SPEC f" forall b d1' d2'. f b d1' d2' = f1 b
-        --      In the rule, d1' and d2' are just wildcards, not used in the RHS
-        -- PLUS the usage-details
-        --      { d1' = dx1; d2' = dx2 }
-        -- where d1', d2' are cloned versions of d1,d2, with the type substitution
-        -- applied.  These auxiliary bindings just avoid duplication of dx1, dx2
-        --
-        -- Note that the substitution is applied to the whole thing.
-        -- This is convenient, but just slightly fragile.  Notably:
-        --      * There had better be no name clashes in a/b/c
-        do { let
-                -- poly_tyvars = [b] in the example above
-                -- spec_tyvars = [a,c]
-                -- ty_args     = [t1,b,t3]
-                spec_tv_binds = [(tv,ty) | (tv, Just ty) <- rhs_tyvars `zip` call_ts]
-                env1          = extendTvSubstList env spec_tv_binds
-                (rhs_env, poly_tyvars) = substBndrs env1
-                                            [tv | (tv, Nothing) <- rhs_tyvars `zip` call_ts]
-
-             -- Clone rhs_dicts, including instantiating their types
-           ; inst_dict_ids <- mapM (newDictBndr rhs_env) rhs_dict_ids
-           ; let (rhs_env2, dx_binds, spec_dict_args)
-                            = bindAuxiliaryDicts rhs_env rhs_dict_ids call_ds inst_dict_ids
-                 ty_args    = mk_ty_args call_ts poly_tyvars
-                 ev_args    = map varToCoreExpr inst_dict_ids  -- ev_args, ev_bndrs:
-                 ev_bndrs   = exprsFreeIdsList ev_args         -- See Note [Evidence foralls]
-                 rule_args  = ty_args     ++ ev_args
-                 rule_bndrs = poly_tyvars ++ ev_bndrs
+              (CI { ci_key = call_args, ci_arity = call_arity })
+      = ASSERT(call_arity <= fn_arity)
 
+        -- See Note [Specialising Calls]
+        do { (rhs_env2, unused_bndrs, rule_bndrs, rule_args, unspec_bndrs, dx_binds, spec_args)
+               <- specHeader env rhs_bndrs $ dropWhileEndLE isUnspecArg call_args
+           ; let rhs_body' = mkLams unused_bndrs rhs_body
            ; dflags <- getDynFlags
            ; if already_covered dflags rules_acc rule_args
              then return spec_acc
@@ -1288,25 +1429,28 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                   --                           , ppr dx_binds ]) $
                   do
            {    -- Figure out the type of the specialised function
-             let body_ty = applyTypeToArgs rhs fn_type rule_args
-                 (lam_args, app_args)           -- Add a dummy argument if body_ty is unlifted
+             let body = mkLams unspec_bndrs rhs_body'
+                 body_ty = substTy rhs_env2 $ exprType body
+                 (lam_extra_args, app_args)     -- See Note [Specialisations Must Be Lifted]
                    | isUnliftedType body_ty     -- C.f. WwLib.mkWorkerArgs
                    , not (isJoinId fn)
-                   = (poly_tyvars ++ [voidArgId], poly_tyvars ++ [voidPrimId])
-                   | otherwise = (poly_tyvars, poly_tyvars)
-                 spec_id_ty = mkLamTypes lam_args body_ty
+                   = ([voidArgId], unspec_bndrs ++ [voidPrimId])
+                   | otherwise = ([], unspec_bndrs)
                  join_arity_change = length app_args - length rule_args
                  spec_join_arity | Just orig_join_arity <- isJoinId_maybe fn
                                  = Just (orig_join_arity + join_arity_change)
                                  | otherwise
                                  = Nothing
 
+           ; (spec_rhs, rhs_uds) <- specExpr rhs_env2 (mkLams lam_extra_args body)
+           ; let spec_id_ty = exprType spec_rhs
            ; spec_f <- newSpecIdSM fn spec_id_ty spec_join_arity
-           ; (spec_rhs, rhs_uds) <- specExpr rhs_env2 (mkLams lam_args body)
            ; this_mod <- getModule
            ; let
                 -- The rule to put in the function's specialisation is:
-                --      forall b, d1',d2'.  f t1 b t3 d1' d2' = f1 b
+                --      forall x @b d1' d2'.
+                --          f x @T1 @b @T2 d1' d2' = f1 x @b
+                -- See Note [Specialising Calls]
                 herald = case mb_mod of
                            Nothing        -- Specialising local fn
                                -> text "SPEC"
@@ -1315,7 +1459,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
 
                 rule_name = mkFastString $ showSDoc dflags $
                             herald <+> ftext (occNameFS (getOccName fn))
-                                   <+> hsep (map ppr_call_key_ty call_ts)
+                                   <+> hsep (mapMaybe ppr_call_key_ty call_args)
                             -- This name ends up in interface files, so use occNameString.
                             -- Otherwise uniques end up there, making builds
                             -- less deterministic (See #4012 comment:61 ff)
@@ -1338,6 +1482,7 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                       Nothing -> rule_wout_eta
 
                 -- Add the { d1' = dx1; d2' = dx2 } usage stuff
+                -- See Note [Specialising Calls]
                 spec_uds = foldr consDictBind rhs_uds dx_binds
 
                 --------------------------------------
@@ -1352,11 +1497,9 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                   = (inl_prag { inl_inline = NoUserInline }, noUnfolding)
 
                   | otherwise
-                  = (inl_prag, specUnfolding dflags poly_tyvars spec_app
-                                             arity_decrease fn_unf)
+                  = (inl_prag, specUnfolding dflags unspec_bndrs spec_app n_dicts fn_unf)
 
-                arity_decrease = length spec_dict_args
-                spec_app e = (e `mkApps` ty_args) `mkApps` spec_dict_args
+                spec_app e = e `mkApps` spec_args
 
                 --------------------------------------
                 -- Adding arity information just propagates it a bit faster
@@ -1368,13 +1511,116 @@ specCalls mb_mod env existing_rules calls_for_me fn rhs
                                         `setIdUnfolding`  spec_unf
                                         `asJoinId_maybe`  spec_join_arity
 
-           ; return ( spec_rule                  : rules_acc
+                _rule_trace_doc = vcat [ ppr spec_f, ppr fn_type, ppr spec_id_ty
+                                       , ppr rhs_bndrs, ppr call_args
+                                       , ppr spec_rule
+                                       ]
+
+           ; -- pprTrace "spec_call: rule" _rule_trace_doc
+             return ( spec_rule                  : rules_acc
                     , (spec_f_w_arity, spec_rhs) : pairs_acc
                     , spec_uds           `plusUDs` uds_acc
                     ) } }
 
-{- Note [Account for casts in binding]
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+{- Note [Specialisation Must Preserve Sharing]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider a function:
+
+    f :: forall a. Eq a => a -> blah
+    f =
+      if expensive
+         then f1
+         else f2
+
+As written, all calls to 'f' will share 'expensive'. But if we specialise 'f'
+at 'Int', eg:
+
+    $sfInt = SUBST[a->Int,dict->dEqInt] (if expensive then f1 else f2)
+
+    RULE "SPEC f"
+      forall (d :: Eq Int).
+        f Int _ = $sfIntf
+
+We've now lost sharing between 'f' and '$sfInt' for 'expensive'. Yikes!
+
+To avoid this, we only generate specialisations for functions whose arity is
+enough to bind all of the arguments we need to specialise.  This ensures our
+specialised functions don't do any work before receiving all of their dicts,
+and thus avoids the 'f' case above.
+
+Note [Specialisations Must Be Lifted]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider a function 'f':
+
+    f = forall a. Eq a => Array# a
+
+used like
+
+    case x of
+      True -> ...f @Int dEqInt...
+      False -> 0
+
+Naively, we might generate an (expensive) specialisation
+
+    $sfInt :: Array# Int
+
+even in the case that @x = False@! Instead, we add a dummy 'Void#' argument to
+the specialisation '$sfInt' ($sfInt :: Void# -> Array# Int) in order to
+preserve laziness.
+
+Note [Specialising Calls]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Suppose we have a function:
+
+    f :: Int -> forall a b c. (Foo a, Foo c) => Bar -> Qux
+    f = \x -> /\ a b c -> \d1 d2 bar -> rhs
+
+and suppose it is called at:
+
+    f 7 @T1 @T2 @T3 dFooT1 dFooT3 bar
+
+This call is described as a 'CallInfo' whose 'ci_key' is
+
+    [ UnspecArg, SpecType T1, UnspecType, SpecType T3, SpecDict dFooT1
+    , SpecDict dFooT3, UnspecArg ]
+
+Why are 'a' and 'c' identified as 'SpecType', while 'b' is 'UnspecType'?
+Because we must specialise the function on type variables that appear
+free in its *dictionary* arguments; but not on type variables that do not
+appear in any dictionaries, i.e. are fully polymorphic.
+
+Because this call has dictionaries applied, we'd like to specialise
+the call on any type argument that appears free in those dictionaries.
+In this case, those are (a ~ T1, c ~ T3).
+
+As a result, we'd like to generate a function:
+
+    $sf :: Int -> forall b. Bar -> Qux
+    $sf = SUBST[a->T1, c->T3, d1->d1', d2->d2'] (\x -> /\ b -> \bar -> rhs)
+
+Note that the substitution is applied to the whole thing.  This is
+convenient, but just slightly fragile.  Notably:
+  * There had better be no name clashes in a/b/c
+
+We must construct a rewrite rule:
+
+    RULE "SPEC f @T1 _ @T3"
+      forall (x :: Int) (@b :: Type) (d1' :: Foo T1) (d2' :: Foo T3).
+        f x @T1 @b @T3 d1' d2' = $sf x @b
+
+In the rule, d1' and d2' are just wildcards, not used in the RHS.  Note
+additionally that 'bar' isn't captured by this rule --- we bind only
+enough etas in order to capture all of the *specialised* arguments.
+
+Finally, we must also construct the usage-details
+
+     { d1' = dx1; d2' = dx2 }
+
+where d1', d2' are cloned versions of d1,d2, with the type substitution
+applied.  These auxiliary bindings just avoid duplication of dx1, dx2.
+
+Note [Account for casts in binding]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Consider
    f :: Eq a => a -> IO ()
    {-# INLINABLE f
@@ -1888,16 +2134,14 @@ data CallInfoSet = CIS Id (Bag CallInfo)
   -- These dups are eliminated by already_covered in specCalls
 
 data CallInfo
-  = CI { ci_key  :: CallKey     -- Type arguments
-       , ci_args :: [DictExpr]  -- Dictionary arguments
-       , ci_fvs  :: VarSet      -- Free vars of the ci_key and ci_args
+  = CI { ci_key  :: [SpecArg]   -- All arguments
+       , ci_arity :: Int        -- The number of variables necessary to bind
+                                -- all of the specialised arguments
+       , ci_fvs  :: VarSet      -- Free vars of the ci_key
                                 -- call (including tyvars)
                                 -- [*not* include the main id itself, of course]
     }
 
-newtype CallKey   = CallKey [Maybe Type]
-  -- Nothing => unconstrained type argument
-
 type DictExpr = CoreExpr
 
 ciSetFilter :: (CallInfo -> Bool) -> CallInfoSet -> CallInfoSet
@@ -1911,16 +2155,15 @@ pprCallInfo :: Id -> CallInfo -> SDoc
 pprCallInfo fn (CI { ci_key = key })
   = ppr fn <+> ppr key
 
-ppr_call_key_ty :: Maybe Type -> SDoc
-ppr_call_key_ty Nothing   = char '_'
-ppr_call_key_ty (Just ty) = char '@' <+> pprParendType ty
-
-instance Outputable CallKey where
-  ppr (CallKey ts) = brackets (fsep (map ppr_call_key_ty ts))
+ppr_call_key_ty :: SpecArg -> Maybe SDoc
+ppr_call_key_ty (SpecType ty) = Just $ char '@' <+> pprParendType ty
+ppr_call_key_ty UnspecType    = Just $ char '_'
+ppr_call_key_ty (SpecDict _)  = Nothing
+ppr_call_key_ty UnspecArg     = Nothing
 
 instance Outputable CallInfo where
-  ppr (CI { ci_key = key, ci_args = args, ci_fvs = fvs })
-    = text "CI" <> braces (hsep [ ppr key, ppr args, ppr fvs ])
+  ppr (CI { ci_key = key, ci_fvs = fvs })
+    = text "CI" <> braces (hsep [ fsep (mapMaybe ppr_call_key_ty key), ppr fvs ])
 
 unionCalls :: CallDetails -> CallDetails -> CallDetails
 unionCalls c1 c2 = plusDVarEnv_C unionCallInfoSet c1 c2
@@ -1939,17 +2182,29 @@ callInfoFVs :: CallInfoSet -> VarSet
 callInfoFVs (CIS _ call_info) =
   foldrBag (\(CI { ci_fvs = fv }) vs -> unionVarSet fv vs) emptyVarSet call_info
 
+computeArity :: [SpecArg] -> Int
+computeArity = length . filter isValueArg . dropWhileEndLE isUnspecArg
+
+callSpecArity :: [TyCoBinder] -> Int
+callSpecArity = length . filter (not . isNamedBinder) . dropWhileEndLE isVisibleBinder
+
+getTheta :: [TyCoBinder] -> [PredType]
+getTheta = fmap tyBinderType . filter isInvisibleBinder . filter (not . isNamedBinder)
+
+
 ------------------------------------------------------------
-singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails
-singleCall id tys dicts
+singleCall :: Id -> [SpecArg] -> UsageDetails
+singleCall id args
   = MkUD {ud_binds = emptyBag,
           ud_calls = unitDVarEnv id $ CIS id $
-                     unitBag (CI { ci_key = CallKey tys
-                                 , ci_args = dicts
+                     unitBag (CI { ci_key  = args -- used to be tys
+                                 , ci_arity = computeArity args
                                  , ci_fvs  = call_fvs }) }
   where
+    tys      = getSpecTypes args
+    dicts    = getSpecDicts args
     call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
-    tys_fvs  = tyCoVarsOfTypes (catMaybes tys)
+    tys_fvs  = tyCoVarsOfTypes tys
         -- The type args (tys) are guaranteed to be part of the dictionary
         -- types, because they are just the constrained types,
         -- and the dictionary is therefore sure to be bound
@@ -1973,8 +2228,8 @@ mkCallUDs' env f args
   = emptyUDs
 
   |  not (all type_determines_value theta)
-  || not (spec_tys `lengthIs` n_tyvars)
-  || not ( dicts   `lengthIs` n_dicts)
+  || not (computeArity ci_key <= idArity f)
+  || not (length dicts == length theta)
   || not (any (interestingDict env) dicts)    -- Note [Interesting dictionary arguments]
   -- See also Note [Specialisations already covered]
   = -- pprTrace "mkCallUDs: discarding" _trace_doc
@@ -1982,27 +2237,28 @@ mkCallUDs' env f args
 
   | otherwise
   = -- pprTrace "mkCallUDs: keeping" _trace_doc
-    singleCall f spec_tys dicts
+    singleCall f ci_key
   where
-    _trace_doc = vcat [ppr f, ppr args, ppr n_tyvars, ppr n_dicts
-                      , ppr (map (interestingDict env) dicts)]
-    (tyvars, theta, _)      = tcSplitSigmaTy (idType f)
-    constrained_tyvars      = tyCoVarsOfTypes theta
-    n_tyvars                = length tyvars
-    n_dicts                 = length theta
-
-    spec_tys = [mk_spec_ty tv ty | (tv, ty) <- tyvars `type_zip` args]
-    dicts    = [dict_expr | (_, dict_expr) <- theta `zip` (drop n_tyvars args)]
-
-    -- ignores Coercion arguments
-    type_zip :: [TyVar] -> [CoreExpr] -> [(TyVar, Type)]
-    type_zip tvs      (Coercion _ : args) = type_zip tvs args
-    type_zip (tv:tvs) (Type ty : args)    = (tv, ty) : type_zip tvs args
-    type_zip _        _                   = []
-
-    mk_spec_ty tyvar ty
-        | tyvar `elemVarSet` constrained_tyvars = Just ty
-        | otherwise                             = Nothing
+    _trace_doc = vcat [ppr f, ppr args, ppr (map (interestingDict env) dicts)]
+    pis                = fst $ splitPiTys $ idType f
+    theta              = getTheta pis
+    constrained_tyvars = tyCoVarsOfTypes theta
+
+    ci_key :: [SpecArg]
+    ci_key = fmap (\(t, a) ->
+      case t of
+        Named (binderVar -> tyVar)
+          |  tyVar `elemVarSet` constrained_tyvars
+          -> case a of
+              Type ty -> SpecType ty
+              _ -> pprPanic "ci_key" $ ppr a
+          |  otherwise
+          -> UnspecType
+        Anon InvisArg _ -> SpecDict a
+        Anon VisArg _ -> UnspecArg
+                ) $ zip pis args
+
+    dicts = getSpecDicts ci_key
 
     want_calls_for f = isLocalId f || isJust (maybeUnfoldingTemplate (realIdUnfolding f))
          -- For imported things, we gather call instances if
index 7d8e96f..b27c842 100644 (file)
@@ -2,8 +2,12 @@ TOP=../../..
 include $(TOP)/mk/boilerplate.mk
 include $(TOP)/mk/test.mk
 
-.PHONY: T4007
+.PHONY: T4007 T16473
 T4007:
        $(RM) -f T4007.hi T4007.o
        '$(TEST_HC)' $(TEST_HC_OPTS) -c -O -ddump-rule-firings T4007.hs
 
+T16473:
+       $(RM) -f T16473.hi T16473.o
+       '$(TEST_HC)' $(TEST_HC_OPTS) -c -O -ddump-rule-firings T16473.hs
+
diff --git a/testsuite/tests/perf/compiler/T16473.hs b/testsuite/tests/perf/compiler/T16473.hs
new file mode 100644 (file)
index 0000000..8a9751e
--- /dev/null
@@ -0,0 +1,102 @@
+{-# LANGUAGE BangPatterns        #-}
+{-# LANGUAGE DataKinds           #-}
+{-# LANGUAGE DeriveFunctor       #-}
+{-# LANGUAGE GADTs               #-}
+{-# LANGUAGE KindSignatures      #-}
+{-# LANGUAGE LambdaCase          #-}
+{-# LANGUAGE RankNTypes          #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators       #-}
+
+{-# OPTIONS_GHC -flate-specialise -O2 #-}
+
+module Main (main) where
+
+import qualified Control.Monad.State.Strict as S
+import           Data.Foldable
+import           Data.Functor.Identity
+import           Data.Kind
+import           Data.Monoid
+import           Data.Tuple
+
+main :: IO ()
+main = print $ badCore 100
+
+badCore :: Int -> Int
+badCore n  = getSum $ fst $ run  $ runState mempty $ for_ [0..n] $ \i ->   modify (<> Sum i)
+
+data Union (r :: [Type -> Type]) a where
+  Union :: e a -> Union '[e] a
+
+decomp :: Union (e ': r) a -> e a
+decomp (Union a) = a
+{-# INLINE decomp #-}
+
+absurdU :: Union '[] a -> b
+absurdU = absurdU
+
+newtype Semantic r a = Semantic
+  { runSemantic
+        :: forall m
+         . Monad m
+        => (forall x. Union r x -> m x)
+        -> m a
+  }
+
+instance Functor (Semantic f) where
+  fmap f (Semantic m) = Semantic $ \k -> fmap f $ m k
+  {-# INLINE fmap #-}
+
+instance Applicative (Semantic f) where
+  pure a = Semantic $ const $ pure a
+  {-# INLINE pure #-}
+  Semantic f <*> Semantic a = Semantic $ \k -> f k <*> a k
+  {-# INLINE (<*>) #-}
+
+instance Monad (Semantic f) where
+  return = pure
+  {-# INLINE return #-}
+  Semantic ma >>= f = Semantic $ \k -> do
+    z <- ma k
+    runSemantic (f z) k
+  {-# INLINE (>>=) #-}
+
+data State s a
+  = Get (s -> a)
+  | Put s a
+  deriving Functor
+
+get :: Semantic '[State s] s
+get = Semantic $ \k -> k $ Union $ Get id
+{-# INLINE get #-}
+
+put :: s -> Semantic '[State s] ()
+put !s = Semantic $ \k -> k $ Union $! Put s ()
+{-# INLINE put #-}
+
+modify :: (s -> s) -> Semantic '[State s] ()
+modify f = do
+  !s <- get
+  put $! f s
+{-# INLINE modify #-}
+
+runState :: s -> Semantic (State s ': r) a -> Semantic r (s, a)
+runState = interpretInStateT $ \case
+  Get k   -> fmap k S.get
+  Put s k -> S.put s >> pure k
+{-# INLINE[3] runState #-}
+
+run :: Semantic '[] a -> a
+run (Semantic m) = runIdentity $ m absurdU
+{-# INLINE run #-}
+
+interpretInStateT
+    :: (forall x. e x -> S.StateT s (Semantic r) x)
+    -> s
+    -> Semantic (e ': r) a
+    -> Semantic r (s, a)
+interpretInStateT f s (Semantic m) = Semantic $ \k ->
+  fmap swap $ flip S.runStateT s $ m $ \u ->
+    S.mapStateT (\z -> runSemantic z k) $ f $ decomp u
+{-# INLINE interpretInStateT #-}
+
diff --git a/testsuite/tests/perf/compiler/T16473.stdout b/testsuite/tests/perf/compiler/T16473.stdout
new file mode 100644 (file)
index 0000000..3a1f5a5
--- /dev/null
@@ -0,0 +1,139 @@
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op liftA2 (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op <*> (BUILTIN)
+Rule fired: Class op <$ (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op <*> (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op get (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op >> (BUILTIN)
+Rule fired: Class op put (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op get (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op >> (BUILTIN)
+Rule fired: Class op put (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op show (BUILTIN)
+Rule fired: Class op mempty (BUILTIN)
+Rule fired: Class op fromInteger (BUILTIN)
+Rule fired: integerToInt (BUILTIN)
+Rule fired: Class op <> (BUILTIN)
+Rule fired: Class op + (BUILTIN)
+Rule fired: Class op enumFromTo (BUILTIN)
+Rule fired: Class op *> (BUILTIN)
+Rule fired: Class op *> (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op *> (BUILTIN)
+Rule fired: Class op *> (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op *> (BUILTIN)
+Rule fired: Class op *> (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: fold/build (GHC.Base)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: ># (BUILTIN)
+Rule fired: ==# (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: SPEC/Main $fApplicativeStateT @ Identity _ (Main)
+Rule fired: SPEC/Main $fMonadStateT_$c>>= @ Identity _ (Main)
+Rule fired: SPEC/Main $fMonadStateT_$c>> @ Identity _ (Main)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: SPEC/Main $fApplicativeStateT @ Identity _ (Main)
+Rule fired: SPEC/Main $fMonadStateT_$c>>= @ Identity _ (Main)
+Rule fired: SPEC/Main $fMonadStateT_$c>> @ Identity _ (Main)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: SPEC/Main $fApplicativeStateT @ Identity _ (Main)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: SPEC/Main $fApplicativeStateT @ Identity _ (Main)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op return (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: SPEC/Main $fFunctorStateT @ Identity _ (Main)
+Rule fired:
+    SPEC/Main $fApplicativeStateT_$cpure @ Identity _ (Main)
+Rule fired: SPEC/Main $fApplicativeStateT_$c<*> @ Identity _ (Main)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: SPEC/Main $fApplicativeStateT_$c*> @ Identity _ (Main)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: SPEC/Main $fFunctorStateT @ Identity _ (Main)
+Rule fired:
+    SPEC/Main $fApplicativeStateT_$cpure @ Identity _ (Main)
+Rule fired: SPEC/Main $fApplicativeStateT_$c<*> @ Identity _ (Main)
+Rule fired: SPEC/Main $fApplicativeStateT_$c*> @ Identity _ (Main)
+Rule fired: SPEC/Main $fMonadStateT @ Identity _ (Main)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op <*> (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op <*> (BUILTIN)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op $p1Applicative (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: Class op >>= (BUILTIN)
+Rule fired: Class op fmap (BUILTIN)
+Rule fired: SPEC go @ (StateT (Sum Int) Identity) (Main)
+Rule fired: Class op $p1Monad (BUILTIN)
+Rule fired: Class op pure (BUILTIN)
+Rule fired: SPEC/Main $fMonadStateT @ Identity _ (Main)
+Rule fired: SPEC go @ (StateT (Sum Int) Identity) (Main)
+Rule fired: Class op fmap (BUILTIN)
index 44216f4..0db9bcf 100644 (file)
@@ -404,3 +404,5 @@ test('T16190',
       collect_stats(),
       multimod_compile,
       ['T16190.hs', '-v0'])
+
+test('T16473', normal, makefile_test, ['T16473'])
index c0e91b9..3fd78bd 100644 (file)
@@ -1,7 +1,7 @@
 
 ==================== Tidy Core rules ====================
 "SPEC shared @ []"
-    forall (irred :: Domain [] Int) ($dMyFunctor :: MyFunctor []).
+    forall ($dMyFunctor :: MyFunctor []) (irred :: Domain [] Int).
       shared @ [] $dMyFunctor irred
       = bar_$sshared
 
index 3af33f1..e9cc798 100644 (file)
@@ -1,5 +1,10 @@
-\r
-T16282.hs: warning: [-Wall-missed-specialisations]\r
-    Could not specialise imported function ‘Data.Map.Internal.$w$cshowsPrec’\r
-      when specialising ‘Data.Map.Internal.$fShowMap_$cshowsPrec’\r
-    Probable fix: add INLINABLE pragma on ‘Data.Map.Internal.$w$cshowsPrec’\r
+
+T16282.hs: warning: [-Wall-missed-specialisations]
+    Could not specialise imported function ‘Data.Foldable.$wmapM_’
+      when specialising ‘mapM_’
+    Probable fix: add INLINABLE pragma on ‘Data.Foldable.$wmapM_’
+
+T16282.hs: warning: [-Wall-missed-specialisations]
+    Could not specialise imported function ‘Data.Map.Internal.$w$cshowsPrec’
+      when specialising ‘Data.Map.Internal.$fShowMap_$cshowsPrec’
+    Probable fix: add INLINABLE pragma on ‘Data.Map.Internal.$w$cshowsPrec’