-- STM stress test
3 {-# OPTIONS -fffi #-}
4 module Main (main) where
6 import Control.Concurrent
7 import Control.Concurrent.STM
8 import System.Random
9 import Data.Array
10 import GHC.Conc ( unsafeIOToSTM )
11 import Control.Monad ( when )
12 import System.IO
13 import System.IO.Unsafe
14 import System.Environment
15 import Foreign.C
-- | The number of bank accounts
n_accounts :: Int
n_accounts = 7

-- | The number of threads transferring money between accounts
n_actors :: Int
n_actors = 10

-- | The max initial number of monetary units in each account
init_credit :: Int
init_credit = 5

-- | The maximum size of a transfer
max_transfer :: Int
max_transfer = 3

-- | The maximum amount transferred by the source/sink thread
max_source :: Int
max_source = 3

-- | The number of transfers each actor thread completes before it
-- increments the shared completion counter
max_transactions :: Int
max_transactions = 2000

-- | The bank: one 'TVar' balance per account, indexed by account number
-- (@main@ builds it with bounds @(1, n_accounts)@)
type Accounts = Array Int (TVar Int)
-- | One actor: performs 'max_transactions' random transfers between two
-- distinct accounts, then increments the shared @done@ counter.
--
-- Each transfer moves between 1 and 'max_transfer' units and blocks
-- (via 'retry') until the source account holds enough credit.
thread :: Int -> TVar Int -> Accounts -> IO ()
thread tid done accounts = loop max_transactions
  where
    -- All transfers finished: signal completion by bumping the counter.
    loop 0 = atomically $ do
      x <- readTVar done
      writeTVar done (x + 1)
    loop n = do
      src <- randomRIO (1, n_accounts)
      dst <- randomRIO (1, n_accounts)
      if src == dst
        then loop n -- same account: draw again without counting a transfer
        else do
          amount <- randomRIO (1, max_transfer)
          start tid src dst amount
          atomically_ tid $ do
            let src_acc = accounts ! src
                dst_acc = accounts ! dst
            credit_src <- readTVar src_acc
            -- Block until the source can cover the whole transfer.
            when (credit_src < amount) retry
            writeTVar src_acc (credit_src - amount)
            credit_dst <- readTVar dst_acc
            writeTVar dst_acc (credit_dst + amount)
          loop (n - 1)
-- | Emit one trace line of the form @start <tid> <src> <dst> <amount>@
-- through 'puts', announcing a transfer that is about to be attempted.
start tid src dst amount =
  puts ("start " ++ unwords [show tid, show src, show dst, show amount])
-- | Run the stress test: create 'n_accounts' accounts with random initial
-- balances, fork 'n_actors' transfer threads, and block until every one
-- of them has incremented the shared @done@ counter.
main :: IO ()
main = do
  hSetBuffering stdout LineBuffering

  {-
  args <- getArgs
  case args of
     [n,m] -> let g = read (n ++ ' ':m) in setStdGen g
     [] -> do g <- getStdGen
              print g
  -}

  -- for a deterministic run, we set the random seed explicitly:
  -- NOTE(review): the seed-setting statement that belongs here is missing
  -- from this copy of the file; until it is restored the run is not
  -- deterministic.

  -- HACK: the global commitVar requires atomically, so we want to seq it outside of
  -- an enclosing atomically (otherwise STM gets very confused).
  commitVar `seq` return ()

  -- print n_actors
  -- print n_accounts
  amounts <- sequence (replicate n_accounts (randomRIO (0, init_credit)))
  -- mapM print amounts
  tvars <- atomically $ mapM newTVar amounts
  let accounts = listArray (1, n_accounts) tvars
  done <- atomically (newTVar 0)
  mapM_ (\tid -> forkIO (thread tid done accounts)) [1 .. n_actors]

  -- Wait until every actor thread has reported completion.
  atomically $ do
    x <- readTVar done
    when (x < n_actors) retry
-- | Thread id the source/sink thread reports when it logs via 'start'
sourceThreadId :: Int
sourceThreadId = 0

-- | Pseudo account number for "outside the bank": logged as the source of
-- deposits and the destination of withdrawals by the source/sink thread
sourceAccount :: Int
sourceAccount = 0
-- | A thread that alternates between dropping some cash into an account
-- (source), and removing some cash from an account (sink).
--
-- Each step picks an amount in 1 .. 'max_source' and an account in
-- 1 .. 'n_accounts', logs it with 'start' (using 'sourceAccount' as the
-- outside party), and applies it.  The loop has no exit, so this never
-- returns.
sourceSinkThread :: Accounts -> IO a
sourceSinkThread accounts = loop True
  where
    loop source = do
      amount <- randomRIO (1, max_source)
      acct <- randomRIO (1, n_accounts)
      if source
        then do
          start sourceThreadId sourceAccount acct amount
          transfer acct amount
        else do
          start sourceThreadId acct sourceAccount amount
          transfer acct (-amount)
      loop (not source)

    -- Add a (possibly negative) amount to one account's balance.
    transfer acct amount = do
      let t = accounts ! acct
      -- NOTE(review): the two lines opening this transaction were missing
      -- from this copy; reconstructed to use the tracing wrapper with the
      -- source/sink thread id. Confirm against upstream.
      atomically_ sourceThreadId $ do
        x <- readTVar t
        writeTVar t (max 0 (x + amount)) -- never drop below zero,
                                         -- and don't block.
-- -----------------------------------------------------------------------------
-- Our tracing wrapper for atomically
{-# NOINLINE commitVar #-}
-- | Process-wide 'TVar' holding a list of 'Int's, used by the tracing
-- wrapper.
--
-- NOTE(review): presumably records the ids of threads whose transactions
-- committed; confirm against 'atomically_'.
--
-- It is a top-level value built with 'unsafePerformIO'.  NOINLINE stops
-- GHC from inlining (and so duplicating) the initialiser, so every use
-- sees the same TVar.  'newTVarIO' is used instead of
-- @atomically (newTVar [])@ because the latter throws a nested-atomically
-- error if this value is first forced inside a transaction.
commitVar :: TVar [Int]
commitVar = unsafePerformIO (newTVarIO [])
123 atomically_ :: Int -> STM a -> IO a
124 atomically_ tid stm = do
125 r <- atomically \$ do
126 stmTrace ("execute " ++ show tid)
127 r <- stm `orElse` do
128 stmTrace ("retry " ++ show tid)
129 retry