Re-add more primops for atomic ops on byte arrays
authorJohan Tibell <johan.tibell@gmail.com>
Fri, 27 Jun 2014 11:48:24 +0000 (13:48 +0200)
committerJohan Tibell <johan.tibell@gmail.com>
Mon, 30 Jun 2014 20:12:45 +0000 (22:12 +0200)
This is the second attempt to add this functionality. The first
attempt was reverted in 950fcae46a82569e7cd1fba1637a23b419e00ecd, due
to register allocator failure on x86. Given how the register
allocator currently works, we don't have enough registers on x86 to
support cmpxchg using complicated addressing modes. Instead we fall
back to a simpler addressing mode on x86.

Adds the following primops:

 * atomicReadIntArray#
 * atomicWriteIntArray#
 * fetchSubIntArray#
 * fetchOrIntArray#
 * fetchXorIntArray#
 * fetchAndIntArray#

Makes these pre-existing out-of-line primops inline:

 * fetchAddIntArray#
 * casIntArray#

22 files changed:
compiler/cmm/CmmMachOp.hs
compiler/cmm/CmmSink.hs
compiler/cmm/PprC.hs
compiler/codeGen/StgCmmPrim.hs
compiler/llvmGen/Llvm/AbsSyn.hs
compiler/llvmGen/Llvm/PpLlvm.hs
compiler/llvmGen/LlvmCodeGen/CodeGen.hs
compiler/nativeGen/CPrim.hs
compiler/nativeGen/PPC/CodeGen.hs
compiler/nativeGen/SPARC/CodeGen.hs
compiler/nativeGen/X86/CodeGen.hs
compiler/nativeGen/X86/Instr.hs
compiler/nativeGen/X86/Ppr.hs
compiler/prelude/primops.txt.pp
includes/stg/MiscClosures.h
libraries/ghc-prim/cbits/atomic.c [new file with mode: 0644]
libraries/ghc-prim/ghc-prim.cabal
rts/Linker.c
rts/PrimOps.cmm
testsuite/tests/concurrent/should_run/AtomicPrimops.hs [new file with mode: 0644]
testsuite/tests/concurrent/should_run/AtomicPrimops.stdout [new file with mode: 0644]
testsuite/tests/concurrent/should_run/all.T

index c4ec393..d8ce492 100644 (file)
@@ -19,6 +19,9 @@ module CmmMachOp
     -- CallishMachOp
     , CallishMachOp(..), callishMachOpHints
     , pprCallishMachOp
+
+    -- Atomic read-modify-write
+    , AtomicMachOp(..)
    )
 where
 
@@ -547,8 +550,24 @@ data CallishMachOp
 
   | MO_PopCnt Width
   | MO_BSwap Width
+
+  -- Atomic read-modify-write.
+  | MO_AtomicRMW Width AtomicMachOp
+  | MO_AtomicRead Width
+  | MO_AtomicWrite Width
+  | MO_Cmpxchg Width
   deriving (Eq, Show)
 
+-- | The operation to perform atomically.
+data AtomicMachOp =
+      AMO_Add
+    | AMO_Sub
+    | AMO_And
+    | AMO_Nand
+    | AMO_Or
+    | AMO_Xor
+      deriving (Eq, Show)
+
 pprCallishMachOp :: CallishMachOp -> SDoc
 pprCallishMachOp mo = text (show mo)
 
index 4c02542..4dced9a 100644 (file)
@@ -650,6 +650,10 @@ data AbsMem
 -- perhaps we ought to have a special annotation for calls that can
 -- modify heap/stack memory.  For now we just use the conservative
 -- definition here.
+--
+-- Some CallishMachOp imply a memory barrier e.g. AtomicRMW and
+-- therefore we should never float any memory operations across one of
+-- these calls.
 
 
 bothMems :: AbsMem -> AbsMem -> AbsMem
index 47b247e..455c79b 100644 (file)
@@ -753,6 +753,10 @@ pprCallishMachOp_for_C mop
         MO_Memmove      -> ptext (sLit "memmove")
         (MO_BSwap w)    -> ptext (sLit $ bSwapLabel w)
         (MO_PopCnt w)   -> ptext (sLit $ popCntLabel w)
+        (MO_AtomicRMW w amop) -> ptext (sLit $ atomicRMWLabel w amop)
+        (MO_Cmpxchg w)  -> ptext (sLit $ cmpxchgLabel w)
+        (MO_AtomicRead w)  -> ptext (sLit $ atomicReadLabel w)
+        (MO_AtomicWrite w) -> ptext (sLit $ atomicWriteLabel w)
         (MO_UF_Conv w)  -> ptext (sLit $ word2FloatLabel w)
 
         MO_S_QuotRem  {} -> unsupported
index 40a5e36..e4c682b 100644 (file)
@@ -769,6 +769,25 @@ emitPrimOp _ res PrefetchByteArrayOp0        args = doPrefetchByteArrayOp 0 res
 emitPrimOp _ res PrefetchMutableByteArrayOp0 args = doPrefetchByteArrayOp 0 res args
 emitPrimOp _ res PrefetchAddrOp0             args = doPrefetchAddrOp 0 res args
 
+-- Atomic read-modify-write
+emitPrimOp dflags [res] FetchAddByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Add mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchSubByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Sub mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchAndByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_And mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchNandByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Nand mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchOrByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Or mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchXorByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Xor mba ix (bWord dflags) n
+emitPrimOp dflags [res] AtomicReadByteArrayOp_Int [mba, ix] =
+    doAtomicReadByteArray res mba ix (bWord dflags)
+emitPrimOp dflags [] AtomicWriteByteArrayOp_Int [mba, ix, val] =
+    doAtomicWriteByteArray mba ix (bWord dflags) val
+emitPrimOp dflags [res] CasByteArrayOp_Int [mba, ix, old, new] =
+    doCasByteArray res mba ix (bWord dflags) old new
 
 -- The rest just translate straightforwardly
 emitPrimOp dflags [res] op [arg]
@@ -1933,6 +1952,81 @@ doWriteSmallPtrArrayOp addr idx val = do
     emit (setInfo addr (CmmLit (CmmLabel mkSMAP_DIRTY_infoLabel)))
 
 ------------------------------------------------------------------------------
