af807c8fd7794c6d07ca104d1580a15f487a06fe
[ghc.git] / compiler / vectorise / Vectorise / Convert.hs
1 module Vectorise.Convert
2 ( fromVect
3 )
4 where
5
6 import Vectorise.Monad
7 import Vectorise.Builtins
8 import Vectorise.Type.Type
9
10 import CoreSyn
11 import TyCon
12 import Type
13 import TyCoRep
14 import NameSet
15 import FastString
16 import Outputable
17
18 import Control.Applicative
19 import Prelude -- avoid redundant import warning due to AMP
20
-- |Convert a vectorised expression such that it computes the non-vectorised equivalent of its
-- value.
--
-- For functions, we eta expand the function and convert the arguments and result. For example,
-- for an original binding @foo :: Double -> Double -> Double@ whose vectorised closure is
-- @$v_foo@, we produce
--
-- >  \(x :: Double) ->
-- >  \(y :: Double) ->
-- >  ($v_foo $: x) $: y
--
-- where each closure application (written @$:@ above) is built with the @applyVar@ builtin.
-- Each argument is converted with 'toVect' before being applied, and the result is converted by
-- recursing on the result type. When a type is accepted by 'identityConv' (e.g. a scalar type
-- like @Double@ in the example), these conversions return the expression unchanged.
--
-- We use the type of the original binding to work out how many outer lambdas to add. Once no
-- function arrow remains, the remaining type must pass 'identityConv', otherwise the conversion
-- fails.
--
fromVect :: Type        -- ^ The type of the original binding.
         -> CoreExpr    -- ^ Expression giving the closure to use, e.g. @$v_foo@.
         -> VM CoreExpr

-- Convert the type to the core view if it isn't already (i.e. if 'coreView' can rewrite it),
-- so that the patterns below see the underlying type.
fromVect ty expr
  | Just ty' <- coreView ty
  = fromVect ty' expr

-- For each function constructor in the original type we add an outer lambda to bind the
-- parameter variable, and an inner closure application of the vectorised expression to the
-- (converted) parameter.
fromVect (ForAllTy (Anon arg_ty) res_ty) expr
  = do
      arg     <- newLocalVar (fsLit "x") arg_ty
      varg    <- toVect arg_ty (Var arg)      -- the argument, in vectorised form
      varg_ty <- vectType arg_ty
      vres_ty <- vectType res_ty
      apply   <- builtin applyVar
      -- 'mkTyApps' and 'mkApps' are infixl 9 here, so this is
      -- (apply @varg_ty @vres_ty) expr varg.
      body    <- fromVect res_ty
                 $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
      return $ Lam arg body

-- If the type isn't a function, then we currently can't convert it unless the type is scalar
-- (i.e., is identical to the non-vectorised version); 'identityConv' rejects anything else,
-- including quantified types.
fromVect ty expr
  = identityConv ty >> return expr
63
-- |Convert an expression so that it evaluates to the vectorised equivalent of the value of the
-- original expression.
--
-- WARNING: this currently only handles types accepted by 'identityConv' (scalar types, whose
-- vectorised value coincides with the original one). For those the expression is returned
-- unchanged; for any other type the check in 'identityConv' fails via 'noV'.
--
toVect :: Type -> CoreExpr -> VM CoreExpr
toVect ty expr
  = do
      identityConv ty
      return expr
72
-- |Check that the type is neutral under type vectorisation, i.e., all involved type
-- constructors are not altered by vectorisation as they contain no parallel arrays.
--
-- The type is first rewritten with 'coreView' where possible. A type-constructor application
-- passes when every argument passes and the constructor passes 'identityConvTyCon'. Every other
-- form of type is rejected with 'noV'.
--
identityConv :: Type -> VM ()
identityConv ty
  | Just expanded <- coreView ty
  = identityConv expanded
  | otherwise
  = case ty of
      TyConApp tycon args -> mapM_ identityConv args >> identityConvTyCon tycon
      LitTy {}      -> noV $ text "identityConv: not sure about literal types under vectorisation"
      TyVarTy {}    -> noV $ text "identityConv: type variable changes under vectorisation"
      AppTy {}      -> noV $ text "identityConv: type appl. changes under vectorisation"
      ForAllTy {}   -> noV $ text "identityConv: quantified type changes under vectorisation"
      CastTy {}     -> noV $ text "identityConv: not sure about casted types under vectorisation"
      CoercionTy {} -> noV $ text "identityConv: not sure about coercions under vectorisation"
90
-- |Check that this type constructor is not changed by vectorisation, i.e., it does not embed
-- any parallel arrays.
--
-- A constructor whose name is in 'globalParallelTyCons' is rejected with 'noV', except for the
-- builtin 'parrayTyCon' itself, which is let through.
--
identityConvTyCon :: TyCon -> VM ()
identityConvTyCon tc
  = do
      parallelTyCons <- globalParallelTyCons
      parray         <- builtin parrayTyCon
      let embedsParallel = tyConName tc `elemNameSet` parallelTyCons
          isPArray       = tc == parray
      if embedsParallel && not isPArray
        then noV idErr
        else return ()
  where
    idErr = text "identityConvTyCon: type constructor contains parallel arrays" <+> ppr tc