Generate better fp abs for X86 and llvm with default cmm otherwise
authorDominic Steinitz <dominic@steinitz.org>
Tue, 7 Mar 2017 14:26:16 +0000 (09:26 -0500)
committerBen Gamari <ben@smart-cactus.org>
Tue, 7 Mar 2017 18:32:33 +0000 (13:32 -0500)
Currently we have this in libraries/base/GHC/Float.hs:
```
abs x | x == 0    = 0 -- handles (-0.0)
      | x >  0    = x
      | otherwise = negateFloat x
```
But 3-4 years ago it was noted that this was inefficient:
https://mail.haskell.org/pipermail/libraries/2013-April/019690.html

We can generate better code for X86 and llvm and for others generate
some custom cmm code which is similar to what the compiler generates
now.

Reviewers: austin, simonmar, hvr, bgamari

Reviewed By: bgamari

Subscribers: dfeuer, thomie

Differential Revision: https://phabricator.haskell.org/D3265

compiler/cmm/CmmMachOp.hs
compiler/cmm/PprC.hs
compiler/codeGen/StgCmmPrim.hs
compiler/llvmGen/LlvmCodeGen/CodeGen.hs
compiler/nativeGen/PPC/CodeGen.hs
compiler/nativeGen/SPARC/CodeGen.hs
compiler/nativeGen/X86/CodeGen.hs
compiler/nativeGen/X86/Ppr.hs
compiler/prelude/primops.txt.pp
libraries/base/GHC/Float.hs

index a8cbd68..d736f14 100644 (file)
@@ -528,6 +528,7 @@ data CallishMachOp
   | MO_F64_Atan
   | MO_F64_Log
   | MO_F64_Exp
+  | MO_F64_Fabs
   | MO_F64_Sqrt
   | MO_F32_Pwr
   | MO_F32_Sin
@@ -541,6 +542,7 @@ data CallishMachOp
   | MO_F32_Atan
   | MO_F32_Log
   | MO_F32_Exp
+  | MO_F32_Fabs
   | MO_F32_Sqrt
 
   | MO_UF_Conv Width
index dba8ca6..6a84e30 100644 (file)
@@ -754,6 +754,7 @@ pprCallishMachOp_for_C mop
         MO_F64_Log      -> text "log"
         MO_F64_Exp      -> text "exp"
         MO_F64_Sqrt     -> text "sqrt"
+        MO_F64_Fabs     -> unsupported
         MO_F32_Pwr      -> text "powf"
         MO_F32_Sin      -> text "sinf"
         MO_F32_Cos      -> text "cosf"
@@ -767,6 +768,7 @@ pprCallishMachOp_for_C mop
         MO_F32_Log      -> text "logf"
         MO_F32_Exp      -> text "expf"
         MO_F32_Sqrt     -> text "sqrtf"
+        MO_F32_Fabs     -> unsupported
         MO_WriteBarrier -> text "write_barrier"
         MO_Memcpy _     -> text "memcpy"
         MO_Memset _     -> text "memset"
index 14eb425..0edde06 100644 (file)
@@ -844,6 +844,12 @@ callishPrimOpSupported dflags op
       WordMul2Op     | ncg && x86ish
                          || llvm      -> Left (MO_U_Mul2     (wordWidth dflags))
                      | otherwise      -> Right genericWordMul2Op
+      FloatFabsOp    | (ncg && x86ish)
+                         || llvm      -> Left MO_F32_Fabs
+                     | otherwise      -> Right $ genericFabsOp W32
+      DoubleFabsOp   | (ncg && x86ish)
+                         || llvm      -> Left MO_F64_Fabs
+                     | otherwise      -> Right $ genericFabsOp W64
 
       _ -> pprPanic "emitPrimOp: can't translate PrimOp " (ppr op)
  where
@@ -1064,6 +1070,34 @@ genericWordMul2Op [res_h, res_l] [arg_x, arg_y]
                         topHalf (CmmReg r)])]
 genericWordMul2Op _ _ = panic "genericWordMul2Op"
 
