Vectoriser gets all DPH library identifiers from Data.Array.Parallel.Prim
[ghc.git] / compiler / vectorise / Vectorise / Utils / PADict.hs
1 module Vectorise.Utils.PADict (
2 paDictArgType,
3 paDictOfType,
4 paMethod,
5 prDictOfReprType,
6 prDictOfPReprInstTyCon
7 ) where
8
9 import Vectorise.Monad
10 import Vectorise.Builtins
11 import Vectorise.Utils.Base
12
13 import CoreSyn
14 import CoreUtils
15 import Coercion
16 import Type
17 import TypeRep
18 import TyCon
19 import Var
20 import Outputable
21 import FastString
22 import Control.Monad
23
24
-- | Construct the type of the 'PA' dictionary argument that abstracts over the
-- given type variable, or 'Nothing' if its kind admits none.
--
-- For a tyvar @(v :: *)@ the result is just @PA v@.  For an arrow kind, each
-- argument whose own kind admits a dictionary introduces a fresh quantified
-- variable together with a dictionary-transformer argument.  For example, for
-- @(v :: (* -> *) -> *)@ the result is
--
-- > forall (a :: * -> *). (forall (b :: *). PA b -> PA (a b)) -> PA (v a)
--
-- A kind that is neither lifted (@*@) nor an arrow yields 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    -- 'dictTy t k' builds the dictionary type for a type 't' of kind 'k'.
    dictTy t (FunTy argKind resKind)
      = do
          argTv    <- newTyVar (fsLit "a") argKind
          mArgDict <- dictTy (TyVarTy argTv) argKind
          -- NOTE(review): when the argument kind admits no dictionary, the
          -- result kind is processed for the *unapplied* 't' (no quantifier,
          -- no application to 'argTv') -- confirm this is intended.
          maybe (dictTy t resKind) (abstractOver argTv) mArgDict
      where
        -- Quantify over the fresh variable, take its dictionary type as an
        -- argument, and build the result for 't' applied to that variable.
        abstractOver argTv argDict
          = liftM (fmap (ForAllTy argTv . FunTy argDict))
                  (dictTy (AppTy t (TyVarTy argTv)) resKind)
    dictTy t k
      | isLiftedTypeKind k
      = liftM (\cls -> Just (mkClassPred cls [t])) (builtin paClass)
    dictTy _ _ = return Nothing
50
51
-- | Get the 'PA' dictionary for some type.
--
-- The type is split into its head and its arguments.  For a type-variable or
-- type-constructor head, the head's dfun is looked up, instantiated at the
-- argument types, and applied to the (recursively built) 'PA' dictionaries of
-- those arguments.  Any other head is a vectorisation error.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = dictOfHead fnTy argTys
  where
    (fnTy, argTys) = splitAppTys ty

    -- 'PA' dictionaries for each argument type, in order.
    argDicts :: [Type] -> VM [CoreExpr]
    argDicts = mapM paDictOfType

    dictOfHead :: Type -> [Type] -> VM CoreExpr
    -- look through type synonyms at the head first
    dictOfHead hd args
      | Just hd' <- coreView hd
      = dictOfHead hd' args
    -- type variable: the dfun is already an expression in the environment
    dictOfHead (TyVarTy tv) args
      = do
          dfun  <- maybeCantVectoriseM "No PA dictionary for type variable"
                                       (ppr tv <+> text "in" <+> ppr ty)
                                       (lookupTyVarPA tv)
          dicts <- argDicts args
          return (mkApps (mkTyApps dfun args) dicts)
    -- type constructor: the dfun is a variable
    dictOfHead (TyConApp tc []) args
      = do
          dfun  <- maybeCantVectoriseM "No PA dictionary for type constructor"
                                       (ppr tc <+> text "in" <+> ppr ty)
                                       (lookupTyConPA tc)
          dicts <- argDicts args
          return (mkApps (mkTyApps (Var dfun) args) dicts)
    dictOfHead _ _
      = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
