Add Cmm support for representing 128-bit-wide SIMD vectors.
authorGeoffrey Mainland <gmainlan@microsoft.com>
Fri, 19 Oct 2012 08:32:02 +0000 (09:32 +0100)
committerGeoffrey Mainland <gmainlan@microsoft.com>
Fri, 1 Feb 2013 22:00:23 +0000 (22:00 +0000)
compiler/cmm/CmmCallConv.hs
compiler/cmm/CmmCommonBlockElim.hs
compiler/cmm/CmmExpr.hs
compiler/cmm/CmmType.hs
compiler/cmm/PprC.hs
compiler/cmm/PprCmmExpr.hs
compiler/llvmGen/LlvmCodeGen/Base.hs
compiler/llvmGen/LlvmCodeGen/Data.hs

index 7007872..dd4d6a6 100644 (file)
@@ -67,8 +67,11 @@ assignArgumentsPos dflags off conv arg_ty reps = (stk_off, assignments)
       assignments = reg_assts ++ stk_assts
 
       assign_regs assts []     _    = (assts, [])
-      assign_regs assts (r:rs) regs = if isFloatType ty then float else int
-        where float = case (w, regs) of
+      assign_regs assts (r:rs) regs | isVecType ty   = vec
+                                    | isFloatType ty = float
+                                    | otherwise      = int
+        where vec = (assts, (r:rs))
+              float = case (w, regs) of
                         (W32, (vs, fs, ds, ls, s:ss)) -> k (RegisterParam (FloatReg s), (vs, fs, ds, ls, ss))
                         (W32, (vs, f:fs, ds, ls, ss))
                             | not hasSseRegs          -> k (RegisterParam f, (vs, fs, ds, ls, ss))
index 614edf2..522d323 100644 (file)
@@ -119,6 +119,7 @@ hash_block block =
         hash_lit :: CmmLit -> Word32
         hash_lit (CmmInt i _) = fromInteger i
         hash_lit (CmmFloat r _) = truncate r
+        hash_lit (CmmVec ls) = hash_list hash_lit ls
         hash_lit (CmmLabel _) = 119 -- ugh
         hash_lit (CmmLabelOff _ i) = cvt $ 199 + i
         hash_lit (CmmLabelDiffOff _ _ i) = cvt $ 299 + i
index 87713c6..dce9624 100644 (file)
@@ -33,6 +33,7 @@ import BlockId
 import CLabel
 import DynFlags
 import Unique
+import Outputable (panic)
 
 import Data.Set (Set)
 import qualified Data.Set as Set
@@ -101,6 +102,7 @@ data CmmLit
         -- it will be used as a signed or unsigned value (the CmmType doesn't
         -- distinguish between signed & unsigned).
   | CmmFloat  Rational Width
+  | CmmVec [CmmLit]                     -- Vector literal
   | CmmLabel    CLabel                  -- Address of label
   | CmmLabelOff CLabel Int              -- Address of label + byte offset
 
@@ -133,6 +135,11 @@ cmmExprType dflags (CmmStackSlot _ _)  = bWord dflags -- an address
 cmmLitType :: DynFlags -> CmmLit -> CmmType
 cmmLitType _      (CmmInt _ width)     = cmmBits  width
 cmmLitType _      (CmmFloat _ width)   = cmmFloat width
+cmmLitType _      (CmmVec [])          = panic "cmmLitType: CmmVec []"
+cmmLitType cflags (CmmVec (l:ls))      = let ty = cmmLitType cflags l
+                                         in if all (`cmmEqType` ty) (map (cmmLitType cflags) ls)
+                                            then cmmVec (1+length ls) ty
+                                            else panic "cmmLitType: CmmVec"
 cmmLitType dflags (CmmLabel lbl)       = cmmLabelType dflags lbl
 cmmLitType dflags (CmmLabelOff lbl _)  = cmmLabelType dflags lbl
 cmmLitType dflags (CmmLabelDiffOff {}) = bWord dflags
