Allocate initial 1-limb mpz_t on the Stack and introduce MPZ# type
authorHerbert Valerio Riedel <hvr@gnu.org>
Wed, 8 Jan 2014 23:19:31 +0000 (00:19 +0100)
committerHerbert Valerio Riedel <hvr@gnu.org>
Mon, 13 Jan 2014 11:42:02 +0000 (12:42 +0100)
We now allocate a 1-limb mpz_t on the stack instead of doing a more
expensive heap-allocation (especially if the heap-allocated copy becomes
garbage right away); this addresses #8647.

In order to delay heap allocations of 1-limb `ByteArray#`s instead of
the previous `(# Int#, ByteArray# #)` pair, a 3-tuple
`(# Int#, ByteArray#, Word# #)` is returned now. This tuple is given the
type-synonym `MPZ#`.

This 3-tuple representation uses either the 1st and the 2nd element, or
the 1st and the 3rd element to represent the limb(s) (NB: undefined
`ByteArray#` elements must not be accessed as they don't point to a
proper `ByteArray#`, see also `DUMMY_BYTE_ARR`); more specifically, the
following encoding is used (where `⊥` means undefined/unused):

 -  (#  0#, ⊥, 0## #) -> value = 0
 -  (#  1#, ⊥, w   #) -> value = w
 -  (# -1#, ⊥, w   #) -> value = -w
 -  (#  s#, d, 0## #) -> value = J# s d

The `mpzToInteger` helper takes care of converting `MPZ#` into an
`Integer`, and allocating a 1-limb `ByteArray#` in case the
value (`w`/`-w`) doesn't fit the `S# Int#` representation).

The following nofib benchmarks benefit from this optimization:

        Program      Size    Allocs   Runtime   Elapsed  TotalMem
 ------------------------------------------------------------------
     bernouilli     +0.2%     -5.2%      0.12      0.12     +0.0%
         gamteb     +0.2%     -1.7%      0.03      0.03     +0.0%
          kahan     +0.3%    -13.2%      0.17      0.17     +0.0%
         mandel     +0.2%    -24.6%      0.04      0.04     +0.0%
          power     +0.2%     -2.6%     -2.0%     -2.0%     -8.3%
      primetest     +0.1%    -17.3%      0.06      0.06     +0.0%
            rsa     +0.2%    -18.5%      0.02      0.02     +0.0%
            scs     +0.1%     -2.9%     -0.1%     -0.1%     +0.0%
         sphere     +0.3%     -0.8%      0.03      0.03     +0.0%
         symalg     +0.2%     -3.1%      0.01      0.01     +0.0%
 ------------------------------------------------------------------
            Min     +0.1%    -24.6%     -4.6%     -4.6%     -8.3%
            Max     +0.3%     +0.0%     +5.9%     +5.9%     +4.5%
 Geometric Mean     +0.2%     -1.0%     +0.2%     +0.2%     -0.0%

Signed-off-by: Herbert Valerio Riedel <hvr@gnu.org>
GHC/Integer/GMP/Prim.hs
GHC/Integer/Type.lhs
cbits/gmp-wrappers.cmm

index 261df29..3790345 100644 (file)
@@ -4,6 +4,8 @@
 
 #include "MachDeps.h"
 module GHC.Integer.GMP.Prim (
+    MPZ#,
+
     cmpInteger#,
     cmpIntegerInt#,
 
@@ -79,6 +81,41 @@ import GHC.Types
 -- Double isn't available yet, and we shouldn't be using defaults anyway:
 default ()
 
+-- | This is represents a @mpz_t@ value in a heap-saving way.
+--
+-- The first tuple element, @/s/@, encodes the sign of the integer
+-- @/i/@ (i.e. @signum /s/ == signum /i/@), and the number of /limbs/
+-- used to represent the magnitude. If @abs /s/ > 1@, the 'ByteArray#'
+-- contains @abs /s/@ limbs encoding the integer. Otherwise, if @abs
+-- /s/ < 2@, the single limb is stored in the 'Word#' element instead
+-- (and the 'ByteArray#' element is undefined and MUST NOT be accessed
+-- as it doesn't point to a proper 'ByteArray#' but rather to an
+-- unsafe-coerced 'Int' in order be polite to the GC -- see
+-- @DUMMY_BYTE_ARR@ in gmp-wrappers.cmm)
+--
+-- More specifically, the following encoding is used (where `⊥` means
+-- undefined/unused):
+--
+-- * (#  0#, ⊥, 0## #) -> value = 0
+-- * (#  1#, ⊥, w   #) -> value = w
+-- * (# -1#, ⊥, w   #) -> value = -w
+-- * (#  s#, d, 0## #) -> value = J# s d
+--
+-- This representation allows to avoid temporary heap allocations
+-- (-> Trac #8647) of 1-limb 'ByteArray#'s which fit into the
+-- 'S#'-constructor. Moreover, this allows to delays 1-limb
+-- 'ByteArray#' heap allocations, as such 1-limb `mpz_t`s can be
+-- optimistically allocated on the Cmm stack and returned as a @#word@
+-- in case the `mpz_t` wasn't grown beyond 1 limb by the GMP
+-- operation.
+--
+-- See also the 'GHC.Integer.Type.mpzToInteger' function which ought
+-- to be used for converting 'MPZ#'s to 'Integer's and the
+-- @MP_INT_1LIMB_RETURN()@ macro in @gmp-wrappers.cmm@ which
+-- constructs 'MPZ#' values in the first place for implementation
+-- details.
+type MPZ# = (# Int#, ByteArray#, Word# #)
+
 -- | Returns -1,0,1 according as first argument is less than, equal to, or greater than second argument.
 --
 foreign import prim "integer_cmm_cmpIntegerzh" cmpInteger#
@@ -92,87 +129,87 @@ foreign import prim "integer_cmm_cmpIntegerIntzh" cmpIntegerInt#
 -- |
 --
 foreign import prim "integer_cmm_plusIntegerzh" plusInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Optimized version of 'plusInteger#' for summing big-ints with small-ints
 --
 foreign import prim "integer_cmm_plusIntegerIntzh" plusIntegerInt#
-  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_minusIntegerzh" minusInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Optimized version of 'minusInteger#' for substracting small-ints from big-ints
 --
 foreign import prim "integer_cmm_minusIntegerIntzh" minusIntegerInt#
-  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_timesIntegerzh" timesInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Optimized version of 'timesInteger#' for multiplying big-ints with small-ints
 --
 foreign import prim "integer_cmm_timesIntegerIntzh" timesIntegerInt#
-  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> MPZ#
 
 -- | Compute div and mod simultaneously, where div rounds towards negative
 -- infinity and\ @(q,r) = divModInteger#(x,y)@ implies
 -- @plusInteger# (timesInteger# q y) r = x@.
 --
 foreign import prim "integer_cmm_quotRemIntegerzh" quotRemInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# MPZ#, MPZ# #)
 
 -- | Variant of 'quotRemInteger#'
 --
 foreign import prim "integer_cmm_quotRemIntegerWordzh" quotRemIntegerWord#
-  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray#, Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Word# -> (# MPZ#, MPZ# #)
 
 -- | Rounds towards zero.
 --
 foreign import prim "integer_cmm_quotIntegerzh" quotInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Rounds towards zero.
 foreign import prim "integer_cmm_quotIntegerWordzh" quotIntegerWord#
-  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Word# -> MPZ#
 
 -- | Satisfies \texttt{plusInteger\# (timesInteger\# (quotInteger\# x y) y) (remInteger\# x y) == x}.
 --
 foreign import prim "integer_cmm_remIntegerzh" remInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Variant of 'remInteger#'
 foreign import prim "integer_cmm_remIntegerWordzh" remIntegerWord#
-  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Word# -> MPZ#
 
 -- | Compute div and mod simultaneously, where div rounds towards negative infinity
 -- and\texttt{(q,r) = divModInteger\#(x,y)} implies \texttt{plusInteger\# (timesInteger\# q y) r = x}.
 --
 foreign import prim "integer_cmm_divModIntegerzh" divModInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# MPZ#, MPZ# #)
 foreign import prim "integer_cmm_divIntegerzh" divInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 foreign import prim "integer_cmm_modIntegerzh" modInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Divisor is guaranteed to be a factor of dividend.
 --
 foreign import prim "integer_cmm_divExactIntegerzh" divExactInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Greatest common divisor.
 --
 foreign import prim "integer_cmm_gcdIntegerzh" gcdInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- | Extended greatest common divisor.
 --
 foreign import prim "integer_cmm_gcdExtIntegerzh" gcdExtInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray#, Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# MPZ#, MPZ# #)
 
 -- | Greatest common divisor, where second argument is an ordinary {\tt Int\#}.
 --
@@ -189,32 +226,34 @@ foreign import prim "integer_cmm_gcdIntzh" gcdInt#
 --  represent an {\tt Integer\#} holding the mantissa.
 --
 foreign import prim "integer_cmm_decodeDoublezh" decodeDouble#
-  :: Double# -> (# Int#, Int#, ByteArray# #)
+  :: Double# -> (# Int#, MPZ# #)
 
 -- |
 --
+-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value.
 foreign import prim "integer_cmm_int2Integerzh" int2Integer#
   :: Int# -> (# Int#, ByteArray# #)
 
 -- |
 --
+-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value.
 foreign import prim "integer_cmm_word2Integerzh" word2Integer#
   :: Word# -> (# Int#, ByteArray# #)
 
 -- |
 --
 foreign import prim "integer_cmm_andIntegerzh" andInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_orIntegerzh" orInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_xorIntegerzh" xorInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- |
 --
@@ -224,37 +263,37 @@ foreign import prim "integer_cmm_testBitIntegerzh" testBitInteger#
 -- |
 --
 foreign import prim "integer_cmm_mul2ExpIntegerzh" mul2ExpInteger#
-  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_fdivQ2ExpIntegerzh" fdivQ2ExpInteger#
-  :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_powIntegerzh" powInteger#
-  :: Int# -> ByteArray# -> Word# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Word# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_powModIntegerzh" powModInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_powModSecIntegerzh" powModSecInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_recipModIntegerzh" recipModInteger#
-  :: Int# -> ByteArray# -> Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> Int# -> ByteArray# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_nextPrimeIntegerzh" nextPrimeInteger#
-  :: Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> MPZ#
 
 -- |
 --
@@ -269,12 +308,12 @@ foreign import prim "integer_cmm_sizeInBasezh" sizeInBaseInteger#
 -- |
 --
 foreign import prim "integer_cmm_importIntegerFromByteArrayzh" importIntegerFromByteArray#
-  :: ByteArray# -> Word# -> Word# -> Int# -> (# Int#, ByteArray# #)
+  :: ByteArray# -> Word# -> Word# -> Int# -> MPZ#
 
 -- |
 --
 foreign import prim "integer_cmm_importIntegerFromAddrzh" importIntegerFromAddr#
-  :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Int#, ByteArray# #)
+  :: Addr# -> Word# -> Int# -> State# s -> (# State# s, MPZ# #)
 
 -- |
 --
@@ -289,12 +328,14 @@ foreign import prim "integer_cmm_exportIntegerToAddrzh" exportIntegerToAddr#
 -- |
 --
 foreign import prim "integer_cmm_complementIntegerzh" complementInteger#
-  :: Int# -> ByteArray# -> (# Int#, ByteArray# #)
+  :: Int# -> ByteArray# -> MPZ#
 
 #if WORD_SIZE_IN_BITS < 64
+-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value.
 foreign import prim "integer_cmm_int64ToIntegerzh" int64ToInteger#
   :: Int64# -> (# Int#, ByteArray# #)
 
+-- Note: This primitive doesn't use 'MPZ#' because its purpose is to instantiate a 'J#'-value.
 foreign import prim "integer_cmm_word64ToIntegerzh" word64ToInteger#
   :: Word64# -> (# Int#, ByteArray# #)
 
index 731c5fc..ab4fe9d 100644 (file)
@@ -37,6 +37,7 @@ import GHC.Prim (
 
 import GHC.Integer.GMP.Prim (
     -- GMP-related primitives
+    MPZ#,
     cmpInteger#, cmpIntegerInt#,
     plusInteger#, plusIntegerInt#, minusInteger#, minusIntegerInt#,
     timesInteger#, timesIntegerInt#,
@@ -172,6 +173,37 @@ smartJ# (-1#) mb# | isTrue# (v <# 0#) = S# v
     where
       v = negateInt# (indexIntArray# mb# 0#)
 smartJ# s# mb# = J# s# mb#
+
+-- |Construct 'Integer' out of a 'MPZ#' as returned by GMP wrapper primops
+--
+-- IMPORTANT: The 'ByteArray#' element MUST NOT be accessed unless the
+-- size-element indicates more than one limb!
+--
+-- See notes at definition site of 'MPZ#' in "GHC.Integer.GMP.Prim"
+-- for more details.
+mpzToInteger :: MPZ# -> Integer
+mpzToInteger (# 0#, _, _ #) = S# 0#
+mpzToInteger (# 1#, _, w# #) | isTrue# (v# >=# 0#) = S# v#
+                             | True = case word2Integer# w# of (# _, d #) -> J# 1# d
+    where
+      v# = word2Int# w#
+mpzToInteger (# -1#, _, w# #) | isTrue# (v# <=# 0#) = S# v#
+                              | True = case word2Integer# w# of (# _, d #) -> J# -1# d
+    where
+      v# = negateInt# (word2Int# w#)
+mpzToInteger (# s#, mb#, _ #) = J# s# mb#
+
+-- | Variant of 'mpzToInteger' for pairs of 'Integer's
+mpzToInteger2 :: (# MPZ#, MPZ# #) -> (# Integer, Integer #)
+mpzToInteger2 (# mpz1, mpz2 #) = (# i1, i2 #)
+    where
+      !i1 = mpzToInteger mpz1 -- This use of `!` avoids creating thunks,
+      !i2 = mpzToInteger mpz2 -- see also Note [Use S# if possible].
+
+-- |Negate MPZ#
+mpzNeg :: MPZ# -> MPZ#
+mpzNeg (# s#, mb#, w# #) = (# negateInt# s#, mb#, w# #)
+
 \end{code}
 
 Note [Use S# if possible]
@@ -221,26 +253,19 @@ Just using smartJ# in this way has good results:
 
 {-# NOINLINE quotRemInteger #-}
 quotRemInteger :: Integer -> Integer -> (# Integer, Integer #)
-quotRemInteger a@(S# INT_MINBOUND) b = quotRemInteger (toBig a) b
+quotRemInteger (S# INT_MINBOUND) b = quotRemInteger minIntAsBig b
 quotRemInteger (S# i) (S# j) = case quotRemInt# i j of
                                    (# q, r #) -> (# S# q, S# r #)
 quotRemInteger (J# s1 d1) (S# b) | isTrue# (b <# 0#)
   = case quotRemIntegerWord# s1 d1 (int2Word# (negateInt# b)) of
-          (# s3, d3, s4, d4 #) -> let !q = smartJ# (negateInt# s3) d3
-                                      !r = smartJ# s4 d4
-                                  in (# q, r #)
+          (# q, r #) -> let !q' = mpzToInteger(mpzNeg q)
+                            !r' = mpzToInteger(mpzNeg r)
+                        in (# q', r' #)
 quotRemInteger (J# s1 d1) (S# b)
-  = case quotRemIntegerWord# s1 d1 (int2Word# b) of
-          (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3
-                                      !r = smartJ# s4 d4
-                                  in (# q, r #)
+  = mpzToInteger2(quotRemIntegerWord# s1 d1 (int2Word# b))
 quotRemInteger i1@(S# _) i2@(J# _ _) = quotRemInteger (toBig i1) i2
 quotRemInteger (J# s1 d1) (J# s2 d2)
-  = case (quotRemInteger# s1 d1 s2 d2) of
-          (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3
-                                      !r = smartJ# s4 d4
-                                  in (# q, r #)
-                           -- See Note [Use S# if possible]
+  = mpzToInteger2(quotRemInteger# s1 d1 s2 d2) -- See Note [Use S# if possible]
 
 {-# NOINLINE divModInteger #-}
 divModInteger :: Integer -> Integer -> (# Integer, Integer #)
@@ -256,11 +281,7 @@ divModInteger (S# i) (S# j) = (# S# d, S# m #)
 
 divModInteger i1@(J# _ _) i2@(S# _) = divModInteger i1 (toBig i2)
 divModInteger i1@(S# _) i2@(J# _ _) = divModInteger (toBig i1) i2
-divModInteger (J# s1 d1) (J# s2 d2)
-  = case (divModInteger# s1 d1 s2 d2) of
-          (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3
-                                      !r = smartJ# s4 d4
-                                  in (# q, r #)
+divModInteger (J# s1 d1) (J# s2 d2) = mpzToInteger2 (divModInteger# s1 d1 s2 d2)
 
 {-# NOINLINE remInteger #-}
 remInteger :: Integer -> Integer -> Integer
@@ -276,12 +297,11 @@ remInteger ia@(S# a) (J# sb b)
 -}
 remInteger ia@(S# _) ib@(J# _ _) = remInteger (toBig ia) ib
 remInteger (J# sa a) (S# b)
-  = case remIntegerWord# sa a w of
-          (# sr, r #) -> smartJ# sr r
+  = mpzToInteger (remIntegerWord# sa a w)
   where
     w = int2Word# (if isTrue# (b <# 0#) then negateInt# b else b)
 remInteger (J# sa a) (J# sb b)
-  = case remInteger# sa a sb b of (# sr, r #) -> smartJ# sr r
+  = mpzToInteger (remInteger# sa a sb b)
 
 {-# NOINLINE quotInteger #-}
 quotInteger :: Integer -> Integer -> Integer
@@ -295,13 +315,11 @@ quotInteger (S# a) (J# sb b)
 -}
 quotInteger ia@(S# _) ib@(J# _ _) = quotInteger (toBig ia) ib
 quotInteger (J# sa a) (S# b) | isTrue# (b <# 0#)
-  = case quotIntegerWord# sa a (int2Word# (negateInt# b)) of
-          (# sq, q #) -> smartJ# (negateInt# sq) q
+  = mpzToInteger (mpzNeg (quotIntegerWord# sa a (int2Word# (negateInt# b))))
 quotInteger (J# sa a) (S# b)
-  = case quotIntegerWord# sa a (int2Word# b) of
-          (# sq, q #) -> smartJ# sq q
+  = mpzToInteger (quotIntegerWord# sa a (int2Word# b))
 quotInteger (J# sa a) (J# sb b)
-  = case quotInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
+  = mpzToInteger (quotInteger# sa a sb b)
 
 {-# NOINLINE modInteger #-}
 modInteger :: Integer -> Integer -> Integer
@@ -310,10 +328,9 @@ modInteger (S# a) (S# b) = S# (modInt# a b)
 modInteger ia@(S# _) ib@(J# _ _) = modInteger (toBig ia) ib
 modInteger (J# sa a) (S# b)
   = case int2Integer# b of { (# sb, b' #) ->
-    case modInteger# sa a sb b' of { (# sr, r #) ->
-    S# (integer2Int# sr r) }}
+    mpzToInteger (modInteger# sa a sb b') }
 modInteger (J# sa a) (J# sb b)
-  = case modInteger# sa a sb b of (# sr, r #) -> smartJ# sr r
+  = mpzToInteger (modInteger# sa a sb b)
 
 {-# NOINLINE divInteger #-}
 divInteger :: Integer -> Integer -> Integer
@@ -321,10 +338,9 @@ divInteger (S# INT_MINBOUND) b = divInteger minIntAsBig b
 divInteger (S# a) (S# b) = S# (divInt# a b)
 divInteger ia@(S# _) ib@(J# _ _) = divInteger (toBig ia) ib
 divInteger (J# sa a) (S# b)
-  = case int2Integer# b of { (# sb, b' #) ->
-    case divInteger# sa a sb b' of (# sq, q #) -> smartJ# sq q }
+  = case int2Integer# b of { (# sb, b' #) -> mpzToInteger (divInteger# sa a sb b') }
 divInteger (J# sa a) (J# sb b)
-  = case divInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
+  = mpzToInteger (divInteger# sa a sb b)
 \end{code}
 
 
@@ -344,8 +360,7 @@ gcdInteger ia@(S# a)  ib@(J# sb b)
        where !absA  = if isTrue# (a  <# 0#) then negateInt# a  else a
              !absSb = if isTrue# (sb <# 0#) then negateInt# sb else sb
 gcdInteger ia@(J# _ _) ib@(S# _) = gcdInteger ib ia
-gcdInteger (J# sa a) (J# sb b)
-  = case gcdInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
+gcdInteger (J# sa a) (J# sb b)   = mpzToInteger (gcdInteger# sa a sb b)
 
 -- | Extended euclidean algorithm.
 --
@@ -356,11 +371,7 @@ gcdExtInteger :: Integer -> Integer -> (# Integer, Integer #)
 gcdExtInteger a@(S# _)   b@(S# _) = gcdExtInteger (toBig a) (toBig b)
 gcdExtInteger a@(S# _) b@(J# _ _) = gcdExtInteger (toBig a) b
 gcdExtInteger a@(J# _ _) b@(S# _) = gcdExtInteger a (toBig b)
-gcdExtInteger (J# sa a) (J# sb b)
-  = case gcdExtInteger# sa a sb b of
-      (# sg, g, ss, s #) -> let !g' = smartJ# sg g
-                                !s' = smartJ# ss s
-                            in (# g', s' #)
+gcdExtInteger (J# sa a) (J# sb b) = mpzToInteger2 (gcdExtInteger# sa a sb b)
 
 -- | Compute least common multiple.
 {-# NOINLINE lcmInteger #-}
@@ -387,10 +398,8 @@ divExact (S# a) (J# sb b)
   = S# (quotInt# a (integer2Int# sb b))
 divExact (J# sa a) (S# b)
   = case int2Integer# b of
-    (# sb, b' #) -> case divExactInteger# sa a sb b' of
-                    (# sd, d #) -> smartJ# sd d
-divExact (J# sa a) (J# sb b)
-  = case divExactInteger# sa a sb b of (# sd, d #) -> smartJ# sd d
+    (# sb, b' #) -> mpzToInteger (divExactInteger# sa a sb b')
+divExact (J# sa a) (J# sb b) = mpzToInteger (divExactInteger# sa a sb b)
 \end{code}
 
 
@@ -529,14 +538,11 @@ plusInteger (S# i)      (S# j)   = case addIntC# i j of
                                        if isTrue# (c ==# 0#)
                                        then S# r
                                        else case int2Integer# i of
-                                            (# s, d #) -> case plusIntegerInt# s d j of
-                                                          (# s', d' #) -> J# s' d'
+                                            (# s, d #) -> mpzToInteger (plusIntegerInt# s d j)
 plusInteger i1@(J# _ _) (S# 0#)   = i1
-plusInteger (J# s1 d1)  (S# j)    = case plusIntegerInt# s1 d1 j of
-                                    (# s, d #) -> smartJ# s d
+plusInteger (J# s1 d1)  (S# j)    = mpzToInteger (plusIntegerInt# s1 d1 j)
 plusInteger i1@(S# _) i2@(J# _ _) = plusInteger i2 i1
-plusInteger (J# s1 d1) (J# s2 d2) = case plusInteger# s1 d1 s2 d2 of
-                                    (# s, d #) -> smartJ# s d
+plusInteger (J# s1 d1) (J# s2 d2) = mpzToInteger (plusInteger# s1 d1 s2 d2)
 
 {-# NOINLINE minusInteger #-}
 minusInteger :: Integer -> Integer -> Integer
@@ -544,32 +550,25 @@ minusInteger (S# i)      (S# j)    = case subIntC# i j of
                                      (# r, c #) ->
                                          if isTrue# (c ==# 0#) then S# r
                                          else case int2Integer# i of
-                                              (# s, d #) -> case minusIntegerInt# s d j of
-                                                            (# s', d' #) -> J# s' d'
+                                              (# s, d #) -> mpzToInteger (minusIntegerInt# s d j)
 minusInteger i1@(J# _ _) (S# 0#)   = i1
-minusInteger (J# s1 d1)  (S# j)    = case minusIntegerInt# s1 d1 j of
-                                     (# s, d #) -> smartJ# s d
+minusInteger (J# s1 d1)  (S# j)    = mpzToInteger (minusIntegerInt# s1 d1 j)
 minusInteger (S# 0#)    (J# s2 d2) = J# (negateInt# s2) d2
-minusInteger (S# i)     (J# s2 d2) = case plusIntegerInt# (negateInt# s2) d2 i of
-                                     (# s, d #) -> smartJ# s d
-minusInteger (J# s1 d1) (J# s2 d2) = case minusInteger# s1 d1 s2 d2 of
-                                     (# s, d #) -> smartJ# s d
+minusInteger (S# i)     (J# s2 d2) = mpzToInteger (plusIntegerInt# (negateInt# s2) d2 i)
+minusInteger (J# s1 d1) (J# s2 d2) = mpzToInteger (minusInteger# s1 d1 s2 d2)
 
 {-# NOINLINE timesInteger #-}
 timesInteger :: Integer -> Integer -> Integer
 timesInteger (S# i) (S# j)         = if isTrue# (mulIntMayOflo# i j ==# 0#)
                                      then S# (i *# j)
                                      else case int2Integer# i of
-                                          (# s, d #) -> case timesIntegerInt# s d j of
-                                                        (# s', d' #) -> smartJ# s' d'
+                                          (# s, d #) -> mpzToInteger (timesIntegerInt# s d j)
 timesInteger (S# 0#)     _         = S# 0#
 timesInteger (S# -1#)    i2        = negateInteger i2
 timesInteger (S# 1#)     i2        = i2
-timesInteger (S# i1)    (J# s2 d2) = case timesIntegerInt# s2 d2 i1 of
-                                     (# s, d #) -> J# s d
+timesInteger (S# i1)    (J# s2 d2) = mpzToInteger (timesIntegerInt# s2 d2 i1)
 timesInteger i1@(J# _ _) i2@(S# _) = timesInteger i2 i1 -- swap args & retry
-timesInteger (J# s1 d1) (J# s2 d2) = case timesInteger# s1 d1 s2 d2 of
-                                     (# s, d #) -> J# s d
+timesInteger (J# s1 d1) (J# s2 d2) = mpzToInteger (timesInteger# s1 d1 s2 d2)
 
 {-# NOINLINE negateInteger #-}
 negateInteger :: Integer -> Integer
@@ -599,8 +598,8 @@ encodeDoubleInteger (J# s# d#) e = encodeDouble# s# d# e
 {-# NOINLINE decodeDoubleInteger #-}
 decodeDoubleInteger :: Double# -> (# Integer, Int# #)
 decodeDoubleInteger d = case decodeDouble# d of
-                        (# exp#, s#, d# #) -> let !s = smartJ# s# d#
-                                              in (# s, exp# #)
+                        (# exp#, man# #) -> let !man = mpzToInteger man#
+                                            in (# man, exp# #)
 
 -- previous code: doubleFromInteger n = fromInteger n = encodeFloat n 0
 -- doesn't work too well, because encodeFloat is defined in
@@ -646,8 +645,7 @@ andInteger :: Integer -> Integer -> Integer
 x@(S# _)   `andInteger` y@(J# _ _)   = toBig x `andInteger` y
 x@(J# _ _) `andInteger` y@(S# _)     = x `andInteger` toBig y
 (J# s1 d1) `andInteger`   (J# s2 d2) =
-     case andInteger# s1 d1 s2 d2 of
-       (# s, d #) -> smartJ# s d
+     mpzToInteger (andInteger# s1 d1 s2 d2)
 
 {-# NOINLINE orInteger #-}
 orInteger :: Integer -> Integer -> Integer
@@ -655,8 +653,7 @@ orInteger :: Integer -> Integer -> Integer
 x@(S# _)   `orInteger` y@(J# _ _)   = toBig x `orInteger` y
 x@(J# _ _) `orInteger` y@(S# _)     = x `orInteger` toBig y
 (J# s1 d1) `orInteger`   (J# s2 d2) =
-     case orInteger# s1 d1 s2 d2 of
-       (# s, d #) -> J# s d
+     mpzToInteger (orInteger# s1 d1 s2 d2)
 
 {-# NOINLINE xorInteger #-}
 xorInteger :: Integer -> Integer -> Integer
@@ -664,27 +661,24 @@ xorInteger :: Integer -> Integer -> Integer
 x@(S# _)   `xorInteger` y@(J# _ _)   = toBig x `xorInteger` y
 x@(J# _ _) `xorInteger` y@(S# _)     = x `xorInteger` toBig y
 (J# s1 d1) `xorInteger`   (J# s2 d2) =
-     case xorInteger# s1 d1 s2 d2 of
-       (# s, d #) -> smartJ# s d
+     mpzToInteger (xorInteger# s1 d1 s2 d2)
 
 {-# NOINLINE complementInteger #-}
 complementInteger :: Integer -> Integer
 complementInteger (S# x)
     = S# (word2Int# (int2Word# x `xor#` int2Word# (0# -# 1#)))
 complementInteger (J# s d)
-    = case complementInteger# s d of (# s', d' #) -> smartJ# s' d'
+    = mpzToInteger (complementInteger# s d)
 
 {-# NOINLINE shiftLInteger #-}
 shiftLInteger :: Integer -> Int# -> Integer
 shiftLInteger j@(S# _) i = shiftLInteger (toBig j) i
-shiftLInteger (J# s d) i = case mul2ExpInteger# s d i of
-                           (# s', d' #) -> J# s' d'
+shiftLInteger (J# s d) i = mpzToInteger (mul2ExpInteger# s d i)
 
 {-# NOINLINE shiftRInteger #-}
 shiftRInteger :: Integer -> Int# -> Integer
 shiftRInteger j@(S# _) i = shiftRInteger (toBig j) i
-shiftRInteger (J# s d) i = case fdivQ2ExpInteger# s d i of
-                           (# s', d' #) -> smartJ# s' d'
+shiftRInteger (J# s d) i = mpzToInteger (fdivQ2ExpInteger# s d i)
 
 {-# NOINLINE testBitInteger #-}
 testBitInteger :: Integer -> Int# -> Bool
@@ -695,8 +689,7 @@ testBitInteger (J# s d) i = isTrue# (testBitInteger# s d i /=# 0#)
 {-# NOINLINE powInteger #-}
 powInteger :: Integer -> Word# -> Integer
 powInteger j@(S# _) e = powInteger (toBig j) e
-powInteger (J# s d) e = case powInteger# s d e of
-                            (# s', d' #) -> smartJ# s' d'
+powInteger (J# s d) e = mpzToInteger (powInteger# s d e)
 
 -- | \"@'powModInteger' /b/ /e/ /m/@\" computes base @/b/@ raised to
 -- exponent @/e/@ modulo @/m/@.
@@ -709,8 +702,7 @@ powInteger (J# s d) e = case powInteger# s d e of
 {-# NOINLINE powModInteger #-}
 powModInteger :: Integer -> Integer -> Integer -> Integer
 powModInteger (J# s1 d1) (J# s2 d2) (J# s3 d3) =
-    case powModInteger# s1 d1 s2 d2 s3 d3 of
-        (# s', d' #) -> smartJ# s' d'
+    mpzToInteger (powModInteger# s1 d1 s2 d2 s3 d3)
 powModInteger b e m = powModInteger (toBig b) (toBig e) (toBig m)
 
 -- | \"@'powModSecInteger' /b/ /e/ /m/@\" computes base @/b/@ raised to
@@ -724,8 +716,7 @@ powModInteger b e m = powModInteger (toBig b) (toBig e) (toBig m)
 {-# NOINLINE powModSecInteger #-}
 powModSecInteger :: Integer -> Integer -> Integer -> Integer
 powModSecInteger (J# s1 d1) (J# s2 d2) (J# s3 d3) =
-    case powModSecInteger# s1 d1 s2 d2 s3 d3 of
-        (# s', d' #) -> J# s' d'
+    mpzToInteger (powModSecInteger# s1 d1 s2 d2 s3 d3)
 powModSecInteger b e m = powModSecInteger (toBig b) (toBig e) (toBig m)
 
 -- | \"@'recipModInteger' /x/ /m/@\" computes the inverse of @/x/@ modulo @/m/@. If
@@ -740,8 +731,7 @@ recipModInteger :: Integer -> Integer -> Integer
 recipModInteger j@(S# _) m@(S# _)   = recipModInteger (toBig j) (toBig m)
 recipModInteger j@(S# _) m@(J# _ _) = recipModInteger (toBig j) m
 recipModInteger j@(J# _ _) m@(S# _) = recipModInteger j (toBig m)
-recipModInteger (J# s d) (J# ms md) = case recipModInteger# s d ms md of
-                           (# s', d' #) -> smartJ# s' d'
+recipModInteger (J# s d) (J# ms md) = mpzToInteger (recipModInteger# s d ms md)
 
 -- | Probalistic Miller-Rabin primality test.
 --
@@ -771,7 +761,7 @@ testPrimeInteger (J# s d) reps = testPrimeInteger# s d reps
 {-# NOINLINE nextPrimeInteger #-}
 nextPrimeInteger :: Integer -> Integer
 nextPrimeInteger j@(S# _) = nextPrimeInteger (toBig j)
-nextPrimeInteger (J# s d) = case nextPrimeInteger# s d of (# s', d' #) -> smartJ# s' d'
+nextPrimeInteger (J# s d) = mpzToInteger (nextPrimeInteger# s d)
 
 -- | Compute number of digits (without sign) in given @/base/@.
 --
@@ -861,7 +851,7 @@ exportIntegerToAddr j@(S# _) addr o e = exportIntegerToAddr (toBig j) addr o e -
 -- * returns a new 'Integer'
 {-# NOINLINE importIntegerFromByteArray #-}
 importIntegerFromByteArray :: ByteArray# -> Word# -> Word# -> Int# -> Integer
-importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e of (# s', d' #) -> smartJ# s' d'
+importIntegerFromByteArray ba o l e = mpzToInteger (importIntegerFromByteArray# ba o l e)
 
 -- | Read 'Integer' (without sign) from memory location at @/addr/@ in
 -- base-256 representation.
@@ -874,7 +864,7 @@ importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e
 {-# NOINLINE importIntegerFromAddr #-}
 importIntegerFromAddr :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Integer #)
 importIntegerFromAddr addr l e st = case importIntegerFromAddr# addr l e st of
-                                      (# st', s', d' #) -> let !j = smartJ# s' d' in (# st', j #)
+                                      (# st', mpz #) -> let !j = mpzToInteger mpz in (# st', j #)
 
 \end{code}
 
index 2c9bbd2..28c1333 100644 (file)
@@ -28,7 +28,6 @@
 #include "Cmm.h"
 #include "GmpDerivedConstants.h"
 
-import "integer-gmp" __gmpz_init;
 import "integer-gmp" __gmpz_add;
 import "integer-gmp" __gmpz_add_ui;
 import "integer-gmp" __gmpz_sub;
@@ -68,6 +67,8 @@ import "integer-gmp" __gmpz_export;
 
 import "integer-gmp" integer_cbits_decodeDouble;
 
+import "integer-gmp" stg_INTLIKE_closure;
+
 /* -----------------------------------------------------------------------------
    Arbitrary-precision Integer operations.
 
@@ -75,6 +76,15 @@ import "integer-gmp" integer_cbits_decodeDouble;
    the case for all the platforms that GHC supports, currently.
    -------------------------------------------------------------------------- */
 
+/* This is used when a dummy pointer is needed for a ByteArray# return value
+
+   Ideally this would be a statically allocated 'ByteArray#'
+   containing SIZEOF_W 0-bytes. However, since in those cases when a
+   dummy value is needed, the 'ByteArray#' is not supposed to be
+   accessed anyway, this is should be a tolerable hack.
+ */
+#define DUMMY_BYTE_ARR (stg_INTLIKE_closure+1)
+
 /* set mpz_t from Int#/ByteArray# */
 #define MP_INT_SET_FROM_BA(mp_ptr,i,ba)                  \
   MP_INT__mp_alloc(mp_ptr) = W_TO_INT(BYTE_ARR_WDS(ba)); \
@@ -85,42 +95,103 @@ import "integer-gmp" integer_cbits_decodeDouble;
 #define MP_INT_AS_PAIR(mp_ptr) \
   TO_W_(MP_INT__mp_size(mp_ptr)),(MP_INT__mp_d(mp_ptr)-SIZEOF_StgArrWords)
 
+#define MP_INT_TO_BA(mp_ptr) \
+  (MP_INT__mp_d(mp_ptr)-SIZEOF_StgArrWords)
+
+/* Size of mpz_t with single limb */
+#define SIZEOF_MP_INT_1LIMB (SIZEOF_MP_INT+WDS(1))
+
+/* Initialize 0-valued single-limb mpz_t at mp_ptr */
+#define MP_INT_1LIMB_INIT0(mp_ptr)                       \
+  MP_INT__mp_alloc(mp_ptr) = W_TO_INT(1);                \
+  MP_INT__mp_size(mp_ptr)  = W_TO_INT(0);                \
+  MP_INT__mp_d(mp_ptr)     = (mp_ptr+SIZEOF_MP_INT)
+
+
+/* return mpz_t as (# s::Int#, d::ByteArray#, l1::Word# #) tuple
+ *
+ * semantics:
+ *
+ *  (#  0, _, 0 #) -> value = 0
+ *  (#  1, _, w #) -> value =  w
+ *  (# -1, _, w #) -> value = -w
+ *  (#  s, d, 0 #) -> value =  J# s d
+ *
+ */
+#define MP_INT_1LIMB_RETURN(mp_ptr)                    \
+  CInt __mp_s;                                         \
+  __mp_s = MP_INT__mp_size(mp_ptr);                    \
+                                                       \
+  if (__mp_s == W_TO_INT(0))                           \
+  {                                                    \
+    return (0,DUMMY_BYTE_ARR,0);                       \
+  }                                                    \
+                                                       \
+  if (__mp_s == W_TO_INT(-1) || __mp_s == W_TO_INT(1)) \
+  {                                                    \
+    return (TO_W_(__mp_s),DUMMY_BYTE_ARR,W_[MP_INT__mp_d(mp_ptr)]); \
+  }                                                    \
+                                                       \
+  return (TO_W_(__mp_s),MP_INT_TO_BA(mp_ptr),0)
+
+/* Helper macro used by MP_INT_1LIMB_RETURN2 */
+#define MP_INT_1LIMB_AS_TUP3(s,d,w,mp_ptr) \
+  CInt s; P_ d; W_ w;                            \
+  s = MP_INT__mp_size(mp_ptr);                   \
+                                                 \
+  if (s == W_TO_INT(0))                          \
+  {                                              \
+    d = DUMMY_BYTE_ARR; w = 0;                            \
+  } else {                                       \
+    if (s == W_TO_INT(-1) || s == W_TO_INT(1))   \
+    {                                            \
+      d = DUMMY_BYTE_ARR; w = W_[MP_INT__mp_d(mp_ptr)];   \
+    } else {                                     \
+      d = MP_INT_TO_BA(mp_ptr); w = 0;           \
+    }                                            \
+  }
 
-/* :: ByteArray# -> Word# -> Word# -> Int# -> (# Int#, ByteArray# #) */
+#define MP_INT_1LIMB_RETURN2(mp_ptr1,mp_ptr2)         \
+  MP_INT_1LIMB_AS_TUP3(__r1s,__r1d,__r1w,mp_ptr1);    \
+  MP_INT_1LIMB_AS_TUP3(__r2s,__r2d,__r2w,mp_ptr2);    \
+  return (TO_W_(__r1s),__r1d,__r1w, TO_W_(__r2s),__r2d,__r2w)
+
+/* :: ByteArray# -> Word# -> Word# -> Int# -> (# Int#, ByteArray#, Word# #) */
 integer_cmm_importIntegerFromByteArrayzh (P_ ba, W_ of, W_ sz, W_ e)
 {
   W_ src_ptr;
   W_ mp_result;
 
 again:
-  STK_CHK_GEN_N (SIZEOF_MP_INT);
+  STK_CHK_GEN_N (SIZEOF_MP_INT_1LIMB);
   MAYBE_GC(again);
 
-  mp_result = Sp - SIZEOF_MP_INT;
+  mp_result = Sp - SIZEOF_MP_INT_1LIMB;
+  MP_INT_1LIMB_INIT0(mp_result);
 
   src_ptr = BYTE_ARR_CTS(ba) + of;
 
-  ccall __gmpz_init(mp_result "ptr");
   ccall __gmpz_import(mp_result "ptr", sz, W_TO_INT(e), W_TO_INT(1), W_TO_INT(0), 0, src_ptr "ptr");
 
-  return(MP_INT_AS_PAIR(mp_result));
+  MP_INT_1LIMB_RETURN(mp_result);
 }
 
-/* :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Int#, ByteArray# #) */
+/* :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Int#, ByteArray#, Word# #) */
 integer_cmm_importIntegerFromAddrzh (W_ src_ptr, W_ sz, W_ e)
 {
   W_ mp_result;
 
 again:
-  STK_CHK_GEN_N (SIZEOF_MP_INT);
+  STK_CHK_GEN_N (SIZEOF_MP_INT_1LIMB);
   MAYBE_GC(again);
 
-  mp_result = Sp - SIZEOF_MP_INT;
+  mp_result = Sp - SIZEOF_MP_INT_1LIMB;
+
+  MP_INT_1LIMB_INIT0(mp_result);
 
-  ccall __gmpz_init(mp_result "ptr");
   ccall __gmpz_import(mp_result "ptr", sz, W_TO_INT(e), W_TO_INT(1), W_TO_INT(0), 0, src_ptr "ptr");
 
-  return(MP_INT_AS_PAIR(mp_result));
+  MP_INT_1LIMB_RETURN(mp_result);
 }
 
 /* :: Int# -> ByteArray# -> MutableByteArray# s -> Word# -> Int# -> State# s -> (# State# s, Word# #) */
@@ -329,22 +400,22 @@ name (W_ ws1, P_ d1, W_ ws2, P_ d2)                             \
   W_ mp_result1;                                                \
                                                                 \
 again:                                                          \
-  STK_CHK_GEN_N (3 * SIZEOF_MP_INT);                            \
+  STK_CHK_GEN_N (2*SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB);        \
   MAYBE_GC(again);                                              \
                                                                 \
-  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                          \
-  mp_tmp2    = Sp - 2 * SIZEOF_MP_INT;                          \
-  mp_result1 = Sp - 3 * SIZEOF_MP_INT;                          \
+  mp_tmp1    = Sp - 1*SIZEOF_MP_INT;                            \
+  mp_tmp2    = Sp - 2*SIZEOF_MP_INT;                            \
+  mp_result1 = Sp - 2*SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB;      \
                                                                 \
   MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1);                           \
   MP_INT_SET_FROM_BA(mp_tmp2,ws2,d2);                           \
                                                                 \
-  ccall __gmpz_init(mp_result1 "ptr");                          \
+  MP_INT_1LIMB_INIT0(mp_result1);                               \
                                                                 \
   /* Perform the operation */                                   \
   ccall mp_fun(mp_result1 "ptr",mp_tmp1  "ptr",mp_tmp2  "ptr"); \
                                                                 \
-  return (MP_INT_AS_PAIR(mp_result1));                          \
+  MP_INT_1LIMB_RETURN(mp_result1);                              \
 }
 
 #define GMP_TAKE3_RET1(name,mp_fun)                             \
@@ -356,25 +427,25 @@ name (W_ ws1, P_ d1, W_ ws2, P_ d2, W_ ws3, P_ d3)              \
   W_ mp_result1;                                                \
                                                                 \
 again:                                                          \
-  STK_CHK_GEN_N (4 * SIZEOF_MP_INT);                            \
+  STK_CHK_GEN_N (3*SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB);        \
   MAYBE_GC(again);                                              \
                                                                 \
-  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                          \
-  mp_tmp2    = Sp - 2 * SIZEOF_MP_INT;                          \
-  mp_tmp3    = Sp - 3 * SIZEOF_MP_INT;                          \
-  mp_result1 = Sp - 4 * SIZEOF_MP_INT;                          \
+  mp_tmp1    = Sp - 1*SIZEOF_MP_INT;                            \
+  mp_tmp2    = Sp - 2*SIZEOF_MP_INT;                            \
+  mp_tmp3    = Sp - 3*SIZEOF_MP_INT;                            \
+  mp_result1 = Sp - 3*SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB;      \
                                                                 \
   MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1);                           \
   MP_INT_SET_FROM_BA(mp_tmp2,ws2,d2);                           \
   MP_INT_SET_FROM_BA(mp_tmp3,ws3,d3);                           \
                                                                 \
-  ccall __gmpz_init(mp_result1 "ptr");                          \
+  MP_INT_1LIMB_INIT0(mp_result1);                               \
                                                                 \
   /* Perform the operation */                                   \
-  ccall mp_fun(mp_result1 "ptr",mp_tmp1  "ptr",mp_tmp2  "ptr",  \
-               mp_tmp3  "ptr");                                 \
+  ccall mp_fun(mp_result1 "ptr",                                \
+               mp_tmp1 "ptr", mp_tmp2 "ptr", mp_tmp3 "ptr");    \
                                                                 \
-  return (MP_INT_AS_PAIR(mp_result1));                          \
+  MP_INT_1LIMB_RETURN(mp_result1);                              \
 }
 
 #define GMP_TAKE1_UL1_RET1(name,mp_fun)                         \
@@ -385,20 +456,20 @@ name (W_ ws1, P_ d1, W_ wul)                                    \
                                                                 \
   /* call doYouWantToGC() */                                    \
 again:                                                          \
-  STK_CHK_GEN_N (2 * SIZEOF_MP_INT);                            \
+  STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB);          \
   MAYBE_GC(again);                                              \
                                                                 \
-  mp_tmp     = Sp - 1 * SIZEOF_MP_INT;                          \
-  mp_result  = Sp - 2 * SIZEOF_MP_INT;                          \
+  mp_tmp     = Sp - SIZEOF_MP_INT;                              \
+  mp_result  = Sp - SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB;        \
                                                                 \
   MP_INT_SET_FROM_BA(mp_tmp,ws1,d1);                            \
                                                                 \
-  ccall __gmpz_init(mp_result "ptr");                           \
+  MP_INT_1LIMB_INIT0(mp_result);                                \
                                                                 \
   /* Perform the operation */                                   \
   ccall mp_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wul));  \
                                                                 \
-  return (MP_INT_AS_PAIR(mp_result));                           \
+  MP_INT_1LIMB_RETURN(mp_result);                               \
 }
 
 #define GMP_TAKE1_I1_RETI1(name,mp_fun)                         \
@@ -446,20 +517,20 @@ name (W_ ws1, P_ d1)                                            \
   W_ mp_result1;                                                \
                                                                 \
 again:                                                          \
-  STK_CHK_GEN_N (2 * SIZEOF_MP_INT);                            \
+  STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB);          \
   MAYBE_GC(again);                                              \
                                                                 \
-  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                          \
-  mp_result1 = Sp - 2 * SIZEOF_MP_INT;                          \
+  mp_tmp1    = Sp - SIZEOF_MP_INT;                              \
+  mp_result1 = Sp - SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB;        \
                                                                 \
   MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1);                           \
                                                                 \
-  ccall __gmpz_init(mp_result1 "ptr");                          \
+  MP_INT_1LIMB_INIT0(mp_result1);                               \
                                                                 \
   /* Perform the operation */                                   \
   ccall mp_fun(mp_result1 "ptr",mp_tmp1 "ptr");                 \
                                                                 \
-  return(MP_INT_AS_PAIR(mp_result1));                           \
+  MP_INT_1LIMB_RETURN(mp_result1);                              \
 }
 
 #define GMP_TAKE2_RET2(name,mp_fun)                                     \
@@ -471,24 +542,25 @@ name (W_ ws1, P_ d1, W_ ws2, P_ d2)                                     \
   W_ mp_result2;                                                        \
                                                                         \
 again:                                                                  \
-  STK_CHK_GEN_N (4 * SIZEOF_MP_INT);                                    \
+  STK_CHK_GEN_N (2*SIZEOF_MP_INT + 2*SIZEOF_MP_INT_1LIMB);              \
   MAYBE_GC(again);                                                      \
                                                                         \
-  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                                  \
-  mp_tmp2    = Sp - 2 * SIZEOF_MP_INT;                                  \
-  mp_result1 = Sp - 3 * SIZEOF_MP_INT;                                  \
-  mp_result2 = Sp - 4 * SIZEOF_MP_INT;                                  \
+  mp_tmp1    = Sp - 1*SIZEOF_MP_INT;                                    \
+  mp_tmp2    = Sp - 2*SIZEOF_MP_INT;                                    \
+  mp_result1 = Sp - 2*SIZEOF_MP_INT - 1*SIZEOF_MP_INT_1LIMB;            \
+  mp_result2 = Sp - 2*SIZEOF_MP_INT - 2*SIZEOF_MP_INT_1LIMB;            \
                                                                         \
   MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1);                                   \
   MP_INT_SET_FROM_BA(mp_tmp2,ws2,d2);                                   \
                                                                         \
-  ccall __gmpz_init(mp_result1 "ptr");                                  \
-  ccall __gmpz_init(mp_result2 "ptr");                                  \
+  MP_INT_1LIMB_INIT0(mp_result1);                                       \
+  MP_INT_1LIMB_INIT0(mp_result2);                                       \
                                                                         \
   /* Perform the operation */                                           \
-  ccall mp_fun(mp_result1 "ptr",mp_result2 "ptr",mp_tmp1 "ptr",mp_tmp2 "ptr"); \
+  ccall mp_fun(mp_result1 "ptr", mp_result2 "ptr",                      \
+               mp_tmp1 "ptr", mp_tmp2 "ptr");                           \
                                                                         \
-  return (MP_INT_AS_PAIR(mp_result1),MP_INT_AS_PAIR(mp_result2));       \
+  MP_INT_1LIMB_RETURN2(mp_result1, mp_result2);                         \
 }
 
 #define GMP_TAKE1_UL1_RET2(name,mp_fun)                                 \
@@ -499,23 +571,23 @@ name (W_ ws1, P_ d1, W_ wul2)                                           \
   W_ mp_result2;                                                        \
                                                                         \
 again:                                                                  \
-  STK_CHK_GEN_N (3 * SIZEOF_MP_INT);                                    \
+  STK_CHK_GEN_N (SIZEOF_MP_INT + 2*SIZEOF_MP_INT_1LIMB);                \
   MAYBE_GC(again);                                                      \
                                                                         \
-  mp_tmp1    = Sp - 1 * SIZEOF_MP_INT;                                  \
-  mp_result1 = Sp - 2 * SIZEOF_MP_INT;                                  \
-  mp_result2 = Sp - 3 * SIZEOF_MP_INT;                                  \
+  mp_tmp1    = Sp - SIZEOF_MP_INT;                                      \
+  mp_result1 = Sp - SIZEOF_MP_INT - 1*SIZEOF_MP_INT_1LIMB;              \
+  mp_result2 = Sp - SIZEOF_MP_INT - 2*SIZEOF_MP_INT_1LIMB;              \
                                                                         \
   MP_INT_SET_FROM_BA(mp_tmp1,ws1,d1);                                   \
                                                                         \
-  ccall __gmpz_init(mp_result1 "ptr");                                  \
-  ccall __gmpz_init(mp_result2 "ptr");                                  \
+  MP_INT_1LIMB_INIT0(mp_result1);                                       \
+  MP_INT_1LIMB_INIT0(mp_result2);                                       \
                                                                         \
   /* Perform the operation */                                           \
   ccall mp_fun(mp_result1 "ptr", mp_result2 "ptr",                      \
                mp_tmp1 "ptr", W_TO_LONG(wul2));                         \
                                                                         \
-  return (MP_INT_AS_PAIR(mp_result1),MP_INT_AS_PAIR(mp_result2));       \
+  MP_INT_1LIMB_RETURN2(mp_result1, mp_result2);                         \
 }
 
 GMP_TAKE2_RET1(integer_cmm_plusIntegerzh,           __gmpz_add)
@@ -657,16 +729,17 @@ integer_cmm_cmpIntegerzh (W_ usize, P_ d1, W_ vsize, P_ d2)
 
 integer_cmm_decodeDoublezh (D_ arg)
 {
-    D_ arg;
-    W_ p;
     W_ mp_tmp1;
     W_ mp_tmp_w;
 
-    STK_CHK_GEN_N (2 * SIZEOF_MP_INT);
+#if SIZEOF_DOUBLE != SIZEOF_W
+    W_ p;
+
+    STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_W);
     ALLOC_PRIM (ARR_SIZE);
 
-    mp_tmp1  = Sp - 1 * SIZEOF_MP_INT;
-    mp_tmp_w = Sp - 2 * SIZEOF_MP_INT;
+    mp_tmp1  = Sp - SIZEOF_MP_INT;
+    mp_tmp_w = Sp - SIZEOF_MP_INT - SIZEOF_W;
 
     /* Be prepared to tell Lennart-coded integer_cbits_decodeDouble
        where mantissa.d can be put (it does not care about the rest) */
@@ -675,14 +748,29 @@ integer_cmm_decodeDoublezh (D_ arg)
     StgArrWords_bytes(p) = DOUBLE_MANTISSA_SIZE;
     MP_INT__mp_d(mp_tmp1) = BYTE_ARR_CTS(p);
 
+#else
+    /* When SIZEOF_DOUBLE == SIZEOF_W == 8, the result will fit into a
+       single 8-byte limb, and so we avoid allocating on the Heap and
+       use only the Stack instead */
+
+    STK_CHK_GEN_N (SIZEOF_MP_INT_1LIMB + SIZEOF_W);
+
+    mp_tmp1  = Sp - SIZEOF_MP_INT_1LIMB;
+    mp_tmp_w = Sp - SIZEOF_MP_INT_1LIMB - SIZEOF_W;
+
+    MP_INT_1LIMB_INIT0(mp_tmp1);
+#endif
+
     /* Perform the operation */
-    ccall integer_cbits_decodeDouble(mp_tmp1 "ptr", mp_tmp_w "ptr",arg);
+    ccall integer_cbits_decodeDouble(mp_tmp1 "ptr", mp_tmp_w "ptr", arg);
+
+    /* returns: (Int# (expn), MPZ#) */
+    MP_INT_1LIMB_AS_TUP3(r1s, r1d, r1w, mp_tmp1);
 
-    /* returns: (Int# (expn), Int#, ByteArray#) */
-    return (W_[mp_tmp_w], TO_W_(MP_INT__mp_size(mp_tmp1)), p);
+    return (W_[mp_tmp_w], TO_W_(r1s), r1d, r1w);
 }
 
-/* :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray# #) */
+/* :: Int# -> ByteArray# -> Int# -> (# Int#, ByteArray#, Word# #) */
 #define GMPX_TAKE1_UL1_RET1(name,pos_arg_fun,neg_arg_fun)               \
 name(W_ ws1, P_ d1, W_ wl)                                              \
 {                                                                       \
@@ -690,23 +778,23 @@ name(W_ ws1, P_ d1, W_ wl)                                              \
   W_ mp_result;                                                         \
                                                                         \
 again:                                                                  \
-  STK_CHK_GEN_N (2 * SIZEOF_MP_INT);                                    \
+  STK_CHK_GEN_N (SIZEOF_MP_INT + SIZEOF_MP_INT_1LIMB);                  \
   MAYBE_GC(again);                                                      \
                                                                         \
-  mp_tmp     = Sp - 1 * SIZEOF_MP_INT;                                  \
-  mp_result  = Sp - 2 * SIZEOF_MP_INT;                                  \
+  mp_tmp     = Sp - SIZEOF_MP_INT;                                      \
+  mp_result  = Sp - SIZEOF_MP_INT - SIZEOF_MP_INT_1LIMB;                \
                                                                         \
   MP_INT_SET_FROM_BA(mp_tmp,ws1,d1);                                    \
                                                                         \
-  ccall __gmpz_init(mp_result "ptr");                                   \
+  MP_INT_1LIMB_INIT0(mp_result);                                        \
                                                                         \
   if(%lt(wl,0)) {                                                       \
       ccall neg_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(-wl)); \
-      return(MP_INT_AS_PAIR(mp_result));                                \
+  } else {                                                              \
+      ccall pos_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wl));  \
   }                                                                     \
                                                                         \
-  ccall pos_arg_fun(mp_result "ptr", mp_tmp "ptr", W_TO_LONG(wl));      \
-  return(MP_INT_AS_PAIR(mp_result));                                    \
+  MP_INT_1LIMB_RETURN(mp_result);                                       \
 }
 
 /* NB: We need both primitives as we can't express 'minusIntegerInt#'