+-- This replicates what we had in libraries/base/GHC/Float.hs:
+--
+--    abs x    | x == 0    = 0 -- handles (-0.0)
+--             | x >  0    = x
+--             | otherwise = negateFloat x
+genericFabsOp :: Width -> GenericOp
+genericFabsOp w [res_r] [aa]
+ = do dflags <- getDynFlags
+      let zero   = CmmLit (CmmFloat 0 w)
+
+          eq x y = CmmMachOp (MO_F_Eq w) [x, y]
+          gt x y = CmmMachOp (MO_F_Gt w) [x, y]
+
+          neg x  = CmmMachOp (MO_F_Neg w) [x]
+
+          g1 = catAGraphs [mkAssign (CmmLocal res_r) zero]
+          g2 = catAGraphs [mkAssign (CmmLocal res_r) aa]
+
+      res_t <- CmmLocal <$> newTemp (cmmExprType dflags aa)
+      let g3 = catAGraphs [mkAssign res_t aa,
+                           mkAssign (CmmLocal res_r) (neg (CmmReg res_t))]
+
+      g4 <- mkCmmIfThenElse (gt aa zero) g2 g3
+
+      emit =<< mkCmmIfThenElse (eq aa zero) g1 g4
+
+genericFabsOp _ _ _ = panic "genericFabsOp"
+
 -- These PrimOps are NOPs in Cmm
 
 nopOp :: PrimOp -> Bool
index d88d234..40c5498 100644 (file)
@@ -690,6 +690,7 @@ cmmPrimOpFunctions mop = do
     MO_F32_Exp    -> fsLit "expf"
     MO_F32_Log    -> fsLit "logf"
     MO_F32_Sqrt   -> fsLit "llvm.sqrt.f32"
+    MO_F32_Fabs   -> fsLit "llvm.fabs.f32"
     MO_F32_Pwr    -> fsLit "llvm.pow.f32"
 
     MO_F32_Sin    -> fsLit "llvm.sin.f32"
@@ -707,6 +708,7 @@ cmmPrimOpFunctions mop = do
     MO_F64_Exp    -> fsLit "exp"
     MO_F64_Log    -> fsLit "log"
     MO_F64_Sqrt   -> fsLit "llvm.sqrt.f64"
+    MO_F64_Fabs   -> fsLit "llvm.fabs.f64"
     MO_F64_Pwr    -> fsLit "llvm.pow.f64"
 
     MO_F64_Sin    -> fsLit "llvm.sin.f64"
index 849516f..1f06c7b 100644 (file)
@@ -1525,6 +1525,7 @@ genCCall' dflags gcp target dest_regs args
                     MO_F32_Exp   -> (fsLit "exp", True)
                     MO_F32_Log   -> (fsLit "log", True)
                     MO_F32_Sqrt  -> (fsLit "sqrt", True)
+                    MO_F32_Fabs  -> unsupported
 
                     MO_F32_Sin   -> (fsLit "sin", True)
                     MO_F32_Cos   -> (fsLit "cos", True)
@@ -1542,6 +1543,7 @@ genCCall' dflags gcp target dest_regs args
                     MO_F64_Exp   -> (fsLit "exp", False)
                     MO_F64_Log   -> (fsLit "log", False)
                     MO_F64_Sqrt  -> (fsLit "sqrt", False)
+                    MO_F64_Fabs  -> unsupported
 
                     MO_F64_Sin   -> (fsLit "sin", False)
                     MO_F64_Cos   -> (fsLit "cos", False)
index a6d3f94..3e9058b 100644 (file)
@@ -610,6 +610,7 @@ outOfLineMachOp_table mop
         MO_F32_Exp    -> fsLit "expf"
         MO_F32_Log    -> fsLit "logf"
         MO_F32_Sqrt   -> fsLit "sqrtf"
+        MO_F32_Fabs   -> unsupported
         MO_F32_Pwr    -> fsLit "powf"
 
         MO_F32_Sin    -> fsLit "sinf"
@@ -627,6 +628,7 @@ outOfLineMachOp_table mop
         MO_F64_Exp    -> fsLit "exp"
         MO_F64_Log    -> fsLit "log"
         MO_F64_Sqrt   -> fsLit "sqrt"
+        MO_F64_Fabs   -> unsupported
         MO_F64_Pwr    -> fsLit "pow"
 
         MO_F64_Sin    -> fsLit "sin"
index 877d822..704514e 100644 (file)
@@ -2043,17 +2043,24 @@ genCCall dflags is32Bit (PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] =
 genCCall _ is32Bit target dest_regs args = do
   dflags <- getDynFlags
   let platform = targetPlatform dflags
+      sse2     = isSse2Enabled dflags
   case (target, dest_regs) of
     -- void return type prim op
     (PrimTarget op, []) ->
         outOfLineCmmOp op Nothing args
     -- we only cope with a single result for foreign calls
     (PrimTarget op, [r])
