Support MO_U_QuotRem2 in LLVM backend
authorMichal Terepeta <michal.terepeta@gmail.com>
Mon, 3 Aug 2015 06:41:13 +0000 (08:41 +0200)
committerBen Gamari <ben@smart-cactus.org>
Mon, 3 Aug 2015 06:41:32 +0000 (08:41 +0200)
This adds support for MO_U_QuotRem2 in LLVM backend.  Similarly to
MO_U_Mul2 we use the standard LLVM instructions (in this case 'udiv'
and 'urem') but do the computation on double the word width (e.g., for
64-bit we will do them on 128 registers).

Test Plan: validate

Reviewers: rwbarton, austin, bgamari

Reviewed By: bgamari

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D1100

GHC Trac Issues: #9430

compiler/codeGen/StgCmmPrim.hs
compiler/llvmGen/LlvmCodeGen/CodeGen.hs
testsuite/tests/primops/should_run/T9430.hs

index 243e2a3..d201eaf 100644 (file)
@@ -808,7 +808,8 @@ callishPrimOpSupported dflags op
       WordQuotRemOp  | ncg && x86ish  -> Left (MO_U_QuotRem  (wordWidth dflags))
                      | otherwise      -> Right (genericWordQuotRemOp dflags)
 
-      WordQuotRem2Op | ncg && x86ish  -> Left (MO_U_QuotRem2 (wordWidth dflags))
+      WordQuotRem2Op | (ncg && x86ish)
+                          || llvm     -> Left (MO_U_QuotRem2 (wordWidth dflags))
                      | otherwise      -> Right (genericWordQuotRem2Op dflags)
 
       WordAdd2Op     | (ncg && x86ish)
index fb02120..517da53 100644 (file)
@@ -32,6 +32,13 @@ import UniqSupply
 import Unique
 import Util
 
+import Control.Monad.Trans.Class
+import Control.Monad.Trans.Writer
+
+#if MIN_VERSION_base(4,8,0)
+#else
+import Data.Monoid ( Monoid, mappend, mempty )
+#endif
 import Data.List ( nub )
 import Data.Maybe ( catMaybes )
 
@@ -288,6 +295,53 @@ genCall (PrimTarget (MO_U_Mul2 w)) [dstH, dstL] [lhs, rhs] = do
            toOL [ stmt3 , stmt4, stmt5, stmt6, stmt7, stmt8, storeL, storeH ]
     return (stmts, decls1 ++ decls2)
 
+-- MO_U_QuotRem2 is another case we handle by widening the registers to double
+-- the width and use normal LLVM instructions (similarly to the MO_U_Mul2). The
+-- main difference here is that we need to conmbine two words into one register
+-- and then use both 'udiv' and 'urem' instructions to compute the result.
+genCall (PrimTarget (MO_U_QuotRem2 w)) [dstQ, dstR] [lhsH, lhsL, rhs] = run $ do
+    let width = widthToLlvmInt w
+        bitWidth = widthInBits w
+        width2x = LMInt (bitWidth * 2)
+    -- First zero-extend all parameters to double width.
+    let zeroExtend expr = do
+            var <- liftExprData $ exprToVar expr
+            doExprW width2x $ Cast LM_Zext var width2x
+    lhsExtH <- zeroExtend lhsH
+    lhsExtL <- zeroExtend lhsL
+    rhsExt <- zeroExtend rhs
+    -- Now we combine the first two parameters (that represent the high and low
+    -- bits of the value). So first left-shift the high bits to their position
+    -- and then bit-or them with the low bits.
+    let widthLlvmLit = LMLitVar $ LMIntLit (fromIntegral bitWidth) width
+    lhsExtHShifted <- doExprW width2x $ LlvmOp LM_MO_Shl lhsExtH widthLlvmLit
+    lhsExt <- doExprW width2x $ LlvmOp LM_MO_Or lhsExtHShifted lhsExtL
+    -- Finally, we can call 'udiv' and 'urem' to compute the results.
+    retExtDiv <- doExprW width2x $ LlvmOp LM_MO_UDiv lhsExt rhsExt
+    retExtRem <- doExprW width2x $ LlvmOp LM_MO_URem lhsExt rhsExt
+    -- And since everything is in 2x width, we need to truncate the results and
+    -- then return them.
+    let narrow var = doExprW width $ Cast LM_Trunc var width
+    retDiv <- narrow retExtDiv
+    retRem <- narrow retExtRem
+    dstRegQ <- lift $ getCmmReg (CmmLocal dstQ)
+    dstRegR <- lift $ getCmmReg (CmmLocal dstR)
+    statement $ Store retDiv dstRegQ
+    statement $ Store retRem dstRegR
+  where
+    -- TODO(michalt): Consider extracting this and using in more places.
+    -- Hopefully this should cut down on the noise of accumulating the
+    -- statements and declarations.
+    doExprW :: LlvmType -> LlvmExpression -> WriterT LlvmAccum LlvmM LlvmVar
+    doExprW a b = do
+        (var, stmt) <- lift $ doExpr a b
+        statement stmt
+        return var
+    run :: WriterT LlvmAccum LlvmM () -> LlvmM (LlvmStatements, [LlvmCmmDecl])
+    run action = do
+        LlvmAccum stmts decls <- execWriterT action
+        return (stmts, decls)
+
 -- Handle the MO_{Add,Sub}IntC separately. LLVM versions return a record from
 -- which we need to extract the actual values.
 genCall t@(PrimTarget (MO_AddIntC w)) [dstV, dstO] [lhs, rhs] =
