01fbede4bd45597c89b83701802535b8c6f869c2
[ghc.git] / compiler / vectorise / Vectorise / Utils / PADict.hs
1 module Vectorise.Utils.PADict (
2 paDictArgType,
3 paDictOfType,
4 paMethod,
5 prDictOfReprType,
6 prDictOfPReprInstTyCon
7 ) where
8
9 import Vectorise.Monad
10 import Vectorise.Builtins
11 import Vectorise.Utils.Base
12
13 import CoreSyn
14 import CoreUtils
15 import FamInstEnv
16 import Coercion
17 import Type
18 import TypeRep
19 import TyCon
20 import CoAxiom
21 import Var
22 import Outputable
23 import DynFlags
24 import FastString
25 import Control.Monad
26
27
-- | Build the type of the PA dictionary argument that abstracts over the type
-- variable @tv@, or 'Nothing' when its kind calls for no dictionary.
--
-- For a variable @(v :: *)@ the result is simply @PA v@.  For an arrow kind, a
-- fresh variable is quantified for the argument kind and the construction
-- recurses on both the argument and the result kind.  For example, for
-- @(v :: (* -> *) -> *)@ the result is
--
-- > forall (a :: * -> *). (forall (b :: *). PA b -> PA (a b)) -> PA (v a)
--
-- Any kind that is neither a lifted type kind nor an arrow kind yields
-- 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTypeAt (TyVarTy tv) (tyVarKind tv)
  where
    -- Compute the dictionary type for the type @t@, which has kind @k@.
    dictTypeAt t (FunTy argKind resKind) = do
      fresh    <- newTyVar (fsLit "a") argKind
      mArgDict <- dictTypeAt (TyVarTy fresh) argKind
      case mArgDict of
        -- NOTE(review): here the argument is skipped without applying @t@ to
        -- @fresh@, so @t@ is continued at the result kind; confirm intended.
        Nothing      -> dictTypeAt t resKind
        Just argDict -> do
          mResDict <- dictTypeAt (AppTy t (TyVarTy fresh)) resKind
          return (fmap (abstractOver fresh argDict) mResDict)

    dictTypeAt t k
      | isLiftedTypeKind k = do
          paCls <- builtin paClass
          return (Just (mkClassPred paCls [t]))

    dictTypeAt _ _ = return Nothing

    -- @forall v. argDict -> body@
    abstractOver v argDict body = ForAllTy v (FunTy argDict body)
53
54
-- | Build the expression for the PA dictionary of a type.
--
-- The type is split into a head and its arguments (type synonyms in the head
-- are looked through).
--
-- * A type-variable head uses the dictionary recorded for that variable.
-- * A type-constructor head uses the constructor's PA dfun.  When the
--   constructor is applied to arguments, the dfun also receives the PR
--   dictionary of the representation instance for the full application.
--
-- In both cases the head's dictionary is applied to the argument types and
-- then to the PA dictionaries of the arguments, computed recursively.  Any
-- other head is rejected via 'cantVectorise'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = dictForApp headTy argTys
  where
    (headTy, argTys) = splitAppTys ty

    dictForApp :: Type -> [Type] -> VM CoreExpr

    -- look through type synonyms in the head first
    dictForApp fn args
      | Just fn' <- coreView fn
      = dictForApp fn' args

    -- type variable: apply its recorded dfun to the argument types and the
    -- PA dictionaries of the arguments
    dictForApp (TyVarTy tv) args = do
      dfun  <- maybeCantVectoriseM "No PA dictionary for type variable"
                                   (ppr tv <+> text "in" <+> ppr ty)
                                   (lookupTyVarPA tv)
      dicts <- mapM paDictOfType args
      return (mkApps (mkTyApps dfun args) dicts)

    -- type constructor: as above, but a polymorphic tycon's dfun also takes
    -- the PR dictionary of its representation type
    dictForApp (TyConApp tc []) args = do
      dfun     <- maybeCantVectoriseM noPADictErr
                                      (ppr tc <+> text "in" <+> ppr ty)
                                      (lookupTyConPA tc)
      reprArgs <- reprDict args
      dicts    <- mapM paDictOfType args
      return (mkApps (mkApps (mkTyApps (Var dfun) args) reprArgs) dicts)
      where
        noPADictErr = "No PA dictionary for type constructor (did you import 'Data.Array.Parallel'?)"

        -- a nullary application needs no representation dictionary
        reprDict []     = return []
        reprDict tyArgs = do
          pr <- prDictOfPReprInst (TyConApp tc tyArgs)
          return [pr]

    -- anything else cannot be given a PA dictionary
    dictForApp _ _ = do
      dflags <- getDynFlags
      cantVectorise dflags "Can't construct PA dictionary for type" (ppr ty)
102
-- | Produce code that refers to a method of the 'PA' class at type @ty@.
--
-- If @ty@ is headed by a type constructor from "GHC.Prim" (e.g. 'Int#'), the
-- builtin chosen by @query@ for that constructor is returned directly.
-- Otherwise the builtin chosen by @method@ is applied to @ty@ and to the PA
-- dictionary of @ty@.
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod method query ty =
  case splitPrimTyCon ty of
    Just primTc -> do
      primMethod <- builtin (query primTc)
      return (Var primMethod)
    Nothing -> do
      selector <- builtin method
      paDict   <- paDictOfType ty
      return (Var selector `mkApps` [Type ty, paDict])
