Reimplement `gcdExtInteger` (#9281)
authorHerbert Valerio Riedel <hvr@gnu.org>
Sat, 29 Nov 2014 16:19:05 +0000 (17:19 +0100)
committerHerbert Valerio Riedel <hvr@gnu.org>
Sat, 29 Nov 2014 17:48:07 +0000 (18:48 +0100)
`gcdExtInteger` has been available since `integer-gmp-0.5.1`
(added via 71e29584603cff38e7b83d3eb28b248362569d61)

libraries/integer-gmp2/cbits/wrappers.c
libraries/integer-gmp2/src/GHC/Integer/GMP/Internals.hs
libraries/integer-gmp2/src/GHC/Integer/Type.hs
testsuite/tests/lib/integer/integerGmpInternals.hs

index 3023816..0557ff7 100644 (file)
@@ -56,6 +56,24 @@ mp_limb_zero_p(const mp_limb_t sp[], mp_size_t sn)
   return !sn || ((sn == 1 || sn == -1) && !sp[0]);
 }
 
+static inline mp_size_t
+mp_size_abs(const mp_size_t x)
+{
+  return x>=0 ? x : -x;
+}
+
+static inline mp_size_t
+mp_size_min(const mp_size_t x, const mp_size_t y)
+{
+  return x<y ? x : y;
+}
+
+static inline mp_size_t
+mp_size_minabs(const mp_size_t x, const mp_size_t y)
+{
+  return mp_size_min(mp_size_abs(x), mp_size_abs(y));
+}
+
 /* Perform arithmetic right shift on MPNs (multi-precision naturals)
  *
  * pre-conditions:
@@ -249,6 +267,54 @@ integer_gmp_mpn_gcd(mp_limb_t r[],
   }
 }
 
+/* wraps mpz_gcdext()
+ *
+ * Set g to the greatest common divisor of x and y, and in addition
+ * set s and t to coefficients satisfying x*s + y*t = g.
+ *
+ * The {gp,gn} array is zero-padded (as otherwise 'gn' can't be
+ * reconstructed).
+ *
+ * g must have space for exactly gn=min(xn,yn) limbs.
+ * s must have space for at least xn limbs.
+ *
+ * return value: signed 'sn' of {sp,sn}
+ */
+mp_size_t
+integer_gmp_gcdext(mp_limb_t s0[], mp_limb_t g0[],
+                   const mp_limb_t x0[], const mp_size_t xn,
+                   const mp_limb_t y0[], const mp_size_t yn)
+{
+  const mp_size_t gn0 = mp_size_minabs(xn, yn);
+  const mpz_t x = CONST_MPZ_INIT(x0, mp_limb_zero_p(x0,xn) ? 0 : xn);
+  const mpz_t y = CONST_MPZ_INIT(y0, mp_limb_zero_p(y0,yn) ? 0 : yn);
+
+  mpz_t g, s;
+  mpz_init (g);
+  mpz_init (s);
+
+  mpz_gcdext (g, s, NULL, x, y);
+
+  const mp_size_t gn = g[0]._mp_size;
+  assert(0 <= gn && gn <= gn0);
+  memset(g0, 0, gn0*sizeof(mp_limb_t));
+  memcpy(g0, g[0]._mp_d, gn*sizeof(mp_limb_t));
+  mpz_clear (g);
+
+  const mp_size_t ssn = s[0]._mp_size;
+  const mp_size_t sn  = mp_size_abs(ssn);
+  assert(sn <= xn);
+  memcpy(s0, s[0]._mp_d, sn*sizeof(mp_limb_t));
+  mpz_clear (s);
+
+  if (!sn) {
+    s0[0] = 0;
+    return 1;
+  }
+
+  return ssn;
+}
+
 /* Truncating (i.e. rounded towards zero) integer division-quotient of MPN */
 void
 integer_gmp_mpn_tdiv_q (mp_limb_t q[],
index 9559755..48dd5d2 100644 (file)
@@ -44,6 +44,7 @@ module GHC.Integer.GMP.Internals
     , bitInteger
     , popCountInteger
     , gcdInteger
+    , gcdExtInteger
     , lcmInteger
     , sqrInteger
     , powModInteger
index 6284917..db24560 100644 (file)
@@ -1256,6 +1256,45 @@ gcdBigNat x@(BN# x#) y@(BN# y#)
     nx# = sizeofBigNat# x
     ny# = sizeofBigNat# y
 
+-- | Extended euclidean algorithm.
+--
+-- For @/a/@ and @/b/@, compute their greatest common divisor @/g/@
+-- and the coefficient @/s/@ satisfying @/a//s/ + /b//t/ = /g/@.
+--
+-- /Since: 0.5.1.0/
+{-# NOINLINE gcdExtInteger #-}
+gcdExtInteger :: Integer -> Integer -> (# Integer, Integer #)
+gcdExtInteger a b = case gcdExtSBigNat a' b' of
+    (# g, s #) -> let !g' = bigNatToInteger  g
+                      !s' = sBigNatToInteger s
+                  in (# g', s' #)
+  where
+    a' = integerToSBigNat a
+    b' = integerToSBigNat b
+
+-- internal helper
+gcdExtSBigNat :: SBigNat -> SBigNat -> (# BigNat, SBigNat #)
+gcdExtSBigNat x y = case runS go of (g,s) -> (# g, s #)
+  where
+    go = do
+        g@(MBN# g#) <- newBigNat# gn0#
+        s@(MBN# s#) <- newBigNat# (absI# xn#)
+        I# ssn_# <- liftIO (integer_gmp_gcdext# s# g# x# xn# y# yn#)
+        let ssn# = narrowGmpSize# ssn_#
+            sn#  = absI# ssn#
+        s' <- unsafeShrinkFreezeBigNat# s sn#
+        g' <- unsafeRenormFreezeBigNat# g
+        case ssn# >=# 0# of
+            0# -> return ( g', NegBN s' )
+            _  -> return ( g', PosBN s' )
+
+    !(BN# x#) = absSBigNat x
+    !(BN# y#) = absSBigNat y
+    xn# = ssizeofSBigNat# x
+    yn# = ssizeofSBigNat# y
+
+    gn0# = minI# (absI# xn#) (absI# yn#)
+
 ----------------------------------------------------------------------------
 -- modular exponentiation
 
@@ -1446,6 +1485,11 @@ foreign import ccall unsafe "integer_gmp_mpn_gcd"
   c_mpn_gcd# :: MutableByteArray# s -> ByteArray# -> GmpSize#
                 -> ByteArray# -> GmpSize# -> IO GmpSize
 
+foreign import ccall unsafe "integer_gmp_gcdext"
+  integer_gmp_gcdext# :: MutableByteArray# s -> MutableByteArray# s
+                         -> ByteArray# -> GmpSize#
+                         -> ByteArray# -> GmpSize# -> IO GmpSize
+
 -- mp_limb_t mpn_add_1 (mp_limb_t *rp, const mp_limb_t *s1p, mp_size_t n,
 --                      mp_limb_t s2limb)
 foreign import ccall unsafe "gmp.h __gmpn_add_1"
@@ -1952,3 +1996,7 @@ sgnI# x# = (x# ># 0#) -# (x# <# 0#)
 
 cmpI# :: Int# -> Int# -> Int#
 cmpI# x# y# = (x# ># y#) -# (x# <# y#)
+
+minI# :: Int# -> Int# -> Int#
+minI# x# y# | isTrue# (x# <=# y#) = x#
+            | True                = y#
index 2f49a75..628f8e0 100644 (file)
@@ -22,17 +22,7 @@ recipModInteger = I.recipModInteger
 
 -- FIXME: Lacks GMP2 version
 gcdExtInteger :: Integer -> Integer -> (Integer, Integer)
-gcdExtInteger a b = (d, u) -- stolen from `arithmoi` package
-  where
-    (d, x, y) = eGCD 0 1 1 0 (abs a) (abs b)
-    u | a < 0     = negate x
-      | otherwise = x
-    v | b < 0     = negate y
-      | otherwise = y
-    eGCD !n1 o1 !n2 o2 r s
-      | s == 0    = (r, o1, o2)
-      | otherwise = case r `quotRem` s of
-                      (q, t) -> eGCD (o1 - q*n1) n1 (o2 - q*n2) n2 s t
+gcdExtInteger a b = case I.gcdExtInteger a b of (# g, s #) -> (g, s)
 
 -- FIXME: Lacks GMP2 version
 powModSecInteger :: Integer -> Integer -> Integer -> Integer