@@ -1767,3 +1821,21 @@ getTBAAMeta u = do
 -- | Returns TBAA meta data for given register
 getTBAARegMeta :: GlobalReg -> LlvmM [MetaAnnot]
 getTBAARegMeta = getTBAAMeta . getTBAA
+
+
+-- | A more convenient way of accumulating LLVM statements and declarations.
+data LlvmAccum = LlvmAccum LlvmStatements [LlvmCmmDecl]
+
+instance Monoid LlvmAccum where
+    mempty = LlvmAccum nilOL []
+    LlvmAccum stmtsA declsA `mappend` LlvmAccum stmtsB declsB =
+        LlvmAccum (stmtsA `mappend` stmtsB) (declsA `mappend` declsB)
+
+liftExprData :: LlvmM ExprData -> WriterT LlvmAccum LlvmM LlvmVar
+liftExprData action = do
+    (var, stmts, decls) <- lift action
+    tell $ LlvmAccum stmts decls
+    return var
+
+statement :: LlvmStatement -> WriterT LlvmAccum LlvmM ()
+statement stmt = tell $ LlvmAccum (unitOL stmt) []
index aec2d26..eedc0a7 100644 (file)
@@ -34,6 +34,23 @@ checkW (expX, expY) op (W# a) (W# b) =
                     "Expected " ++ show expX ++ " and " ++ show expY
                         ++ " but got " ++ show (W# x) ++ " and " ++ show (W# y)
 
+checkW2
+    :: (Word, Word)  -- ^ expected results
+    -> (Word# -> Word# -> Word# -> (# Word#, Word# #))
+                     -- ^ primop
+    -> Word          -- ^ first argument
+    -> Word          -- ^ second argument
+    -> Word          -- ^ third argument
+    -> Maybe String  -- ^ maybe error
+checkW2 (expX, expY) op (W# a) (W# b) (W# c) =
+    case op a b c of
+        (# x, y #)
+            | W# x == expX && W# y == expY -> Nothing
+            | otherwise ->
+                Just $
+                    "Expected " ++ show expX ++ " and " ++ show expY
+                        ++ " but got " ++ show (W# x) ++ " and " ++ show (W# y)
+
 check :: String -> Maybe String -> IO ()
 check s (Just err) = error $ "Error for " ++ s ++ ": " ++ err
 check _ Nothing    = return ()
@@ -91,3 +108,21 @@ main = do
       checkW (2, maxBound - 2) timesWord2# maxBound 3
     check "timesWord2# 3 maxBound" $
       checkW (2, maxBound - 2) timesWord2# 3 maxBound
+
+    check "quotRemWord2# 0 0 1" $ checkW2 (0, 0) quotRemWord2# 0 0 1
+    check "quotRemWord2# 0 4 2" $ checkW2 (2, 0) quotRemWord2# 0 4 2
+    check "quotRemWord2# 0 7 3" $ checkW2 (2, 1) quotRemWord2# 0 7 3
+    check "quotRemWord2# 1 0 (2 ^ 63)" $
+      checkW2 (2, 0) quotRemWord2# 1 0 (2 ^ 63)
+    check "quotRemWord2# 1 1 (2 ^ 63)" $
+      checkW2 (2, 1) quotRemWord2# 1 1 (2 ^ 63)
+    check "quotRemWord2# 1 0 maxBound" $
+      checkW2 (1, 1) quotRemWord2# 1 0 maxBound
+    check "quotRemWord2# 2 0 maxBound" $
+      checkW2 (2, 2) quotRemWord2# 2 0 maxBound
+    check "quotRemWord2# 1 maxBound maxBound" $
+      checkW2 (2, 1) quotRemWord2# 1 maxBound maxBound
+    check "quotRemWord2# (2 ^ 63) 0 maxBound" $
+      checkW2 (2 ^ 63, 2 ^ 63) quotRemWord2# (2 ^ 63) 0 maxBound
+    check "quotRemWord2# (2 ^ 63) maxBound maxBound" $
+      checkW2 (2 ^ 63 + 1, 2 ^ 63) quotRemWord2# (2 ^ 63) maxBound maxBound