Make Specialise work with casts
authorSimon Peyton Jones <simonpj@microsoft.com>
Tue, 28 Feb 2017 17:11:33 +0000 (12:11 -0500)
committerDavid Feuer <David.Feuer@gmail.com>
Tue, 28 Feb 2017 17:11:35 +0000 (12:11 -0500)
With my upcoming early-inlining patch it turned out that Specialise
was getting stuck on casts.  This patch fixes it; see Note
[Account for casts in binding] in Specialise.

Reviewers: austin, goldfire, bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D3192

compiler/coreSyn/CoreSubst.hs
compiler/specialise/Specialise.hs
compiler/types/Coercion.hs
compiler/types/Unify.hs

index 53072dc..043d3c3 100644 (file)
@@ -34,7 +34,7 @@ module CoreSubst (
         -- ** Simple expression optimiser
         simpleOptPgm, simpleOptExpr, simpleOptExprWith,
         exprIsConApp_maybe, exprIsLiteral_maybe, exprIsLambda_maybe,
-        pushCoArg, pushCoValArg, pushCoTyArg
+        pushCoArg, pushCoValArg, pushCoTyArg, collectBindersPushingCo
     ) where
 
 #include "HsVersions.h"
@@ -1614,7 +1614,7 @@ exprIsLambda_maybe _ _e
 
 Here we implement the "push rules" from FC papers:
 
-* The push-argument ules, where we can move a coercion past an argument.
+* The push-argument rules, where we can move a coercion past an argument.
   We have
       (fun |> co) arg
   and we want to transform it to
@@ -1687,7 +1687,7 @@ pushCoValArg co
   = Just (mkRepReflCo arg, mkRepReflCo res)
 
   | isFunTy tyL
-  , [_, _, co1, co2] <- decomposeCo 4 co
+  , (co1, co2) <- decomposeFunCo co
               -- If   co  :: (tyL1 -> tyL2) ~ (tyR1 -> tyR2)
               -- then co1 :: tyL1 ~ tyR1
               --      co2 :: tyL2 ~ tyR2
@@ -1711,7 +1711,7 @@ pushCoercionIntoLambda in_scope x e co
     , Pair s1s2 t1t2 <- coercionKind co
     , Just (_s1,_s2) <- splitFunTy_maybe s1s2
     , Just (t1,_t2) <- splitFunTy_maybe t1t2
-    = let [_rep1, _rep2, co1, co2] = decomposeCo 4 co
+    = let (co1, co2) = decomposeFunCo co
           -- Should we optimize the coercions here?
           -- Otherwise they might not match too well
           x' = x `setIdType` t1
@@ -1784,3 +1784,57 @@ pushCoDataCon dc dc_args co
 
   where
     Pair from_ty to_ty = coercionKind co
+
+collectBindersPushingCo :: CoreExpr -> ([Var], CoreExpr)
+-- Collect lambda binders, pushing coercions inside if possible
+-- E.g.   (\x.e) |> g         g :: <Int> -> blah
+--        = (\x. e |> Nth 1 g)
+--
+-- That is,
+--
+-- collectBindersPushingCo ((\x.e) |> g) === ([x], e |> Nth 1 g)
+collectBindersPushingCo e
+  = go [] e
+  where
+    -- Peel off lambdas until we hit a cast.
+    go :: [Var] -> CoreExpr -> ([Var], CoreExpr)
+    -- The accumulator is in reverse order
+    go bs (Lam b e)   = go (b:bs) e
+    go bs (Cast e co) = go_c bs e co
+    go bs e           = (reverse bs, e)
+
+    -- We are in a cast; peel off casts until we hit a lambda.
+    go_c :: [Var] -> CoreExpr -> Coercion -> ([Var], CoreExpr)
+    -- (go_c bs e c) is same as (go bs e (e |> c))
+    go_c bs (Cast e co1) co2 = go_c bs e (co1 `mkTransCo` co2)
+    go_c bs (Lam b e)    co  = go_lam bs b e co
+    go_c bs e            co  = (reverse bs, mkCast e co)
+
+    -- We are in a lambda under a cast; peel off lambdas and build a
+    -- new coercion for the body.
+    go_lam :: [Var] -> Var -> CoreExpr -> Coercion -> ([Var], CoreExpr)
+    -- (go_lam bs b e c) is same as (go_c bs (\b.e) c)
+    go_lam bs b e co
+      | isTyVar b
+      , let Pair tyL tyR = coercionKind co
+      , ASSERT( isForAllTy tyL )
+        isForAllTy tyR
+      , isReflCo (mkNthCo 0 co)  -- See Note [collectBindersPushingCo]
+      = go_c (b:bs) e (mkInstCo co (mkNomReflCo (mkTyVarTy b)))
+
+      | isId b
+      , let Pair tyL tyR = coercionKind co
+      , ASSERT( isFunTy tyL) isFunTy tyR
+      , (co_arg, co_res) <- decomposeFunCo co
+      , isReflCo co_arg  -- See Note [collectBindersPushingCo]
+      = go_c (b:bs) e co_res
+
+      | otherwise = (reverse bs, mkCast (Lam b e) co)
+
+{- Note [collectBindersPushingCo]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We just look for coercions of form
+   <type> -> blah
+(and similarly for foralls) to keep this function simple.  We could do
+more elaborate stuff, but it'd involve substitution etc.
+-}
index 9e189df..4419643 100644 (file)
@@ -1153,8 +1153,8 @@ specCalls :: Maybe Module      -- Just this_mod  =>  specialising imported fn
 
 specCalls mb_mod env rules_for_me calls_for_me fn rhs
         -- The first case is the interesting one
