Use unsafeDupablePerformIO where possible
[packages/text.git] / Data / Text / Lazy / Encoding / Fusion.hs
index 997acdf..0d0c724 100644 (file)
@@ -1,11 +1,11 @@
-{-# LANGUAGE BangPatterns, Rank2Types #-}
+{-# LANGUAGE BangPatterns, CPP, Rank2Types #-}
 
 -- |
 -- Module      : Data.Text.Lazy.Encoding.Fusion
--- Copyright   : (c) Bryan O'Sullivan 2009
+-- Copyright   : (c) 2009, 2010 Bryan O'Sullivan
 --
 -- License     : BSD-style
--- Maintainer  : bos@serpentine.com, rtharper@aftereternity.co.uk, 
+-- Maintainer  : bos@serpentine.com, rtomharper@googlemail.com,
 --               duncan@haskell.org
 -- Stability   : experimental
 -- Portability : portable
@@ -17,11 +17,11 @@ module Data.Text.Lazy.Encoding.Fusion
     (
     -- * Streaming
     --  streamASCII
-     streamUtf8
-    --, streamUtf16LE
-    --, streamUtf16BE
-    --, streamUtf32LE
-    --, streamUtf32BE
+      streamUtf8
+    , streamUtf16LE
+    , streamUtf16BE
+    , streamUtf32LE
+    , streamUtf32BE
 
     -- * Unstreaming
     , unstream
@@ -35,84 +35,255 @@ import qualified Data.ByteString.Unsafe as B
 import Data.Text.Encoding.Fusion.Common
 import Data.Text.Encoding.Error
 import Data.Text.Fusion (Step(..), Stream(..))
-import Data.Text.Fusion.Internal (M(..), PairS(..), S(..))
 import Data.Text.Fusion.Size
-import Data.Text.UnsafeChar (unsafeChr8)
-import Data.Word (Word8)
+import Data.Text.UnsafeChar (unsafeChr, unsafeChr8, unsafeChr32)
+import Data.Text.UnsafeShift (shiftL)
+import Data.Word (Word8, Word16, Word32)
 import qualified Data.Text.Encoding.Utf8 as U8
-import System.IO.Unsafe (unsafePerformIO)
+import qualified Data.Text.Encoding.Utf16 as U16
+import qualified Data.Text.Encoding.Utf32 as U32
+import Data.Text.Unsafe (unsafeDupablePerformIO)
 import Foreign.ForeignPtr (withForeignPtr, ForeignPtr)
 import Foreign.Storable (pokeByteOff)
 import Data.ByteString.Internal (mallocByteString, memcpy)
+#if defined(ASSERTS)
 import Control.Exception (assert)
+#endif
 import qualified Data.ByteString.Internal as B
 
+data S = S0
+       | S1 {-# UNPACK #-} !Word8
+       | S2 {-# UNPACK #-} !Word8 {-# UNPACK #-} !Word8
+       | S3 {-# UNPACK #-} !Word8 {-# UNPACK #-} !Word8 {-# UNPACK #-} !Word8
+       | S4 {-# UNPACK #-} !Word8 {-# UNPACK #-} !Word8 {-# UNPACK #-} !Word8 {-# UNPACK #-} !Word8
+
+data T = T !ByteString !S {-# UNPACK #-} !Int
+
 -- | /O(n)/ Convert a lazy 'ByteString' into a 'Stream Char', using
 -- UTF-8 encoding.
 streamUtf8 :: OnDecodeError -> ByteString -> Stream Char
-streamUtf8 onErr bs0 = Stream next (bs0 :*: empty :*: 0) unknownSize
-    where
-      empty = S N N N N
-      {-# INLINE next #-}
-      next (bs@(Chunk ps _) :*: S N _ _ _ :*: i)
-          | i < len && U8.validate1 a =
-              Yield (unsafeChr8 a) (bs :*: empty :*: i+1)
-          | i + 1 < len && U8.validate2 a b =
-              Yield (U8.chr2 a b) (bs :*: empty :*: i+2)
-          | i + 2 < len && U8.validate3 a b c =
-              Yield (U8.chr3 a b c) (bs :*: empty :*: i+3)
-          | i + 4 < len && U8.validate4 a b c d =
-              Yield (U8.chr4 a b c d) (bs :*: empty :*: i+4)
-          where len = B.length ps
-                a = B.unsafeIndex ps i
-                b = B.unsafeIndex ps (i+1)
-                c = B.unsafeIndex ps (i+2)
-                d = B.unsafeIndex ps (i+3)
-      next st@(bs :*: s :*: i) =
-        case s of
-          S (J a) N _ _             | U8.validate1 a ->
-            Yield (unsafeChr8 a) es
-          S (J a) (J b) N _         | U8.validate2 a b ->
-            Yield (U8.chr2 a b) es
-          S (J a) (J b) (J c) N     | U8.validate3 a b c ->
-            Yield (U8.chr3 a b c) es
-          S (J a) (J b) (J c) (J d) | U8.validate4 a b c d ->
-            Yield (U8.chr4 a b c d) es
-          _ -> consume st
-         where es = bs :*: empty :*: i
-      {-# INLINE consume #-}
-      consume (bs@(Chunk ps rest) :*: s :*: i)
-          | i >= B.length ps = consume (rest :*: s  :*: 0)
-          | otherwise =
-        case s of
-          S N _ _ _ -> next (bs :*: S x N N N :*: i+1)
-          S a N _ _ -> next (bs :*: S a x N N :*: i+1)
-          S a b N _ -> next (bs :*: S a b x N :*: i+1)
-          S a b c N -> next (bs :*: S a b c x :*: i+1)
-          S (J a) b c d -> decodeError "streamUtf8" "UTF-8" onErr (Just a)
-                           (bs :*: S b c d N :*: i+1)
-          where x = J (B.unsafeIndex ps i)
-      consume (Empty :*: S N _ _ _ :*: _) = Done
-      consume st = decodeError "streamUtf8" "UTF-8" onErr Nothing st
+streamUtf8 onErr bs0 = Stream next (T bs0 S0 0) unknownSize
+  where
+    next (T bs@(Chunk ps _) S0 i)
+      | i < len && U8.validate1 a =
+          Yield (unsafeChr8 a)    (T bs S0 (i+1))
+      | i + 1 < len && U8.validate2 a b =
+          Yield (U8.chr2 a b)     (T bs S0 (i+2))
+      | i + 2 < len && U8.validate3 a b c =
+          Yield (U8.chr3 a b c)   (T bs S0 (i+3))
+      | i + 3 < len && U8.validate4 a b c d =
+          Yield (U8.chr4 a b c d) (T bs S0 (i+4))
+      where len = B.length ps
+            a = B.unsafeIndex ps i
+            b = B.unsafeIndex ps (i+1)
+            c = B.unsafeIndex ps (i+2)
+            d = B.unsafeIndex ps (i+3)
+    next st@(T bs s i) =
+      case s of
+        S1 a       | U8.validate1 a       -> Yield (unsafeChr8 a)    es
+        S2 a b     | U8.validate2 a b     -> Yield (U8.chr2 a b)     es
+        S3 a b c   | U8.validate3 a b c   -> Yield (U8.chr3 a b c)   es
+        S4 a b c d | U8.validate4 a b c d -> Yield (U8.chr4 a b c d) es
+        _ -> consume st
+       where es = T bs S0 i
+    consume (T bs@(Chunk ps rest) s i)
+        | i >= B.length ps = consume (T rest s 0)
+        | otherwise =
+      case s of
+        S0         -> next (T bs (S1 x)       (i+1))
+        S1 a       -> next (T bs (S2 a x)     (i+1))
+        S2 a b     -> next (T bs (S3 a b x)   (i+1))
+        S3 a b c   -> next (T bs (S4 a b c x) (i+1))
+        S4 a b c d -> decodeError "streamUtf8" "UTF-8" onErr (Just a)
+                           (T bs (S3 b c d)   (i+1))
+        where x = B.unsafeIndex ps i
+    consume (T Empty S0 _) = Done
+    consume st             = decodeError "streamUtf8" "UTF-8" onErr Nothing st
 {-# INLINE [0] streamUtf8 #-}
 
+-- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using little
+-- endian UTF-16 encoding.
+streamUtf16LE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf16LE onErr bs0 = Stream next (T bs0 S0 0) unknownSize
+  where
+    next (T bs@(Chunk ps _) S0 i)
+      | i + 1 < len && U16.validate1 x1 =
+          Yield (unsafeChr x1)         (T bs S0 (i+2))
+      | i + 3 < len && U16.validate2 x1 x2 =
+          Yield (U16.chr2 x1 x2)       (T bs S0 (i+4))
+      where len = B.length ps
+            x1   = c (idx  i)      (idx (i + 1))
+            x2   = c (idx (i + 2)) (idx (i + 3))
+            c w1 w2 = w1 + (w2 `shiftL` 8)
+            idx = fromIntegral . B.unsafeIndex ps :: Int -> Word16
+    next st@(T bs s i) =
+      case s of
+        S2 w1 w2       | U16.validate1 (c w1 w2)           ->
+          Yield (unsafeChr (c w1 w2))   es
+        S4 w1 w2 w3 w4 | U16.validate2 (c w1 w2) (c w3 w4) ->
+          Yield (U16.chr2 (c w1 w2) (c w3 w4)) es
+        _ -> consume st
+       where es = T bs S0 i
+             c :: Word8 -> Word8 -> Word16
+             c w1 w2 = fromIntegral w1 + (fromIntegral w2 `shiftL` 8)
+    consume (T bs@(Chunk ps rest) s i)
+        | i >= B.length ps = consume (T rest s 0)
+        | otherwise =
+      case s of
+        S0             -> next (T bs (S1 x)          (i+1))
+        S1 w1          -> next (T bs (S2 w1 x)       (i+1))
+        S2 w1 w2       -> next (T bs (S3 w1 w2 x)    (i+1))
+        S3 w1 w2 w3    -> next (T bs (S4 w1 w2 w3 x) (i+1))
+        S4 w1 w2 w3 w4 -> decodeError "streamUtf16LE" "UTF-16LE" onErr (Just w1)
+                           (T bs (S3 w2 w3 w4)       (i+1))
+        where x = B.unsafeIndex ps i
+    consume (T Empty S0 _) = Done
+    consume st             = decodeError "streamUtf16LE" "UTF-16LE" onErr Nothing st
+{-# INLINE [0] streamUtf16LE #-}
+
+-- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using big
+-- endian UTF-16 encoding.
+streamUtf16BE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf16BE onErr bs0 = Stream next (T bs0 S0 0) unknownSize
+  where
+    next (T bs@(Chunk ps _) S0 i)
+      | i + 1 < len && U16.validate1 x1 =
+          Yield (unsafeChr x1)         (T bs S0 (i+2))
+      | i + 3 < len && U16.validate2 x1 x2 =
+          Yield (U16.chr2 x1 x2)       (T bs S0 (i+4))
+      where len = B.length ps
+            x1   = c (idx  i)      (idx (i + 1))
+            x2   = c (idx (i + 2)) (idx (i + 3))
+            c w1 w2 = (w1 `shiftL` 8) + w2
+            idx = fromIntegral . B.unsafeIndex ps :: Int -> Word16
+    next st@(T bs s i) =
+      case s of
+        S2 w1 w2       | U16.validate1 (c w1 w2)           ->
+          Yield (unsafeChr (c w1 w2))   es
+        S4 w1 w2 w3 w4 | U16.validate2 (c w1 w2) (c w3 w4) ->
+          Yield (U16.chr2 (c w1 w2) (c w3 w4)) es
+        _ -> consume st
+       where es = T bs S0 i
+             c :: Word8 -> Word8 -> Word16
+             c w1 w2 = (fromIntegral w1 `shiftL` 8) + fromIntegral w2
+    consume (T bs@(Chunk ps rest) s i)
+        | i >= B.length ps = consume (T rest s 0)
+        | otherwise =
+      case s of
+        S0             -> next (T bs (S1 x)          (i+1))
+        S1 w1          -> next (T bs (S2 w1 x)       (i+1))
+        S2 w1 w2       -> next (T bs (S3 w1 w2 x)    (i+1))
+        S3 w1 w2 w3    -> next (T bs (S4 w1 w2 w3 x) (i+1))
+        S4 w1 w2 w3 w4 -> decodeError "streamUtf16BE" "UTF-16BE" onErr (Just w1)
+                           (T bs (S3 w2 w3 w4)       (i+1))
+        where x = B.unsafeIndex ps i
+    consume (T Empty S0 _) = Done
+    consume st             = decodeError "streamUtf16BE" "UTF-16BE" onErr Nothing st
+{-# INLINE [0] streamUtf16BE #-}
+
+-- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using big
+-- endian UTF-32 encoding.
+streamUtf32BE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf32BE onErr bs0 = Stream next (T bs0 S0 0) unknownSize
+  where
+    next (T bs@(Chunk ps _) S0 i)
+      | i + 3 < len && U32.validate x =
+          Yield (unsafeChr32 x)       (T bs S0 (i+4))
+      where len = B.length ps
+            x = shiftL x1 24 + shiftL x2 16 + shiftL x3 8 + x4
+            x1    = idx i
+            x2    = idx (i+1)
+            x3    = idx (i+2)
+            x4    = idx (i+3)
+            idx = fromIntegral . B.unsafeIndex ps :: Int -> Word32
+    next st@(T bs s i) =
+      case s of
+        S4 w1 w2 w3 w4 | U32.validate (c w1 w2 w3 w4) ->
+          Yield (unsafeChr32 (c w1 w2 w3 w4)) es
+        _ -> consume st
+       where es = T bs S0 i
+             c :: Word8 -> Word8 -> Word8 -> Word8 -> Word32
+             c w1 w2 w3 w4 = shifted
+              where
+               shifted = shiftL x1 24 + shiftL x2 16 + shiftL x3 8 + x4
+               x1 = fromIntegral w1
+               x2 = fromIntegral w2
+               x3 = fromIntegral w3
+               x4 = fromIntegral w4
+    consume (T bs@(Chunk ps rest) s i)
+        | i >= B.length ps = consume (T rest s 0)
+        | otherwise =
+      case s of
+        S0             -> next (T bs (S1 x)          (i+1))
+        S1 w1          -> next (T bs (S2 w1 x)       (i+1))
+        S2 w1 w2       -> next (T bs (S3 w1 w2 x)    (i+1))
+        S3 w1 w2 w3    -> next (T bs (S4 w1 w2 w3 x) (i+1))
+        S4 w1 w2 w3 w4 -> decodeError "streamUtf32BE" "UTF-32BE" onErr (Just w1)
+                           (T bs (S3 w2 w3 w4)       (i+1))
+        where x = B.unsafeIndex ps i
+    consume (T Empty S0 _) = Done
+    consume st             = decodeError "streamUtf32BE" "UTF-32BE" onErr Nothing st
+{-# INLINE [0] streamUtf32BE #-}
+
+-- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using little
+-- endian UTF-32 encoding.
+streamUtf32LE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf32LE onErr bs0 = Stream next (T bs0 S0 0) unknownSize
+  where
+    next (T bs@(Chunk ps _) S0 i)
+      | i + 3 < len && U32.validate x =
+          Yield (unsafeChr32 x)       (T bs S0 (i+4))
+      where len = B.length ps
+            x = shiftL x4 24 + shiftL x3 16 + shiftL x2 8 + x1
+            x1    = idx i
+            x2    = idx (i+1)
+            x3    = idx (i+2)
+            x4    = idx (i+3)
+            idx = fromIntegral . B.unsafeIndex ps :: Int -> Word32
+    next st@(T bs s i) =
+      case s of
+        S4 w1 w2 w3 w4 | U32.validate (c w1 w2 w3 w4) ->
+          Yield (unsafeChr32 (c w1 w2 w3 w4)) es
+        _ -> consume st
+       where es = T bs S0 i
+             c :: Word8 -> Word8 -> Word8 -> Word8 -> Word32
+             c w1 w2 w3 w4 = shifted
+              where
+               shifted = shiftL x4 24 + shiftL x3 16 + shiftL x2 8 + x1
+               x1 = fromIntegral w1
+               x2 = fromIntegral w2
+               x3 = fromIntegral w3
+               x4 = fromIntegral w4
+    consume (T bs@(Chunk ps rest) s i)
+        | i >= B.length ps = consume (T rest s 0)
+        | otherwise =
+      case s of
+        S0             -> next (T bs (S1 x)          (i+1))
+        S1 w1          -> next (T bs (S2 w1 x)       (i+1))
+        S2 w1 w2       -> next (T bs (S3 w1 w2 x)    (i+1))
+        S3 w1 w2 w3    -> next (T bs (S4 w1 w2 w3 x) (i+1))
+        S4 w1 w2 w3 w4 -> decodeError "streamUtf32LE" "UTF-32LE" onErr (Just w1)
+                           (T bs (S3 w2 w3 w4)       (i+1))
+        where x = B.unsafeIndex ps i
+    consume (T Empty S0 _) = Done
+    consume st             = decodeError "streamUtf32LE" "UTF-32LE" onErr Nothing st
+{-# INLINE [0] streamUtf32LE #-}
+
 -- | /O(n)/ Convert a 'Stream' 'Word8' to a lazy 'ByteString'.
 unstreamChunks :: Int -> Stream Word8 -> ByteString
 unstreamChunks chunkSize (Stream next s0 len0) = chunk s0 (upperBound 4 len0)
-  where chunk s1 len1 = unsafePerformIO $ do
-          let len = min len1 chunkSize
+  where chunk s1 len1 = unsafeDupablePerformIO $ do
+          let len = max 4 (min len1 chunkSize)
           mallocByteString len >>= loop len 0 s1
           where
             loop !n !off !s fp = case next s of
                 Done | off == 0 -> return Empty
-                     | otherwise -> do
-                      bs <- trimUp fp off
-                      return $! Chunk bs Empty
+                     | otherwise -> return $! Chunk (trimUp fp off) Empty
                 Skip s' -> loop n off s' fp
                 Yield x s'
                     | off == chunkSize -> do
-                      bs <- trimUp fp off
-                      return (Chunk bs (chunk s (n - B.length bs)))
+                      let !newLen = n - off
+                      return $! Chunk (trimUp fp off) (chunk s newLen)
                     | off == n -> realloc fp n off s' x
                     | otherwise -> do
                       withForeignPtr fp $ \p -> pokeByteOff p off x
@@ -123,10 +294,13 @@ unstreamChunks chunkSize (Stream next s0 len0) = chunk s0 (upperBound 4 len0)
               fp' <- copy0 fp n n'
               withForeignPtr fp' $ \p -> pokeByteOff p off x
               loop n' (off+1) s fp'
-            {-# NOINLINE trimUp #-}
-            trimUp fp off = return $! B.PS fp 0 off
+            trimUp fp off = B.PS fp 0 off
             copy0 :: ForeignPtr Word8 -> Int -> Int -> IO (ForeignPtr Word8)
-            copy0 !src !srcLen !destLen = assert (srcLen <= destLen) $ do
+            copy0 !src !srcLen !destLen =
+#if defined(ASSERTS)
+              assert (srcLen <= destLen) $
+#endif
+              do
                 dest <- mallocByteString destLen
                 withForeignPtr src  $ \src'  ->
                     withForeignPtr dest $ \dest' ->