index 9a443c1..49a2dc1 100644 (file)
@@ -1,7 +1,7 @@
 
 module CmmType
     ( CmmType   -- Abstract
-    , b8, b16, b32, b64, f32, f64, bWord, bHalfWord, gcWord
+    , b8, b16, b32, b64, b128, f32, f64, bWord, bHalfWord, gcWord
     , cInt, cLong
     , cmmBits, cmmFloat
     , typeWidth, cmmEqType, cmmEqType_ignoring_ptrhood
@@ -17,6 +17,13 @@ module CmmType
     , rEP_StgEntCounter_allocs
 
     , ForeignHint(..)
+
+    , Length
+    , vec, vec2, vec4, vec8, vec16
+    , vec2f64, vec2b64, vec4f32, vec4b32, vec8b16, vec16b8
+    , cmmVec
+    , vecLength, vecElemType
+    , isVecType
    )
 where
 
@@ -42,10 +49,11 @@ import Data.Int
 data CmmType    -- The important one!
   = CmmType CmmCat Width
 
-data CmmCat     -- "Category" (not exported)
-   = GcPtrCat   -- GC pointer
-   | BitsCat    -- Non-pointer
-   | FloatCat   -- Float
+data CmmCat                -- "Category" (not exported)
+   = GcPtrCat              -- GC pointer
+   | BitsCat               -- Non-pointer
+   | FloatCat              -- Float
+   | VecCat Length CmmCat  -- Vector
    deriving( Eq )
         -- See Note [Signed vs unsigned] at the end
 
@@ -53,9 +61,10 @@ instance Outputable CmmType where
   ppr (CmmType cat wid) = ppr cat <> ppr (widthInBits wid)
 
 instance Outputable CmmCat where
-  ppr FloatCat  = ptext $ sLit("F")
-  ppr GcPtrCat  = ptext $ sLit("P")
-  ppr BitsCat   = ptext $ sLit("I")
+  ppr FloatCat       = ptext $ sLit("F")
+  ppr GcPtrCat       = ptext $ sLit("P")
+  ppr BitsCat        = ptext $ sLit("I")
+  ppr (VecCat n cat) = ppr cat <> text "x" <> ppr n <> text "V"
 
 -- Why is CmmType stratified?  For native code generation,
 -- most of the time you just want to know what sort of register
@@ -77,10 +86,15 @@ cmmEqType_ignoring_ptrhood :: CmmType -> CmmType -> Bool
 cmmEqType_ignoring_ptrhood (CmmType c1 w1) (CmmType c2 w2)
    = c1 `weak_eq` c2 && w1==w2
    where
-      FloatCat `weak_eq` FloatCat = True
-      FloatCat `weak_eq` _other   = False
-      _other   `weak_eq` FloatCat = False
-      _word1   `weak_eq` _word2   = True        -- Ignores GcPtr
+     weak_eq :: CmmCat -> CmmCat -> Bool
+     FloatCat         `weak_eq` FloatCat         = True
+     FloatCat         `weak_eq` _other           = False
+     _other           `weak_eq` FloatCat         = False
+     (VecCat l1 cat1) `weak_eq` (VecCat l2 cat2) = l1 == l2
+                                                   && cat1 `weak_eq` cat2
+     (VecCat {})      `weak_eq` _other           = False
+     _other           `weak_eq` (VecCat {})      = False
+     _word1           `weak_eq` _word2           = True        -- Ignores GcPtr
 
 --- Simple operations on CmmType -----
 typeWidth :: CmmType -> Width
@@ -92,11 +106,12 @@ cmmFloat = CmmType FloatCat
 
 -------- Common CmmTypes ------------
 -- Floats and words of specific widths
-b8, b16, b32, b64, f32, f64 :: CmmType
+b8, b16, b32, b64, b128, f32, f64 :: CmmType
 b8     = cmmBits W8
 b16    = cmmBits W16
 b32    = cmmBits W32
 b64    = cmmBits W64
+b128   = cmmBits W128
 f32    = cmmFloat W32
 f64    = cmmFloat W64
 
@@ -244,6 +259,51 @@ narrowS W32 x = fromIntegral (fromIntegral x :: Int32)
 narrowS W64 x = fromIntegral (fromIntegral x :: Int64)
 narrowS _ _ = panic "narrowTo"
 
