Implement new `compareByteArrays#` primop
[ghc.git] / testsuite / tests / codeGen / should_run / compareByteArrays.hs
1 {-# LANGUAGE MagicHash #-}
2 {-# LANGUAGE RankNTypes #-}
3 {-# LANGUAGE UnboxedTuples #-}
4
-- exercise the 'compareByteArrays#' primitive
6
7 module Main (main) where
8
9 import Control.Monad
10 import Control.Monad.ST
11 import Data.List
12 import GHC.Exts (Int (..))
13 import GHC.Prim
14 import GHC.ST (ST (ST))
15 import GHC.Word (Word8 (..))
16 import Text.Printf
17
-- | Boxed wrapper around an immutable 'ByteArray#' so that it can carry
-- ordinary class instances and be stored in lists.
data BA = BA# ByteArray#

-- | Renders the array as a bracketed, comma-separated list of two-digit
-- hexadecimal bytes, e.g. @[0xca,0xfe]@.
instance Show BA where
  show ba = '[' : intercalate "," hexBytes ++ "]"
    where
      hexBytes = [printf "0x%02x" b | b <- unpack ba]

-- | Whole-array equality: same length and identical bytes, starting at
-- offset 0 in both arrays.
instance Eq BA where
  (==) lhs rhs = eqByteArray lhs 0 lenL rhs 0 lenR
    where
      lenL = sizeofByteArray lhs
      lenR = sizeofByteArray rhs

-- | Whole-array ordering as defined by 'ordByteArray'. Bytes of the common
-- prefix are compared first; on a tie the shorter array sorts first.
instance Ord BA where
  compare lhs rhs = ordByteArray lhs 0 (len lhs) rhs 0 (len rhs)
    where
      len = sizeofByteArray
28
-- | Boxed front-end to the 'compareByteArrays#' primop. It compares @len@
-- bytes of the first array, starting at byte offset @off1@, with @len@
-- bytes of the second array, starting at byte offset @off2@.
--
-- Every caller in this module inspects only the sign of the result
-- (negative, zero or positive).
compareByteArrays :: BA -> Int -> BA -> Int -> Int -> Int
compareByteArrays (BA# src1#) (I# off1#) (BA# src2#) (I# off2#) (I# len#) =
  I# (compareByteArrays# src1# off1# src2# off2# len#)
32
33 {-
34 copyByteArray :: BA -> Int -> MBA s -> Int -> Int -> ST s ()
35 copyByteArray (BA# src#) (I# srcOfs#) (MBA# dest#) (I# destOfs#) (I# n#)
36 = ST $ \s -> case copyByteArray# src# srcOfs# dest# destOfs# n# s of
37 s' -> (# s', () #)
38 -}
39
-- | Reads the byte stored at index @idx@. This wrapper performs no bounds
-- check of its own.
indexWord8Array :: BA -> Int -> Word8
indexWord8Array (BA# arr#) (I# idx#) =
  W8# (indexWord8Array# arr# idx#)
43
-- | Size of the array in bytes.
sizeofByteArray :: BA -> Int
sizeofByteArray ba = case ba of
  BA# arr# -> I# (sizeofByteArray# arr#)
46
47
-- | Boxed wrapper around a 'MutableByteArray#' that belongs to the 'ST'
-- state thread @s@.
data MBA s = MBA# (MutableByteArray# s)

-- | Allocates a fresh mutable byte array of @size@ bytes. This wrapper does
-- not initialise the contents itself.
newByteArray :: Int -> ST s (MBA s)
newByteArray (I# size#) = ST (\s0 ->
  case newByteArray# size# s0 of
    (# s1, marr# #) -> (# s1, MBA# marr# #))
54
-- | Stores @byte@ at index @idx@ of the mutable array. This wrapper performs
-- no bounds check of its own.
writeWord8Array :: MBA s -> Int -> Word8 -> ST s ()
writeWord8Array (MBA# marr#) (I# idx#) (W8# byte#) = ST (\s0 ->
  case writeWord8Array# marr# idx# byte# s0 of
    s1 -> (# s1, () #))
59
-- | Turns the mutable array into an immutable 'BA' via
-- 'unsafeFreezeByteArray#'.
--
-- NOTE(review): as with any unsafe freeze, the 'MBA' should not be written
-- after this call. 'createByteArray' freezes only after its initialiser
-- has finished.
unsafeFreezeByteArray :: MBA s -> ST s BA
unsafeFreezeByteArray (MBA# marr#) = ST (\s0 ->
  case unsafeFreezeByteArray# marr# s0 of
    (# s1, arr# #) -> (# s1, BA# arr# #))
64
65 ----------------------------------------------------------------------------
66 -- high-level operations
67
-- | Allocates @size@ bytes, lets @fill@ initialise them, and then freezes
-- the result. The rank-2 type of @fill@ prevents the mutable array from
-- escaping the 'runST' computation.
createByteArray :: Int -> (forall s. MBA s -> ST s ()) -> BA
createByteArray size fill =
  runST (newByteArray size >>= \buf -> fill buf >> unsafeFreezeByteArray buf)
73
-- | Builds a byte array whose @i@-th byte is the @i@-th element of the list.
pack :: [Word8] -> BA
pack bytes =
  createByteArray (length bytes) (\mba -> zipWithM_ (writeWord8Array mba) [0 ..] bytes)
81
-- | Lists every byte of the array, in index order (inverse of 'pack').
unpack :: BA -> [Word8]
unpack ba = [indexWord8Array ba i | i <- [0 .. sizeofByteArray ba - 1]]
88
-- | Equality of the @n1@-byte slice of @ba1@ at @ofs1@ with the @n2@-byte
-- slice of @ba2@ at @ofs2@. Slices of different length are never equal.
-- Two empty slices are equal without calling the primop.
eqByteArray :: BA -> Int -> Int -> BA -> Int -> Int -> Bool
eqByteArray ba1 ofs1 n1 ba2 ofs2 n2 =
  n1 == n2 && (n1 == 0 || compareByteArrays ba1 ofs1 ba2 ofs2 n1 == 0)
94
-- | Orders the @n1@-byte slice of @ba1@ at @ofs1@ against the @n2@-byte
-- slice of @ba2@ at @ofs2@. The common-length prefix is compared first.
-- If it ties, or either slice is empty, the shorter slice sorts first.
ordByteArray :: BA -> Int -> Int -> BA -> Int -> Int -> Ordering
ordByteArray ba1 ofs1 n1 ba2 ofs2 n2 =
  case prefixOrder of
    EQ -> compare n1 n2
    decided -> decided
  where
    common = min n1 n2
    -- Avoid invoking the primop when there is nothing to compare.
    prefixOrder
      | common == 0 = EQ
      | otherwise = compare (compareByteArrays ba1 ofs1 ba2 ofs2 common) 0
106
-- | Test driver. The output appears in this order:
--
-- 1. @BEGIN@.
-- 2. The sign of a few hand-picked comparisons.
-- 3. An equality/ordering table for two fixed arrays.
-- 4. A sorting cross-check.
-- 5. @END@.
--
-- The brute-force section prints a @FAIL@ line only when the primop
-- disagrees with the list-based reference comparison, so a correct
-- implementation contributes no output there.
main :: IO ()
main = do
  putStrLn "BEGIN"
  -- a couple of low-level tests
  print (compareByteArrays s1 0 s2 0 4 `compare` 0)
  print (compareByteArrays s2 0 s1 0 4 `compare` 0)
  print (compareByteArrays s1 0 s2 0 3 `compare` 0)
  print (compareByteArrays s1 0 s2 1 3 `compare` 0)
  print (compareByteArrays s1 3 s2 2 1 `compare` 0)

  forM_ [(s1,s1),(s1,s2),(s2,s1),(s2,s2)] $ \(x,y) ->
    print (x == y, compare x y)

  -- realistic test: sorting packed arrays must agree with sorting the
  -- underlying byte lists
  print (sort (map pack strs) == map pack (sort strs))

  -- brute-force test: compare every pair of length-n windows of 'rng'
  -- against the comparison of the corresponding list slices.
  -- A window of length n can start anywhere in [0 .. rnglen - n]. The
  -- upper bound must include rnglen - n, otherwise windows that end on the
  -- last byte of 'rng' are never exercised.
  forM_ [1..15] $ \n ->
    forM_ [0 .. rnglen - n] $ \j ->
      forM_ [0 .. rnglen - n] $ \k -> do
        let iut = compareByteArrays srng j srng k n `compare` 0
            ref = take n (drop j rng) `compare` take n (drop k rng)
        unless (iut == ref) $
          print ("FAIL",n,j,k,iut,ref)

  putStrLn "END"
  where
    s1, s2 :: BA
    s1 = pack [0xca,0xfe,0xba,0xbe]
    s2 = pack [0xde,0xad,0xbe,0xef]

    -- Split 'rng' into consecutive chunks of lengths 1,2,3,4,0,1,2,...
    -- (this includes empty chunks).
    strs = let go i xs = case splitAt (i `mod` 5) xs of
                 ([],[]) -> []
                 (y,ys) -> y : go (i+1) ys
           in go 1 rng

    srng = pack rng

    rnglen = length rng

    rng :: [Word8]
    rng = [ 0xc1, 0x60, 0x31, 0xb6, 0x46, 0x81, 0xa7, 0xc6, 0xa8, 0xf4, 0x1e, 0x5d, 0xb7, 0x7c, 0x0b, 0xcd
          , 0x10, 0xfa, 0xe3, 0xdd, 0xf4, 0x26, 0xf9, 0x50, 0x4b, 0x9c, 0xdf, 0xc4, 0xda, 0xca, 0xc1, 0x60
          , 0x91, 0xf8, 0x70, 0x1a, 0x53, 0x89, 0xf1, 0xd9, 0xee, 0xff, 0x52, 0xb8, 0x1c, 0x5e, 0x25, 0x69
          , 0xd1, 0xa1, 0x08, 0x47, 0x93, 0x89, 0x71, 0x7a, 0xe4, 0x56, 0x24, 0x1b, 0xa1, 0x43, 0x63, 0xc0
          , 0x4d, 0xec, 0x93, 0x30, 0xb7, 0x98, 0x19, 0x23, 0x4e, 0x00, 0x76, 0x7e, 0xf4, 0xcc, 0x8b, 0x92
          , 0x19, 0xc5, 0x3d, 0xf4, 0xa0, 0x4f, 0xe3, 0x64, 0x1b, 0x4e, 0x01, 0xc9, 0xfc, 0x47, 0x3e, 0x16
          , 0xa4, 0x78, 0xdd, 0x12, 0x20, 0xa6, 0x0b, 0xcd, 0x82, 0x06, 0xd0, 0x2a, 0x19, 0x2d, 0x2f, 0xf2
          , 0x8a, 0xf0, 0xc2, 0x2d, 0x0e, 0xfb, 0x39, 0x55, 0xb2, 0xfb, 0x6e, 0xd0, 0xfa, 0xf0, 0x87, 0x57
          , 0x93, 0xa3, 0xae, 0x36, 0x1f, 0xcf, 0x91, 0x45, 0x44, 0x11, 0x62, 0x7f, 0x18, 0x9a, 0xcb, 0x54
          , 0x78, 0x3c, 0x04, 0xbe, 0x3e, 0xd4, 0x2c, 0xbf, 0x73, 0x38, 0x9e, 0xf5, 0xc9, 0xbe, 0xd9, 0xf8
          , 0xe5, 0xf5, 0x41, 0xbb, 0x84, 0x03, 0x2c, 0xe2, 0x0d, 0xe5, 0x8b, 0x1c, 0x75, 0xf7, 0x4c, 0x49
          , 0xfe, 0xac, 0x9f, 0xf4, 0x36, 0xf2, 0xba, 0x5f, 0xc0, 0xda, 0x24, 0xfc, 0x10, 0x61, 0xf0, 0xb6
          , 0xa7, 0xc7, 0xba, 0xc6, 0xb0, 0x41, 0x04, 0x8c, 0xd0, 0xe8, 0x48, 0x41, 0x38, 0xa4, 0x84, 0x21
          , 0xb6, 0xb1, 0x21, 0x33, 0x58, 0xf2, 0xa5, 0xe5, 0x73, 0xf2, 0xd7, 0xbc, 0xc7, 0x7e, 0x86, 0xee
          , 0x81, 0xb1, 0xcd, 0x42, 0xc0, 0x2c, 0xd0, 0xa0, 0x8d, 0xb5, 0x4a, 0x5b, 0xc1, 0xfe, 0xcc, 0x92
          , 0x59, 0xf4, 0x71, 0x96, 0x58, 0x6a, 0xb6, 0xa2, 0xf7, 0x67, 0x76, 0x01, 0xc5, 0x8b, 0xc9, 0x6f
          , 0x38, 0x93, 0xf3, 0xaa, 0x89, 0xf7, 0xb2, 0x2a, 0x0f, 0x19, 0x7b, 0x48, 0xbe, 0x86, 0x37, 0xd1
          , 0x30, 0xfa, 0xce, 0x72, 0xf4, 0x25, 0x64, 0xee, 0xde, 0x3a, 0x5c, 0x02, 0x32, 0xe6, 0x31, 0x3a
          , 0x4b, 0x18, 0x47, 0x30, 0xa4, 0x2c, 0xf8, 0x4d, 0xc5, 0xee, 0x0b, 0x9c, 0x75, 0x43, 0x2a, 0xf9
          ]