+-- Atomic read-modify-write
+
+-- | Emit an atomic modification to a byte array element. The result
+-- reg contains that previous value of the element. Implies a full
+-- memory barrier.
+doAtomicRMW :: LocalReg      -- ^ Result reg
+            -> AtomicMachOp  -- ^ Atomic op (e.g. add)
+            -> CmmExpr       -- ^ MutableByteArray#
+            -> CmmExpr       -- ^ Index
+            -> CmmType       -- ^ Type of element by which we are indexing
+            -> CmmExpr       -- ^ Op argument (e.g. amount to add)
+            -> FCode ()
+doAtomicRMW res amop mba idx idx_ty n = do
+    dflags <- getDynFlags
+    let width = typeWidth idx_ty
+        addr  = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+                width mba idx
+    emitPrimCall
+        [ res ]
+        (MO_AtomicRMW width amop)
+        [ addr, n ]
+
+-- | Emit an atomic read to a byte array that acts as a memory barrier.
+doAtomicReadByteArray
+    :: LocalReg  -- ^ Result reg
+    -> CmmExpr   -- ^ MutableByteArray#
+    -> CmmExpr   -- ^ Index
+    -> CmmType   -- ^ Type of element by which we are indexing
+    -> FCode ()
+doAtomicReadByteArray res mba idx idx_ty = do
+    dflags <- getDynFlags
+    let width = typeWidth idx_ty
+        addr  = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+                width mba idx
+    emitPrimCall
+        [ res ]
+        (MO_AtomicRead width)
+        [ addr ]
+
+-- | Emit an atomic write to a byte array that acts as a memory barrier.
+doAtomicWriteByteArray
+    :: CmmExpr   -- ^ MutableByteArray#
+    -> CmmExpr   -- ^ Index
+    -> CmmType   -- ^ Type of element by which we are indexing
+    -> CmmExpr   -- ^ Value to write
+    -> FCode ()
+doAtomicWriteByteArray mba idx idx_ty val = do
+    dflags <- getDynFlags
+    let width = typeWidth idx_ty
+        addr  = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+                width mba idx
+    emitPrimCall
+        [ {- no results -} ]
+        (MO_AtomicWrite width)
+        [ addr, val ]
+
+doCasByteArray
+    :: LocalReg  -- ^ Result reg
+    -> CmmExpr   -- ^ MutableByteArray#
+    -> CmmExpr   -- ^ Index
+    -> CmmType   -- ^ Type of element by which we are indexing
+    -> CmmExpr   -- ^ Old value
+    -> CmmExpr   -- ^ New value
+    -> FCode ()
+doCasByteArray res mba idx idx_ty old new = do
+    dflags <- getDynFlags
+    let width = (typeWidth idx_ty)
+        addr = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+               width mba idx
+    emitPrimCall
+        [ res ]
+        (MO_Cmpxchg width)
+        [ addr, old, new ]
+
+------------------------------------------------------------------------------
 -- Helpers for emitting function calls
 
 -- | Emit a call to @memcpy@.