-      | not is32Bit -> outOfLineCmmOp op (Just r) args
+      | sse2 -> case op of
+          MO_F32_Fabs -> case args of
+            [x] -> sse2FabsCode W32 x
+            _ -> panic "genCCall: Wrong number of arguments for fabs"
+          MO_F64_Fabs -> case args of
+            [x] -> sse2FabsCode W64 x
+            _ -> panic "genCCall: Wrong number of arguments for fabs"
+          _other_op -> outOfLineCmmOp op (Just r) args
       | otherwise -> do
         l1 <- getNewLabelNat
         l2 <- getNewLabelNat
-        sse2 <- sse2Enabled
         if sse2
           then
             outOfLineCmmOp op (Just r) args
@@ -2082,6 +2089,23 @@ genCCall _ is32Bit target dest_regs args = do
               = panic $ "genCCall.actuallyInlineFloatOp: bad number of arguments! ("
                       ++ show (length args) ++ ")"
 
+        sse2FabsCode :: Width -> CmmExpr -> NatM InstrBlock
+        sse2FabsCode w x = do
+          let fmt = floatFormat w
+          x_code <- getAnyReg x
+          let
+            const | FF32 <- fmt = CmmInt 0x7fffffff W32
+                  | otherwise   = CmmInt 0x7fffffffffffffff W64
+          Amode amode amode_code <- memConstant (widthInBytes w) const
+          tmp <- getNewRegNat fmt
+          let
+            code dst = x_code dst `appOL` amode_code `appOL` toOL [
+                MOV fmt (OpAddr amode) (OpReg tmp),
+                AND fmt (OpReg tmp) (OpReg dst)
+                ]
+
+          return $ code (getRegisterReg platform True (CmmLocal r))
+
     (PrimTarget (MO_S_QuotRem  width), _) -> divOp1 platform True  width dest_regs args
     (PrimTarget (MO_U_QuotRem  width), _) -> divOp1 platform False width dest_regs args
     (PrimTarget (MO_U_QuotRem2 width), _) -> divOp2 platform False width dest_regs args
@@ -2599,6 +2623,7 @@ outOfLineCmmOp mop res args
 
         fn = case mop of
               MO_F32_Sqrt  -> fsLit "sqrtf"
+              MO_F32_Fabs  -> unsupported
               MO_F32_Sin   -> fsLit "sinf"
               MO_F32_Cos   -> fsLit "cosf"
               MO_F32_Tan   -> fsLit "tanf"
@@ -2615,6 +2640,7 @@ outOfLineCmmOp mop res args
               MO_F32_Pwr   -> fsLit "powf"
 
               MO_F64_Sqrt  -> fsLit "sqrt"
+              MO_F64_Fabs  -> unsupported
               MO_F64_Sin   -> fsLit "sin"
               MO_F64_Cos   -> fsLit "cos"
               MO_F64_Tan   -> fsLit "tan"
@@ -3050,8 +3076,16 @@ sse2NegCode w x = do
   x_code <- getAnyReg x
   -- This is how gcc does it, so it can't be that bad:
   let
-    const | FF32 <- fmt = CmmInt 0x80000000 W32
-          | otherwise   = CmmInt 0x8000000000000000 W64
+    const = case fmt of
+      FF32 -> CmmInt 0x80000000 W32
+      FF64 -> CmmInt 0x8000000000000000 W64
+      x@II8  -> wrongFmt x
+      x@II16 -> wrongFmt x
+      x@II32 -> wrongFmt x
+      x@II64 -> wrongFmt x
+      x@FF80 -> wrongFmt x
+      where
+        wrongFmt x = panic $ "sse2NegCode: " ++ show x
   Amode amode amode_code <- memConstant (widthInBytes w) const
   tmp <- getNewRegNat fmt
   let
index 223ea13..7d19e99 100644 (file)
@@ -631,6 +631,8 @@ pprInstr (SUB_CC format src dst)
 pprInstr (AND II64 src@(OpImm (ImmInteger mask)) dst)
   | 0 <= mask && mask < 0xffffffff
     = pprInstr (AND II32 src dst)
+pprInstr (AND FF32 src dst) = pprOpOp (sLit "andps") FF32 src dst
+pprInstr (AND FF64 src dst) = pprOpOp (sLit "andpd") FF64 src dst
 pprInstr (AND format src dst) = pprFormatOpOp (sLit "and") format src dst
 pprInstr (OR  format src dst) = pprFormatOpOp (sLit "or")  format src dst
 