115
-- | Given a type @ty@, return the PR dictionary for @PRepr ty@.
--
-- Looks up the @PRepr@ family instance matching @ty@, then hands its axiom and
-- matched type arguments to 'prDictOfPReprInstTyCon'.
prDictOfPReprInst :: Type -> VM CoreExpr
prDictOfPReprInst ty =
  preprSynTyCon ty >>= \famMatch ->
    case famMatch of
      FamInstMatch { fim_instance = inst, fim_tys = instArgs } ->
        prDictOfPReprInstTyCon ty (famInstAxiom inst) instArgs
124
-- | Given a type @ty@, the axiom of its @PRepr@ instance and the instance's
-- type arguments, return the PR dictionary for @PRepr ty@.  Suppose we have:
--
-- > type instance PRepr (T a1 ... an) = t
--
-- which is internally translated into
--
-- > type :R:PRepr a1 ... an = t
--
-- together with the corresponding axiom.  Then
--
-- > prDictOfPReprInstTyCon (T a1 ... an) :R:PRepr u1 ... un
--
-- builds the PR dictionary for the axiom's right-hand side instantiated at
-- @u1 ... un@.  It then casts that dictionary by the symmetric axiom
-- instance lifted under @PR@, giving a dictionary of type
-- @PR (PRepr (T u1 ... un))@.
--
-- The first argument is currently ignored.
prDictOfPReprInstTyCon :: Type -> CoAxiom Unbranched -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon _ty prepr_ax prepr_args = do
  rhsDict <- prDictOfReprType' (mkUnbranchedAxInstRHS prepr_ax prepr_args)
  prCo    <- mkBuiltinCo prTyCon
  let axCo   = mkUnbranchedAxInstCo Nominal prepr_ax prepr_args
      castCo = mkAppCo prCo (mkSymCo axCo)
  return (mkCast rhsDict castCo)
150
-- | Get the PR dictionary for a type.  The argument must be a representation
-- type.
--
-- * @PRepr t@: the PR dictionary is selected from the PA dictionary of @t@
--   using the 'paPRSel' builtin.
-- * Any other type-constructor application: the constructor's PR dfun is
--   looked up and applied to dictionaries for the arguments by 'prDFunApply'.
-- * Anything else (a type variable or an application of one): the PR
--   dictionary is likewise selected from the type's PA dictionary.
--
-- A @PRepr@ application with other than exactly one argument violates an
-- internal invariant and triggers 'pprPanic'.
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  | Just (tycon, tyargs) <- splitTyConApp_maybe ty
  = do
      prepr <- builtin preprTyCon
      if tycon == prepr
        then case tyargs of
               [ty'] -> prFromPA ty'
               -- previously an irrefutable pattern; fail with context instead
               _     -> pprPanic "Vectorise.Utils.PADict.prDictOfReprType: PRepr must be applied to exactly one type"
                                 (ppr ty)
        else do
          -- a representation tycon must have a PR instance
          dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $
                    lookupTyConPR tycon
          prDFunApply dfun tyargs

  | otherwise
  -- It is a tyvar or an application of a tyvar: determine the PR dictionary
  -- from its PA dictionary.
  --
  -- NOTE: This assumes that PRepr t ~ t for all representation types t.
  --
  -- FIXME: This doesn't work for kinds other than * at the moment. We'd
  -- have to simply abstract the term over the missing type arguments.
  = prFromPA ty
  where
    -- Select the PR dictionary out of the PA dictionary of @t@.
    prFromPA t = do
      pa    <- paDictOfType t
      prsel <- builtin paPRSel
      return $ Var prsel `mkApps` [Type t, pa]
184
-- | Like 'prDictOfReprType', but when that computation fails (as recovered by
-- 'orElseV'), report via 'cantVectorise' that the representation type has no
-- PR dictionary.
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty = prDictOfReprType ty `orElseV` noDict
  where
    noDict = do
      dflags <- getDynFlags
      cantVectorise dflags "No PR dictionary for representation type" (ppr ty)
190
-- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding
-- to the argument types.
--
-- The dfun's context is read off its type: each argument of the
-- (forall-stripped) function type must be a class constraint headed by either
-- the PA or the PR tycon.  There must be exactly one such constraint per type
-- in @tys@.  For each pair, the matching PA or PR dictionary of the type is
-- built.  A dfun with no context is only applied to the types.  Any other
-- shape is reported via 'cantVectorise' as an invalid PR dfun type.
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  | Just [] <- ctxs    -- PR (a :-> b) doesn't have a context
  = return $ Var dfun `mkTyApps` tys

  | Just tycons <- ctxs
  , length tycons == length tys
  = do
      pa     <- builtin paTyCon
      pr     <- builtin prTyCon
      dflags <- getDynFlags
      args   <- zipWithM (dictionary dflags pa pr) tys tycons
      return $ Var dfun `mkTyApps` tys `mkApps` args

  | otherwise
  = do
      dflags <- getDynFlags
      invalid dflags
  where
    -- The dfun's contexts: if its type is (PA a, PR b) => PR (C a b) then
    -- ctxs is Just [PA, PR].  Nothing if any context is not a tycon
    -- application.
    ctxs :: Maybe [TyCon]
    ctxs = fmap (map fst)
         . mapM splitTyConApp_maybe
         . fst . splitFunTys
         . snd . splitForAllTys
         $ varType dfun

    -- Build the dictionary demanded by one context entry.
    dictionary dflags pa pr ty tycon
      | tycon == pa = paDictOfType ty
      | tycon == pr = prDictOfReprType ty
      | otherwise   = invalid dflags

    invalid dflags = cantVectorise dflags "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)
227