60d5dd16ffb13095019c404fe824778420890029
[packages/stm.git] / tests / conc049.hs
1 -- STM stress test
2
3 {-# OPTIONS -fffi #-}
4 module Main (main) where
5
6 import Control.Concurrent
7 import Control.Concurrent.STM
8 import System.Random
9 import Data.Array
10 import GHC.Conc ( unsafeIOToSTM )
11 import Control.Monad ( when )
12 import System.IO
13 import System.IO.Unsafe
14 import System.Environment
15 import Foreign.C
16
-- | How many bank accounts exist; accounts are numbered from 1.
n_accounts :: Int
n_accounts = 7

-- | How many worker threads shuffle money between accounts.
n_actors :: Int
n_actors = 10

-- | Upper bound on the randomly chosen opening balance of an account.
init_credit :: Int
init_credit = 5

-- | Upper bound on the amount moved by a single worker transfer.
max_transfer :: Int
max_transfer = 3

-- | Upper bound on the amount the source/sink thread adds or removes.
max_source :: Int
max_source = 3

-- | Number of transfers each worker performs before reporting completion.
max_transactions :: Int
max_transactions = 2000

-- | The bank: account number mapped to the cell holding its balance.
type Accounts = Array Int (TVar Int)
40
-- | Body of one worker thread.
--
-- Performs 'max_transactions' transfers between two distinct, randomly
-- chosen accounts and then increments the shared @done@ counter.
thread :: Int -> TVar Int -> Accounts -> IO ()
thread tid done accounts = go max_transactions
  where
    -- Every transfer has been made: bump the completion counter.
    go 0 = atomically (readTVar done >>= writeTVar done . (+ 1))
    go remaining = do
      from <- randomRIO (1, n_accounts)
      to   <- randomRIO (1, n_accounts)
      if from /= to
        then do
          amount <- randomRIO (1, max_transfer)
          start tid from to amount
          atomically_ tid (move from to amount)
          go (remaining - 1)
        -- Same account drawn twice: draw again without using up a transfer.
        else go remaining

    -- The transfer itself; blocks (via 'retry') until the paying account
    -- holds at least @amount@.
    move from to amount = do
      available <- readTVar payer
      when (available < amount) retry
      writeTVar payer (available - amount)
      held <- readTVar payee
      writeTVar payee (held + amount)
      where
        payer = accounts ! from
        payee = accounts ! to
59
-- | Emit a @start \<tid\> \<src\> \<dst\> \<amount\>@ trace line via 'puts'.
start :: (Show a, Show b, Show c, Show d) => a -> b -> c -> d -> IO ()
start tid src dst amount =
  puts (unwords ["start", show tid, show src, show dst, show amount])
62
-- | Set up the accounts, fork the workers and the source/sink thread, and
-- wait until every worker has reported completion.
main :: IO ()
main = do
  hSetBuffering stdout LineBuffering

  -- Disabled: take the seed from the command line, or print the current one.
  {-
  args <- getArgs
  case args of
    [n,m] -> let g = read (n ++ ' ':m) in setStdGen g
    []    -> do g <- getStdGen
                print g
  -}

  -- Fixed seed, so that the random choices are the same on every run.
  setStdGen (read "526454551 6356")

  -- HACK: the global commitVar requires atomically, so it is forced here,
  -- outside of any enclosing atomically (otherwise STM gets very confused).
  commitVar `seq` return ()

  -- print n_actors
  -- print n_accounts
  balances <- mapM (const (randomRIO (0, init_credit))) [1 .. n_accounts]
  -- mapM print balances
  cells <- atomically (mapM newTVar balances)
  let accounts = listArray (1, n_accounts) cells
  done <- atomically (newTVar 0)
  mapM_ (\actor -> forkIO (thread actor done accounts)) [1 .. n_actors]
  _ <- forkIO (sourceSinkThread accounts)

  -- Block until all workers have incremented the completion counter.
  atomically $ do
    finished <- readTVar done
    when (finished < n_actors) retry
93
-- | Thread id used in traces by the source/sink thread.
sourceThreadId :: Int
sourceThreadId = 0

-- | Pseudo account number standing for "outside the bank" in trace lines;
-- real accounts are numbered from 1.
sourceAccount :: Int
sourceAccount = 0
96
-- | A thread that never terminates: it alternates between depositing some
-- cash into a random account (source) and withdrawing some cash from a
-- random account (sink).
sourceSinkThread :: Accounts -> IO a
sourceSinkThread accounts = go True
  where
    go depositing = do
      amount <- randomRIO (1, max_source)
      acct   <- randomRIO (1, n_accounts)
      let (from, to, delta)
            | depositing = (sourceAccount, acct, amount)
            | otherwise  = (acct, sourceAccount, negate amount)
      start sourceThreadId from to amount
      adjust acct delta
      go (not depositing)

    -- Apply @delta@ to the account, clamping at zero so that a withdrawal
    -- never drives the balance negative and never blocks.
    --
    -- NB. the strict application ($!) is necessary to keep this test out
    -- of a bad state.  Without it this thread fills up all the accounts
    -- with thunks which the other threads have to evaluate.  They'll keep
    -- getting blocked on each other, and meanwhile this thread can keep
    -- on filling up the accounts with more thunks.
    adjust acct delta =
      atomically_ sourceThreadId $ do
        balance <- readTVar cell
        writeTVar cell $! max 0 (balance + delta)
      where
        cell = accounts ! acct
123
124 -- -----------------------------------------------------------------------------
125 -- Our tracing wrapper for atomically
126
-- | Global list of thread ids whose transactions have committed but have
-- not yet been reported in a @commit@ trace line (most recent first).
--
-- Created with 'newTVarIO' rather than @atomically . newTVar@: running
-- 'atomically' inside 'unsafePerformIO' is unsafe if the thunk is first
-- forced inside another transaction, whereas 'newTVarIO' is meant for
-- exactly this top-level use.  NOINLINE keeps this a single shared
-- variable.
{-# NOINLINE commitVar #-}
commitVar :: TVar [Int]
commitVar = unsafePerformIO (newTVarIO [])
129
-- | Tracing wrapper around 'atomically'.
--
-- Runs @stm@ as thread @tid@, tracing @execute@ on each attempt and
-- @retry@ when @stm@ retries, and records @tid@ in 'commitVar' as part of
-- the same transaction.  A second transaction then traces a @commit@ line
-- for every recorded id, oldest first, and empties 'commitVar'.
atomically_ :: Int -> STM a -> IO a
atomically_ tid stm = do
    result <- atomically $ do
        stmTrace ("execute " ++ show tid)
        outcome <- stm `orElse` traceRetry
        pending <- readTVar commitVar
        writeTVar commitVar (tid : pending)
        return outcome

    atomically $ do
        pending <- readTVar commitVar
        mapM_ (stmTrace . ("commit " ++) . show) (reverse pending)
        writeTVar commitVar []

    return result
  where
    -- Trace the retry, then really retry so the whole transaction blocks.
    traceRetry = stmTrace ("retry " ++ show tid) >> retry
146
-- | Emit a trace line from inside a transaction.  The I/O is smuggled in
-- with 'unsafeIOToSTM', so it is not undone if the transaction re-runs.
stmTrace :: String -> STM ()
stmTrace = unsafeIOToSTM . puts
148
-- | Marshal the string to a C string and hand it to 'c_puts'; a result of
-- -1 is turned into an errno-based exception labelled @"puts"@.
puts :: String -> IO ()
puts = throwErrnoIfMinus1_ "puts" . (`withCString` c_puts)
151
-- | The C routine behind 'puts'.
--
-- NB. the imported entity is @strlen@, not @puts@ (the original name is
-- kept in the comment below), so calling this measures the string and
-- prints nothing.  Presumably this silences the trace output while still
-- making an FFI call at each trace point -- confirm before re-enabling.
foreign import ccall unsafe "strlen" {- "puts" -}
  c_puts :: CString -> IO CInt