index 11928b6..76cfe67 100644 (file)
@@ -531,6 +531,8 @@ primop   DoubleDivOp   "/##"   Dyadic
 
 primop   DoubleNegOp   "negateDouble#"  Monadic   Double# -> Double#
 
+primop   DoubleFabsOp  "fabsDouble#"    Monadic   Double# -> Double#
+
 primop   Double2IntOp   "double2Int#"          GenPrimOp  Double# -> Int#
    {Truncates a {\tt Double#} value to the nearest {\tt Int#}.
     Results are undefined if the truncation if truncation yields
@@ -657,6 +659,8 @@ primop   FloatDivOp   "divideFloat#"      Dyadic
 
 primop   FloatNegOp   "negateFloat#"      Monadic    Float# -> Float#
 
+primop   FloatFabsOp  "fabsFloat#"        Monadic    Float# -> Float#
+
 primop   Float2IntOp   "float2Int#"      GenPrimOp  Float# -> Int#
    {Truncates a {\tt Float#} value to the nearest {\tt Int#}.
     Results are undefined if the truncation if truncation yields
index 18dd288..64467b3 100644 (file)
@@ -245,9 +245,7 @@ instance  Num Float  where
     (-)         x y     =  minusFloat x y
     negate      x       =  negateFloat x
     (*)         x y     =  timesFloat x y
-    abs x    | x == 0    = 0 -- handles (-0.0)
-             | x >  0    = x
-             | otherwise = negateFloat x
+    abs         x       =  fabsFloat x
     signum x | x > 0     = 1
              | x < 0     = negateFloat 1
              | otherwise = x -- handles 0.0, (-0.0), and NaN
@@ -427,9 +425,7 @@ instance  Num Double  where
     (-)         x y     =  minusDouble x y
     negate      x       =  negateDouble x
     (*)         x y     =  timesDouble x y
-    abs x    | x == 0    = 0 -- handles (-0.0)
-             | x >  0    = x
-             | otherwise = negateDouble x
+    abs         x       =  fabsDouble x
     signum x | x > 0     = 1
              | x < 0     = negateDouble 1
              | otherwise = x -- handles 0.0, (-0.0), and NaN
@@ -1087,13 +1083,14 @@ geFloat     (F# x) (F# y) = isTrue# (geFloat# x y)
 ltFloat     (F# x) (F# y) = isTrue# (ltFloat# x y)
 leFloat     (F# x) (F# y) = isTrue# (leFloat# x y)
 
-expFloat, logFloat, sqrtFloat :: Float -> Float
+expFloat, logFloat, sqrtFloat, fabsFloat :: Float -> Float
 sinFloat, cosFloat, tanFloat  :: Float -> Float
 asinFloat, acosFloat, atanFloat  :: Float -> Float
 sinhFloat, coshFloat, tanhFloat  :: Float -> Float
 expFloat    (F# x) = F# (expFloat# x)
 logFloat    (F# x) = F# (logFloat# x)
 sqrtFloat   (F# x) = F# (sqrtFloat# x)
+fabsFloat   (F# x) = F# (fabsFloat# x)
 sinFloat    (F# x) = F# (sinFloat# x)
 cosFloat    (F# x) = F# (cosFloat# x)
 tanFloat    (F# x) = F# (tanFloat# x)
@@ -1131,13 +1128,14 @@ double2Float (D# x) = F# (double2Float# x)
 float2Double :: Float -> Double
 float2Double (F# x) = D# (float2Double# x)
 
-expDouble, logDouble, sqrtDouble :: Double -> Double
+expDouble, logDouble, sqrtDouble, fabsDouble :: Double -> Double
 sinDouble, cosDouble, tanDouble  :: Double -> Double
 asinDouble, acosDouble, atanDouble  :: Double -> Double
 sinhDouble, coshDouble, tanhDouble  :: Double -> Double
 expDouble    (D# x) = D# (expDouble# x)
 logDouble    (D# x) = D# (logDouble# x)
 sqrtDouble   (D# x) = D# (sqrtDouble# x)
+fabsDouble   (D# x) = D# (fabsDouble# x)
 sinDouble    (D# x) = D# (sinDouble# x)
 cosDouble    (D# x) = D# (cosDouble# x)
 tanDouble    (D# x) = D# (tanDouble# x)