-  |  rhs_tyvars `lengthIs`     n_tyvars -- Rhs of fn's defn has right number of big lambdas
-  && rhs_ids    `lengthAtLeast` n_dicts -- and enough dict args
+  |  rhs_tyvars `lengthIs`      n_tyvars -- Rhs of fn's defn has right number of big lambdas
+  && rhs_bndrs1 `lengthAtLeast` n_dicts -- and enough dict args
   && notNull calls_for_me               -- And there are some calls to specialise
   && not (isNeverActive (idInlineActivation fn))
         -- Don't specialise NOINLINE things
@@ -1178,7 +1178,7 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
     return ([], [], emptyUDs)
   where
     _trace_doc = sep [ ppr rhs_tyvars, ppr n_tyvars
-                     , ppr rhs_ids, ppr n_dicts
+                     , ppr rhs_bndrs, ppr n_dicts
                      , ppr (idInlineActivation fn) ]
 
     fn_type                 = idType fn
@@ -1194,11 +1194,12 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
         -- Figure out whether the function has an INLINE pragma
         -- See Note [Inline specialisations]
 
-    (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs
-
-    rhs_dict_ids = take n_dicts rhs_ids
-    body         = mkLams (drop n_dicts rhs_ids) rhs_body
-                -- Glue back on the non-dict lambdas
+    (rhs_bndrs, rhs_body)      = CoreSubst.collectBindersPushingCo rhs
+                                 -- See Note [Account for casts in binding]
+    (rhs_tyvars, rhs_bndrs1)   = span isTyVar rhs_bndrs
+    (rhs_dict_ids, rhs_bndrs2) = splitAt n_dicts rhs_bndrs1
+    body                       = mkLams rhs_bndrs2 rhs_body
+                                 -- Glue back on the non-dict lambdas
 
     already_covered :: DynFlags -> [CoreExpr] -> Bool
     already_covered dflags args      -- Note [Specialisations already covered]
@@ -1350,7 +1351,23 @@ specCalls mb_mod env rules_for_me calls_for_me fn rhs
 
            ; return (Just ((spec_f_w_arity, spec_rhs), final_uds, spec_env_rule)) } }
 