index f92bd89..24d0856 100644 (file)
@@ -65,6 +65,8 @@ data LlvmFunction = LlvmFunction {
 
 type LlvmFunctions = [LlvmFunction]
 
+type SingleThreaded = Bool
+
 -- | LLVM ordering types for synchronization purposes. (Introduced in LLVM
 -- 3.0). Please see the LLVM documentation for a better description.
 data LlvmSyncOrdering
@@ -224,6 +226,11 @@ data LlvmExpression
   | Load LlvmVar
 
   {- |
+    Atomic load of the value at location ptr
+  -}
+  | ALoad LlvmSyncOrdering SingleThreaded LlvmVar
+
+  {- |
     Navigate in an structure, selecting elements
       * inbound: Is the pointer inbounds? (computed pointer doesn't overflow)
       * ptr:     Location of the structure
index 0250782..7307725 100644 (file)
@@ -239,6 +239,7 @@ ppLlvmExpression expr
         Insert     vec elt idx      -> ppInsert vec elt idx
         GetElemPtr inb ptr indexes  -> ppGetElementPtr inb ptr indexes
         Load       ptr              -> ppLoad ptr
+        ALoad      ord st ptr       -> ppALoad ord st ptr
         Malloc     tp amount        -> ppMalloc tp amount
         Phi        tp precessors    -> ppPhi tp precessors
         Asm        asm c ty v se sk -> ppAsm asm c ty v se sk
@@ -327,13 +328,18 @@ ppSyncOrdering SyncSeqCst    = text "seq_cst"
 -- of specifying alignment.
 
 ppLoad :: LlvmVar -> SDoc
-ppLoad var
-    | isVecPtrVar var = text "load" <+> ppr var <>
-                        comma <+> text "align 1"
-    | otherwise       = text "load" <+> ppr var
+ppLoad var = text "load" <+> ppr var <> align
   where
-    isVecPtrVar :: LlvmVar -> Bool
-    isVecPtrVar = isVector . pLower . getVarType
+    align | isVector . pLower . getVarType $ var = text ", align 1"
+          | otherwise = empty
+
+ppALoad :: LlvmSyncOrdering -> SingleThreaded -> LlvmVar -> SDoc
+ppALoad ord st var = sdocWithDynFlags $ \dflags ->
+  let alignment = (llvmWidthInBits dflags $ getVarType var) `quot` 8
+      align     = text ", align" <+> ppr alignment
+      sThreaded | st        = text " singlethread"
+                | otherwise = empty
+  in text "load atomic" <+> ppr var <> sThreaded <+> ppSyncOrdering ord <> align
 
 ppStore :: LlvmVar -> LlvmVar -> SDoc
 ppStore val dst
index 5175535..4a56600 100644 (file)
@@ -15,6 +15,7 @@ import BlockId
 import CodeGen.Platform ( activeStgRegs, callerSaves )
 import CLabel
 import Cmm
+import CPrim
 import PprCmm
 import CmmUtils
 import Hoopl
@@ -32,6 +33,7 @@ import Unique
 import Data.List ( nub )
 import Data.Maybe ( catMaybes )
 
+type Atomic = Bool
 type LlvmStatements = OrdList LlvmStatement
 
 -- -----------------------------------------------------------------------------
@@ -228,6 +230,17 @@ genCall t@(PrimTarget (MO_PopCnt w)) dsts args =
 genCall t@(PrimTarget (MO_BSwap w)) dsts args =
     genCallSimpleCast w t dsts args
 
+genCall (PrimTarget (MO_AtomicRead _)) [dst] [addr] = do
+  dstV <- getCmmReg (CmmLocal dst)
+  (v1, stmts, top) <- genLoad True addr (localRegType dst)
+  let stmt1 = Store v1 dstV
+  return (stmts `snocOL` stmt1, top)
+
+-- TODO: implement these properly rather than calling to RTS functions.
+-- genCall t@(PrimTarget (MO_AtomicWrite width)) [] [addr, val] = undefined
+-- genCall t@(PrimTarget (MO_AtomicRMW width amop)) [dst] [addr, n] = undefined
+-- genCall t@(PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] = undefined
+
 -- Handle memcpy function specifically since llvm's intrinsic version takes
 -- some extra parameters.
 genCall t@(PrimTarget op) [] args'
@@ -548,7 +561,6 @@ cmmPrimOpFunctions mop = do
 
     (MO_Prefetch_Data _ )-> fsLit "llvm.prefetch"
 
-
     MO_S_QuotRem {}  -> unsupported
     MO_U_QuotRem {}  -> unsupported
     MO_U_QuotRem2 {} -> unsupported
@@ -558,6 +570,12 @@ cmmPrimOpFunctions mop = do
     MO_Touch         -> unsupported
     MO_UF_Conv _     -> unsupported
 
+    MO_AtomicRead _  -> unsupported
+
+    MO_AtomicRMW w amop -> fsLit $ atomicRMWLabel w amop
+    MO_Cmpxchg w        -> fsLit $ cmpxchgLabel w
+    MO_AtomicWrite w    -> fsLit $ atomicWriteLabel w
+
 -- | Tail function calls
 genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
 
@@ -849,7 +867,7 @@ exprToVarOpt opt e = case e of
         -> genLit opt lit
 
     CmmLoad e' ty
-        -> genLoad e' ty
+        -> genLoad False e' ty
 
     -- Cmmreg in expression is the value, so must load. If you want actual
     -- reg pointer, call getCmmReg directly.
@@ -1268,41 +1286,41 @@ genMachOp_slow _ _ _ = panic "genMachOp: More then 2 expressions in MachOp!"
 
 
 -- | Handle CmmLoad expression.
-genLoad :: CmmExpr -> CmmType -> LlvmM ExprData
+genLoad :: Atomic -> CmmExpr -> CmmType -> LlvmM ExprData
 
 -- First we try to detect a few common cases and produce better code for
 -- these then the default case. We are mostly trying to detect Cmm code
 -- like I32[Sp + n] and use 'getelementptr' operations instead of the
 -- generic case that uses casts and pointer arithmetic
-genLoad e@(CmmReg (CmmGlobal r)) ty
-    = genLoad_fast e r 0 ty
+genLoad atomic e@(CmmReg (CmmGlobal r)) ty
+    = genLoad_fast atomic e r 0 ty
 
-genLoad e@(CmmRegOff (CmmGlobal r) n) ty
-    = genLoad_fast e r n ty
+genLoad atomic e@(CmmRegOff (CmmGlobal r) n) ty
+    = genLoad_fast atomic e r n ty
 
-genLoad e@(CmmMachOp (MO_Add _) [
+genLoad atomic e@(CmmMachOp (MO_Add _) [
                             (CmmReg (CmmGlobal r)),
                             (CmmLit (CmmInt n _))])
                 ty
-    = genLoad_fast e r (fromInteger n) ty
+    = genLoad_fast atomic e r (fromInteger n) ty
 
-genLoad e@(CmmMachOp (MO_Sub _) [
+genLoad atomic e@(CmmMachOp (MO_Sub _) [
                             (CmmReg (CmmGlobal r)),
                             (CmmLit (CmmInt n _))])
                 ty
-    = genLoad_fast e r (negate $ fromInteger n) ty
+    = genLoad_fast atomic e r (negate $ fromInteger n) ty
 
 -- generic case
-genLoad e ty
+genLoad atomic e ty
     = do other <- getTBAAMeta otherN
-         genLoad_slow e ty other
+         genLoad_slow atomic e ty other
 
 -- | Handle CmmLoad expression.
 -- This is a special case for loading from a global register pointer
 -- offset such as I32[Sp+8].
-genLoad_fast :: CmmExpr -> GlobalReg -> Int -> CmmType
-                -> LlvmM ExprData
-genLoad_fast e r n ty = do
+genLoad_fast :: Atomic -> CmmExpr -> GlobalReg -> Int -> CmmType
+             -> LlvmM ExprData
+genLoad_fast atomic e r n ty = do
     dflags <- getDynFlags
     (gv, grt, s1) <- getCmmRegVal (CmmGlobal r)
     meta          <- getTBAARegMeta r
@@ -1315,7 +1333,7 @@ genLoad_fast e r n ty = do
                 case grt == ty' of
                      -- were fine
                      True -> do
-                         (var, s3) <- doExpr ty' (MExpr meta $ Load ptr)
+                         (var, s3) <- doExpr ty' (MExpr meta $ loadInstr ptr)
                          return (var, s1 `snocOL` s2 `snocOL` s3,
                                      [])
 
@@ -1323,32 +1341,34 @@ genLoad_fast e r n ty = do
                      False -> do
                          let pty = pLift ty'
                          (ptr', s3) <- doExpr pty $ Cast LM_Bitcast ptr pty
-                         (var, s4) <- doExpr ty' (MExpr meta $ Load ptr')
+                         (var, s4) <- doExpr ty' (MExpr meta $ loadInstr ptr')
                          return (var, s1 `snocOL` s2 `snocOL` s3
                                     `snocOL` s4, [])
 
             -- If its a bit type then we use the slow method since
             -- we can't avoid casting anyway.
-            False -> genLoad_slow e ty meta
-
+            False -> genLoad_slow atomic  e ty meta
+  where
+    loadInstr ptr | atomic    = ALoad SyncSeqCst False ptr
+                  | otherwise = Load ptr
 
 -- | Handle Cmm load expression.
 -- Generic case. Uses casts and pointer arithmetic if needed.
-genLoad_slow :: CmmExpr -> CmmType -> [MetaAnnot] -> LlvmM ExprData
-genLoad_slow e ty meta = do
+genLoad_slow :: Atomic -> CmmExpr -> CmmType -> [MetaAnnot] -> LlvmM ExprData
+genLoad_slow atomic e ty meta = do
     (iptr, stmts, tops) <- exprToVar e
     dflags <- getDynFlags
     case getVarType iptr of
          LMPointer _ -> do
                     (dvar, load) <- doExpr (cmmToLlvmType ty)
-                                           (MExpr meta $ Load iptr)
+                                           (MExpr meta $ loadInstr iptr)
                     return (dvar, stmts `snocOL` load, tops)
 
          i@(LMInt _) | i == llvmWord dflags -> do
                     let pty = LMPointer $ cmmToLlvmType ty
                     (ptr, cast)  <- doExpr pty $ Cast LM_Inttoptr iptr pty
                     (dvar, load) <- doExpr (cmmToLlvmType ty)
-                                           (MExpr meta $ Load ptr)
+                                           (MExpr meta $ loadInstr ptr)
                     return (dvar, stmts `snocOL` cast `snocOL` load, tops)
 
          other -> do dflags <- getDynFlags
@@ -1357,6 +1377,9 @@ genLoad_slow e ty meta = do
                             "Size of Ptr: " ++ show (llvmPtrBits dflags) ++
                             ", Size of var: " ++ show (llvmWidthInBits dflags other) ++
                             ", Var: " ++ showSDoc dflags (ppr iptr)))
+  where
+    loadInstr ptr | atomic    = ALoad SyncSeqCst False ptr
+                  | otherwise = Load ptr
 
 
 -- | Handle CmmReg expression. This will return a pointer to the stack
index a6f4cab..34782df 100644 (file)
@@ -1,11 +1,16 @@
 -- | Generating C symbol names emitted by the compiler.
 module CPrim
-    ( popCntLabel
+    ( atomicReadLabel
+    , atomicWriteLabel
+    , atomicRMWLabel
+    , cmpxchgLabel
+    , popCntLabel
     , bSwapLabel
     , word2FloatLabel
     ) where
 
 import CmmType
+import CmmMachOp
 import Outputable
 
 popCntLabel :: Width -> String
@@ -31,3 +36,46 @@ word2FloatLabel w = "hs_word2float" ++ pprWidth w
     pprWidth W32 = "32"
     pprWidth W64 = "64"
     pprWidth w   = pprPanic "word2FloatLabel: Unsupported word width " (ppr w)
+
+atomicRMWLabel :: Width -> AtomicMachOp -> String
+atomicRMWLabel w amop = "hs_atomic_" ++ pprFunName amop ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "atomicRMWLabel: Unsupported word width " (ppr w)
+
+    pprFunName AMO_Add  = "add"
+    pprFunName AMO_Sub  = "sub"
+    pprFunName AMO_And  = "and"
+    pprFunName AMO_Nand = "nand"
+    pprFunName AMO_Or   = "or"
+    pprFunName AMO_Xor  = "xor"
+
+cmpxchgLabel :: Width -> String
+cmpxchgLabel w = "hs_cmpxchg" ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "cmpxchgLabel: Unsupported word width " (ppr w)
+
+atomicReadLabel :: Width -> String
+atomicReadLabel w = "hs_atomicread" ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "atomicReadLabel: Unsupported word width " (ppr w)
+
+atomicWriteLabel :: Width -> String
+atomicWriteLabel w = "hs_atomicwrite" ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "atomicWriteLabel: Unsupported word width " (ppr w)
index 91651e6..22a2c7c 100644 (file)
@@ -1160,6 +1160,10 @@ genCCall' dflags gcp target dest_regs args0
 
                     MO_BSwap w   -> (fsLit $ bSwapLabel w, False)
                     MO_PopCnt w  -> (fsLit $ popCntLabel w, False)
+                    MO_AtomicRMW w amop -> (fsLit $ atomicRMWLabel w amop, False)
+                    MO_Cmpxchg w -> (fsLit $ cmpxchgLabel w, False)
+                    MO_AtomicRead w  -> (fsLit $ atomicReadLabel w, False)
+                    MO_AtomicWrite w -> (fsLit $ atomicWriteLabel w, False)
 
                     MO_S_QuotRem {}  -> unsupported
                     MO_U_QuotRem {}  -> unsupported
index f5e61d0..51f89d6 100644 (file)
@@ -654,6 +654,10 @@ outOfLineMachOp_table mop
 
         MO_BSwap w   -> fsLit $ bSwapLabel w
         MO_PopCnt w  -> fsLit $ popCntLabel w
+        MO_AtomicRMW w amop -> fsLit $ atomicRMWLabel w amop
+        MO_Cmpxchg w -> fsLit $ cmpxchgLabel w
+        MO_AtomicRead w -> fsLit $ atomicReadLabel w
+        MO_AtomicWrite w -> fsLit $ atomicWriteLabel w
 
         MO_S_QuotRem {}  -> unsupported
         MO_U_QuotRem {}  -> unsupported
index fa93767..8e9b49d 100644 (file)
@@ -1057,6 +1057,18 @@ getAmode' _ expr = do
   (reg,code) <- getSomeReg expr
   return (Amode (AddrBaseIndex (EABaseReg reg) EAIndexNone (ImmInt 0)) code)
 
+-- | Like 'getAmode', but on 32-bit use simple register addressing
+-- (i.e. no index register). This stops us from running out of
+-- registers on x86 when using instructions such as cmpxchg, which can
+-- use up to three virtual registers and one fixed register.
+getSimpleAmode :: DynFlags -> Bool -> CmmExpr -> NatM Amode
+getSimpleAmode dflags is32Bit addr
+    | is32Bit = do
+        addr_code <- getAnyReg addr
+        addr_r <- getNewRegNat (intSize (wordWidth dflags))
+        let amode = AddrBaseIndex (EABaseReg addr_r) EAIndexNone (ImmInt 0)
+        return $! Amode amode (addr_code addr_r)
+    | otherwise = getAmode addr
 
 x86_complex_amode :: CmmExpr -> CmmExpr -> Integer -> Integer -> NatM Amode
 x86_complex_amode base index shift offset
@@ -1761,6 +1773,99 @@ genCCall dflags is32Bit (PrimTarget (MO_UF_Conv width)) dest_regs args = do
   where
     lbl = mkCmmCodeLabel primPackageId (fsLit (word2FloatLabel width))
 
+genCCall dflags is32Bit (PrimTarget (MO_AtomicRMW width amop)) [dst] [addr, n] = do
+    Amode amode addr_code <-
+        if amop `elem` [AMO_Add, AMO_Sub]
+        then getAmode addr
+        else getSimpleAmode dflags is32Bit addr  -- See genCCall for MO_Cmpxchg
+    arg <- getNewRegNat size
+    arg_code <- getAnyReg n
+    use_sse2 <- sse2Enabled
+    let platform = targetPlatform dflags
+        dst_r    = getRegisterReg platform use_sse2 (CmmLocal dst)
+    code <- op_code dst_r arg amode
+    return $ addr_code `appOL` arg_code arg `appOL` code
+  where
+    -- Code for the operation
+    op_code :: Reg       -- Destination reg
+            -> Reg       -- Register containing argument
+            -> AddrMode  -- Address of location to mutate
+            -> NatM (OrdList Instr)
+    op_code dst_r arg amode = case amop of
+        -- In the common case where dst_r is a virtual register the
+        -- final move should go away, because it's the last use of arg
+        -- and the first use of dst_r.
+        AMO_Add  -> return $ toOL [ LOCK
+                                  , XADD size (OpReg arg) (OpAddr amode)
+                                  , MOV size (OpReg arg) (OpReg dst_r)
+                                  ]
+        AMO_Sub  -> return $ toOL [ NEGI size (OpReg arg)
+                                  , LOCK
+                                  , XADD size (OpReg arg) (OpAddr amode)
+                                  , MOV size (OpReg arg) (OpReg dst_r)
+                                  ]
+        AMO_And  -> cmpxchg_code (\ src dst -> unitOL $ AND size src dst)
+        AMO_Nand -> cmpxchg_code (\ src dst -> toOL [ AND size src dst
+                                                    , NOT size dst
+                                                    ])
+        AMO_Or   -> cmpxchg_code (\ src dst -> unitOL $ OR size src dst)
+        AMO_Xor  -> cmpxchg_code (\ src dst -> unitOL $ XOR size src dst)
+      where
+        -- Simulate operation that lacks a dedicated instruction using
+        -- cmpxchg.
+        cmpxchg_code :: (Operand -> Operand -> OrdList Instr)
+                     -> NatM (OrdList Instr)
+        cmpxchg_code instrs = do
+            lbl <- getBlockIdNat
+            tmp <- getNewRegNat size
+            return $ toOL
+                [ MOV size (OpAddr amode) (OpReg eax)
+                , JXX ALWAYS lbl
+                , NEWBLOCK lbl
+                  -- Keep old value so we can return it:
+                , MOV size (OpReg eax) (OpReg dst_r)
+                , MOV size (OpReg eax) (OpReg tmp)
+                ]
+                `appOL` instrs (OpReg arg) (OpReg tmp) `appOL` toOL
+                [ LOCK
+                , CMPXCHG size (OpReg tmp) (OpAddr amode)
+                , JXX NE lbl
+                ]
+
+    size = intSize width
+
+genCCall dflags _ (PrimTarget (MO_AtomicRead width)) [dst] [addr] = do
+  load_code <- intLoadCode (MOV (intSize width)) addr
+  let platform = targetPlatform dflags
+  use_sse2 <- sse2Enabled
+  return (load_code (getRegisterReg platform use_sse2 (CmmLocal dst)))
+
+genCCall _ _ (PrimTarget (MO_AtomicWrite width)) [] [addr, val] = do
+    assignMem_IntCode (intSize width) addr val
+
+genCCall dflags is32Bit (PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] = do
+    -- On x86 we don't have enough registers to use cmpxchg with a
+    -- complicated addressing mode, so on that architecture we
+    -- pre-compute the address first.
+    Amode amode addr_code <- getSimpleAmode dflags is32Bit addr
+    newval <- getNewRegNat size
+    newval_code <- getAnyReg new
+    oldval <- getNewRegNat size
+    oldval_code <- getAnyReg old
+    use_sse2 <- sse2Enabled
+    let platform = targetPlatform dflags
+        dst_r    = getRegisterReg platform use_sse2 (CmmLocal dst)
+        code     = toOL
+                   [ MOV size (OpReg oldval) (OpReg eax)
+                   , LOCK
+                   , CMPXCHG size (OpReg newval) (OpAddr amode)
+                   , MOV size (OpReg eax) (OpReg dst_r)
+                   ]
+    return $ addr_code `appOL` newval_code newval `appOL` oldval_code oldval
+        `appOL` code
+  where
+    size = intSize width
+
 genCCall _ is32Bit target dest_regs args
  | is32Bit   = genCCall32 target dest_regs args
  | otherwise = genCCall64 target dest_regs args
@@ -2385,6 +2490,11 @@ outOfLineCmmOp mop res args
               MO_PopCnt _  -> fsLit "popcnt"
               MO_BSwap _   -> fsLit "bswap"
 
+              MO_AtomicRMW _ _ -> fsLit "atomicrmw"
+              MO_AtomicRead _  -> fsLit "atomicread"
+              MO_AtomicWrite _ -> fsLit "atomicwrite"
+              MO_Cmpxchg _     -> fsLit "cmpxchg"
+
               MO_UF_Conv _ -> unsupported
 
               MO_S_QuotRem {}  -> unsupported
index 05fff9b..ac91747 100644 (file)
@@ -327,6 +327,10 @@ data Instr
         | PREFETCH  PrefetchVariant Size Operand -- prefetch Variant, addr size, address to prefetch
                                         -- variant can be NTA, Lvl0, Lvl1, or Lvl2
 
+        | LOCK  -- lock prefix
+        | XADD        Size Operand Operand  -- src (r), dst (r/m)
+        | CMPXCHG     Size Operand Operand  -- src (r), dst (r/m), eax implicit
+
 data PrefetchVariant = NTA | Lvl0 | Lvl1 | Lvl2
 
 
@@ -337,6 +341,8 @@ data Operand
 
 
 
+-- | Returns which registers are read and written as a (read, written)
+-- pair.
 x86_regUsageOfInstr :: Platform -> Instr -> RegUsage
 x86_regUsageOfInstr platform instr
  = case instr of
@@ -428,10 +434,21 @@ x86_regUsageOfInstr platform instr
 
     -- note: might be a better way to do this
     PREFETCH _  _ src -> mkRU (use_R src []) []
+    LOCK                -> noUsage
+    XADD _ src dst      -> usageMM src dst
+    CMPXCHG _ src dst   -> usageRMM src dst (OpReg eax)
 
     _other              -> panic "regUsage: unrecognised instr"
-
  where
+    -- # Definitions
+    --
+    -- Written: If the operand is a register, it's written. If it's an
+    -- address, registers mentioned in the address are read.
+    --
+    -- Modified: If the operand is a register, it's both read and
+    -- written. If it's an address, registers mentioned in the address
+    -- are read.
+
     -- 2 operand form; first operand Read; second Written
     usageRW :: Operand -> Operand -> RegUsage
     usageRW op (OpReg reg)      = mkRU (use_R op []) [reg]
@@ -444,6 +461,18 @@ x86_regUsageOfInstr platform instr
     usageRM op (OpAddr ea)      = mkRUR (use_R op $! use_EA ea [])
     usageRM _ _                 = panic "X86.RegInfo.usageRM: no match"
 
+    -- 2 operand form; first operand Modified; second Modified
+    usageMM :: Operand -> Operand -> RegUsage
+    usageMM (OpReg src) (OpReg dst) = mkRU [src, dst] [src, dst]
+    usageMM (OpReg src) (OpAddr ea) = mkRU (use_EA ea [src]) [src]
+    usageMM _ _                     = panic "X86.RegInfo.usageMM: no match"
+
+    -- 3 operand form; first operand Read; second Modified; third Modified
+    usageRMM :: Operand -> Operand -> Operand -> RegUsage
+    usageRMM (OpReg src) (OpReg dst) (OpReg reg) = mkRU [src, dst, reg] [dst, reg]
+    usageRMM (OpReg src) (OpAddr ea) (OpReg reg) = mkRU (use_EA ea [src, reg]) [reg]
+    usageRMM _ _ _                               = panic "X86.RegInfo.usageRMM: no match"
+
     -- 1 operand form; operand Modified
     usageM :: Operand -> RegUsage
     usageM (OpReg reg)          = mkRU [reg] [reg]
@@ -476,6 +505,7 @@ x86_regUsageOfInstr platform instr
         where src' = filter (interesting platform) src
               dst' = filter (interesting platform) dst
 
+-- | Is this register interesting for the register allocator?
 interesting :: Platform -> Reg -> Bool
 interesting _        (RegVirtual _)              = True
 interesting platform (RegReal (RealRegSingle i)) = isFastTrue (freeReg platform i)
@@ -483,6 +513,8 @@ interesting _        (RegReal (RealRegPair{}))   = panic "X86.interesting: no re
 
 
 
+-- | Applies the supplied function to all registers in instructions.
+-- Typically used to change virtual registers to real registers.
 x86_patchRegsOfInstr :: Instr -> (Reg -> Reg) -> Instr
 x86_patchRegsOfInstr instr env
  = case instr of
@@ -571,6 +603,10 @@ x86_patchRegsOfInstr instr env
 
     PREFETCH lvl size src -> PREFETCH lvl size (patchOp src)
 
+    LOCK                -> instr
+    XADD sz src dst     -> patch2 (XADD sz) src dst
+    CMPXCHG sz src dst  -> patch2 (CMPXCHG sz) src dst
+
     _other              -> panic "patchRegs: unrecognised instr"
 
   where
index 459c041..7771c02 100644 (file)
@@ -886,6 +886,14 @@ pprInstr GFREE
             ptext (sLit "\tffree %st(4) ;ffree %st(5)")
           ]
 
+-- Atomics
+
+pprInstr LOCK = ptext (sLit "\tlock")
+
+pprInstr (XADD size src dst) = pprSizeOpOp (sLit "xadd") size src dst
+
+pprInstr (CMPXCHG size src dst) = pprSizeOpOp (sLit "cmpxchg") size src dst
+
 pprInstr _
         = panic "X86.Ppr.pprInstr: no match"
 
index 4851315..4faa585 100644 (file)
@@ -1363,19 +1363,79 @@ primop  SetByteArrayOp "setByteArray#" GenPrimOp
   code_size = { primOpCodeSizeForeignCall + 4 }
   can_fail = True
 
+-- Atomic operations
+
+primop  AtomicReadByteArrayOp_Int "atomicReadIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array and an offset in Int units, read an element. The
+    index is assumed to be in bounds. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop  AtomicWriteByteArrayOp_Int "atomicWriteIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> State# s
+   {Given an array and an offset in Int units, write an element. The
+    index is assumed to be in bounds. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
 primop CasByteArrayOp_Int "casIntArray#" GenPrimOp
    MutableByteArray# s -> Int# -> Int# -> Int# -> State# s -> (# State# s, Int# #)
-   {Machine-level atomic compare and swap on a word within a ByteArray.}
-   with
-   out_of_line = True
-   has_side_effects = True
+   {Given an array, an offset in Int units, the expected old value, and
+    the new value, perform an atomic compare and swap i.e. write the new
+    value if the current value matches the provided old value. Returns
+    the value of the element before the operation. Implies a full memory
+    barrier.}
+   with has_side_effects = True
+        can_fail = True
 
 primop FetchAddByteArrayOp_Int "fetchAddIntArray#" GenPrimOp
    MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
-   {Machine-level word-sized fetch-and-add within a ByteArray.}
-   with
-   out_of_line = True
-   has_side_effects = True
+   {Given an array, and offset in Int units, and a value to add,
+    atomically add the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchSubByteArrayOp_Int "fetchSubIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to subtract,
+    atomically substract the value to the element. Returns the value of
+    the element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchAndByteArrayOp_Int "fetchAndIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to AND,
+    atomically AND the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchNandByteArrayOp_Int "fetchNandIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to NAND,
+    atomically NAND the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchOrByteArrayOp_Int "fetchOrIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to OR,
+    atomically OR the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchXorByteArrayOp_Int "fetchXorIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to XOR,
+    atomically XOR the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
 
 
 ------------------------------------------------------------------------
index 0c4d2f9..ee5a119 100644 (file)
@@ -348,7 +348,6 @@ RTS_FUN_DECL(stg_newByteArrayzh);
 RTS_FUN_DECL(stg_newPinnedByteArrayzh);
 RTS_FUN_DECL(stg_newAlignedPinnedByteArrayzh);
 RTS_FUN_DECL(stg_casIntArrayzh);
-RTS_FUN_DECL(stg_fetchAddIntArrayzh);
 RTS_FUN_DECL(stg_newArrayzh);
 RTS_FUN_DECL(stg_newArrayArrayzh);
 RTS_FUN_DECL(stg_copyArrayzh);
diff --git a/libraries/ghc-prim/cbits/atomic.c b/libraries/ghc-prim/cbits/atomic.c
new file mode 100644 (file)
index 0000000..e3d6cc1
--- /dev/null
@@ -0,0 +1,306 @@
+#include "Rts.h"
+
+// Fallbacks for atomic primops on byte arrays. The builtins used
+// below are supported on both GCC and LLVM.
+//
+// Ideally these function would take StgWord8, StgWord16, etc but
+// older GCC versions incorrectly assume that the register that the
+// argument is passed in has been zero extended, which is incorrect
+// according to the ABI and is not what GHC does when it generates
+// calls to these functions.
+
+// FetchAddByteArrayOp_Int
+
+extern StgWord hs_atomic_add8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_add8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_add(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_add16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_add16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_add(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_add32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_add32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_add(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_add64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_add64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_add(x, val);
+}
+
+// FetchSubByteArrayOp_Int
+
+extern StgWord hs_atomic_sub8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_sub8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_sub(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_sub16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_sub16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_sub(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_sub32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_sub32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_sub(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_sub64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_sub64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_sub(x, val);
+}
+
+// FetchAndByteArrayOp_Int
+
+extern StgWord hs_atomic_and8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_and8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_and(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_and16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_and16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_and(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_and32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_and32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_and(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_and64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_and64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_and(x, val);
+}
+
+// FetchNandByteArrayOp_Int
+
+// Workaround for http://llvm.org/bugs/show_bug.cgi?id=8842
+#define CAS_NAND(x, val)                                            \
+  {                                                                 \
+    __typeof__ (*(x)) tmp = *(x);                                   \
+    while (!__sync_bool_compare_and_swap(x, tmp, ~(tmp & (val)))) { \
+      tmp = *(x);                                                   \
+    }                                                               \
+    return tmp;                                                     \
+  }
+
+extern StgWord hs_atomic_nand8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_nand8(volatile StgWord8 *x, StgWord val)
+{
+#ifdef __clang__
+  CAS_NAND(x, (StgWord8) val)
+#else
+  return __sync_fetch_and_nand(x, (StgWord8) val);
+#endif
+}
+
+extern StgWord hs_atomic_nand16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_nand16(volatile StgWord16 *x, StgWord val)
+{
+#ifdef __clang__
+  CAS_NAND(x, (StgWord16) val);
+#else
+  return __sync_fetch_and_nand(x, (StgWord16) val);
+#endif
+}
+
+extern StgWord hs_atomic_nand32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_nand32(volatile StgWord32 *x, StgWord val)
+{
+#ifdef __clang__
+  CAS_NAND(x, (StgWord32) val);
+#else
+  return __sync_fetch_and_nand(x, (StgWord32) val);
+#endif
+}
+
+extern StgWord64 hs_atomic_nand64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_nand64(volatile StgWord64 *x, StgWord64 val)
+{
+#ifdef __clang__
+  CAS_NAND(x, val);
+#else
+  return __sync_fetch_and_nand(x, val);
+#endif
+}
+
+// FetchOrByteArrayOp_Int
+
+extern StgWord hs_atomic_or8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_or8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_or(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_or16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_or16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_or(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_or32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_or32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_or(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_or64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_or64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_or(x, val);
+}
+
+// FetchXorByteArrayOp_Int
+
+extern StgWord hs_atomic_xor8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_xor8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_xor(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_xor16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_xor16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_xor(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_xor32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_xor32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_xor(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_xor64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_xor64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_xor(x, val);
+}
+
+// CasByteArrayOp_Int
+
+extern StgWord hs_cmpxchg8(volatile StgWord8 *x, StgWord old, StgWord new);
+StgWord
+hs_cmpxchg8(volatile StgWord8 *x, StgWord old, StgWord new)
+{
+  return __sync_val_compare_and_swap(x, (StgWord8) old, (StgWord8) new);
+}
+
+extern StgWord hs_cmpxchg16(volatile StgWord16 *x, StgWord old, StgWord new);
+StgWord
+hs_cmpxchg16(volatile StgWord16 *x, StgWord old, StgWord new)
+{
+  return __sync_val_compare_and_swap(x, (StgWord16) old, (StgWord16) new);
+}
+
+extern StgWord hs_cmpxchg32(volatile StgWord32 *x, StgWord old, StgWord new);
+StgWord
+hs_cmpxchg32(volatile StgWord32 *x, StgWord old, StgWord new)
+{
+  return __sync_val_compare_and_swap(x, (StgWord32) old, (StgWord32) new);
+}
+
+extern StgWord hs_cmpxchg64(volatile StgWord64 *x, StgWord64 old, StgWord64 new);
+StgWord
+hs_cmpxchg64(volatile StgWord64 *x, StgWord64 old, StgWord64 new)
+{
+  return __sync_val_compare_and_swap(x, old, new);
+}
+
+// AtomicReadByteArrayOp_Int
+
+extern StgWord hs_atomicread8(volatile StgWord8 *x);
+StgWord
+hs_atomicread8(volatile StgWord8 *x)
+{
+  return *x;
+}
+
+extern StgWord hs_atomicread16(volatile StgWord16 *x);
+StgWord
+hs_atomicread16(volatile StgWord16 *x)
+{
+  return *x;
+}
+
+extern StgWord hs_atomicread32(volatile StgWord32 *x);
+StgWord
+hs_atomicread32(volatile StgWord32 *x)
+{
+  return *x;
+}
+
+extern StgWord64 hs_atomicread64(volatile StgWord64 *x);
+StgWord64
+hs_atomicread64(volatile StgWord64 *x)
+{
+  return *x;
+}
+
+// AtomicWriteByteArrayOp_Int
+
+extern void hs_atomicwrite8(volatile StgWord8 *x, StgWord val);
+void
+hs_atomicwrite8(volatile StgWord8 *x, StgWord val)
+{
+  *x = (StgWord8) val;
+}
+
+extern void hs_atomicwrite16(volatile StgWord16 *x, StgWord val);
+void
+hs_atomicwrite16(volatile StgWord16 *x, StgWord val)
+{
+  *x = (StgWord16) val;
+}
+
+extern void hs_atomicwrite32(volatile StgWord32 *x, StgWord val);
+void
+hs_atomicwrite32(volatile StgWord32 *x, StgWord val)
+{
+  *x = (StgWord32) val;
+}
+
+extern void hs_atomicwrite64(volatile StgWord64 *x, StgWord64 val);
+void
+hs_atomicwrite64(volatile StgWord64 *x, StgWord64 val)
+{
+  *x = (StgWord64) val;
+}
index c861342..bc9f571 100644 (file)
@@ -52,6 +52,7 @@ Library
         exposed-modules: GHC.Prim
 
     c-sources:
+        cbits/atomic.c
         cbits/bswap.c
         cbits/debug.c
         cbits/longlong.c
index e5e61bb..ad96d74 100644 (file)
@@ -1186,7 +1186,6 @@ typedef struct _RtsSymbolVal {
       SymI_HasProto(stg_newBCOzh)                                       \
       SymI_HasProto(stg_newByteArrayzh)                                 \
       SymI_HasProto(stg_casIntArrayzh)                                  \
-      SymI_HasProto(stg_fetchAddIntArrayzh)                             \
       SymI_HasProto(stg_newMVarzh)                                      \
       SymI_HasProto(stg_newMutVarzh)                                    \
       SymI_HasProto(stg_newTVarzh)                                      \
index 4d7baca..5f04a6d 100644 (file)
@@ -151,18 +151,6 @@ stg_casIntArrayzh( gcptr arr, W_ ind, W_ old, W_ new )
 }
 
 
-stg_fetchAddIntArrayzh( gcptr arr, W_ ind, W_ incr )
-/* MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #) */
-{
-    W_ p, h;
-
-    p = arr + SIZEOF_StgArrWords + WDS(ind);
-    (h) = ccall atomic_inc(p, incr);
-
-    return(h);
-}
-
-
 stg_newArrayzh ( W_ n /* words */, gcptr init )
 {
     W_ words, size, p;
diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.hs b/testsuite/tests/concurrent/should_run/AtomicPrimops.hs
new file mode 100644 (file)
index 0000000..0c55aba
--- /dev/null
@@ -0,0 +1,245 @@
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedTuples #-}
+
+module Main ( main ) where
+
+import Control.Concurrent
+import Control.Concurrent.MVar
+import Control.Monad (when)
+import Foreign.Storable
+import GHC.Exts
+import GHC.IO
+
+-- | Iterations per worker.
+iters :: Int
+iters = 1000000
+
+main :: IO ()
+main = do
+    fetchAddSubTest
+    fetchAndTest
+    fetchNandTest
+    fetchOrTest
+    fetchXorTest
+    casTest
+    readWriteTest
+
+-- | Test fetchAddIntArray# by having two threads concurrenctly
+-- increment a counter and then checking the sum at the end.
+fetchAddSubTest :: IO ()
+fetchAddSubTest = do
+    tot <- race 0
+        (\ mba -> work fetchAddIntArray mba iters 2)
+        (\ mba -> work fetchSubIntArray mba iters 1)
+    assertEq 1000000 tot "fetchAddSubTest"
+  where
+    work :: (MByteArray -> Int -> Int -> IO ()) -> MByteArray -> Int -> Int
+         -> IO ()
+    work op mba 0 val = return ()
+    work op mba n val = op mba 0 val >> work op mba (n-1) val
+
+-- | Test fetchXorIntArray# by having two threads concurrenctly XORing
+-- and then checking the result at the end. Works since XOR is
+-- commutative.
+--
+-- Covers the code paths for AND, NAND, and OR as well.
+fetchXorTest :: IO ()
+fetchXorTest = do
+    res <- race n0
+        (\ mba -> work mba iters t1pat)
+        (\ mba -> work mba iters t2pat)
+    assertEq expected res "fetchXorTest"
+  where
+    work :: MByteArray -> Int -> Int -> IO ()
+    work mba 0 val = return ()
+    work mba n val = fetchXorIntArray mba 0 val >> work mba (n-1) val
+
+    -- Initial value is a large prime and the two patterns are 1010...
+    -- and 0101...
+    (n0, t1pat, t2pat)
+        | sizeOf (undefined :: Int) == 8 =
+            (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
+        | otherwise = (0x0000ffff, 0x55555555, 0x99999999)
+    expected
+        | sizeOf (undefined :: Int) == 8 = 4294967295
+        | otherwise = 65535
+
+-- The tests for AND, NAND, and OR are trivial for two reasons:
+--
+--  * The code path is already well exercised by 'fetchXorTest'.
+--
+--  * It's harder to test these operations, as a long sequence of them
+--    convert to a single value but we'd like to write a test in the
+--    style of 'fetchXorTest' that applies the operation repeatedly,
+--    to make it likely that any race conditions are detected.
+--
+-- Right now we only test that they return the correct value for a
+-- single op on each thread.
+
+fetchOpTest :: (MByteArray -> Int -> Int -> IO ())
+            -> Int -> String -> IO ()
+fetchOpTest op expected name = do
+    res <- race n0
+        (\ mba -> work mba t1pat)
+        (\ mba -> work mba t2pat)
+    assertEq expected res name
+  where
+    work :: MByteArray -> Int -> IO ()
+    work mba val = op mba 0 val
+
+    -- Initial value is a large prime and the two patterns are 1010...
+    -- and 0101...
+    (n0, t1pat, t2pat)
+        | sizeOf (undefined :: Int) == 8 =
+            (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
+        | otherwise = (0x0000ffff, 0x55555555, 0x99999999)
+
+fetchAndTest :: IO ()
+fetchAndTest = fetchOpTest fetchAndIntArray expected "fetchAndTest"
+  where expected
+            | sizeOf (undefined :: Int) == 8 = 286331153
+            | otherwise = 4369
+
+fetchNandTest :: IO ()
+fetchNandTest = fetchOpTest fetchNandIntArray expected "fetchNandTest"
+  where expected
+            | sizeOf (undefined :: Int) == 8 = 7378697629770151799
+            | otherwise = -2576976009
+
+fetchOrTest :: IO ()
+fetchOrTest = fetchOpTest fetchOrIntArray expected "fetchOrTest"
+  where expected
+            | sizeOf (undefined :: Int) == 8 = 15987178197787607039
+            | otherwise = 3722313727
+
+-- | Test casIntArray# by using it to emulate fetchAddIntArray# and
+-- then having two threads concurrenctly increment a counter,
+-- checking the sum at the end.
+casTest :: IO ()
+casTest = do
+    tot <- race 0
+        (\ mba -> work mba iters 1)
+        (\ mba -> work mba iters 2)
+    assertEq 3000000 tot "casTest"
+  where
+    work :: MByteArray -> Int -> Int -> IO ()
+    work mba 0 val = return ()
+    work mba n val = add mba 0 val >> work mba (n-1) val
+
+    -- Fetch-and-add implemented using CAS.
+    add :: MByteArray -> Int -> Int -> IO ()
+    add mba ix n = do
+        old <- readIntArray mba ix
+        old' <- casIntArray mba ix old (old + n)
+        when (old /= old') $ add mba ix n
+
+-- | Tests atomic reads and writes by making sure that one thread sees
+-- updates that are done on another. This test isn't very good at the
+-- moment, as this might work even without atomic ops, but at least it
+-- exercises the code.
+readWriteTest :: IO ()
+readWriteTest = do
+    mba <- newByteArray (sizeOf (undefined :: Int))
+    writeIntArray mba 0 0
+    latch <- newEmptyMVar
+    done <- newEmptyMVar
+    forkIO $ do
+        takeMVar latch
+        n <- atomicReadIntArray mba 0
+        assertEq 1 n "readWriteTest"
+        putMVar done ()
+    atomicWriteIntArray mba 0 1
+    putMVar latch ()
+    takeMVar done
+
+-- | Create two threads that mutate the byte array passed to them
+-- concurrently. The array is one word large.
+race :: Int                    -- ^ Initial value of array element
+     -> (MByteArray -> IO ())  -- ^ Thread 1 action
+     -> (MByteArray -> IO ())  -- ^ Thread 2 action
+     -> IO Int                 -- ^ Final value of array element
+race n0 thread1 thread2 = do
+    done1 <- newEmptyMVar
+    done2 <- newEmptyMVar
+    mba <- newByteArray (sizeOf (undefined :: Int))
+    writeIntArray mba 0 n0
+    forkIO $ thread1 mba >> putMVar done1 ()
+    forkIO $ thread2 mba >> putMVar done2 ()
+    mapM_ takeMVar [done1, done2]
+    readIntArray mba 0
+
+------------------------------------------------------------------------
+-- Test helper
+
+assertEq :: (Eq a, Show a) => a -> a -> String -> IO ()
+assertEq expected actual name
+    | expected == actual = putStrLn $ name ++ ": OK"
+    | otherwise = do
+        putStrLn $ name ++ ": FAIL"
+        putStrLn $ "Expected: " ++ show expected
+        putStrLn $ "  Actual: " ++ show actual
+
+------------------------------------------------------------------------
+-- Wrappers around MutableByteArray#
+
+data MByteArray = MBA (MutableByteArray# RealWorld)
+
+fetchAddIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchAddIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchAddIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchSubIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchSubIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchSubIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchAndIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchAndIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchAndIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchNandIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchNandIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchNandIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchOrIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchOrIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchOrIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchXorIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchXorIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchXorIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+newByteArray :: Int -> IO MByteArray
+newByteArray (I# n#) = IO $ \ s# ->
+    case newByteArray# n# s# of
+        (# s2#, mba# #) -> (# s2#, MBA mba# #)
+
+writeIntArray :: MByteArray -> Int -> Int -> IO ()
+writeIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case writeIntArray# mba# ix# n# s# of
+        s2# -> (# s2#, () #)
+
+readIntArray :: MByteArray -> Int -> IO Int
+readIntArray (MBA mba#) (I# ix#) = IO $ \ s# ->
+    case readIntArray# mba# ix# s# of
+        (# s2#, n# #) -> (# s2#, I# n# #)
+
+atomicWriteIntArray :: MByteArray -> Int -> Int -> IO ()
+atomicWriteIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case atomicWriteIntArray# mba# ix# n# s# of
+        s2# -> (# s2#, () #)
+
+atomicReadIntArray :: MByteArray -> Int -> IO Int
+atomicReadIntArray (MBA mba#) (I# ix#) = IO $ \ s# ->
+    case atomicReadIntArray# mba# ix# s# of
+        (# s2#, n# #) -> (# s2#, I# n# #)
+
+casIntArray :: MByteArray -> Int -> Int -> Int -> IO Int
+casIntArray (MBA mba#) (I# ix#) (I# old#) (I# new#) = IO $ \ s# ->
+    case casIntArray# mba# ix# old# new# s# of
+        (# s2#, old2# #) -> (# s2#, I# old2# #)
diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout b/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout
new file mode 100644 (file)
index 0000000..c37041a
--- /dev/null
@@ -0,0 +1,7 @@
+fetchAddSubTest: OK
+fetchAndTest: OK
+fetchNandTest: OK
+fetchOrTest: OK
+fetchXorTest: OK
+casTest: OK
+readWriteTest: OK
index 0b502c3..0a66892 100644 (file)
@@ -81,6 +81,7 @@ test('tryReadMVar1', normal, compile_and_run, [''])
 test('tryReadMVar2', normal, compile_and_run, [''])
 
 test('T7970', normal, compile_and_run, [''])
+test('AtomicPrimops', normal, compile_and_run, [''])
 
 # -----------------------------------------------------------------------------
 # These tests we only do for a full run