Avoid copying if possible in `concat`
authorBen Gamari <ben@smart-cactus.org>
Sun, 15 May 2016 21:37:51 +0000 (23:37 +0200)
committerDuncan Coutts <duncan@community.haskell.org>
Sun, 13 Nov 2016 20:56:08 +0000 (20:56 +0000)
The `binary` package revealed a rather obvious missing optimization here
when it forced extraneous copies with the evaluation of `concat [a,b]`
where `a` is empty and `b` is large.

Here we rework `Data.ByteString.concat` and `Data.ByteString.Lazy.concat`
to more aggressively avoid unnecessary copies in the face of
concatentations of lists with empty chunks. This rework has the nice
advantage of avoiding allocation during computation of the final buffer
length in the case where a copy is necessary (whereas previously
`checkedSum` would fail to fuse and therefore require allocation for its
list argument).

Data/ByteString/Internal.hs
Data/ByteString/Lazy.hs
Data/ByteString/Lazy/Internal.hs

index 4a9983b..031403e 100644 (file)
@@ -34,7 +34,6 @@ module Data.ByteString.Internal (
         unpackBytes, unpackAppendBytesLazy, unpackAppendBytesStrict,
         unpackChars, unpackAppendCharsLazy, unpackAppendCharsStrict,
         unsafePackAddress,
-        checkedSum,
 
         -- * Low level imperative construction
         create,                 -- :: Int -> (Ptr Word8 -> IO ()) -> IO ByteString
@@ -51,6 +50,7 @@ module Data.ByteString.Internal (
 
         -- * Utilities
         nullForeignPtr,         -- :: ForeignPtr Word8
+        checkedAdd,             -- :: String -> Int -> Int -> Int
 
         -- * Standard C Functions
         c_strlen,               -- :: CString -> IO CInt
@@ -76,7 +76,7 @@ module Data.ByteString.Internal (
         inlinePerformIO               -- :: IO a -> a
   ) where
 
-import Prelude hiding (concat)
+import Prelude hiding (concat, null)
 import qualified Data.List as List
 
 import Foreign.ForeignPtr       (ForeignPtr, withForeignPtr)
@@ -461,24 +461,63 @@ append (PS fp1 off1 len1) (PS fp2 off2 len2) =
       withForeignPtr fp2 $ \p2 -> memcpy destptr2 (p2 `plusPtr` off2) len2
 
 concat :: [ByteString] -> ByteString
-concat []     = mempty
-concat [bs]   = bs
-concat bss0   = unsafeCreate totalLen $ \ptr -> go bss0 ptr
+concat = \bss0 -> goLen0 bss0 bss0
+    -- The idea here is we first do a pass over the input list to determine:
+    --
+    --  1. is a copy necessary? e.g. @concat []@, @concat [mempty, "hello"]@,
+    --     and @concat ["hello", mempty, mempty]@ can all be handled without
+    --     copying.
+    --  2. if a copy is necessary, how large is the result going to be?
+    --
+    -- If a copy is necessary then we create a buffer of the appropriate size
+    -- and do another pass over the input list, copying the chunks into the
+    -- buffer. Also, since foreign calls aren't entirely free we skip over
+    -- empty chunks while copying.
+    --
+    -- We pass the original [ByteString] (bss0) through as an argument through
+    -- goLen0, goLen1, and goLen since we will need it again in goCopy. Passing
+    -- it as an explicit argument avoids capturing it in these functions'
+    -- closures which would result in unnecessary closure allocation.
   where
-    totalLen = checkedSum "concat" [ len | (PS _ _ len) <- bss0 ]
-    go []                  !_   = return ()
-    go (PS fp off len:bss) !ptr = do
+    -- It's still possible that the result is empty
+    goLen0 _    []                     = mempty
+    goLen0 bss0 (PS _ _ 0     :bss)    = goLen0 bss0 bss
+    goLen0 bss0 (bs           :bss)    = goLen1 bss0 bs bss
+
+    -- It's still possible that the result is a single chunk
+    goLen1 _    bs []                  = bs
+    goLen1 bss0 bs (PS _ _ 0  :bss)    = goLen1 bss0 bs bss
+    goLen1 bss0 bs (PS _ _ len:bss)    = goLen bss0 (checkedAdd "concat" len' len) bss
+      where PS _ _ len' = bs
+
+    -- General case, just find the total length we'll need
+    goLen bss0 !total (PS _ _ len:bss) = goLen bss0 total' bss
+      where total' = checkedAdd "concat" total len
+    goLen bss0 total [] =
+      unsafeCreate total $ \ptr -> goCopy bss0 ptr
+
+    -- Copy the data
+    goCopy []                  !_   = return ()
+    goCopy (PS _  _   0  :bss) !ptr = goCopy bss ptr
+    goCopy (PS fp off len:bss) !ptr = do
       withForeignPtr fp $ \p -> memcpy ptr (p `plusPtr` off) len
-      go bss (ptr `plusPtr` len)
-
--- | Add a list of non-negative numbers.  Errors out on overflow.
-checkedSum :: String -> [Int] -> Int
-checkedSum fun = go 0
-  where go !a (x:xs)
-            | ax >= 0   = go ax xs
-            | otherwise = overflowError fun
-          where ax = a + x
-        go a  _         = a
+      goCopy bss (ptr `plusPtr` len)
+{-# NOINLINE concat #-}
+
+{-# RULES
+"ByteString concat [] -> mempty"
+   concat [] = mempty
+"ByteString concat [bs] -> bs" forall x.
+   concat [x] = x
+ #-}
+
+-- | Add two non-negative numbers. Errors out on overflow.
+checkedAdd :: String -> Int -> Int -> Int
+checkedAdd fun x y
+  | r >= 0    = r
+  | otherwise = overflowError fun
+  where r = x + y
+{-# INLINE checkedAdd #-}
 
 ------------------------------------------------------------------------
 
index 5b1cf5a..329b4d8 100644 (file)
@@ -282,17 +282,40 @@ fromStrict bs | S.null bs = Empty
 -- avoid converting back and forth between strict and lazy bytestrings.
 --
 toStrict :: ByteString -> S.ByteString
-toStrict Empty           = S.empty
-toStrict (Chunk c Empty) = c
-toStrict cs0 = S.unsafeCreate totalLen $ \ptr -> go cs0 ptr
+toStrict = \cs -> goLen0 cs cs
+    -- We pass the original [ByteString] (bss0) through as an argument through
+    -- goLen0, goLen1, and goLen since we will need it again in goCopy. Passing
+    -- it as an explicit argument avoids capturing it in these functions'
+    -- closures which would result in unnecessary closure allocation.
   where
-    totalLen = S.checkedSum "Lazy.toStrict" . L.map S.length . toChunks $ cs0
-
-    go Empty                        !_       = return ()
-    go (Chunk (S.PS fp off len) cs) !destptr =
+    -- It's still possible that the result is empty
+    goLen0 _   Empty                   = S.empty
+    goLen0 cs0 (Chunk c cs) | S.null c = goLen0 cs0 cs
+    goLen0 cs0 (Chunk c cs)            = goLen1 cs0 c cs
+
+    -- It's still possible that the result is a single chunk
+    goLen1 _   bs Empty                = bs
+    goLen1 cs0 bs (Chunk c cs)
+      | S.null c                   = goLen1 cs0 bs cs
+      | otherwise                  =
+        goLen cs0 (S.checkedAdd "Lazy.concat" (S.length bs) (S.length c)) cs
+
+    -- General case, just find the total length we'll need
+    goLen cs0 !total (Chunk c cs)      = goLen cs0 total' cs
+      where
+        total' = S.checkedAdd "Lazy.concat" total (S.length c)
+    goLen cs0 total Empty =
+      S.unsafeCreate total $ \ptr -> goCopy cs0 ptr
+
+    -- Copy the data
+    goCopy Empty                        !_   = return ()
+    goCopy (Chunk (S.PS _  _   0  ) cs) !ptr = goCopy cs ptr
+    goCopy (Chunk (S.PS fp off len) cs) !ptr = do
       withForeignPtr fp $ \p -> do
-        S.memcpy destptr (p `plusPtr` off) len
-        go cs (destptr `plusPtr` len)
+        S.memcpy ptr (p `plusPtr` off) len
+        goCopy cs (ptr `plusPtr` len)
+-- See the comment on Data.ByteString.Internal.concat for some background on
+-- this implementation.
 
 ------------------------------------------------------------------------
 
index a292cfb..fcf6cc6 100644 (file)
@@ -73,6 +73,7 @@ import Data.Data                (Data(..), mkNoRepType)
 --
 data ByteString = Empty | Chunk {-# UNPACK #-} !S.ByteString ByteString
     deriving (Typeable)
+-- See 'invariant' function later in this module for internal invariants.
 
 instance Eq  ByteString where
     (==)    = eq