Add more primops for atomic ops on byte arrays
authorJohan Tibell <johan.tibell@gmail.com>
Mon, 9 Jun 2014 09:43:21 +0000 (11:43 +0200)
committerJohan Tibell <johan.tibell@gmail.com>
Tue, 24 Jun 2014 17:06:53 +0000 (19:06 +0200)
Summary:
Add more primops for atomic ops on byte arrays

Adds the following primops:

 * atomicReadIntArray#
 * atomicWriteIntArray#
 * fetchSubIntArray#
 * fetchOrIntArray#
 * fetchXorIntArray#
 * fetchAndIntArray#

Makes these pre-existing out-of-line primops inline:

 * fetchAddIntArray#
 * casIntArray#

23 files changed:
compiler/cmm/CmmMachOp.hs
compiler/cmm/CmmSink.hs
compiler/cmm/PprC.hs
compiler/codeGen/StgCmmPrim.hs
compiler/llvmGen/Llvm/AbsSyn.hs
compiler/llvmGen/Llvm/PpLlvm.hs
compiler/llvmGen/LlvmCodeGen/CodeGen.hs
compiler/nativeGen/CPrim.hs
compiler/nativeGen/PPC/CodeGen.hs
compiler/nativeGen/SPARC/CodeGen.hs
compiler/nativeGen/X86/CodeGen.hs
compiler/nativeGen/X86/Instr.hs
compiler/nativeGen/X86/Ppr.hs
compiler/prelude/primops.txt.pp
includes/stg/MiscClosures.h
libraries/ghc-prim/cbits/atomic.c [new file with mode: 0644]
libraries/ghc-prim/ghc-prim.cabal
rts/Linker.c
rts/PrimOps.cmm
testsuite/tests/concurrent/should_run/.gitignore
testsuite/tests/concurrent/should_run/AtomicPrimops.hs [new file with mode: 0644]
testsuite/tests/concurrent/should_run/AtomicPrimops.stdout [new file with mode: 0644]
testsuite/tests/concurrent/should_run/all.T

index c4ec393..d8ce492 100644 (file)
@@ -19,6 +19,9 @@ module CmmMachOp
     -- CallishMachOp
     , CallishMachOp(..), callishMachOpHints
     , pprCallishMachOp
+
+    -- Atomic read-modify-write
+    , AtomicMachOp(..)
    )
 where
 
@@ -547,8 +550,24 @@ data CallishMachOp
 
   | MO_PopCnt Width
   | MO_BSwap Width
+
+  -- Atomic read-modify-write.
+  | MO_AtomicRMW Width AtomicMachOp
+  | MO_AtomicRead Width
+  | MO_AtomicWrite Width
+  | MO_Cmpxchg Width
   deriving (Eq, Show)
 
+-- | The operation to perform atomically.
+data AtomicMachOp =
+      AMO_Add
+    | AMO_Sub
+    | AMO_And
+    | AMO_Nand
+    | AMO_Or
+    | AMO_Xor
+      deriving (Eq, Show)
+
 pprCallishMachOp :: CallishMachOp -> SDoc
 pprCallishMachOp mo = text (show mo)
 
index 4c02542..4dced9a 100644 (file)
@@ -650,6 +650,10 @@ data AbsMem
 -- perhaps we ought to have a special annotation for calls that can
 -- modify heap/stack memory.  For now we just use the conservative
 -- definition here.
+--
+-- Some CallishMachOp imply a memory barrier e.g. AtomicRMW and
+-- therefore we should never float any memory operations across one of
+-- these calls.
 
 
 bothMems :: AbsMem -> AbsMem -> AbsMem
index 47b247e..455c79b 100644 (file)
@@ -753,6 +753,10 @@ pprCallishMachOp_for_C mop
         MO_Memmove      -> ptext (sLit "memmove")
         (MO_BSwap w)    -> ptext (sLit $ bSwapLabel w)
         (MO_PopCnt w)   -> ptext (sLit $ popCntLabel w)
+        (MO_AtomicRMW w amop) -> ptext (sLit $ atomicRMWLabel w amop)
+        (MO_Cmpxchg w)  -> ptext (sLit $ cmpxchgLabel w)
+        (MO_AtomicRead w)  -> ptext (sLit $ atomicReadLabel w)
+        (MO_AtomicWrite w) -> ptext (sLit $ atomicWriteLabel w)
         (MO_UF_Conv w)  -> ptext (sLit $ word2FloatLabel w)
 
         MO_S_QuotRem  {} -> unsupported
index 40a5e36..e4c682b 100644 (file)
@@ -769,6 +769,25 @@ emitPrimOp _ res PrefetchByteArrayOp0        args = doPrefetchByteArrayOp 0 res
 emitPrimOp _ res PrefetchMutableByteArrayOp0 args = doPrefetchByteArrayOp 0 res args
 emitPrimOp _ res PrefetchAddrOp0             args = doPrefetchAddrOp 0 res args
 
+-- Atomic read-modify-write
+emitPrimOp dflags [res] FetchAddByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Add mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchSubByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Sub mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchAndByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_And mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchNandByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Nand mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchOrByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Or mba ix (bWord dflags) n
+emitPrimOp dflags [res] FetchXorByteArrayOp_Int [mba, ix, n] =
+    doAtomicRMW res AMO_Xor mba ix (bWord dflags) n
+emitPrimOp dflags [res] AtomicReadByteArrayOp_Int [mba, ix] =
+    doAtomicReadByteArray res mba ix (bWord dflags)
+emitPrimOp dflags [] AtomicWriteByteArrayOp_Int [mba, ix, val] =
+    doAtomicWriteByteArray mba ix (bWord dflags) val
+emitPrimOp dflags [res] CasByteArrayOp_Int [mba, ix, old, new] =
+    doCasByteArray res mba ix (bWord dflags) old new
 
 -- The rest just translate straightforwardly
 emitPrimOp dflags [res] op [arg]
@@ -1933,6 +1952,81 @@ doWriteSmallPtrArrayOp addr idx val = do
     emit (setInfo addr (CmmLit (CmmLabel mkSMAP_DIRTY_infoLabel)))
 
 ------------------------------------------------------------------------------