+-----------------------------------------------------------------------------
+--              SIMD
+-----------------------------------------------------------------------------
+
+type Length = Int
+
+vec :: Length -> CmmType -> CmmType
+vec l (CmmType cat w) = CmmType (VecCat l cat) vecw
+  where
+    vecw :: Width
+    vecw = widthFromBytes (l*widthInBytes w)
+
+vec2, vec4, vec8, vec16 :: CmmType -> CmmType
+vec2  = vec 2
+vec4  = vec 4
+vec8  = vec 8
+vec16 = vec 16
+
+vec2f64, vec2b64, vec4f32, vec4b32, vec8b16, vec16b8 :: CmmType
+vec2f64 = vec 2 f64
+vec2b64 = vec 2 b64
+vec4f32 = vec 4 f32
+vec4b32 = vec 4 b32
+vec8b16 = vec 8 b16
+vec16b8 = vec 16 b8
+
+cmmVec :: Int -> CmmType -> CmmType
+cmmVec n (CmmType cat w) =
+    CmmType (VecCat n cat) (widthFromBytes (n*widthInBytes w))
+
+vecLength :: CmmType -> Length
+vecLength (CmmType (VecCat l _) _) = l
+vecLength _                        = panic "vecLength: not a vector"
+
+vecElemType :: CmmType -> CmmType
+vecElemType (CmmType (VecCat l cat) w) = CmmType cat scalw
+  where
+    scalw :: Width
+    scalw = widthFromBytes (widthInBytes w `div` l)
+vecElemType _ = panic "vecElemType: not a vector"
+
+isVecType :: CmmType -> Bool
+isVecType (CmmType (VecCat {}) _) = True
+isVecType _                       = False
+
 -------------------------------------------------------------------------
 -- Hints
 
index 45f46b8..2ca8b67 100644 (file)
@@ -467,6 +467,8 @@ pprLit lit = case lit of
                 -- these constants come from <math.h>
                 -- see #1861
 
+    CmmVec {} -> panic "PprC printing vector literal"
+
     CmmBlock bid       -> mkW_ <> pprCLabelAddr (infoTblLbl bid)
     CmmHighStackMark   -> panic "PprC printing high stack mark"
     CmmLabel clbl      -> mkW_ <> pprCLabelAddr clbl
index 71c8446..3c9fa06 100644 (file)
@@ -194,6 +194,7 @@ pprLit lit = sdocWithDynFlags $ \dflags ->
                space <> dcolon <+> ppr rep ]
 
     CmmFloat f rep     -> hsep [ double (fromRat f), dcolon, ppr rep ]
+    CmmVec lits        -> char '<' <> commafy (map pprLit lits) <> char '>'
     CmmLabel clbl      -> ppr clbl
     CmmLabelOff clbl i -> ppr clbl <> ppr_offset i
     CmmLabelDiffOff clbl1 clbl2 i -> ppr clbl1 <> char '-'
index 45f20d7..1457efe 100644 (file)
@@ -70,7 +70,8 @@ type UnresStatic = Either UnresLabel LlvmStatic
 
 -- | Translate a basic CmmType to an LlvmType.
 cmmToLlvmType :: CmmType -> LlvmType
-cmmToLlvmType ty | isFloatType ty = widthToLlvmFloat $ typeWidth ty
+cmmToLlvmType ty | isVecType ty   = LMVector (vecLength ty) (cmmToLlvmType (vecElemType ty))
+                 | isFloatType ty = widthToLlvmFloat $ typeWidth ty
                  | otherwise      = widthToLlvmInt   $ typeWidth ty
 
 -- | Translate a Cmm Float Width to a LlvmType.
index fd0d7cc..83b5453 100644 (file)
@@ -171,6 +171,14 @@ genStaticLit (CmmInt i w)
 genStaticLit (CmmFloat r w)
     = Right $ LMStaticLit (LMFloatLit (fromRational r) (widthToLlvmFloat w))
 
+genStaticLit (CmmVec ls)
+    = Right $ LMStaticLit (LMVectorLit (map toLlvmLit ls))
+  where
+    toLlvmLit :: CmmLit -> LlvmLit
+    toLlvmLit lit = case genStaticLit lit of
+                   Right (LMStaticLit llvmLit) -> llvmLit
+                   _ -> panic "genStaticLit"
+
 -- Leave unresolved, will fix later
 genStaticLit c@(CmmLabel        _    ) = Left $ c
 genStaticLit c@(CmmLabelOff     _   _) = Left $ c