Add new `mpz_{sub,add}_ui`-based primop (re #8647)
authorHerbert Valerio Riedel <hvr@gnu.org>
Fri, 3 Jan 2014 21:36:04 +0000 (22:36 +0100)
committerHerbert Valerio Riedel <hvr@gnu.org>
Sat, 4 Jan 2014 21:41:34 +0000 (22:41 +0100)
This adds `{plus,minus}IntegerInt#` which help to reduce temporary
allocations in `plusInteger` and `minusInteger`.

This and the previous commit introducing `timesIntegerInt#` (i.e. baeeef7af6e)
result in reduced allocations for the following nofib benchmarks on Linux/amd64:

         Program      Size    Allocs   Runtime   Elapsed  TotalMem
  ------------------------------------------------------------------
      bernouilli     +0.0%     -4.2%      0.12      0.12     +0.0%
           kahan     +0.1%    -12.6%      0.17      0.17     +0.0%
        pidigits     +0.0%     -0.5%     -4.7%     -4.5%     +0.0%
           power     +0.0%     -2.7%     +3.1%     +3.1%     +9.1%
       primetest     +0.0%     -4.2%      0.07      0.07     +0.0%
             rsa     +0.0%     -4.1%      0.02      0.02     +0.0%
             scs     +0.0%     -2.6%     -0.8%     -0.7%     +0.0%
  ------------------------------------------------------------------
             Min     +0.0%    -12.6%     -4.7%     -4.5%     -5.0%
             Max     +0.1%     +0.2%     +3.1%     +3.1%     +9.1%
  Geometric Mean     +0.1%     -0.3%     -0.0%     +0.0%     +0.1%
  ------------------------------------------------------------------

Signed-off-by: Herbert Valerio Riedel <hvr@gnu.org>
GHC/Integer/GMP/Prim.hs
GHC/Integer/Type.lhs
cbits/gmp-wrappers.cmm

index 3958f13..80a59bd 100644 (file)
@@ -8,7 +8,9 @@ module GHC.Integer.GMP.Prim (
     cmpIntegerInt#,
 
     plusInteger#,
+    plusIntegerInt#,
     minusInteger#,
+    minusIntegerInt#,
     timesInteger#,
     timesIntegerInt#,
 
@@ -88,11 +90,21 @@ foreign import prim "integer_cmm_cmpIntegerIntzh" cmpIntegerInt#
 foreign import prim "integer_cmm_plusIntegerzh" plusInteger#
   :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
 
+-- | Optimized version of 'plusInteger#' for summing big-ints with small-ints
+--
+foreign import prim "integer_cmm_plusIntegerIntzh" plusIntegerInt#
+  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+
 -- |
 --
 foreign import prim "integer_cmm_minusIntegerzh" minusInteger#
   :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
 
+-- | Optimized version of 'minusInteger#' for substracting small-ints from big-ints
+--
+foreign import prim "integer_cmm_minusIntegerIntzh" minusIntegerInt#
+  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+
 -- |
 --
 foreign import prim "integer_cmm_timesIntegerzh" timesInteger#
index 5c6919c..0e3cec7 100644 (file)
@@ -38,7 +38,8 @@ import GHC.Prim (
 import GHC.Integer.GMP.Prim (
     -- GMP-related primitives
     cmpInteger#, cmpIntegerInt#,
-    plusInteger#, minusInteger#, timesInteger#, timesIntegerInt#,
+    plusInteger#, plusIntegerInt#, minusInteger#, minusIntegerInt#,
+    timesInteger#, timesIntegerInt#,
     quotRemInteger#, quotInteger#, remInteger#,
     divModInteger#, divInteger#, modInteger#,
     gcdInteger#, gcdExtInteger#, gcdIntegerInt#, gcdInt#, divExactInteger#,
@@ -505,25 +506,34 @@ signumInteger (J# s d)
 
 {-# NOINLINE plusInteger #-}
 plusInteger :: Integer -> Integer -> Integer
-plusInteger i1@(S# i) i2@(S# j)  = case addIntC# i j of
+plusInteger (S# i)      (S# j)   = case addIntC# i j of
                                    (# r, c #) ->
                                        if isTrue# (c ==# 0#)
                                        then S# r
-                                       else plusInteger (toBig i1) (toBig i2)
-plusInteger i1@(J# _ _) i2@(S# _) = plusInteger i1 (toBig i2)
-plusInteger i1@(S# _) i2@(J# _ _) = plusInteger (toBig i1) i2
+                                       else case int2Integer# i of
+                                            (# s, d #) -> case plusIntegerInt# s d j of
+                                                          (# s', d' #) -> J# s' d'
+plusInteger i1@(J# _ _) (S# 0#)   = i1
+plusInteger (J# s1 d1)  (S# j)    = case plusIntegerInt# s1 d1 j of
+                                    (# s, d #) -> smartJ# s d
+plusInteger i1@(S# _) i2@(J# _ _) = plusInteger i2 i1
 plusInteger (J# s1 d1) (J# s2 d2) = case plusInteger# s1 d1 s2 d2 of
                                     (# s, d #) -> smartJ# s d
 
 {-# NOINLINE minusInteger #-}
 minusInteger :: Integer -> Integer -> Integer
-minusInteger i1@(S# i) i2@(S# j)   = case subIntC# i j of
+minusInteger (S# i)      (S# j)    = case subIntC# i j of
                                      (# r, c #) ->
                                          if isTrue# (c ==# 0#) then S# r
-                                         else minusInteger (toBig i1)
-                                                           (toBig i2)
-minusInteger i1@(J# _ _) i2@(S# _) = minusInteger i1 (toBig i2)
-minusInteger i1@(S# _) i2@(J# _ _) = minusInteger (toBig i1) i2
+                                         else case int2Integer# i of
+                                              (# s, d #) -> case minusIntegerInt# s d j of
+                                                            (# s', d' #) -> J# s' d'
+minusInteger i1@(J# _ _) (S# 0#)   = i1
+minusInteger (J# s1 d1)  (S# j)    = case minusIntegerInt# s1 d1 j of
+                                     (# s, d #) -> smartJ# s d
+minusInteger (S# 0#)    (J# s2 d2) = J# (negateInt# s2) d2
+minusInteger (S# i)     (J# s2 d2) = case plusIntegerInt# (negateInt# s2) d2 i of
+                                     (# s, d #) -> smartJ# s d
 minusInteger (J# s1 d1) (J# s2 d2) = case minusInteger# s1 d1 s2 d2 of
                                      (# s, d #) -> smartJ# s d
 
index 39b6fba..3ab699e 100644 (file)
@@ -30,7 +30,9 @@
 
 import "integer-gmp" __gmpz_init;
 import "integer-gmp" __gmpz_add;
+import "integer-gmp" __gmpz_add_ui;
 import "integer-gmp" __gmpz_sub;
+import "integer-gmp" __gmpz_sub_ui;
 import "integer-gmp" __gmpz_mul;
 import "integer-gmp" __gmpz_mul_2exp;
 import "integer-gmp" __gmpz_mul_si;
@@ -646,3 +648,35 @@ integer_cmm_decodeDoublezh (D_ arg)
     /* returns: (Int# (expn), Int#, ByteArray#) */
     return (W_[mp_tmp_w], TO_W_(MP_INT__mp_size(mp_tmp1)), p);
 }
+
+/* :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) */
+#define GMPX_TAKE1_UL1_RET1(name,pos_arg_fun,neg_arg_fun)               \
+name(W_ ws1, P_ d1, W_ wl)                                              \
+{                                                                       \
+  W_ mp_tmp;                                                            \
+  W_ mp_result;                                                         \
+                                                                        \
+again:                                                                  \
+  STK_CHK_GEN_N (2 * SIZEOF_MP_INT);                                    \
+  MAYBE_GC(again);                                                      \
+                                                                        \
+  mp_tmp     = Sp - 1 * SIZEOF_MP_INT;                                  \
+  mp_result  = Sp - 2 * SIZEOF_MP_INT;                                  \
+                                                                        \
+  MP_INT_SET_FROM_BA(mp_tmp,ws1,d1);                                    \
+                                                                        \
+  ccall __gmpz_init(mp_result "ptr");                                   \
+                                                                        \
+  if(%lt(wl,0)) {                                                       \
+      ccall neg_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(-wl)); \
+      return(MP_INT_AS_PAIR(mp_result));                                \
+  }                                                                     \
+                                                                        \
+  ccall pos_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wl));      \
+  return(MP_INT_AS_PAIR(mp_result));                                    \
+}
+
+/* NB: We need both primitives as we can't express 'minusIntegerInt#'
+   in terms of 'plusIntegerInt#' for @minBound :: Int@ */
+GMPX_TAKE1_UL1_RET1(integer_cmm_plusIntegerIntzh,__gmpz_add_ui,__gmpz_sub_ui)
+GMPX_TAKE1_UL1_RET1(integer_cmm_minusIntegerIntzh,__gmpz_sub_ui,__gmpz_add_ui)