Implement getSizeofMutableByteArrayOp primop
authorBen Gamari <bgamari.foss@gmail.com>
Fri, 21 Aug 2015 08:37:39 +0000 (10:37 +0200)
committerBen Gamari <ben@smart-cactus.org>
Fri, 21 Aug 2015 10:10:06 +0000 (12:10 +0200)
Now since ByteArrays are mutable we need to be more explicit about when
the size is queried.

Test Plan: Add testcase and validate

Reviewers: goldfire, hvr, austin

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D1139

GHC Trac Issues: #9447

compiler/codeGen/StgCmmPrim.hs
compiler/main/BreakArray.hs
compiler/prelude/primops.txt.pp
libraries/integer-gmp/src/GHC/Integer/Type.hs

index d201eaf..7188c05 100644 (file)
@@ -327,6 +327,11 @@ emitPrimOp dflags [res] SizeofByteArrayOp [arg]
 emitPrimOp dflags [res] SizeofMutableByteArrayOp [arg]
    = emitPrimOp dflags [res] SizeofByteArrayOp [arg]
 
+--  #define getSizzeofMutableByteArrayzh(r,a) \
+--      r = ((StgArrWords *)(a))->bytes
+emitPrimOp dflags [res] GetSizeofMutableByteArrayOp [arg]
+   = emitAssign (CmmLocal res) (cmmLoadIndexW dflags arg (fixedHdrSizeW dflags) (bWord dflags))
+
 
 --  #define touchzh(o)                  /* nothing */
 emitPrimOp _ res@[] TouchOp args@[_arg]
index 6455912..65bf932 100644 (file)
@@ -34,6 +34,7 @@ import Control.Monad
 
 import ExtsCompat46
 import GHC.IO ( IO(..) )