+-- Atomic read-modify-write
+
+-- | Emit an atomic modification to a byte array element. The result
+-- reg contains that previous value of the element. Implies a full
+-- memory barrier.
+doAtomicRMW :: LocalReg      -- ^ Result reg
+            -> AtomicMachOp  -- ^ Atomic op (e.g. add)
+            -> CmmExpr       -- ^ MutableByteArray#
+            -> CmmExpr       -- ^ Index
+            -> CmmType       -- ^ Type of element by which we are indexing
+            -> CmmExpr       -- ^ Op argument (e.g. amount to add)
+            -> FCode ()
+doAtomicRMW res amop mba idx idx_ty n = do
+    dflags <- getDynFlags
+    let width = typeWidth idx_ty
+        addr  = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+                width mba idx
+    emitPrimCall
+        [ res ]
+        (MO_AtomicRMW width amop)
+        [ addr, n ]
+
+-- | Emit an atomic read to a byte array that acts as a memory barrier.
+doAtomicReadByteArray
+    :: LocalReg  -- ^ Result reg
+    -> CmmExpr   -- ^ MutableByteArray#
+    -> CmmExpr   -- ^ Index
+    -> CmmType   -- ^ Type of element by which we are indexing
+    -> FCode ()
+doAtomicReadByteArray res mba idx idx_ty = do
+    dflags <- getDynFlags
+    let width = typeWidth idx_ty
+        addr  = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+                width mba idx
+    emitPrimCall
+        [ res ]
+        (MO_AtomicRead width)
+        [ addr ]
+
+-- | Emit an atomic write to a byte array that acts as a memory barrier.
+doAtomicWriteByteArray
+    :: CmmExpr   -- ^ MutableByteArray#
+    -> CmmExpr   -- ^ Index
+    -> CmmType   -- ^ Type of element by which we are indexing
+    -> CmmExpr   -- ^ Value to write
+    -> FCode ()
+doAtomicWriteByteArray mba idx idx_ty val = do
+    dflags <- getDynFlags
+    let width = typeWidth idx_ty
+        addr  = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+                width mba idx
+    emitPrimCall
+        [ {- no results -} ]
+        (MO_AtomicWrite width)
+        [ addr, val ]
+
+doCasByteArray
+    :: LocalReg  -- ^ Result reg
+    -> CmmExpr   -- ^ MutableByteArray#
+    -> CmmExpr   -- ^ Index
+    -> CmmType   -- ^ Type of element by which we are indexing
+    -> CmmExpr   -- ^ Old value
+    -> CmmExpr   -- ^ New value
+    -> FCode ()
+doCasByteArray res mba idx idx_ty old new = do
+    dflags <- getDynFlags
+    let width = (typeWidth idx_ty)
+        addr = cmmIndexOffExpr dflags (arrWordsHdrSize dflags)
+               width mba idx
+    emitPrimCall
+        [ res ]
+        (MO_Cmpxchg width)
+        [ addr, old, new ]
+
+------------------------------------------------------------------------------
 -- Helpers for emitting function calls
 
 -- | Emit a call to @memcpy@.
