Expose new internal exponentiation primitives
authorHerbert Valerio Riedel <hvr@gnu.org>
Sun, 29 Sep 2013 08:05:05 +0000 (10:05 +0200)
committerHerbert Valerio Riedel <hvr@gnu.org>
Sun, 29 Sep 2013 16:06:43 +0000 (18:06 +0200)
This exposes the GMP functions `mpz_pow_ui()`, `mpz_powm()`, and
`mpz_invert()` as `powInteger`, `powModInteger`, and `recipModInteger`
respectively in the module `GHC.Integer.GMP.Internals`.

Signed-off-by: Herbert Valerio Riedel <hvr@gnu.org>
GHC/Integer/GMP/Internals.hs
GHC/Integer/GMP/Prim.hs
GHC/Integer/Type.lhs
cbits/gmp-wrappers.cmm

index 4ad1f62..4a7ff5d 100644 (file)
@@ -1,6 +1,6 @@
 {-# LANGUAGE NoImplicitPrelude #-}
 
-module GHC.Integer.GMP.Internals (Integer(..), gcdInt, gcdInteger, lcmInteger)
+module GHC.Integer.GMP.Internals (Integer(..), gcdInt, gcdInteger, lcmInteger, powInteger, powModInteger, recipModInteger)
     where
 
 import GHC.Integer.Type
index 7c28ce2..59aa6f4 100644 (file)
@@ -41,6 +41,10 @@ module GHC.Integer.GMP.Prim (
     mul2ExpInteger#,
     fdivQ2ExpInteger#,
 
+    powInteger#,
+    powModInteger#,
+    recipModInteger#,
+
 #if WORD_SIZE_IN_BITS < 64
     int64ToInteger#,  integerToInt64#,
     word64ToInteger#, integerToWord64#,
@@ -179,6 +183,21 @@ foreign import prim "integer_cmm_fdivQ2ExpIntegerzh" fdivQ2ExpInteger#
 
 -- |
 --
+foreign import prim "integer_cmm_powIntegerzh" powInteger#
+  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #)
+
+-- |
+--
+foreign import prim "integer_cmm_powModIntegerzh" powModInteger#
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+
+-- |
+--
+foreign import prim "integer_cmm_recipModIntegerzh" recipModInteger#
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+
+-- |
+--
 foreign import prim "integer_cmm_complementIntegerzh" complementInteger#
   :: Int# -> ByteArray# -> (# Int#, ByteArray# #)
 
index a01542c..554160c 100644 (file)
@@ -45,6 +45,7 @@ import GHC.Integer.GMP.Prim (
     int2Integer#, integer2Int#, word2Integer#, integer2Word#,
     andInteger#, orInteger#, xorInteger#, complementInteger#,
     testBitInteger#, mul2ExpInteger#, fdivQ2ExpInteger#,
+    powInteger#, powModInteger#, recipModInteger#,
 #if WORD_SIZE_IN_BITS < 64
     int64ToInteger#,  integerToInt64#,
     word64ToInteger#, integerToWord64#,
@@ -581,8 +582,46 @@ shiftRInteger (J# s d) i = case fdivQ2ExpInteger# s d i of
 testBitInteger :: Integer -> Int# -> Bool
 testBitInteger j@(S# _) i = testBitInteger (toBig j) i
 testBitInteger (J# s d) i = isTrue# (testBitInteger# s d i /=# 0#)
+
+-- | @powInteger b e@ computes base @b@ raised to exponent @e@.
+{-# NOINLINE powInteger #-}
+powInteger :: Integer -> Word# -> Integer
+powInteger j@(S# _) e = powInteger (toBig j) e
+powInteger (J# s d) e = case powInteger# s d e of
+                            (# s', d' #) -> J# s' d'
+
+-- | @powModInteger b e m@ computes base @b@ raised to exponent @e@
+-- modulo @m@.
+--
+-- Negative exponents are supported if an inverse modulo @m@
+-- exists. It's advised to avoid calling this primitive with negative
+-- exponents unless it is guaranteed the inverse exists, as failure to
+-- do so will likely cause program abortion due to a divide-by-zero
+-- fault. See also 'recipModInteger'.
+{-# NOINLINE powModInteger #-}
+powModInteger :: Integer -> Integer -> Integer -> Integer
+powModInteger (J# s1 d1) (J# s2 d2) (J# s3 d3) =
+    case powModInteger# s1 d1 s2 d2 s3 d3 of
+        (# s', d' #) -> J# s' d'
+powModInteger b e m = powModInteger (toBig b) (toBig e) (toBig m)
+
+-- | @recipModInteger x m@ computes the inverse of @x@ modulo @m@. If
+-- the inverse exists, the return value @y@ will satisfy @0 < y <
+-- abs(m)@, otherwise the result is 0.
+--
+-- Note: The implementation exploits the undocumented property of
+-- @mpz_invert()@ to not mangle the result operand (which is initialized
+-- to 0) in case of non-existence of the inverse.
+{-# NOINLINE recipModInteger #-}
+recipModInteger :: Integer -> Integer -> Integer
+recipModInteger j@(S# _) m@(S# _)   = recipModInteger (toBig j) (toBig m)
+recipModInteger j@(S# _) m@(J# _ _) = recipModInteger (toBig j) m
+recipModInteger j@(J# _ _) m@(S# _) = recipModInteger j (toBig m)
+recipModInteger (J# s d) (J# ms md) = case recipModInteger# s d ms md of
+                           (# s', d' #) -> J# s' d'
 \end{code}
 
+
 %*********************************************************
 %*                                                      *
 \subsection{The @Integer@ hashing@}
index 8a201f1..5c7bb0b 100644 (file)
@@ -49,6 +49,9 @@ import "integer-gmp" __gmpz_and;
 import "integer-gmp" __gmpz_xor;
 import "integer-gmp" __gmpz_ior;
 import "integer-gmp" __gmpz_com;
+import "integer-gmp" __gmpz_pow_ui;
+import "integer-gmp" __gmpz_powm;
+import "integer-gmp" __gmpz_invert;
 
 import "integer-gmp" integer_cbits_decodeDouble;
 
@@ -246,6 +249,47 @@ again:                                                          \
          MP_INT__mp_d(mp_result1) - SIZEOF_StgArrWords);        \
 }
 
+#define GMP_TAKE3_RET1(name,mp_fun)                             \
+name (W_ ws1, P_ d1, W_ ws2, P_ d2, W_ ws3, P_ d3)              \
+{                                                               \
+  CInt s1, s2, s3;                                              \
+  W_ mp_tmp1;                                                   \
+  W_ mp_tmp2;                                                   \
+  W_ mp_tmp3;                                                   \
+  W_ mp_result1;                                                \
+                                                                \
+again:                                                          \
+  STK_CHK_GEN_N (4 * SIZEOF_MP_INT);                            \
+  MAYBE_GC(again);                                              \
+                                                                \
+  s1 = W_TO_INT(ws1);                                           \
+  s2 = W_TO_INT(ws2);                                           \
+  s3 = W_TO_INT(ws3);                                           \
+                                                                \
+  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                          \
+  mp_tmp2    = Sp - 2 * SIZEOF_MP_INT;                          \
+  mp_tmp3    = Sp - 3 * SIZEOF_MP_INT;                          \
+  mp_result1 = Sp - 4 * SIZEOF_MP_INT;                          \
+  MP_INT__mp_alloc(mp_tmp1) = W_TO_INT(BYTE_ARR_WDS(d1));       \
+  MP_INT__mp_size(mp_tmp1)  = (s1);                             \
+  MP_INT__mp_d(mp_tmp1)     = BYTE_ARR_CTS(d1);                 \
+  MP_INT__mp_alloc(mp_tmp2) = W_TO_INT(BYTE_ARR_WDS(d2));       \
+  MP_INT__mp_size(mp_tmp2)  = (s2);                             \
+  MP_INT__mp_d(mp_tmp2)     = BYTE_ARR_CTS(d2);                 \
+  MP_INT__mp_alloc(mp_tmp3) = W_TO_INT(BYTE_ARR_WDS(d3));       \
+  MP_INT__mp_size(mp_tmp3)  = (s3);                             \
+  MP_INT__mp_d(mp_tmp3)     = BYTE_ARR_CTS(d3);                 \
+                                                                \
+  ccall __gmpz_init(mp_result1 "ptr");                          \
+                                                                \
+  /* Perform the operation */                                   \
+  ccall mp_fun(mp_result1 "ptr",mp_tmp1  "ptr",mp_tmp2  "ptr",  \
+               mp_tmp3  "ptr");                                 \
+                                                                \
+  return (TO_W_(MP_INT__mp_size(mp_result1)),                   \
+         MP_INT__mp_d(mp_result1) - SIZEOF_StgArrWords);        \
+}
+
 #define GMP_TAKE1_UL1_RET1(name,mp_fun)                         \
 name (W_ ws1, P_ d1, W_ wul)                                    \
 {                                                               \
@@ -389,6 +433,10 @@ GMP_TAKE1_RET1(integer_cmm_complementIntegerzh,     __gmpz_com)
 GMP_TAKE2_RET2(integer_cmm_quotRemIntegerzh,        __gmpz_tdiv_qr)
 GMP_TAKE2_RET2(integer_cmm_divModIntegerzh,         __gmpz_fdiv_qr)
 
+GMP_TAKE3_RET1(integer_cmm_powModIntegerzh,         __gmpz_powm)
+GMP_TAKE2_RET1(integer_cmm_recipModIntegerzh,       __gmpz_invert)
+GMP_TAKE1_UL1_RET1(integer_cmm_powIntegerzh,        __gmpz_pow_ui)
+
 integer_cmm_gcdIntzh (W_ int1, W_ int2)
 {
     W_ r;