-{- Note [Evidence foralls]
+{- Note [Account for casts in binding]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+   f :: Eq a => a -> IO ()
+   {-# INLINABLE f
+       StableUnf = (/\a \(d:Eq a) (x:a). blah) |> g
+     #-}
+   f = ...
+
+In f's stable unfolding we have done some modest simplification which
+has pushed the cast to the outside.  (I wonder if this is the Right
+Thing, but it's what happens now; see SimplUtils Note [Casts and
+lambdas].)  Now that stable unfolding must be specialised, so we want
+to push the cast back inside. It would be terrible if the cast
+defeated specialisation!  Hence the use of collectBindersPushingCo.
+
+Note [Evidence foralls]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
 Suppose (Trac #12212) that we are specialising
    f :: forall a b. (Num a, F a ~ F b) => blah
index d195b2f..6b7a640 100644 (file)
@@ -48,7 +48,7 @@ module Coercion (
         mapStepResult, unwrapNewTypeStepper,
         topNormaliseNewType_maybe, topNormaliseTypeX,
 
-        decomposeCo, getCoVar_maybe,
+        decomposeCo, decomposeFunCo, getCoVar_maybe,
         splitTyConAppCo_maybe,
         splitAppCo_maybe,
         splitFunCo_maybe,
@@ -293,8 +293,20 @@ ppr_co_ax_branch ppr_rhs
         Destructing coercions
 %*                                                                      *
 %************************************************************************
+
+Note [Function coercions]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Remember that
+  (->) :: forall r1 r2. TYPE r1 -> TYPE r2 -> TYPE LiftedRep
+
+Hence
+  FunCo r co1 co2 :: (s1->t1) ~r (s2->t2)
+is short for
+  TyConAppCo (->) co_rep1 co_rep2 co1 co2
+where co_rep1, co_rep2 are the coercions on the representations.
 -}
 
+
 -- | This breaks a 'Coercion' with type @T A B C ~ T D E F@ into
 -- a list of 'Coercion's of kinds @A ~ D@, @B ~ E@ and @E ~ F@. Hence:
 --
@@ -304,6 +316,16 @@ decomposeCo arity co
   = [mkNthCo n co | n <- [0..(arity-1)] ]
            -- Remember, Nth is zero-indexed
 
+decomposeFunCo :: Coercion -> (Coercion, Coercion)
+-- Expects co :: (s1 -> t1) ~ (s2 -> t2)
+-- Returns (co1 :: s1~s2, co2 :: t1~t2)
+-- See Note [Function coercions] for the "2" and "3"
+decomposeFunCo co = ASSERT2( all_ok, ppr co )
+                    (mkNthCo 2 co, mkNthCo 3 co)
+  where
+    Pair s1t1 s2t2 = coercionKind co
+    all_ok = isFunTy s1t1 && isFunTy s2t2
+
 -- | Attempts to obtain the type variable underlying a 'Coercion'
 getCoVar_maybe :: Coercion -> Maybe CoVar
 getCoVar_maybe (CoVarCo cv) = Just cv
@@ -554,7 +576,7 @@ mkNomReflCo = mkReflCo Nominal
 mkTyConAppCo :: HasDebugCallStack => Role -> TyCon -> [Coercion] -> Coercion
 mkTyConAppCo r tc cos
   | tc `hasKey` funTyConKey
-  , [_rep1, _rep2, co1, co2] <- cos
+  , [_rep1, _rep2, co1, co2] <- cos   -- See Note [Function coercions]
   = -- (a :: TYPE ra) -> (b :: TYPE rb)  ~  (c :: TYPE rc) -> (d :: TYPE rd)
     -- rep1 :: ra  ~  rc        rep2 :: rb  ~  rd
     -- co1  :: a   ~  c         co2  :: b   ~  d
@@ -882,14 +904,26 @@ mkNthCo n (Refl r ty)
 mkNthCo 0 (ForAllCo _ kind_co _) = kind_co
   -- If co :: (forall a1:k1. t1) ~ (forall a2:k2. t2)
   -- then (nth 0 co :: k1 ~ k2)
-mkNthCo n (TyConAppCo _ _ arg_cos) = arg_cos `getNth` n
+
 mkNthCo n co@(FunCo _ arg res)
+  -- See Note [Function coercions]
+  -- If FunCo _ arg_co res_co ::   (s1:TYPE sk1 -> s2:TYPE sk2)
+  --                             ~ (t1:TYPE tk1 -> t2:TYPE tk2)
+  -- Then we want to behave as if co was
+  --    TyConAppCo argk_co resk_co arg_co res_co
+  -- where
+  --    argk_co :: sk1 ~ tk1  =  mkNthCo 0 (mkKindCo arg_co)
+  --    resk_co :: sk2 ~ tk2  =  mkNthCo 0 (mkKindCo res_co)
+  --                             i.e. mkRuntimeRepCo
   = case n of
       0 -> mkRuntimeRepCo arg
       1 -> mkRuntimeRepCo res
       2 -> arg
       3 -> res
       _ -> pprPanic "mkNthCo(FunCo)" (ppr n $$ ppr co)
+
+mkNthCo n (TyConAppCo _ _ arg_cos) = arg_cos `getNth` n
+
 mkNthCo n co = NthCo n co
 
 mkLRCo :: LeftOrRight -> Coercion -> Coercion
@@ -937,8 +971,10 @@ mkKindCo co
        -- generally, calling coercionKind during coercion creation is a bad idea,
        -- as it can lead to exponential behavior. But, we don't have nested mkKindCos,
        -- so it's OK here.
-  , typeKind ty1 `eqType` typeKind ty2
-  = Refl Nominal (typeKind ty1)
+  , let tk1 = typeKind ty1
+        tk2 = typeKind ty2
+  , tk1 `eqType` tk2
+  = Refl Nominal tk1
   | otherwise
   = KindCo co
 
index 7c6409f..ed879eb 100644 (file)
@@ -843,7 +843,7 @@ unify_ty (CoercionTy co1) (CoercionTy co2) kco
              -> do { b <- tvBindFlagL cv
                    ; if b == BindMe
                        then do { checkRnEnvRCo co2
-                               ; let [_, _, co_l, co_r] = decomposeCo 4 kco
+                               ; let (co_l, co_r) = decomposeFunCo kco
                                   -- cv :: t1 ~ t2
                                   -- co2 :: s1 ~ s2
                                   -- co_l :: t1 ~ s1