Wrap `gmpz_tdiv_{q,r,qr}_ui` to optimize `quot`/`rem`
authorHerbert Valerio Riedel <hvr@gnu.org>
Wed, 8 Jan 2014 21:45:13 +0000 (22:45 +0100)
committerHerbert Valerio Riedel <hvr@gnu.org>
Wed, 8 Jan 2014 22:37:41 +0000 (23:37 +0100)
This is useful as `quot`/`rem` are often used with small-int divisors,
like when computing the digits of an `Integer`. This optimization
reduces allocations in the following `nofib` benchmarks:

      Program        Size    Allocs   Runtime   Elapsed  TotalMem
   -----------------------------------------------------------------
        power       +0.3%     -0.8%     -1.2%     -1.2%     +0.0%
    primetest       +0.3%     -3.9%      0.07      0.07     +0.0%
          rsa       +0.3%     -4.0%      0.02      0.02     +0.0%
       symalg       +0.2%     -1.4%      0.01      0.01     +0.0%

This addresses #8647

Signed-off-by: Herbert Valerio Riedel <hvr@gnu.org>
GHC/Integer/GMP/Prim.hs
GHC/Integer/Type.lhs
cbits/gmp-wrappers.cmm

index 80a59bd..261df29 100644 (file)
@@ -15,8 +15,12 @@ module GHC.Integer.GMP.Prim (
     timesIntegerInt#,
 
     quotRemInteger#,
+    quotRemIntegerWord#,
     quotInteger#,
+    quotIntegerWord#,
     remInteger#,
+    remIntegerWord#,
+
     divModInteger#,
     divInteger#,
     modInteger#,
@@ -122,16 +126,29 @@ foreign import prim "integer_cmm_timesIntegerIntzh" timesIntegerInt#
 foreign import prim "integer_cmm_quotRemIntegerzh" quotRemInteger#
   :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #)
 
+-- | Variant of 'quotRemInteger#'
+--
+foreign import prim "integer_cmm_quotRemIntegerWordzh" quotRemIntegerWord#
+  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray#, Int#, ByteArray# #)
+
 -- | Rounds towards zero.
 --
 foreign import prim "integer_cmm_quotIntegerzh" quotInteger#
   :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
 
+-- | Rounds towards zero.
+foreign import prim "integer_cmm_quotIntegerWordzh" quotIntegerWord#
+  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #)
+
 -- | Satisfies \texttt{plusInteger\# (timesInteger\# (quotInteger\# x y) y) (remInteger\# x y) == x}.
 --
 foreign import prim "integer_cmm_remIntegerzh" remInteger#
   :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
 
+-- | Variant of 'remInteger#'
+foreign import prim "integer_cmm_remIntegerWordzh" remIntegerWord#
+  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #)
+
 -- | Compute div and mod simultaneously, where div rounds towards negative infinity
 -- and\texttt{(q,r) = divModInteger\#(x,y)} implies \texttt{plusInteger\# (timesInteger\# q y) r = x}.
 --
index 0e3cec7..adba180 100644 (file)
@@ -40,7 +40,8 @@ import GHC.Integer.GMP.Prim (
     cmpInteger#, cmpIntegerInt#,
     plusInteger#, plusIntegerInt#, minusInteger#, minusIntegerInt#,
     timesInteger#, timesIntegerInt#,
-    quotRemInteger#, quotInteger#, remInteger#,
+    quotRemInteger#, quotRemIntegerWord#,
+    quotInteger#, quotIntegerWord#, remInteger#, remIntegerWord#,
     divModInteger#, divInteger#, modInteger#,
     gcdInteger#, gcdExtInteger#, gcdIntegerInt#, gcdInt#, divExactInteger#,
     decodeDouble#,
@@ -219,7 +220,16 @@ quotRemInteger :: Integer -> Integer -> (# Integer, Integer #)
 quotRemInteger a@(S# INT_MINBOUND) b = quotRemInteger (toBig a) b
 quotRemInteger (S# i) (S# j) = case quotRemInt# i j of
                                    (# q, r #) -> (# S# q, S# r #)
-quotRemInteger i1@(J# _ _) i2@(S# _) = quotRemInteger i1 (toBig i2)
+quotRemInteger (J# s1 d1) (S# b) | isTrue# (b <# 0#)
+  = case quotRemIntegerWord# s1 d1 (int2Word# (negateInt# b)) of
+          (# s3, d3, s4, d4 #) -> let !q = smartJ# (negateInt# s3) d3
+                                      !r = smartJ# s4 d4
+                                  in (# q, r #)
+quotRemInteger (J# s1 d1) (S# b)
+  = case quotRemIntegerWord# s1 d1 (int2Word# b) of
+          (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3
+                                      !r = smartJ# s4 d4
+                                  in (# q, r #)
 quotRemInteger i1@(S# _) i2@(J# _ _) = quotRemInteger (toBig i1) i2
 quotRemInteger (J# s1 d1) (J# s2 d2)
   = case (quotRemInteger# s1 d1 s2 d2) of
@@ -262,9 +272,10 @@ remInteger ia@(S# a) (J# sb b)
 -}
 remInteger ia@(S# _) ib@(J# _ _) = remInteger (toBig ia) ib
 remInteger (J# sa a) (S# b)
-  = case int2Integer# b of { (# sb, b' #) ->
-    case remInteger# sa a sb b' of { (# sr, r #) ->
-    S# (integer2Int# sr r) }}
+  = case remIntegerWord# sa a w of
+          (# sr, r #) -> smartJ# sr r
+  where
+    w = int2Word# (if isTrue# (b <# 0#) then negateInt# b else b)
 remInteger (J# sa a) (J# sb b)
   = case remInteger# sa a sb b of (# sr, r #) -> smartJ# sr r
 
@@ -279,9 +290,12 @@ quotInteger (S# a) (J# sb b)
   | otherwise  = S# 0
 -}
 quotInteger ia@(S# _) ib@(J# _ _) = quotInteger (toBig ia) ib
+quotInteger (J# sa a) (S# b) | isTrue# (b <# 0#)
+  = case quotIntegerWord# sa a (int2Word# (negateInt# b)) of
+          (# sq, q #) -> smartJ# (negateInt# sq) q
 quotInteger (J# sa a) (S# b)
-  = case int2Integer# b of { (# sb, b' #) ->
-    case quotInteger# sa a sb b' of (# sq, q #) -> smartJ# sq q }
+  = case quotIntegerWord# sa a (int2Word# b) of
+          (# sq, q #) -> smartJ# sq q
 quotInteger (J# sa a) (J# sb b)
   = case quotInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
 
index 3ab699e..47de995 100644 (file)
@@ -43,10 +43,13 @@ import "integer-gmp" __gmpz_gcdext;
 import "integer-gmp" __gmpn_gcd_1;
 import "integer-gmp" __gmpn_cmp;
 import "integer-gmp" __gmpz_tdiv_q;
+import "integer-gmp" __gmpz_tdiv_q_ui;
 import "integer-gmp" __gmpz_tdiv_r;
+import "integer-gmp" __gmpz_tdiv_r_ui;
 import "integer-gmp" __gmpz_fdiv_q;
 import "integer-gmp" __gmpz_fdiv_r;
 import "integer-gmp" __gmpz_tdiv_qr;
+import "integer-gmp" __gmpz_tdiv_qr_ui;
 import "integer-gmp" __gmpz_fdiv_qr;
 import "integer-gmp" __gmpz_divexact;
 import "integer-gmp" __gmpz_and;
@@ -488,6 +491,33 @@ again:                                                                  \
   return (MP_INT_AS_PAIR(mp_result1),MP_INT_AS_PAIR(mp_result2));       \
 }
 
+#define GMP_TAKE1_UL1_RET2(name,mp_fun)                                 \
+name (W_ ws1, P_ d1, W_ wul2, P_ d2)                                    \
+{                                                                       \
+  W_ mp_tmp1;                                                           \
+  W_ mp_result1;                                                        \
+  W_ mp_result2;                                                        \
+                                                                        \
+again:                                                                  \
+  STK_CHK_GEN_N (3 * SIZEOF_MP_INT);                                    \
+  MAYBE_GC(again);                                                      \
+                                                                        \
+  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                                  \
+  mp_result1 = Sp - 2 * SIZEOF_MP_INT;                                  \
+  mp_result2 = Sp - 3 * SIZEOF_MP_INT;                                  \
+                                                                        \
+  MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1);                                   \
+                                                                        \
+  ccall __gmpz_init(mp_result1 "ptr");                                  \
+  ccall __gmpz_init(mp_result2 "ptr");                                  \
+                                                                        \
+  /* Perform the operation */                                           \
+  ccall mp_fun(mp_result1 "ptr", mp_result2 "ptr",                      \
+               mp_tmp1 "ptr", W_TO_LONG(wul2));                         \
+                                                                        \
+  return (MP_INT_AS_PAIR(mp_result1),MP_INT_AS_PAIR(mp_result2));       \
+}
+
 GMP_TAKE2_RET1(integer_cmm_plusIntegerzh,           __gmpz_add)
 GMP_TAKE2_RET1(integer_cmm_minusIntegerzh,          __gmpz_sub)
 GMP_TAKE2_RET1(integer_cmm_timesIntegerzh,          __gmpz_mul)
@@ -496,7 +526,9 @@ GMP_TAKE2_RET1(integer_cmm_gcdIntegerzh,            __gmpz_gcd)
 #define CMM_GMPZ_GCDEXT(g,s,a,b) __gmpz_gcdext(g,s,NULL,a,b)
 GMP_TAKE2_RET2(integer_cmm_gcdExtIntegerzh,         CMM_GMPZ_GCDEXT)
 GMP_TAKE2_RET1(integer_cmm_quotIntegerzh,           __gmpz_tdiv_q)
+GMP_TAKE1_UL1_RET1(integer_cmm_quotIntegerWordzh,   __gmpz_tdiv_q_ui)
 GMP_TAKE2_RET1(integer_cmm_remIntegerzh,            __gmpz_tdiv_r)
+GMP_TAKE1_UL1_RET1(integer_cmm_remIntegerWordzh,    __gmpz_tdiv_r_ui)
 GMP_TAKE2_RET1(integer_cmm_divIntegerzh,            __gmpz_fdiv_q)
 GMP_TAKE2_RET1(integer_cmm_modIntegerzh,            __gmpz_fdiv_r)
 GMP_TAKE2_RET1(integer_cmm_divExactIntegerzh,       __gmpz_divexact)
@@ -509,6 +541,7 @@ GMP_TAKE1_UL1_RET1(integer_cmm_fdivQ2ExpIntegerzh,  __gmpz_fdiv_q_2exp)
 GMP_TAKE1_RET1(integer_cmm_complementIntegerzh,     __gmpz_com)
 
 GMP_TAKE2_RET2(integer_cmm_quotRemIntegerzh,        __gmpz_tdiv_qr)
+GMP_TAKE1_UL1_RET2(integer_cmm_quotRemIntegerWordzh,__gmpz_tdiv_qr_ui)
 GMP_TAKE2_RET2(integer_cmm_divModIntegerzh,         __gmpz_fdiv_qr)
 
 GMP_TAKE3_RET1(integer_cmm_powModIntegerzh,         __gmpz_powm)