87
-- | Produce code that refers to a method of the 'PA' class at type @ty@.
--
-- If 'splitPrimTyCon' recognises @ty@ as a type from "GHC.Prim" (e.g.
-- @Int#@), the builtin that @query@ selects for that tycon is returned
-- directly.  Otherwise the builtin @method@ is applied to @ty@ and to the
-- 'PA' dictionary for @ty@.
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod method query ty
  = case splitPrimTyCon ty of
      Just primTc -> Var `liftM` builtin (query primTc)
      Nothing     -> do
                       methodVar <- builtin method
                       dict      <- paDictOfType ty
                       return $ mkApps (Var methodVar) [Type ty, dict]
99
-- | Given a type @ty@, its PRepr synonym tycon and its type arguments,
-- return the PR dictionary for @PRepr ty@. Suppose we have:
--
-- > type instance PRepr (T a1 ... an) = t
--
-- which is internally translated into
--
-- > type :R:PRepr a1 ... an = t
--
-- and the corresponding coercion. Then
--
-- > prDictOfPReprInstTyCon (T a1 ... an) :R:PRepr u1 ... un
--
-- returns a dictionary of type @PR (PRepr (T u1 ... un))@.  It is built for
-- the expanded right-hand side and then cast along @PR@ applied to the
-- symmetric, instantiated family coercion.
--
-- Note that @ty@ is only used for error messages.
prDictOfPReprInstTyCon :: Type -> TyCon -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon ty prepr_tc prepr_args
  | Just rhs <- coreView (mkTyConApp prepr_tc prepr_args)
  = do
      dict   <- prDictOfReprType' rhs
      pr_co  <- mkBuiltinCo prTyCon
      -- The instance tycon of a 'type instance' is expected to carry a family
      -- coercion.  Report a vectorisation error instead of crashing with a
      -- lazy pattern-match failure if it does not.
      arg_co <- case tyConFamilyCoercion_maybe prepr_tc of
                  Just fam_co -> return fam_co
                  Nothing     -> cantVectorise "PRepr instance has no family coercion"
                                               (ppr prepr_tc <+> text "for" <+> ppr ty)
      -- PR (:R:PRepr u1 .. un)  ~  PR (PRepr (T u1 .. un))
      let co = mkAppCo pr_co
             $ mkSymCo
             $ mkAxInstCo arg_co prepr_args
      return $ mkCoerce co dict

  | otherwise = cantVectorise "Invalid PRepr type instance" (ppr ty)
128
-- | Get the PR dictionary for a type.  The argument must be a representation
-- type.
--
-- * For @PRepr t@ the dictionary is selected out of the 'PA' dictionary of
--   @t@ with the builtin 'paPRSel'.
-- * For any other tycon application, the tycon's PR dfun is looked up with
--   'maybeV' and applied with 'prDFunApply'.
-- * Anything else (a tyvar or an application of a tyvar) also goes through
--   the 'PA' dictionary and 'paPRSel'.
--
-- NOTE: The last case assumes that @PRepr t ~ t@ for all representation
-- types @t@.
--
-- FIXME: This doesn't work for kinds other than * at the moment. We'd have to
-- simply abstract the term over the missing type arguments.
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  = case splitTyConApp_maybe ty of
      Nothing              -> viaPADict ty
      Just (tycon, tyargs) -> do
        prepr <- builtin preprTyCon
        if tycon /= prepr
          then do
            -- a representation tycon must have a PR instance
            dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $
                      lookupTyConPR tycon
            prDFunApply dfun tyargs
          else do
            -- 'PRepr' is expected to be applied to exactly one argument
            let [ty'] = tyargs
            viaPADict ty'
  where
    -- Select the PR dictionary out of the PA dictionary of 't'.
    viaPADict t
      = do
          pa    <- paDictOfType t
          prsel <- builtin paPRSel
          return $ Var prsel `mkApps` [Type t, pa]
162
-- | Like 'prDictOfReprType', but if that fails (combined via 'orElseV'),
-- raise a 'cantVectorise' error that names the offending representation type.
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty
  = orElseV (prDictOfReprType ty)
            (cantVectorise "No PR dictionary for representation type" (ppr ty))
167
-- | Apply a tycon's PR dfun to the argument types, then to one dictionary per
-- context constraint: a 'PA' dictionary for a @PA@ constraint and a 'PR'
-- dictionary for a @PR@ constraint, matched positionally with the types.
-- A dfun with an empty context (e.g. the one for @PR (a :-> b)@) is applied
-- to the types only.  Any other shape is reported as an invalid dfun type.
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = case ctxtTyCons of
      Just []
        -> return instantiated
      Just clsTcs
        | length clsTcs == length tys
        -> do
             paTc  <- builtin paTyCon
             prTc  <- builtin prTyCon
             dicts <- zipWithM (dictFor paTc prTc) tys clsTcs
             return $ instantiated `mkApps` dicts
      _ -> invalid
  where
    instantiated = Var dfun `mkTyApps` tys

    -- Head tycons of the dfun's context: for a dfun of type
    -- (PA a, PR b) => PR (C a b) this is Just [PA, PR].  It is Nothing if
    -- some context component is not a tycon application.
    ctxtTyCons
      = let (_, rho)     = splitForAllTys (varType dfun)
            (ctxtTys, _) = splitFunTys rho
        in  fmap (map fst) (mapM splitTyConApp_maybe ctxtTys)

    dictFor paTc prTc argTy clsTc
      | clsTc == paTc = paDictOfType argTy
      | clsTc == prTc = prDictOfReprType argTy
      | otherwise     = invalid

    invalid = cantVectorise "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)
202