Scrutinee Constant Folding
author    Sylvain Henry <sylvain@haskus.fr>
Fri, 9 Dec 2016 15:26:34 +0000 (10:26 -0500)
committer Ben Gamari <ben@smart-cactus.org>
Fri, 9 Dec 2016 15:27:34 +0000 (10:27 -0500)
This patch introduces new rules to perform constant folding through
case-expressions.

E.g.,
```
case t -# 10# of _ {  ===> case t of _ {
         5#      -> e1              15#     -> e1
         8#      -> e2              18#     -> e2
         DEFAULT -> e               DEFAULT -> e
```

The initial motivation is that it allows the "Merge Nested Cases"
optimization to kick in and further simplify the code
(see Trac #12877).

Currently we recognize the following operations for Word# and Int#: Add,
Sub, Xor, Not and Negate (for Int# only).

Test Plan: validate

Reviewers: simonpj, austin, bgamari

Reviewed By: simonpj, bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D2762

GHC Trac Issues: #12877

compiler/basicTypes/Literal.hs
compiler/main/DynFlags.hs
compiler/prelude/PrelRules.hs
compiler/simplCore/SimplUtils.hs
docs/users_guide/using-optimisation.rst
testsuite/tests/perf/compiler/T12877.hs [new file with mode: 0644]
testsuite/tests/perf/compiler/T12877.stdout [new file with mode: 0644]
testsuite/tests/perf/compiler/all.T
utils/mkUserGuidePart/Options/Optimizations.hs

index 8137596..14ef785 100644 (file)
@@ -29,7 +29,7 @@ module Literal
         , inIntRange, inWordRange, tARGET_MAX_INT, inCharRange
         , isZeroLit
         , litFitsInChar
-        , litValue
+        , litValue, isLitValue, isLitValue_maybe, mapLitValue
 
         -- ** Coercions
         , word2IntLit, int2WordLit
@@ -59,6 +59,7 @@ import Data.ByteString (ByteString)
 import Data.Int
 import Data.Word
 import Data.Char
+import Data.Maybe ( isJust )
 import Data.Data ( Data )
 import Numeric ( fromRat )
 
@@ -271,13 +272,37 @@ isZeroLit _              = False
 -- | Returns the 'Integer' contained in the 'Literal', for when that makes
 -- sense, i.e. for 'Char', 'Int', 'Word' and 'LitInteger'.
 litValue  :: Literal -> Integer
-litValue (MachChar   c) = toInteger $ ord c
-litValue (MachInt    i) = i
-litValue (MachInt64  i) = i
-litValue (MachWord   i) = i
-litValue (MachWord64 i) = i
-litValue (LitInteger i _) = i
-litValue l = pprPanic "litValue" (ppr l)
+litValue l
+   | Just x <- isLitValue_maybe l = x
+   | otherwise                    = pprPanic "litValue" (ppr l)
+
+-- | Like 'litValue' (Char, Int, Word, 'LitInteger'), but 'Nothing' otherwise.
+isLitValue_maybe  :: Literal -> Maybe Integer
+isLitValue_maybe lit = case lit of
+   MachChar   c   -> Just (toInteger (ord c))
+   MachInt    i   -> Just i
+   MachInt64  i   -> Just i
+   MachWord   i   -> Just i
+   MachWord64 i   -> Just i
+   LitInteger i _ -> Just i
+   _              -> Nothing
+
+-- | Apply a function to the 'Integer' in a Char, Int, Word or 'LitInteger' literal.
+-- NB: results are not wrapped to the target range ('chr' may fail for Char).
+mapLitValue  :: (Integer -> Integer) -> Literal -> Literal
+mapLitValue f lit = case lit of
+   MachChar   c   -> MachChar (chr (fromInteger (f (toInteger (ord c)))))
+   MachInt    i   -> MachInt (f i)
+   MachInt64  i   -> MachInt64 (f i)
+   MachWord   i   -> MachWord (f i)
+   MachWord64 i   -> MachWord64 (f i)
+   LitInteger i t -> LitInteger (f i) t
+   _              -> pprPanic "mapLitValue" (ppr lit)
+
+-- | True when the 'Literal' carries an 'Integer' value (a Char, Int, Word or
+-- 'LitInteger' literal), i.e. exactly when 'isLitValue_maybe' gives 'Just'.
+isLitValue  :: Literal -> Bool
+isLitValue l = isJust (isLitValue_maybe l)
 
 {-
         Coercions
index cbf247c..d7cde29 100644 (file)
@@ -445,6 +445,7 @@ data GeneralFlag
    | Opt_IgnoreAsserts
    | Opt_DoEtaReduction
    | Opt_CaseMerge
+   | Opt_CaseFolding                    -- Constant folding through case-expressions
    | Opt_UnboxStrictFields
    | Opt_UnboxSmallStrictFields
    | Opt_DictsCheap
@@ -3561,6 +3562,7 @@ fFlagsDeps = [
   flagSpec "building-cabal-package"           Opt_BuildingCabalPackage,
   flagSpec "call-arity"                       Opt_CallArity,
   flagSpec "case-merge"                       Opt_CaseMerge,
+  flagSpec "case-folding"                     Opt_CaseFolding,
   flagSpec "cmm-elim-common-blocks"           Opt_CmmElimCommonBlocks,
   flagSpec "cmm-sink"                         Opt_CmmSink,
   flagSpec "cse"                              Opt_CSE,
@@ -4012,6 +4014,7 @@ optLevelFlags -- see Note [Documenting optimisation flags]
 
     , ([1,2],   Opt_CallArity)
     , ([1,2],   Opt_CaseMerge)
+    , ([1,2],   Opt_CaseFolding)
     , ([1,2],   Opt_CmmElimCommonBlocks)
     , ([1,2],   Opt_CmmSink)
     , ([1,2],   Opt_CSE)
index 8868047..e98fd9f 100644 (file)
@@ -15,7 +15,12 @@ ToDo:
 {-# LANGUAGE CPP, RankNTypes #-}
 {-# OPTIONS_GHC -optc-DNON_POSIX_SOURCE #-}
 
-module PrelRules ( primOpRules, builtinRules ) where
+module PrelRules
+   ( primOpRules
+   , builtinRules
+   , caseRules
+   )
+where
 
 #include "HsVersions.h"
 #include "../includes/MachDeps.h"
@@ -1385,3 +1390,53 @@ match_smallIntegerTo primOp _ _ _ [App (Var x) y]
   | idName x == smallIntegerName
   = Just $ App (Var (mkPrimOpId primOp)) y
 match_smallIntegerTo _ _ _ _ _ = Nothing
+
+
+
+--------------------------------------------------------
+-- Constant folding through case-expressions
+--
+-- cf. item 3, "Scrutinee Constant Folding", in the mkCase comment of SimplUtils
+--------------------------------------------------------
+
+-- | Match a case scrutinee of the form @v op x#@, @x# op v@ or @op v@ and return
+-- @v@ with a function mapping each literal alternative y to the matching v.
+caseRules :: CoreExpr -> Maybe (CoreExpr, Integer -> Integer)
+caseRules scrut = case scrut of
+   -- NB: the function works on unbounded Integers and need not be monotonic.
+   -- v `op` x#  (alternative y matches when v == y inv_op x)
+   App (App (Var f) v) (Lit l)
+      | Just op <- isPrimOpId_maybe f
+      , Just x  <- isLitValue_maybe l ->
+      case op of
+         WordAddOp -> Just (v, \y -> y-x      )
+         IntAddOp  -> Just (v, \y -> y-x      )
+         WordSubOp -> Just (v, \y -> y+x      )
+         IntSubOp  -> Just (v, \y -> y+x      )
+         XorOp     -> Just (v, \y -> y `xor` x)
+         XorIOp    -> Just (v, \y -> y `xor` x)
+         _         -> Nothing
+
+   -- x# `op` v  (Add and Xor commute; for Sub, v == x - y)
+   App (App (Var f) (Lit l)) v
+      | Just op <- isPrimOpId_maybe f
+      , Just x  <- isLitValue_maybe l ->
+      case op of
+         WordAddOp -> Just (v, \y -> y-x      )
+         IntAddOp  -> Just (v, \y -> y-x      )
+         WordSubOp -> Just (v, \y -> x-y      )
+         IntSubOp  -> Just (v, \y -> x-y      )
+         XorOp     -> Just (v, \y -> y `xor` x)
+         XorIOp    -> Just (v, \y -> y `xor` x)
+         _         -> Nothing
+
+   -- op v  (complement and negate are their own inverses)
+   App (Var f) v
+      | Just op <- isPrimOpId_maybe f ->
+      case op of
+         NotOp     -> Just (v, \y -> complement y)
+         NotIOp    -> Just (v, \y -> complement y)
+         IntNegOp  -> Just (v, \y -> negate y    )
+         _         -> Nothing
+
+   _ -> Nothing
index 48dce1d..6c47375 100644 (file)
@@ -60,6 +60,8 @@ import Util
 import MonadUtils
 import Outputable
 import Pair
+import PrelRules
+import Literal
 
 import Control.Monad    ( when )
 
@@ -1752,9 +1754,46 @@ mkCase tries these things
                 False -> False
 
     and similar friends.
+
+3.  Scrutinee Constant Folding
+
+     case x op# k# of _ {  ===> case x of _ {
+        a1# -> e1                  (a1# inv_op# k#) -> e1
+        a2# -> e2                  (a2# inv_op# k#) -> e2
+        ...                        ...
+        DEFAULT -> ed              DEFAULT -> ed
+
+     where (x op# k#) inv_op# k# == x
+
+    And similarly for commuted arguments and for some unary operations.
+
+    The purpose of this transformation is not only to avoid an arithmetic
+    operation at runtime but to allow other transformations to apply in cascade.
+
+    Example with the "Merge Nested Cases" optimization (from #12877):
+
+          main = case t of t0
+             0##     -> ...
+             DEFAULT -> case t0 `minusWord#` 1## of t1
+                0##     -> ...
+                DEFAULT -> case t1 `minusWord#` 1## of t2
+                   0##     -> ...
+                   DEFAULT -> case t2 `minusWord#` 1## of _
+                      0##     -> ...
+                      DEFAULT -> ...
+
+      becomes:
+
+          main = case t of _
+             0##     -> ...
+             1##     -> ...
+             2##     -> ...
+             3##     -> ...
+             DEFAULT -> ...
+
 -}
 
-mkCase, mkCase1, mkCase2
+mkCase, mkCase1, mkCase2, mkCase3
    :: DynFlags
    -> OutExpr -> OutId
    -> OutType -> [OutAlt]               -- Alternatives in standard (increasing) order
@@ -1848,9 +1887,62 @@ mkCase1 _dflags scrut case_bndr _ alts@((_,_,rhs1) : _)      -- Identity case
 mkCase1 dflags scrut bndr alts_ty alts = mkCase2 dflags scrut bndr alts_ty alts
 
 --------------------------------------------------
+--      3. Scrutinee Constant Folding
+--------------------------------------------------
+
+-- Constant folding through the scrutinee (item 3 of the mkCase comment):
+-- when 'caseRules' recognises the scrutinee as a known primop applied to a
+-- variable part v, scrutinise v directly and rewrite every literal
+-- alternative with the inverse operation. Otherwise defer to mkCase3.
+mkCase2 dflags scrut bndr alts_ty alts
+  | gopt Opt_CaseFolding dflags
+  , Just (scrut',f) <- caseRules scrut
+  = mkCase3 dflags scrut' bndr alts_ty (re_sort (map (mapAlt f) alts))
+  | otherwise
+  = mkCase3 dflags scrut bndr alts_ty alts
+  where
+    -- We need to keep the correct association between the scrutinee and its
+    -- binder if the latter isn't dead. Hence we wrap rhs of alternatives with
+    -- "let bndr = ... in":
+    --
+    --     case v + 10 of y        =====> case v of y
+    --        20      -> e1                 10      -> let y = 20     in e1
+    --        DEFAULT -> e2                 DEFAULT -> let y = v + 10 in e2
+    --
+    -- Other transformations give: =====> case v of y'
+    --                                      10      -> let y = 20      in e1
+    --                                      DEFAULT -> let y = y' + 10 in e2
+    --
+    -- NOTE(review): the new case still uses 'bndr' as its case binder and
+    -- each alternative re-binds 'bndr' with a 'let'; a fresh case binder
+    -- would avoid binding the same Id twice -- confirm this is intended.
+    wrap_rhs l rhs
+      | isDeadBinder bndr = rhs
+      | otherwise         = Let (NonRec bndr l) rhs
+
+    -- Literal alternatives get their literal rewritten by f; DEFAULT keeps
+    -- its pattern. No other alternative is expected for these primops.
+    -- NOTE(review): f works on unbounded Integers and its result is not
+    -- wrapped to the Int#/Word# range (e.g. complement of a Word# literal
+    -- yields a negative value) -- confirm whether it should be.
+    mapAlt f alt@(c,bs,e) = case c of
+      DEFAULT          -> (c, bs, wrap_rhs scrut e)
+      LitAlt l
+        | isLitValue l -> (LitAlt (mapLitValue f l), bs, wrap_rhs (Lit l) e)
+      _ -> pprPanic "Unexpected alternative (mkCase2)" (ppr alt)
+
+    -- f need not be monotonic (e.g. \y -> x-y, negate, complement, xor), so
+    -- the rewritten alternatives may be out of order. mkCase3 expects them in
+    -- standard (increasing) order with DEFAULT first, so sort them again;
+    -- the DEFAULT key (Nothing) sorts before every literal key (Just l).
+    re_sort = sortWith alt_key
+    alt_key (LitAlt l, _, _) = Just l
+    alt_key _                = Nothing   -- DEFAULT; mapAlt rejects DataAlt
+
+--------------------------------------------------
 --      Catch-all
 --------------------------------------------------
-mkCase2 _dflags scrut bndr alts_ty alts
+mkCase3 _dflags scrut bndr alts_ty alts
   = return (Case scrut bndr alts_ty alts)
 
 {-
index 6b58093..3e660c1 100644 (file)
@@ -115,7 +115,7 @@ list.
 
     :default: on
 
-    Merge immediately-nested case expressions that scrutinse the same variable.
+    Merge immediately-nested case expressions that scrutinise the same variable.
     For example, ::
 
           case x of
@@ -131,6 +131,25 @@ list.
              Blue -> e2
              Green -> e2
 
+.. ghc-flag:: -fcase-folding
+
+    :default: on
+
+    Allow constant folding in case expressions that scrutinise some primops.
+    For example, ::
+
+          case x `minusWord#` 10## of
+             10## -> e1
+             20## -> e2
+             v    -> e3
+
+    is transformed to ::
+
+          case x of
+             20## -> e1
+             30## -> e2
+             _    -> let v = x `minusWord#` 10## in e3
+
 .. ghc-flag:: -fcall-arity
 
     :default: on
diff --git a/testsuite/tests/perf/compiler/T12877.hs b/testsuite/tests/perf/compiler/T12877.hs
new file mode 100644 (file)
index 0000000..2fc7d58
--- /dev/null
@@ -0,0 +1,117 @@
+-- This ugly cascading case reduces to:
+--    case x of
+--       0 -> "0"
+--       1 -> "1"
+--       _ -> "n"
+-- (at depth k the scrutinee is x + k and the alternatives are k and k + 1),
+-- but only if GHC's case-folding reduction kicks in.
+
+{-# NOINLINE test #-}
+test :: Word -> String
+test x = case x of
+   0  -> "0"
+   1  -> "1"
+   t  -> case t + 1 of
+      1 -> "0"
+      2 -> "1"
+      t  -> case t + 1 of
+         2 -> "0"
+         3 -> "1"
+         t  -> case t + 1 of
+            3 -> "0"
+            4 -> "1"
+            t  -> case t + 1 of
+               4 -> "0"
+               5 -> "1"
+               t  -> case t + 1 of
+                  5 -> "0"
+                  6 -> "1"
+                  t  -> case t + 1 of
+                     6 -> "0"
+                     7 -> "1"
+                     t  -> case t + 1 of
+                        7 -> "0"
+                        8 -> "1"
+                        t  -> case t + 1 of
+                           8 -> "0"
+                           9 -> "1"
+                           t  -> case t + 1 of
+                              9 -> "0"
+                              10 -> "1"
+                              t  -> case t + 1 of
+                                 10 -> "0"
+                                 11 -> "1"
+                                 t  -> case t + 1 of
+                                    11 -> "0"
+                                    12 -> "1"
+                                    t  -> case t + 1 of
+                                       12 -> "0"
+                                       13 -> "1"
+                                       t  -> case t + 1 of
+                                          13 -> "0"
+                                          14 -> "1"
+                                          t  -> case t + 1 of
+                                             14 -> "0"
+                                             15 -> "1"
+                                             t  -> case t + 1 of
+                                                15 -> "0"
+                                                16 -> "1"
+                                                t  -> case t + 1 of
+                                                   16 -> "0"
+                                                   17 -> "1"
+                                                   t  -> case t + 1 of
+                                                      17 -> "0"
+                                                      18 -> "1"
+                                                      t  -> case t + 1 of
+                                                         18 -> "0"
+                                                         19 -> "1"
+                                                         t  -> case t + 1 of
+                                                            19 -> "0"
+                                                            20 -> "1"
+                                                            t  -> case t + 1 of
+                                                               20 -> "0"
+                                                               21 -> "1"
+                                                               t  -> case t + 1 of
+                                                                  21 -> "0"
+                                                                  22 -> "1"
+                                                                  t  -> case t + 1 of
+                                                                     22 -> "0"
+                                                                     23 -> "1"
+                                                                     t  -> case t + 1 of
+                                                                        23 -> "0"
+                                                                        24 -> "1"
+                                                                        t  -> case t + 1 of
+                                                                           24 -> "0"
+                                                                           25 -> "1"
+                                                                           t  -> case t + 1 of
+                                                                              25 -> "0"
+                                                                              26 -> "1"
+                                                                              t  -> case t + 1 of
+                                                                                 26 -> "0"
+                                                                                 27 -> "1"
+                                                                                 t  -> case t + 1 of
+                                                                                    27 -> "0"
+                                                                                    28 -> "1"
+                                                                                    t  -> case t + 1 of
+                                                                                       28 -> "0"
+                                                                                       29 -> "1"
+                                                                                       t  -> case t + 1 of
+                                                                                          29 -> "0"
+                                                                                          30 -> "1"
+                                                                                          t  -> case t + 1 of
+                                                                                             30 -> "0"
+                                                                                             31 -> "1"
+                                                                                             t  -> case t + 1 of
+                                                                                                31 -> "0"
+                                                                                                32 -> "1"
+                                                                                                t  -> case t + 1 of
+                                                                                                   32 -> "0"
+                                                                                                   33 -> "1"
+                                                                                                   t  -> case t + 1 of
+                                                                                                      33 -> "0"
+                                                                                                      34 -> "1"
+                                                                                                      _  -> "n"
+
+main :: IO ()
+main = do
+   putStrLn [last (concat (fmap test [0..12345678]))]
diff --git a/testsuite/tests/perf/compiler/T12877.stdout b/testsuite/tests/perf/compiler/T12877.stdout
new file mode 100644 (file)
index 0000000..8ba3a16
--- /dev/null
@@ -0,0 +1 @@
+n
index 0ccde15..38cbdd0 100644 (file)
@@ -895,3 +895,16 @@ test('T12234',
      compile,
      [''])
 
+test('T12877',
+     [ stats_num_field('bytes allocated',
+          [(wordsize(64), 197582248, 5),
+          # initial (run time):     197582248 (Linux, 64-bit)
+          ])
+     , compiler_stats_num_field('bytes allocated',
+          [(wordsize(64), 135979000, 5),
+          # initial (compile time): 135979000 (Linux, 64-bit)
+          ]),
+     ],
+     compile_and_run,
+     ['-O2'])
+
index 29d35a0..b0f9bc5 100644 (file)
@@ -15,6 +15,11 @@ optimizationsOptions =
          , flagType = DynamicFlag
          , flagReverse = "-fno-case-merge"
          }
+  , flag { flagName = "-fcase-folding"
+         , flagDescription = "Enable constant folding in case expressions. Implied by :ghc-flag:`-O`."
+         , flagType = DynamicFlag
+         , flagReverse = "-fno-case-folding"
+         }
   , flag { flagName = "-fcmm-elim-common-blocks"
          , flagDescription =
            "Enable Cmm common block elimination. Implied by :ghc-flag:`-O`."