+import System.IO.Unsafe ( unsafeDupablePerformIO )
 
 data BreakArray = BA (MutableByteArray# RealWorld)
 
@@ -73,7 +74,16 @@ safeIndex :: DynFlags -> BreakArray -> Int -> Bool
 safeIndex dflags array index = index < size dflags array && index >= 0
 
 size :: DynFlags -> BreakArray -> Int
-size dflags (BA array) = (I# (sizeofMutableByteArray# array)) `div` wORD_SIZE dflags
+size dflags (BA array) = size `div` wORD_SIZE dflags
+  where
+    -- We want to keep this operation pure. The mutable byte array
+    -- is never resized so this is safe.
+    size = unsafeDupablePerformIO $ sizeofMutableByteArray array
+
+    sizeofMutableByteArray :: MutableByteArray# RealWorld -> IO Int
+    sizeofMutableByteArray arr =
+        IO $ \s -> case getSizeofMutableByteArray# arr s of
+                       (# s', n# #) -> (# s', I# n# #)
 
 allocBA :: Int -> IO BreakArray
 allocBA (I# sz) = IO $ \s1 ->
index 6d45ed9..5fe02b2 100644 (file)
@@ -1115,7 +1115,13 @@ primop  SizeofByteArrayOp "sizeofByteArray#" GenPrimOp
 
 primop  SizeofMutableByteArrayOp "sizeofMutableByteArray#" GenPrimOp
    MutableByteArray# s -> Int#
-   {Return the size of the array in bytes.}
+   {Return the size of the array in bytes. Note that this is deprecated as it is
+   unsafe in the presence of concurrent resize operations on the same byte
+   array. See {\tt getSizeofMutableByteArray}.}
+
+primop  GetSizeofMutableByteArrayOp "getSizeofMutableByteArray#" GenPrimOp
+   MutableByteArray# s -> State# s -> (# State# s, Int# #)
+   {Return the number of elements in the array.}
 
 primop IndexByteArrayOp_Char "indexCharArray#" GenPrimOp
    ByteArray# -> Int# -> Char#
index d941c4c..a04d9ad 100644 (file)
@@ -1611,9 +1611,11 @@ sizeofBigNat# (BN# x#)
 
 data MutBigNat s = MBN# !(MutableByteArray# s)
 
-sizeofMutBigNat# :: MutBigNat s -> GmpSize#
-sizeofMutBigNat# (MBN# x#)
-    = sizeofMutableByteArray# x# `uncheckedIShiftRL#` GMP_LIMB_SHIFT#
+getSizeofMutBigNat# :: MutBigNat s -> State# s -> (# State# s, GmpSize# #)
+--getSizeofMutBigNat# :: MutBigNat s -> S s GmpSize#
+getSizeofMutBigNat# (MBN# x#) s =
+    case getSizeofMutableByteArray# x# s of
+        (# s', n# #) -> (# s', n# `uncheckedIShiftRL#` GMP_LIMB_SHIFT# #)
 
 newBigNat# :: GmpSize# -> S s (MutBigNat s)
 newBigNat# limbs# s =
@@ -1634,40 +1636,42 @@ unsafeFreezeBigNat# (MBN# mba#) s = case unsafeFreezeByteArray# mba# s of
 
 resizeMutBigNat# :: MutBigNat s -> GmpSize# -> S s (MutBigNat s)
 resizeMutBigNat# (MBN# mba0#) nsz# s
-  | isTrue# (bsz# ==# sizeofMutableByteArray# mba0#) = (# s, MBN# mba0# #)
-  | True = case resizeMutableByteArray# mba0# bsz# s of
-        (# s', mba# #) -> (# s' , MBN# mba# #)
+  | isTrue# (bsz# ==# n#) = (# s', MBN# mba0# #)
+  | True =
+    case resizeMutableByteArray# mba0# bsz# s' of
+        (# s'', mba# #) -> (# s'', MBN# mba# #)
   where
     bsz# = nsz# `uncheckedIShiftL#` GMP_LIMB_SHIFT#
+    (# s', n# #) = getSizeofMutBigNat# (MBN# mba0#) s
 
 shrinkMutBigNat# :: MutBigNat s -> GmpSize# -> State# s -> State# s
-shrinkMutBigNat# (MBN# mba0#) nsz#
-  | isTrue# (bsz# ==# sizeofMutableByteArray# mba0#) = \s -> s -- no-op
-  | True = shrinkMutableByteArray# mba0# bsz#
+shrinkMutBigNat# (MBN# mba0#) nsz# s
+  | isTrue# (bsz# ==# n#) = s' -- no-op
+  | True                  = shrinkMutableByteArray# mba0# bsz# s'
   where
     bsz# = nsz# `uncheckedIShiftL#` GMP_LIMB_SHIFT#
+    (# s', n# #) = getSizeofMutBigNat# (MBN# mba0#) s
 
 unsafeSnocFreezeBigNat# :: MutBigNat s -> GmpLimb# -> S s BigNat
-unsafeSnocFreezeBigNat# mbn0@(MBN# mba0#) limb# = do
-    -- (MBN# mba#) <- newBigNat# (n# +# 1#)
-    -- _ <- svoid (copyMutableByteArray# mba0# 0# mba# 0# nb0#)
-    (MBN# mba#) <- resizeMutBigNat# mbn0 (n# +# 1#)
-    _ <- svoid (writeWordArray# mba# n# limb#)
-    unsafeFreezeBigNat# (MBN# mba#)
+unsafeSnocFreezeBigNat# mbn0@(MBN# mba0#) limb# s = go s'
   where
     n#   = nb0# `uncheckedIShiftRL#` GMP_LIMB_SHIFT#
-    nb0# = sizeofMutableByteArray# mba0#
+    (# s', nb0# #) = getSizeofMutableByteArray# mba0# s
+    go = do
+        (MBN# mba#) <- resizeMutBigNat# mbn0 (n# +# 1#)
+        _ <- svoid (writeWordArray# mba# n# limb#)
+        unsafeFreezeBigNat# (MBN# mba#)
 
 -- | May shrink underlyng 'ByteArray#' if needed to satisfy BigNat invariant
 unsafeRenormFreezeBigNat# :: MutBigNat s -> S s BigNat
 unsafeRenormFreezeBigNat# mbn s
-  | isTrue# (n0# ==# 0#)  = (# s', nullBigNat #)
-  | isTrue# (n#  ==# 0#)  = (# s', zeroBigNat #)
-  | isTrue# (n#  ==# n0#) = (unsafeFreezeBigNat# mbn) s'
-  | True                  = (unsafeShrinkFreezeBigNat# mbn n#) s'
+  | isTrue# (n0# ==# 0#)  = (# s'', nullBigNat #)
+  | isTrue# (n#  ==# 0#)  = (# s'', zeroBigNat #)
+  | isTrue# (n#  ==# n0#) = (unsafeFreezeBigNat# mbn) s''
+  | True                  = (unsafeShrinkFreezeBigNat# mbn n#) s''
   where
-    (# s', n# #) = normSizeofMutBigNat'# mbn n0# s
-    n0# = sizeofMutBigNat# mbn
+    (# s', n0# #) = getSizeofMutBigNat# mbn s
+    (# s'', n# #) = normSizeofMutBigNat'# mbn n0# s'
 
 -- | Shrink MBN
 unsafeShrinkFreezeBigNat# :: MutBigNat s -> GmpSize# -> S s BigNat
@@ -1695,9 +1699,10 @@ copyWordArray# src src_ofs dst dst_ofs len
 
 -- | Version of 'normSizeofMutBigNat'#' which scans all allocated 'MutBigNat#'
 normSizeofMutBigNat# :: MutBigNat s -> State# s -> (# State# s, Int# #)
-normSizeofMutBigNat# mbn@(MBN# mba) = normSizeofMutBigNat'# mbn sz#
+normSizeofMutBigNat# mbn@(MBN# mba) s = normSizeofMutBigNat'# mbn sz# s'
   where
-    sz# = sizeofMutableByteArray# mba `uncheckedIShiftRA#` GMP_LIMB_SHIFT#
+    (# s', n# #) = getSizeofMutableByteArray# mba s
+    sz# = n# `uncheckedIShiftRA#` GMP_LIMB_SHIFT#
 
 -- | Find most-significant non-zero limb and return its index-position
 -- plus one. Start scanning downward from the initial limb-size
@@ -1726,10 +1731,12 @@ byteArrayToBigNat# ba# n0#
   | isTrue# (n#  ==# 0#)    = zeroBigNat
   | isTrue# (baszr# ==# 0#) -- i.e. ba# is multiple of limb-size
   , isTrue# (baszq# ==# n#) = (BN# ba#)
-  | True = runS $ do
-      mbn@(MBN# mba#) <- newBigNat# n#
-      _ <- svoid (copyByteArray# ba# 0# mba# 0# (sizeofMutableByteArray# mba#))
-      unsafeFreezeBigNat# mbn
+  | True = runS $ \s ->
+      let (# s', mbn@(MBN# mba#) #) = newBigNat# n# s
+          (# s'', ba_sz# #) = getSizeofMutableByteArray# mba# s'
+          go = do _ <- svoid (copyByteArray# ba# 0# mba# 0# ba_sz# )
+                  unsafeFreezeBigNat# mbn
+      in go s''
   where
     (# baszq#, baszr# #) = quotRemInt# (sizeofByteArray# ba#) GMP_LIMB_BYTES#