Try harder to demote results from `J#` to `S#` (re #8638)
authorHerbert Valerio Riedel <hvr@gnu.org>
Mon, 30 Dec 2013 15:05:20 +0000 (16:05 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Thu, 2 Jan 2014 15:26:42 +0000 (15:26 +0000)
Signed-off-by: Herbert Valerio Riedel <hvr@gnu.org>
GHC/Integer/Type.lhs

index 85ffa7c..c206462 100644 (file)
@@ -24,6 +24,7 @@ module GHC.Integer.Type where
 import GHC.Prim (
     -- Other types we use, convert from, or convert to
     Int#, Word#, Double#, Float#, ByteArray#, MutableByteArray#, Addr#, State#,
+    indexIntArray#,
     -- Conversions between those types
     int2Word#, int2Double#, int2Float#, word2Int#,
     -- Operations on Int# that we use for operations on S#
@@ -101,7 +102,11 @@ smallInteger i = S# i
 
 {-# NOINLINE wordToInteger #-}
 wordToInteger :: Word# -> Integer
-wordToInteger w = case word2Integer# w of (# s, d #) -> J# s d
+wordToInteger w = if isTrue# (i >=# 0#)
+                  then S# i
+                  else case word2Integer# w of (# s, d #) -> J# s d
+    where
+      !i = word2Int# w
 
 {-# NOINLINE integerToWord #-}
 integerToWord :: Integer -> Word#
@@ -140,9 +145,26 @@ integerToInt :: Integer -> Int#
 integerToInt (S# i)   = i
 integerToInt (J# s d) = integer2Int# s d
 
+-- | Promote 'S#' to 'J#'
 toBig :: Integer -> Integer
 toBig (S# i)     = case int2Integer# i of { (# s, d #) -> J# s d }
 toBig i@(J# _ _) = i
+
+-- | Demote 'J#' to 'S#' if possible. See also 'smartJ#'.
+toSmall :: Integer -> Integer
+toSmall i@(S# _)  = i
+toSmall (J# 0# _) = S# 0#
+toSmall (J# 1# mb#)  | isTrue# (v ># 0#) = S# v
+    where
+      v = indexIntArray# mb# 0#
+toSmall (J# -1# mb#) | isTrue# (v <# 0#) = S# v
+    where
+      v = negateInt# (indexIntArray# mb# 0#)
+toSmall i         = i
+
+-- | Smart 'J#' constructor which tries to construct 'S#' if possible
+smartJ# :: Int# -> ByteArray# -> Integer
+smartJ# s# mb# = toSmall (J# s# mb#)
 \end{code}
 
 
@@ -175,8 +197,9 @@ quotRemInteger i1@(J# _ _) i2@(S# _) = quotRemInteger i1 (toBig i2)
 quotRemInteger i1@(S# _) i2@(J# _ _) = quotRemInteger (toBig i1) i2
 quotRemInteger (J# s1 d1) (J# s2 d2)
   = case (quotRemInteger# s1 d1 s2 d2) of
-          (# s3, d3, s4, d4 #)
-            -> (# J# s3 d3, J# s4 d4 #)
+          (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3
+                                      !r = smartJ# s4 d4
+                                  in (# q, r #)
 
 {-# NOINLINE divModInteger #-}
 divModInteger :: Integer -> Integer -> (# Integer, Integer #)
@@ -191,8 +214,9 @@ divModInteger i1@(J# _ _) i2@(S# _) = divModInteger i1 (toBig i2)
 divModInteger i1@(S# _) i2@(J# _ _) = divModInteger (toBig i1) i2
 divModInteger (J# s1 d1) (J# s2 d2)
   = case (divModInteger# s1 d1 s2 d2) of
-          (# s3, d3, s4, d4 #)
-            -> (# J# s3 d3, J# s4 d4 #)
+          (# s3, d3, s4, d4 #) -> let !q = smartJ# s3 d3
+                                      !r = smartJ# s4 d4
+                                  in (# q, r #)
 
 {-# NOINLINE remInteger #-}
 remInteger :: Integer -> Integer -> Integer
@@ -212,7 +236,7 @@ remInteger (J# sa a) (S# b)
     case remInteger# sa a sb b' of { (# sr, r #) ->
     S# (integer2Int# sr r) }}
 remInteger (J# sa a) (J# sb b)
-  = case remInteger# sa a sb b of (# sr, r #) -> J# sr r
+  = case remInteger# sa a sb b of (# sr, r #) -> smartJ# sr r
 
 {-# NOINLINE quotInteger #-}
 quotInteger :: Integer -> Integer -> Integer
@@ -227,9 +251,9 @@ quotInteger (S# a) (J# sb b)
 quotInteger ia@(S# _) ib@(J# _ _) = quotInteger (toBig ia) ib
 quotInteger (J# sa a) (S# b)
   = case int2Integer# b of { (# sb, b' #) ->
-    case quotInteger# sa a sb b' of (# sq, q #) -> J# sq q }
+    case quotInteger# sa a sb b' of (# sq, q #) -> smartJ# sq q }
 quotInteger (J# sa a) (J# sb b)
-  = case quotInteger# sa a sb b of (# sg, g #) -> J# sg g
+  = case quotInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
 
 {-# NOINLINE modInteger #-}
 modInteger :: Integer -> Integer -> Integer
@@ -241,7 +265,7 @@ modInteger (J# sa a) (S# b)
     case modInteger# sa a sb b' of { (# sr, r #) ->
     S# (integer2Int# sr r) }}
 modInteger (J# sa a) (J# sb b)
-  = case modInteger# sa a sb b of (# sr, r #) -> J# sr r
+  = case modInteger# sa a sb b of (# sr, r #) -> smartJ# sr r
 
 {-# NOINLINE divInteger #-}
 divInteger :: Integer -> Integer -> Integer
@@ -250,9 +274,9 @@ divInteger (S# a) (S# b) = S# (divInt# a b)
 divInteger ia@(S# _) ib@(J# _ _) = divInteger (toBig ia) ib
 divInteger (J# sa a) (S# b)
   = case int2Integer# b of { (# sb, b' #) ->
-    case divInteger# sa a sb b' of (# sq, q #) -> J# sq q }
+    case divInteger# sa a sb b' of (# sq, q #) -> smartJ# sq q }
 divInteger (J# sa a) (J# sb b)
-  = case divInteger# sa a sb b of (# sg, g #) -> J# sg g
+  = case divInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
 \end{code}
 
 
@@ -273,7 +297,7 @@ gcdInteger ia@(S# a)  ib@(J# sb b)
              !absSb = if isTrue# (sb <# 0#) then negateInt# sb else sb
 gcdInteger ia@(J# _ _) ib@(S# _) = gcdInteger ib ia
 gcdInteger (J# sa a) (J# sb b)
-  = case gcdInteger# sa a sb b of (# sg, g #) -> J# sg g
+  = case gcdInteger# sa a sb b of (# sg, g #) -> smartJ# sg g
 
 -- | Extended euclidean algorithm.
 --
@@ -286,7 +310,9 @@ gcdExtInteger a@(S# _) b@(J# _ _) = gcdExtInteger (toBig a) b
 gcdExtInteger a@(J# _ _) b@(S# _) = gcdExtInteger a (toBig b)
 gcdExtInteger (J# sa a) (J# sb b)
   = case gcdExtInteger# sa a sb b of
-      (# sg, g, ss, s #) -> (# J# sg g, J# ss s #)
+      (# sg, g, ss, s #) -> let !g' = smartJ# sg g
+                                !s' = smartJ# ss s
+                            in (# g', s' #)
 
 -- | Compute least common multiple.
 {-# NOINLINE lcmInteger #-}
@@ -314,9 +340,9 @@ divExact (S# a) (J# sb b)
 divExact (J# sa a) (S# b)
   = case int2Integer# b of
     (# sb, b' #) -> case divExactInteger# sa a sb b' of
-                    (# sd, d #) -> J# sd d
+                    (# sd, d #) -> smartJ# sd d
 divExact (J# sa a) (J# sb b)
-  = case divExactInteger# sa a sb b of (# sd, d #) -> J# sd d
+  = case divExactInteger# sa a sb b of (# sd, d #) -> smartJ# sd d
 \end{code}
 
 
@@ -458,7 +484,7 @@ plusInteger i1@(S# i) i2@(S# j)  = case addIntC# i j of
 plusInteger i1@(J# _ _) i2@(S# _) = plusInteger i1 (toBig i2)
 plusInteger i1@(S# _) i2@(J# _ _) = plusInteger (toBig i1) i2
 plusInteger (J# s1 d1) (J# s2 d2) = case plusInteger# s1 d1 s2 d2 of
-                                    (# s, d #) -> J# s d
+                                    (# s, d #) -> smartJ# s d
 
 {-# NOINLINE minusInteger #-}
 minusInteger :: Integer -> Integer -> Integer
@@ -470,15 +496,18 @@ minusInteger i1@(S# i) i2@(S# j)   = case subIntC# i j of
 minusInteger i1@(J# _ _) i2@(S# _) = minusInteger i1 (toBig i2)
 minusInteger i1@(S# _) i2@(J# _ _) = minusInteger (toBig i1) i2
 minusInteger (J# s1 d1) (J# s2 d2) = case minusInteger# s1 d1 s2 d2 of
-                                     (# s, d #) -> J# s d
+                                     (# s, d #) -> smartJ# s d
 
 {-# NOINLINE timesInteger #-}
 timesInteger :: Integer -> Integer -> Integer
 timesInteger i1@(S# i) i2@(S# j)   = if isTrue# (mulIntMayOflo# i j ==# 0#)
                                      then S# (i *# j)
                                      else timesInteger (toBig i1) (toBig i2)
-timesInteger i1@(J# _ _) i2@(S# _) = timesInteger i1 (toBig i2)
+timesInteger (S# 0#)     _         = S# 0#
+timesInteger (S# -1#)    i2        = negateInteger i2
+timesInteger (S# 1#)     i2        = i2
 timesInteger i1@(S# _) i2@(J# _ _) = timesInteger (toBig i1) i2
+timesInteger i1@(J# _ _) i2@(S# _) = timesInteger i2 i1 -- swap args & retry
 timesInteger (J# s1 d1) (J# s2 d2) = case timesInteger# s1 d1 s2 d2 of
                                      (# s, d #) -> J# s d
 
@@ -510,7 +539,8 @@ encodeDoubleInteger (J# s# d#) e = encodeDouble# s# d# e
 {-# NOINLINE decodeDoubleInteger #-}
 decodeDoubleInteger :: Double# -> (# Integer, Int# #)
 decodeDoubleInteger d = case decodeDouble# d of
-                        (# exp#, s#, d# #) -> (# J# s# d#, exp# #)
+                        (# exp#, s#, d# #) -> let !s = smartJ# s# d#
+                                              in (# s, exp# #)
 
 -- previous code: doubleFromInteger n = fromInteger n = encodeFloat n 0
 -- doesn't work too well, because encodeFloat is defined in
@@ -557,7 +587,7 @@ x@(S# _)   `andInteger` y@(J# _ _)   = toBig x `andInteger` y
 x@(J# _ _) `andInteger` y@(S# _)     = x `andInteger` toBig y
 (J# s1 d1) `andInteger`   (J# s2 d2) =
      case andInteger# s1 d1 s2 d2 of
-       (# s, d #) -> J# s d
+       (# s, d #) -> smartJ# s d
 
 {-# NOINLINE orInteger #-}
 orInteger :: Integer -> Integer -> Integer
@@ -575,14 +605,14 @@ x@(S# _)   `xorInteger` y@(J# _ _)   = toBig x `xorInteger` y
 x@(J# _ _) `xorInteger` y@(S# _)     = x `xorInteger` toBig y
 (J# s1 d1) `xorInteger`   (J# s2 d2) =
      case xorInteger# s1 d1 s2 d2 of
-       (# s, d #) -> J# s d
+       (# s, d #) -> smartJ# s d
 
 {-# NOINLINE complementInteger #-}
 complementInteger :: Integer -> Integer
 complementInteger (S# x)
     = S# (word2Int# (int2Word# x `xor#` int2Word# (0# -# 1#)))
 complementInteger (J# s d)
-    = case complementInteger# s d of (# s', d' #) -> J# s' d'
+    = case complementInteger# s d of (# s', d' #) -> smartJ# s' d'
 
 {-# NOINLINE shiftLInteger #-}
 shiftLInteger :: Integer -> Int# -> Integer
@@ -594,7 +624,7 @@ shiftLInteger (J# s d) i = case mul2ExpInteger# s d i of
 shiftRInteger :: Integer -> Int# -> Integer
 shiftRInteger j@(S# _) i = shiftRInteger (toBig j) i
 shiftRInteger (J# s d) i = case fdivQ2ExpInteger# s d i of
-                           (# s', d' #) -> J# s' d'
+                           (# s', d' #) -> smartJ# s' d'
 
 {-# NOINLINE testBitInteger #-}
 testBitInteger :: Integer -> Int# -> Bool
@@ -606,7 +636,7 @@ testBitInteger (J# s d) i = isTrue# (testBitInteger# s d i /=# 0#)
 powInteger :: Integer -> Word# -> Integer
 powInteger j@(S# _) e = powInteger (toBig j) e
 powInteger (J# s d) e = case powInteger# s d e of
-                            (# s', d' #) -> J# s' d'
+                            (# s', d' #) -> smartJ# s' d'
 
 -- | \"@'powModInteger' /b/ /e/ /m/@\" computes base @/b/@ raised to
 -- exponent @/e/@ modulo @/m/@.
@@ -620,7 +650,7 @@ powInteger (J# s d) e = case powInteger# s d e of
 powModInteger :: Integer -> Integer -> Integer -> Integer
 powModInteger (J# s1 d1) (J# s2 d2) (J# s3 d3) =
     case powModInteger# s1 d1 s2 d2 s3 d3 of
-        (# s', d' #) -> J# s' d'
+        (# s', d' #) -> smartJ# s' d'
 powModInteger b e m = powModInteger (toBig b) (toBig e) (toBig m)
 
 -- | \"@'powModSecInteger' /b/ /e/ /m/@\" computes base @/b/@ raised to
@@ -651,7 +681,7 @@ recipModInteger j@(S# _) m@(S# _)   = recipModInteger (toBig j) (toBig m)
 recipModInteger j@(S# _) m@(J# _ _) = recipModInteger (toBig j) m
 recipModInteger j@(J# _ _) m@(S# _) = recipModInteger j (toBig m)
 recipModInteger (J# s d) (J# ms md) = case recipModInteger# s d ms md of
-                           (# s', d' #) -> J# s' d'
+                           (# s', d' #) -> smartJ# s' d'
 
 -- | Probalistic Miller-Rabin primality test.
 --
@@ -681,7 +711,7 @@ testPrimeInteger (J# s d) reps = testPrimeInteger# s d reps
 {-# NOINLINE nextPrimeInteger #-}
 nextPrimeInteger :: Integer -> Integer
 nextPrimeInteger j@(S# _) = nextPrimeInteger (toBig j)
-nextPrimeInteger (J# s d) = case nextPrimeInteger# s d of (# s', d' #) -> J# s' d'
+nextPrimeInteger (J# s d) = case nextPrimeInteger# s d of (# s', d' #) -> smartJ# s' d'
 
 -- | Compute number of digits (without sign) in given @/base/@.
 --
@@ -769,13 +799,9 @@ exportIntegerToAddr j@(S# _) addr o e = exportIntegerToAddr (toBig j) addr o e -
 --   significant byte first if @/order/@ is @-1#@, and
 --
 -- * returns a new 'Integer'
---
--- It's recommended to avoid calling 'importIntegerFromByteArray' for
--- known to be small integers as this function currently always
--- returns a big integer even if it would fit into a small integer.
 {-# NOINLINE importIntegerFromByteArray #-}
 importIntegerFromByteArray :: ByteArray# -> Word# -> Word# -> Int# -> Integer
-importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e of (# s', d' #) -> J# s' d'
+importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e of (# s', d' #) -> smartJ# s' d'
 
 -- | Read 'Integer' (without sign) from memory location at @/addr/@ in
 -- base-256 representation.
@@ -788,7 +814,7 @@ importIntegerFromByteArray ba o l e = case importIntegerFromByteArray# ba o l e
 {-# NOINLINE importIntegerFromAddr #-}
 importIntegerFromAddr :: Addr# -> Word# -> Int# -> State# s -> (# State# s, Integer #)
 importIntegerFromAddr addr l e st = case importIntegerFromAddr# addr l e st of
-                                      (# st', s', d' #) -> (# st', J# s' d' #)
+                                      (# st', s', d' #) -> let !j = smartJ# s' d' in (# st', j #)
 
 \end{code}