index f92bd89..24d0856 100644 (file)
@@ -65,6 +65,8 @@ data LlvmFunction = LlvmFunction {
 
 type LlvmFunctions = [LlvmFunction]
 
+type SingleThreaded = Bool
+
 -- | LLVM ordering types for synchronization purposes. (Introduced in LLVM
 -- 3.0). Please see the LLVM documentation for a better description.
 data LlvmSyncOrdering
@@ -224,6 +226,11 @@ data LlvmExpression
   | Load LlvmVar
 
   {- |
+    Atomic load of the value at location ptr
+  -}
+  | ALoad LlvmSyncOrdering SingleThreaded LlvmVar
+
+  {- |
     Navigate in an structure, selecting elements
       * inbound: Is the pointer inbounds? (computed pointer doesn't overflow)
       * ptr:     Location of the structure
index 0250782..7307725 100644 (file)
@@ -239,6 +239,7 @@ ppLlvmExpression expr
         Insert     vec elt idx      -> ppInsert vec elt idx
         GetElemPtr inb ptr indexes  -> ppGetElementPtr inb ptr indexes
         Load       ptr              -> ppLoad ptr
+        ALoad      ord st ptr       -> ppALoad ord st ptr
         Malloc     tp amount        -> ppMalloc tp amount
         Phi        tp precessors    -> ppPhi tp precessors
         Asm        asm c ty v se sk -> ppAsm asm c ty v se sk
@@ -327,13 +328,18 @@ ppSyncOrdering SyncSeqCst    = text "seq_cst"
 -- of specifying alignment.
 
 ppLoad :: LlvmVar -> SDoc
-ppLoad var
-    | isVecPtrVar var = text "load" <+> ppr var <>
-                        comma <+> text "align 1"
-    | otherwise       = text "load" <+> ppr var
+ppLoad var = text "load" <+> ppr var <> align
   where
-    isVecPtrVar :: LlvmVar -> Bool
-    isVecPtrVar = isVector . pLower . getVarType
+    align | isVector . pLower . getVarType $ var = text ", align 1"
+          | otherwise = empty
+
+ppALoad :: LlvmSyncOrdering -> SingleThreaded -> LlvmVar -> SDoc
+ppALoad ord st var = sdocWithDynFlags $ \dflags ->
+  let alignment = (llvmWidthInBits dflags $ getVarType var) `quot` 8
+      align     = text ", align" <+> ppr alignment
+      sThreaded | st        = text " singlethread"
+                | otherwise = empty
+  in text "load atomic" <+> ppr var <> sThreaded <+> ppSyncOrdering ord <> align
 
 ppStore :: LlvmVar -> LlvmVar -> SDoc
 ppStore val dst
index 5175535..4a56600 100644 (file)
@@ -15,6 +15,7 @@ import BlockId
 import CodeGen.Platform ( activeStgRegs, callerSaves )
 import CLabel
 import Cmm
+import CPrim
 import PprCmm
 import CmmUtils
 import Hoopl
@@ -32,6 +33,7 @@ import Unique
 import Data.List ( nub )
 import Data.Maybe ( catMaybes )
 
+type Atomic = Bool
 type LlvmStatements = OrdList LlvmStatement
 
 -- -----------------------------------------------------------------------------
@@ -228,6 +230,17 @@ genCall t@(PrimTarget (MO_PopCnt w)) dsts args =
 genCall t@(PrimTarget (MO_BSwap w)) dsts args =
     genCallSimpleCast w t dsts args
 
+genCall (PrimTarget (MO_AtomicRead _)) [dst] [addr] = do
+  dstV <- getCmmReg (CmmLocal dst)
+  (v1, stmts, top) <- genLoad True addr (localRegType dst)
+  let stmt1 = Store v1 dstV
+  return (stmts `snocOL` stmt1, top)
+
+-- TODO: implement these properly rather than calling to RTS functions.
+-- genCall t@(PrimTarget (MO_AtomicWrite width)) [] [addr, val] = undefined
+-- genCall t@(PrimTarget (MO_AtomicRMW width amop)) [dst] [addr, n] = undefined
+-- genCall t@(PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] = undefined
+
 -- Handle memcpy function specifically since llvm's intrinsic version takes
 -- some extra parameters.
 genCall t@(PrimTarget op) [] args'
@@ -548,7 +561,6 @@ cmmPrimOpFunctions mop = do
 
     (MO_Prefetch_Data _ )-> fsLit "llvm.prefetch"
 
-
     MO_S_QuotRem {}  -> unsupported
     MO_U_QuotRem {}  -> unsupported
     MO_U_QuotRem2 {} -> unsupported
@@ -558,6 +570,12 @@ cmmPrimOpFunctions mop = do
     MO_Touch         -> unsupported
     MO_UF_Conv _     -> unsupported
 
+    MO_AtomicRead _  -> unsupported
+
+    MO_AtomicRMW w amop -> fsLit $ atomicRMWLabel w amop
+    MO_Cmpxchg w        -> fsLit $ cmpxchgLabel w
+    MO_AtomicWrite w    -> fsLit $ atomicWriteLabel w
+
 -- | Tail function calls
 genJump :: CmmExpr -> [GlobalReg] -> LlvmM StmtData
 
@@ -849,7 +867,7 @@ exprToVarOpt opt e = case e of
         -> genLit opt lit
 
     CmmLoad e' ty
-        -> genLoad e' ty
+        -> genLoad False e' ty
 
     -- Cmmreg in expression is the value, so must load. If you want actual
     -- reg pointer, call getCmmReg directly.
@@ -1268,41 +1286,41 @@ genMachOp_slow _ _ _ = panic "genMachOp: More then 2 expressions in MachOp!"
 
 
 -- | Handle CmmLoad expression.
-genLoad :: CmmExpr -> CmmType -> LlvmM ExprData
+genLoad :: Atomic -> CmmExpr -> CmmType -> LlvmM ExprData
 
 -- First we try to detect a few common cases and produce better code for
 -- these then the default case. We are mostly trying to detect Cmm code
 -- like I32[Sp + n] and use 'getelementptr' operations instead of the
 -- generic case that uses casts and pointer arithmetic
-genLoad e@(CmmReg (CmmGlobal r)) ty
-    = genLoad_fast e r 0 ty
+genLoad atomic e@(CmmReg (CmmGlobal r)) ty
+    = genLoad_fast atomic e r 0 ty
 
-genLoad e@(CmmRegOff (CmmGlobal r) n) ty
-    = genLoad_fast e r n ty
+genLoad atomic e@(CmmRegOff (CmmGlobal r) n) ty
+    = genLoad_fast atomic e r n ty
 
-genLoad e@(CmmMachOp (MO_Add _) [
+genLoad atomic e@(CmmMachOp (MO_Add _) [
                             (CmmReg (CmmGlobal r)),
                             (CmmLit (CmmInt n _))])
                 ty
-    = genLoad_fast e r (fromInteger n) ty
+    = genLoad_fast atomic e r (fromInteger n) ty
 
-genLoad e@(CmmMachOp (MO_Sub _) [
+genLoad atomic e@(CmmMachOp (MO_Sub _) [
                             (CmmReg (CmmGlobal r)),
                             (CmmLit (CmmInt n _))])
                 ty
-    = genLoad_fast e r (negate $ fromInteger n) ty
+    = genLoad_fast atomic e r (negate $ fromInteger n) ty
 
 -- generic case
-genLoad e ty
+genLoad atomic e ty
     = do other <- getTBAAMeta otherN
-         genLoad_slow e ty other
+         genLoad_slow atomic e ty other
 
 -- | Handle CmmLoad expression.
 -- This is a special case for loading from a global register pointer
 -- offset such as I32[Sp+8].
-genLoad_fast :: CmmExpr -> GlobalReg -> Int -> CmmType
-                -> LlvmM ExprData
-genLoad_fast e r n ty = do
+genLoad_fast :: Atomic -> CmmExpr -> GlobalReg -> Int -> CmmType
+             -> LlvmM ExprData
+genLoad_fast atomic e r n ty = do
     dflags <- getDynFlags
     (gv, grt, s1) <- getCmmRegVal (CmmGlobal r)
     meta          <- getTBAARegMeta r
@@ -1315,7 +1333,7 @@ genLoad_fast e r n ty = do
                 case grt == ty' of
                      -- were fine
                      True -> do
-                         (var, s3) <- doExpr ty' (MExpr meta $ Load ptr)
+                         (var, s3) <- doExpr ty' (MExpr meta $ loadInstr ptr)
                          return (var, s1 `snocOL` s2 `snocOL` s3,
                                      [])
 
@@ -1323,32 +1341,34 @@ genLoad_fast e r n ty = do
                      False -> do
                          let pty = pLift ty'
                          (ptr', s3) <- doExpr pty $ Cast LM_Bitcast ptr pty
-                         (var, s4) <- doExpr ty' (MExpr meta $ Load ptr')
+                         (var, s4) <- doExpr ty' (MExpr meta $ loadInstr ptr')
                          return (var, s1 `snocOL` s2 `snocOL` s3
                                     `snocOL` s4, [])
 
             -- If its a bit type then we use the slow method since
             -- we can't avoid casting anyway.
-            False -> genLoad_slow e ty meta
-
+            False -> genLoad_slow atomic  e ty meta
+  where
+    loadInstr ptr | atomic    = ALoad SyncSeqCst False ptr
+                  | otherwise = Load ptr
 
 -- | Handle Cmm load expression.
 -- Generic case. Uses casts and pointer arithmetic if needed.
-genLoad_slow :: CmmExpr -> CmmType -> [MetaAnnot] -> LlvmM ExprData
-genLoad_slow e ty meta = do
+genLoad_slow :: Atomic -> CmmExpr -> CmmType -> [MetaAnnot] -> LlvmM ExprData
+genLoad_slow atomic e ty meta = do
     (iptr, stmts, tops) <- exprToVar e
     dflags <- getDynFlags
     case getVarType iptr of
          LMPointer _ -> do
                     (dvar, load) <- doExpr (cmmToLlvmType ty)
-                                           (MExpr meta $ Load iptr)
+                                           (MExpr meta $ loadInstr iptr)
                     return (dvar, stmts `snocOL` load, tops)
 
          i@(LMInt _) | i == llvmWord dflags -> do
                     let pty = LMPointer $ cmmToLlvmType ty
                     (ptr, cast)  <- doExpr pty $ Cast LM_Inttoptr iptr pty
                     (dvar, load) <- doExpr (cmmToLlvmType ty)
-                                           (MExpr meta $ Load ptr)
+                                           (MExpr meta $ loadInstr ptr)
                     return (dvar, stmts `snocOL` cast `snocOL` load, tops)
 
          other -> do dflags <- getDynFlags
@@ -1357,6 +1377,9 @@ genLoad_slow e ty meta = do
                             "Size of Ptr: " ++ show (llvmPtrBits dflags) ++
                             ", Size of var: " ++ show (llvmWidthInBits dflags other) ++
                             ", Var: " ++ showSDoc dflags (ppr iptr)))
+  where
+    loadInstr ptr | atomic    = ALoad SyncSeqCst False ptr
+                  | otherwise = Load ptr
 
 
 -- | Handle CmmReg expression. This will return a pointer to the stack
index a6f4cab..34782df 100644 (file)
@@ -1,11 +1,16 @@
 -- | Generating C symbol names emitted by the compiler.
 module CPrim
-    ( popCntLabel
+    ( atomicReadLabel
+    , atomicWriteLabel
+    , atomicRMWLabel
+    , cmpxchgLabel
+    , popCntLabel
     , bSwapLabel
     , word2FloatLabel
     ) where
 
 import CmmType
+import CmmMachOp
 import Outputable
 
 popCntLabel :: Width -> String
@@ -31,3 +36,46 @@ word2FloatLabel w = "hs_word2float" ++ pprWidth w
     pprWidth W32 = "32"
     pprWidth W64 = "64"
     pprWidth w   = pprPanic "word2FloatLabel: Unsupported word width " (ppr w)
+
+atomicRMWLabel :: Width -> AtomicMachOp -> String
+atomicRMWLabel w amop = "hs_atomic_" ++ pprFunName amop ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "atomicRMWLabel: Unsupported word width " (ppr w)
+
+    pprFunName AMO_Add  = "add"
+    pprFunName AMO_Sub  = "sub"
+    pprFunName AMO_And  = "and"
+    pprFunName AMO_Nand = "nand"
+    pprFunName AMO_Or   = "or"
+    pprFunName AMO_Xor  = "xor"
+
+cmpxchgLabel :: Width -> String
+cmpxchgLabel w = "hs_cmpxchg" ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "cmpxchgLabel: Unsupported word width " (ppr w)
+
+atomicReadLabel :: Width -> String
+atomicReadLabel w = "hs_atomicread" ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "atomicReadLabel: Unsupported word width " (ppr w)
+
+atomicWriteLabel :: Width -> String
+atomicWriteLabel w = "hs_atomicwrite" ++ pprWidth w
+  where
+    pprWidth W8  = "8"
+    pprWidth W16 = "16"
+    pprWidth W32 = "32"
+    pprWidth W64 = "64"
+    pprWidth w   = pprPanic "atomicWriteLabel: Unsupported word width " (ppr w)
index 91651e6..22a2c7c 100644 (file)
@@ -1160,6 +1160,10 @@ genCCall' dflags gcp target dest_regs args0
 
                     MO_BSwap w   -> (fsLit $ bSwapLabel w, False)
                     MO_PopCnt w  -> (fsLit $ popCntLabel w, False)
+                    MO_AtomicRMW w amop -> (fsLit $ atomicRMWLabel w amop, False)
+                    MO_Cmpxchg w -> (fsLit $ cmpxchgLabel w, False)
+                    MO_AtomicRead w  -> (fsLit $ atomicReadLabel w, False)
+                    MO_AtomicWrite w -> (fsLit $ atomicWriteLabel w, False)
 
                     MO_S_QuotRem {}  -> unsupported
                     MO_U_QuotRem {}  -> unsupported
index f5e61d0..51f89d6 100644 (file)
@@ -654,6 +654,10 @@ outOfLineMachOp_table mop
 
         MO_BSwap w   -> fsLit $ bSwapLabel w
         MO_PopCnt w  -> fsLit $ popCntLabel w
+        MO_AtomicRMW w amop -> fsLit $ atomicRMWLabel w amop
+        MO_Cmpxchg w -> fsLit $ cmpxchgLabel w
+        MO_AtomicRead w -> fsLit $ atomicReadLabel w
+        MO_AtomicWrite w -> fsLit $ atomicWriteLabel w
 
         MO_S_QuotRem {}  -> unsupported
         MO_U_QuotRem {}  -> unsupported
index fa93767..140e8b2 100644 (file)
@@ -1761,6 +1761,93 @@ genCCall dflags is32Bit (PrimTarget (MO_UF_Conv width)) dest_regs args = do
   where
     lbl = mkCmmCodeLabel primPackageId (fsLit (word2FloatLabel width))
 
+genCCall dflags _ (PrimTarget (MO_AtomicRMW width amop)) [dst] [addr, n] = do
+    Amode amode addr_code <- getAmode addr
+    arg <- getNewRegNat size
+    arg_code <- getAnyReg n
+    use_sse2 <- sse2Enabled
+    let platform = targetPlatform dflags
+        dst_r    = getRegisterReg platform use_sse2 (CmmLocal dst)
+    code <- op_code dst_r arg amode
+    return $ addr_code `appOL` arg_code arg `appOL` code
+  where
+    -- Code for the operation
+    op_code :: Reg       -- ^ The destination reg
+            -> Reg       -- ^ Register containing argument
+            -> AddrMode  -- ^ Address of location to mutate
+            -> NatM (OrdList Instr)
+    op_code dst_r arg amode = case amop of
+        -- In the common case where dst_r is a virtual register the
+        -- final move should go away, because it's the last use of arg
+        -- and the first use of dst_r.
+        AMO_Add  -> return $ toOL [ LOCK
+                                  , XADD size (OpReg arg) (OpAddr amode)
+                                  , MOV size (OpReg arg) (OpReg dst_r)
+                                  ]
+        AMO_Sub  -> return $ toOL [ NEGI size (OpReg arg)
+                                  , LOCK
+                                  , XADD size (OpReg arg) (OpAddr amode)
+                                  , MOV size (OpReg arg) (OpReg dst_r)
+                                  ]
+        AMO_And  -> cmpxchg_code (\ src dst -> unitOL $ AND size src dst)
+        AMO_Nand -> cmpxchg_code (\ src dst -> toOL [ AND size src dst
+                                                    , NOT size dst
+                                                    ])
+        AMO_Or   -> cmpxchg_code (\ src dst -> unitOL $ OR size src dst)
+        AMO_Xor  -> cmpxchg_code (\ src dst -> unitOL $ XOR size src dst)
+      where
+        -- Simulate operation that lacks a dedicated instruction using
+        -- cmpxchg.
+        cmpxchg_code :: (Operand -> Operand -> OrdList Instr)
+                     -> NatM (OrdList Instr)
+        cmpxchg_code instrs = do
+            lbl <- getBlockIdNat
+            tmp <- getNewRegNat size
+            return $ toOL
+                [ MOV size (OpAddr amode) (OpReg eax)
+                , JXX ALWAYS lbl
+                , NEWBLOCK lbl
+                  -- Keep old value so we can return it:
+                , MOV size (OpReg eax) (OpReg dst_r)
+                , MOV size (OpReg eax) (OpReg tmp)
+                ]
+                `appOL` instrs (OpReg arg) (OpReg tmp) `appOL` toOL
+                [ LOCK
+                , CMPXCHG size (OpReg tmp) (OpAddr amode)
+                , JXX NE lbl
+                ]
+
+    size = intSize width
+
+genCCall dflags _ (PrimTarget (MO_AtomicRead width)) [dst] [addr] = do
+  load_code <- intLoadCode (MOV (intSize width)) addr
+  let platform = targetPlatform dflags
+  use_sse2 <- sse2Enabled
+  return (load_code (getRegisterReg platform use_sse2 (CmmLocal dst)))
+
+genCCall _ _ (PrimTarget (MO_AtomicWrite width)) [] [addr, val] = do
+    assignMem_IntCode (intSize width) addr val
+
+genCCall dflags _ (PrimTarget (MO_Cmpxchg width)) [dst] [addr, old, new] = do
+    Amode amode addr_code <- getAmode addr
+    newval <- getNewRegNat size
+    newval_code <- getAnyReg new
+    oldval <- getNewRegNat size
+    oldval_code <- getAnyReg old
+    use_sse2 <- sse2Enabled
+    let platform = targetPlatform dflags
+        dst_r    = getRegisterReg platform use_sse2 (CmmLocal dst)
+        code     = toOL
+                   [ MOV size (OpReg oldval) (OpReg eax)
+                   , LOCK
+                   , CMPXCHG size (OpReg newval) (OpAddr amode)
+                   , MOV size (OpReg eax) (OpReg dst_r)
+                   ]
+    return $ addr_code `appOL` newval_code newval `appOL` oldval_code oldval
+        `appOL` code
+  where
+    size = intSize width
+
 genCCall _ is32Bit target dest_regs args
  | is32Bit   = genCCall32 target dest_regs args
  | otherwise = genCCall64 target dest_regs args
@@ -2385,6 +2472,11 @@ outOfLineCmmOp mop res args
               MO_PopCnt _  -> fsLit "popcnt"
               MO_BSwap _   -> fsLit "bswap"
 
+              MO_AtomicRMW _ _ -> fsLit "atomicrmw"
+              MO_AtomicRead _  -> fsLit "atomicread"
+              MO_AtomicWrite _ -> fsLit "atomicwrite"
+              MO_Cmpxchg _     -> fsLit "cmpxchg"
+
               MO_UF_Conv _ -> unsupported
 
               MO_S_QuotRem {}  -> unsupported
index 05fff9b..ac91747 100644 (file)
@@ -327,6 +327,10 @@ data Instr
         | PREFETCH  PrefetchVariant Size Operand -- prefetch Variant, addr size, address to prefetch
                                         -- variant can be NTA, Lvl0, Lvl1, or Lvl2
 
+        | LOCK  -- lock prefix
+        | XADD        Size Operand Operand  -- src (r), dst (r/m)
+        | CMPXCHG     Size Operand Operand  -- src (r), dst (r/m), eax implicit
+
 data PrefetchVariant = NTA | Lvl0 | Lvl1 | Lvl2
 
 
@@ -337,6 +341,8 @@ data Operand
 
 
 
+-- | Returns which registers are read and written as a (read, written)
+-- pair.
 x86_regUsageOfInstr :: Platform -> Instr -> RegUsage
 x86_regUsageOfInstr platform instr
  = case instr of
@@ -428,10 +434,21 @@ x86_regUsageOfInstr platform instr
 
     -- note: might be a better way to do this
     PREFETCH _  _ src -> mkRU (use_R src []) []
+    LOCK                -> noUsage
+    XADD _ src dst      -> usageMM src dst
+    CMPXCHG _ src dst   -> usageRMM src dst (OpReg eax)
 
     _other              -> panic "regUsage: unrecognised instr"
-
  where
+    -- # Definitions
+    --
+    -- Written: If the operand is a register, it's written. If it's an
+    -- address, registers mentioned in the address are read.
+    --
+    -- Modified: If the operand is a register, it's both read and
+    -- written. If it's an address, registers mentioned in the address
+    -- are read.
+
     -- 2 operand form; first operand Read; second Written
     usageRW :: Operand -> Operand -> RegUsage
     usageRW op (OpReg reg)      = mkRU (use_R op []) [reg]
@@ -444,6 +461,18 @@ x86_regUsageOfInstr platform instr
     usageRM op (OpAddr ea)      = mkRUR (use_R op $! use_EA ea [])
     usageRM _ _                 = panic "X86.RegInfo.usageRM: no match"
 
+    -- 2 operand form; first operand Modified; second Modified
+    usageMM :: Operand -> Operand -> RegUsage
+    usageMM (OpReg src) (OpReg dst) = mkRU [src, dst] [src, dst]
+    usageMM (OpReg src) (OpAddr ea) = mkRU (use_EA ea [src]) [src]
+    usageMM _ _                     = panic "X86.RegInfo.usageMM: no match"
+
+    -- 3 operand form; first operand Read; second Modified; third Modified
+    usageRMM :: Operand -> Operand -> Operand -> RegUsage
+    usageRMM (OpReg src) (OpReg dst) (OpReg reg) = mkRU [src, dst, reg] [dst, reg]
+    usageRMM (OpReg src) (OpAddr ea) (OpReg reg) = mkRU (use_EA ea [src, reg]) [reg]
+    usageRMM _ _ _                               = panic "X86.RegInfo.usageRMM: no match"
+
     -- 1 operand form; operand Modified
     usageM :: Operand -> RegUsage
     usageM (OpReg reg)          = mkRU [reg] [reg]
@@ -476,6 +505,7 @@ x86_regUsageOfInstr platform instr
         where src' = filter (interesting platform) src
               dst' = filter (interesting platform) dst
 
+-- | Is this register interesting for the register allocator?
 interesting :: Platform -> Reg -> Bool
 interesting _        (RegVirtual _)              = True
 interesting platform (RegReal (RealRegSingle i)) = isFastTrue (freeReg platform i)
@@ -483,6 +513,8 @@ interesting _        (RegReal (RealRegPair{}))   = panic "X86.interesting: no re
 
 
 
+-- | Applies the supplied function to all registers in instructions.
+-- Typically used to change virtual registers to real registers.
 x86_patchRegsOfInstr :: Instr -> (Reg -> Reg) -> Instr
 x86_patchRegsOfInstr instr env
  = case instr of
@@ -571,6 +603,10 @@ x86_patchRegsOfInstr instr env
 
     PREFETCH lvl size src -> PREFETCH lvl size (patchOp src)
 
+    LOCK                -> instr
+    XADD sz src dst     -> patch2 (XADD sz) src dst
+    CMPXCHG sz src dst  -> patch2 (CMPXCHG sz) src dst
+
     _other              -> panic "patchRegs: unrecognised instr"
 
   where
index 459c041..7771c02 100644 (file)
@@ -886,6 +886,14 @@ pprInstr GFREE
             ptext (sLit "\tffree %st(4) ;ffree %st(5)")
           ]
 
+-- Atomics
+
+pprInstr LOCK = ptext (sLit "\tlock")
+
+pprInstr (XADD size src dst) = pprSizeOpOp (sLit "xadd") size src dst
+
+pprInstr (CMPXCHG size src dst) = pprSizeOpOp (sLit "cmpxchg") size src dst
+
 pprInstr _
         = panic "X86.Ppr.pprInstr: no match"
 
index 4851315..4faa585 100644 (file)
@@ -1363,19 +1363,79 @@ primop  SetByteArrayOp "setByteArray#" GenPrimOp
   code_size = { primOpCodeSizeForeignCall + 4 }
   can_fail = True
 
+-- Atomic operations
+
+primop  AtomicReadByteArrayOp_Int "atomicReadIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array and an offset in Int units, read an element. The
+    index is assumed to be in bounds. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop  AtomicWriteByteArrayOp_Int "atomicWriteIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> State# s
+   {Given an array and an offset in Int units, write an element. The
+    index is assumed to be in bounds. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
 primop CasByteArrayOp_Int "casIntArray#" GenPrimOp
    MutableByteArray# s -> Int# -> Int# -> Int# -> State# s -> (# State# s, Int# #)
-   {Machine-level atomic compare and swap on a word within a ByteArray.}
-   with
-   out_of_line = True
-   has_side_effects = True
+   {Given an array, an offset in Int units, the expected old value, and
+    the new value, perform an atomic compare and swap i.e. write the new
+    value if the current value matches the provided old value. Returns
+    the value of the element before the operation. Implies a full memory
+    barrier.}
+   with has_side_effects = True
+        can_fail = True
 
 primop FetchAddByteArrayOp_Int "fetchAddIntArray#" GenPrimOp
    MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
-   {Machine-level word-sized fetch-and-add within a ByteArray.}
-   with
-   out_of_line = True
-   has_side_effects = True
+   {Given an array, and offset in Int units, and a value to add,
+    atomically add the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchSubByteArrayOp_Int "fetchSubIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to subtract,
+    atomically substract the value to the element. Returns the value of
+    the element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchAndByteArrayOp_Int "fetchAndIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to AND,
+    atomically AND the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchNandByteArrayOp_Int "fetchNandIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to NAND,
+    atomically NAND the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchOrByteArrayOp_Int "fetchOrIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to OR,
+    atomically OR the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
+
+primop FetchXorByteArrayOp_Int "fetchXorIntArray#" GenPrimOp
+   MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #)
+   {Given an array, and offset in Int units, and a value to XOR,
+    atomically XOR the value to the element. Returns the value of the
+    element before the operation. Implies a full memory barrier.}
+   with has_side_effects = True
+        can_fail = True
 
 
 ------------------------------------------------------------------------
index 0c4d2f9..ee5a119 100644 (file)
@@ -348,7 +348,6 @@ RTS_FUN_DECL(stg_newByteArrayzh);
 RTS_FUN_DECL(stg_newPinnedByteArrayzh);
 RTS_FUN_DECL(stg_newAlignedPinnedByteArrayzh);
 RTS_FUN_DECL(stg_casIntArrayzh);
-RTS_FUN_DECL(stg_fetchAddIntArrayzh);
 RTS_FUN_DECL(stg_newArrayzh);
 RTS_FUN_DECL(stg_newArrayArrayzh);
 RTS_FUN_DECL(stg_copyArrayzh);
diff --git a/libraries/ghc-prim/cbits/atomic.c b/libraries/ghc-prim/cbits/atomic.c
new file mode 100644 (file)
index 0000000..a2e64af
--- /dev/null
@@ -0,0 +1,280 @@
+#include "Rts.h"
+
+// Fallbacks for atomic primops on byte arrays. The builtins used
+// below are supported on both GCC and LLVM.
+//
+// Ideally these function would take StgWord8, StgWord16, etc but
+// older GCC versions incorrectly assume that the register that the
+// argument is passed in has been zero extended, which is incorrect
+// according to the ABI and is not what GHC does when it generates
+// calls to these functions.
+
+// FetchAddByteArrayOp_Int
+
+extern StgWord hs_atomic_add8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_add8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_add(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_add16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_add16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_add(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_add32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_add32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_add(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_add64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_add64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_add(x, val);
+}
+
+// FetchSubByteArrayOp_Int
+
+extern StgWord hs_atomic_sub8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_sub8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_sub(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_sub16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_sub16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_sub(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_sub32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_sub32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_sub(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_sub64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_sub64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_sub(x, val);
+}
+
+// FetchAndByteArrayOp_Int
+
+extern StgWord hs_atomic_and8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_and8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_and(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_and16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_and16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_and(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_and32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_and32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_and(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_and64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_and64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_and(x, val);
+}
+
+// FetchNandByteArrayOp_Int
+
+extern StgWord hs_atomic_nand8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_nand8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_nand(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_nand16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_nand16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_nand(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_nand32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_nand32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_nand(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_nand64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_nand64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_nand(x, val);
+}
+
+// FetchOrByteArrayOp_Int
+
+extern StgWord hs_atomic_or8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_or8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_or(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_or16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_or16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_or(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_or32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_or32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_or(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_or64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_or64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_or(x, val);
+}
+
+// FetchXorByteArrayOp_Int
+
+extern StgWord hs_atomic_xor8(volatile StgWord8 *x, StgWord val);
+StgWord
+hs_atomic_xor8(volatile StgWord8 *x, StgWord val)
+{
+  return __sync_fetch_and_xor(x, (StgWord8) val);
+}
+
+extern StgWord hs_atomic_xor16(volatile StgWord16 *x, StgWord val);
+StgWord
+hs_atomic_xor16(volatile StgWord16 *x, StgWord val)
+{
+  return __sync_fetch_and_xor(x, (StgWord16) val);
+}
+
+extern StgWord hs_atomic_xor32(volatile StgWord32 *x, StgWord val);
+StgWord
+hs_atomic_xor32(volatile StgWord32 *x, StgWord val)
+{
+  return __sync_fetch_and_xor(x, (StgWord32) val);
+}
+
+extern StgWord64 hs_atomic_xor64(volatile StgWord64 *x, StgWord64 val);
+StgWord64
+hs_atomic_xor64(volatile StgWord64 *x, StgWord64 val)
+{
+  return __sync_fetch_and_xor(x, val);
+}
+
+// CasByteArrayOp_Int
+
+extern StgWord hs_cmpxchg8(volatile StgWord8 *x, StgWord old, StgWord new);
+StgWord
+hs_cmpxchg8(volatile StgWord8 *x, StgWord old, StgWord new)
+{
+  return __sync_val_compare_and_swap(x, (StgWord8) old, (StgWord8) new);
+}
+
+extern StgWord hs_cmpxchg16(volatile StgWord16 *x, StgWord old, StgWord new);
+StgWord
+hs_cmpxchg16(volatile StgWord16 *x, StgWord old, StgWord new)
+{
+  return __sync_val_compare_and_swap(x, (StgWord16) old, (StgWord16) new);
+}
+
+extern StgWord hs_cmpxchg32(volatile StgWord32 *x, StgWord old, StgWord new);
+StgWord
+hs_cmpxchg32(volatile StgWord32 *x, StgWord old, StgWord new)
+{
+  return __sync_val_compare_and_swap(x, (StgWord32) old, (StgWord32) new);
+}
+
+extern StgWord hs_cmpxchg64(volatile StgWord64 *x, StgWord64 old, StgWord64 new);
+StgWord
+hs_cmpxchg64(volatile StgWord64 *x, StgWord64 old, StgWord64 new)
+{
+  return __sync_val_compare_and_swap(x, old, new);
+}
+
+// AtomicReadByteArrayOp_Int
+
+extern StgWord hs_atomicread8(volatile StgWord8 *x);
+StgWord
+hs_atomicread8(volatile StgWord8 *x)
+{
+  return *x;
+}
+
+extern StgWord hs_atomicread16(volatile StgWord16 *x);
+StgWord
+hs_atomicread16(volatile StgWord16 *x)
+{
+  return *x;
+}
+
+extern StgWord hs_atomicread32(volatile StgWord32 *x);
+StgWord
+hs_atomicread32(volatile StgWord32 *x)
+{
+  return *x;
+}
+
+extern StgWord64 hs_atomicread64(volatile StgWord64 *x);
+StgWord64
+hs_atomicread64(volatile StgWord64 *x)
+{
+  return *x;
+}
+
+// AtomicWriteByteArrayOp_Int
+
+extern void hs_atomicwrite8(volatile StgWord8 *x, StgWord val);
+void
+hs_atomicwrite8(volatile StgWord8 *x, StgWord val)
+{
+  *x = (StgWord8) val;
+}
+
+extern void hs_atomicwrite16(volatile StgWord16 *x, StgWord val);
+void
+hs_atomicwrite16(volatile StgWord16 *x, StgWord val)
+{
+  *x = (StgWord16) val;
+}
+
+extern void hs_atomicwrite32(volatile StgWord32 *x, StgWord val);
+void
+hs_atomicwrite32(volatile StgWord32 *x, StgWord val)
+{
+  *x = (StgWord32) val;
+}
+
+extern void hs_atomicwrite64(volatile StgWord64 *x, StgWord64 val);
+void
+hs_atomicwrite64(volatile StgWord64 *x, StgWord64 val)
+{
+  *x = (StgWord64) val;
+}
index c861342..bc9f571 100644 (file)
@@ -52,6 +52,7 @@ Library
         exposed-modules: GHC.Prim
 
     c-sources:
+        cbits/atomic.c
         cbits/bswap.c
         cbits/debug.c
         cbits/longlong.c
index f7f554c..fb07d58 100644 (file)
@@ -1186,7 +1186,6 @@ typedef struct _RtsSymbolVal {
       SymI_HasProto(stg_newBCOzh)                                       \
       SymI_HasProto(stg_newByteArrayzh)                                 \
       SymI_HasProto(stg_casIntArrayzh)                                  \
-      SymI_HasProto(stg_fetchAddIntArrayzh)                             \
       SymI_HasProto(stg_newMVarzh)                                      \
       SymI_HasProto(stg_newMutVarzh)                                    \
       SymI_HasProto(stg_newTVarzh)                                      \
index 4d7baca..5f04a6d 100644 (file)
@@ -151,18 +151,6 @@ stg_casIntArrayzh( gcptr arr, W_ ind, W_ old, W_ new )
 }
 
 
-stg_fetchAddIntArrayzh( gcptr arr, W_ ind, W_ incr )
-/* MutableByteArray# s -> Int# -> Int# -> State# s -> (# State# s, Int# #) */
-{
-    W_ p, h;
-
-    p = arr + SIZEOF_StgArrWords + WDS(ind);
-    (h) = ccall atomic_inc(p, incr);
-
-    return(h);
-}
-
-
 stg_newArrayzh ( W_ n /* words */, gcptr init )
 {
     W_ words, size, p;
diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.hs b/testsuite/tests/concurrent/should_run/AtomicPrimops.hs
new file mode 100644 (file)
index 0000000..0c55aba
--- /dev/null
@@ -0,0 +1,245 @@
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedTuples #-}
+
+module Main ( main ) where
+
+import Control.Concurrent
+import Control.Concurrent.MVar
+import Control.Monad (when)
+import Foreign.Storable
+import GHC.Exts
+import GHC.IO
+
+-- | Iterations per worker.
+iters :: Int
+iters = 1000000
+
+main :: IO ()
+main = do
+    fetchAddSubTest
+    fetchAndTest
+    fetchNandTest
+    fetchOrTest
+    fetchXorTest
+    casTest
+    readWriteTest
+
+-- | Test fetchAddIntArray# by having two threads concurrenctly
+-- increment a counter and then checking the sum at the end.
+fetchAddSubTest :: IO ()
+fetchAddSubTest = do
+    tot <- race 0
+        (\ mba -> work fetchAddIntArray mba iters 2)
+        (\ mba -> work fetchSubIntArray mba iters 1)
+    assertEq 1000000 tot "fetchAddSubTest"
+  where
+    work :: (MByteArray -> Int -> Int -> IO ()) -> MByteArray -> Int -> Int
+         -> IO ()
+    work op mba 0 val = return ()
+    work op mba n val = op mba 0 val >> work op mba (n-1) val
+
+-- | Test fetchXorIntArray# by having two threads concurrenctly XORing
+-- and then checking the result at the end. Works since XOR is
+-- commutative.
+--
+-- Covers the code paths for AND, NAND, and OR as well.
+fetchXorTest :: IO ()
+fetchXorTest = do
+    res <- race n0
+        (\ mba -> work mba iters t1pat)
+        (\ mba -> work mba iters t2pat)
+    assertEq expected res "fetchXorTest"
+  where
+    work :: MByteArray -> Int -> Int -> IO ()
+    work mba 0 val = return ()
+    work mba n val = fetchXorIntArray mba 0 val >> work mba (n-1) val
+
+    -- Initial value is a large prime and the two patterns are 1010...
+    -- and 0101...
+    (n0, t1pat, t2pat)
+        | sizeOf (undefined :: Int) == 8 =
+            (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
+        | otherwise = (0x0000ffff, 0x55555555, 0x99999999)
+    expected
+        | sizeOf (undefined :: Int) == 8 = 4294967295
+        | otherwise = 65535
+
+-- The tests for AND, NAND, and OR are trivial for two reasons:
+--
+--  * The code path is already well exercised by 'fetchXorTest'.
+--
+--  * It's harder to test these operations, as a long sequence of them
+--    convert to a single value but we'd like to write a test in the
+--    style of 'fetchXorTest' that applies the operation repeatedly,
+--    to make it likely that any race conditions are detected.
+--
+-- Right now we only test that they return the correct value for a
+-- single op on each thread.
+
+fetchOpTest :: (MByteArray -> Int -> Int -> IO ())
+            -> Int -> String -> IO ()
+fetchOpTest op expected name = do
+    res <- race n0
+        (\ mba -> work mba t1pat)
+        (\ mba -> work mba t2pat)
+    assertEq expected res name
+  where
+    work :: MByteArray -> Int -> IO ()
+    work mba val = op mba 0 val
+
+    -- Initial value is a large prime and the two patterns are 1010...
+    -- and 0101...
+    (n0, t1pat, t2pat)
+        | sizeOf (undefined :: Int) == 8 =
+            (0x00000000ffffffff, 0x5555555555555555, 0x9999999999999999)
+        | otherwise = (0x0000ffff, 0x55555555, 0x99999999)
+
+fetchAndTest :: IO ()
+fetchAndTest = fetchOpTest fetchAndIntArray expected "fetchAndTest"
+  where expected
+            | sizeOf (undefined :: Int) == 8 = 286331153
+            | otherwise = 4369
+
+fetchNandTest :: IO ()
+fetchNandTest = fetchOpTest fetchNandIntArray expected "fetchNandTest"
+  where expected
+            | sizeOf (undefined :: Int) == 8 = 7378697629770151799
+            | otherwise = -2576976009
+
+fetchOrTest :: IO ()
+fetchOrTest = fetchOpTest fetchOrIntArray expected "fetchOrTest"
+  where expected
+            | sizeOf (undefined :: Int) == 8 = 15987178197787607039
+            | otherwise = 3722313727
+
+-- | Test casIntArray# by using it to emulate fetchAddIntArray# and
+-- then having two threads concurrenctly increment a counter,
+-- checking the sum at the end.
+casTest :: IO ()
+casTest = do
+    tot <- race 0
+        (\ mba -> work mba iters 1)
+        (\ mba -> work mba iters 2)
+    assertEq 3000000 tot "casTest"
+  where
+    work :: MByteArray -> Int -> Int -> IO ()
+    work mba 0 val = return ()
+    work mba n val = add mba 0 val >> work mba (n-1) val
+
+    -- Fetch-and-add implemented using CAS.
+    add :: MByteArray -> Int -> Int -> IO ()
+    add mba ix n = do
+        old <- readIntArray mba ix
+        old' <- casIntArray mba ix old (old + n)
+        when (old /= old') $ add mba ix n
+
+-- | Tests atomic reads and writes by making sure that one thread sees
+-- updates that are done on another. This test isn't very good at the
+-- moment, as this might work even without atomic ops, but at least it
+-- exercises the code.
+readWriteTest :: IO ()
+readWriteTest = do
+    mba <- newByteArray (sizeOf (undefined :: Int))
+    writeIntArray mba 0 0
+    latch <- newEmptyMVar
+    done <- newEmptyMVar
+    forkIO $ do
+        takeMVar latch
+        n <- atomicReadIntArray mba 0
+        assertEq 1 n "readWriteTest"
+        putMVar done ()
+    atomicWriteIntArray mba 0 1
+    putMVar latch ()
+    takeMVar done
+
+-- | Create two threads that mutate the byte array passed to them
+-- concurrently. The array is one word large.
+race :: Int                    -- ^ Initial value of array element
+     -> (MByteArray -> IO ())  -- ^ Thread 1 action
+     -> (MByteArray -> IO ())  -- ^ Thread 2 action
+     -> IO Int                 -- ^ Final value of array element
+race n0 thread1 thread2 = do
+    done1 <- newEmptyMVar
+    done2 <- newEmptyMVar
+    mba <- newByteArray (sizeOf (undefined :: Int))
+    writeIntArray mba 0 n0
+    forkIO $ thread1 mba >> putMVar done1 ()
+    forkIO $ thread2 mba >> putMVar done2 ()
+    mapM_ takeMVar [done1, done2]
+    readIntArray mba 0
+
+------------------------------------------------------------------------
+-- Test helper
+
+assertEq :: (Eq a, Show a) => a -> a -> String -> IO ()
+assertEq expected actual name
+    | expected == actual = putStrLn $ name ++ ": OK"
+    | otherwise = do
+        putStrLn $ name ++ ": FAIL"
+        putStrLn $ "Expected: " ++ show expected
+        putStrLn $ "  Actual: " ++ show actual
+
+------------------------------------------------------------------------
+-- Wrappers around MutableByteArray#
+
+data MByteArray = MBA (MutableByteArray# RealWorld)
+
+fetchAddIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchAddIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchAddIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchSubIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchSubIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchSubIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchAndIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchAndIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchAndIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchNandIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchNandIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchNandIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchOrIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchOrIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchOrIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+fetchXorIntArray :: MByteArray -> Int -> Int -> IO ()
+fetchXorIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case fetchXorIntArray# mba# ix# n# s# of
+        (# s2#, _ #) -> (# s2#, () #)
+
+newByteArray :: Int -> IO MByteArray
+newByteArray (I# n#) = IO $ \ s# ->
+    case newByteArray# n# s# of
+        (# s2#, mba# #) -> (# s2#, MBA mba# #)
+
+writeIntArray :: MByteArray -> Int -> Int -> IO ()
+writeIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case writeIntArray# mba# ix# n# s# of
+        s2# -> (# s2#, () #)
+
+readIntArray :: MByteArray -> Int -> IO Int
+readIntArray (MBA mba#) (I# ix#) = IO $ \ s# ->
+    case readIntArray# mba# ix# s# of
+        (# s2#, n# #) -> (# s2#, I# n# #)
+
+atomicWriteIntArray :: MByteArray -> Int -> Int -> IO ()
+atomicWriteIntArray (MBA mba#) (I# ix#) (I# n#) = IO $ \ s# ->
+    case atomicWriteIntArray# mba# ix# n# s# of
+        s2# -> (# s2#, () #)
+
+atomicReadIntArray :: MByteArray -> Int -> IO Int
+atomicReadIntArray (MBA mba#) (I# ix#) = IO $ \ s# ->
+    case atomicReadIntArray# mba# ix# s# of
+        (# s2#, n# #) -> (# s2#, I# n# #)
+
+casIntArray :: MByteArray -> Int -> Int -> Int -> IO Int
+casIntArray (MBA mba#) (I# ix#) (I# old#) (I# new#) = IO $ \ s# ->
+    case casIntArray# mba# ix# old# new# s# of
+        (# s2#, old2# #) -> (# s2#, I# old2# #)
diff --git a/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout b/testsuite/tests/concurrent/should_run/AtomicPrimops.stdout
new file mode 100644 (file)
index 0000000..c37041a
--- /dev/null
@@ -0,0 +1,7 @@
+fetchAddSubTest: OK
+fetchAndTest: OK
+fetchNandTest: OK
+fetchOrTest: OK
+fetchXorTest: OK
+casTest: OK
+readWriteTest: OK
index 0b502c3..0a66892 100644 (file)
@@ -81,6 +81,7 @@ test('tryReadMVar1', normal, compile_and_run, [''])
 test('tryReadMVar2', normal, compile_and_run, [''])
 
 test('T7970', normal, compile_and_run, [''])
+test('AtomicPrimops', normal, compile_and_run, [''])
 
 # -----------------------------------------------------------------------------
 # These tests we only do for a full run