Re-add FunTy (big patch)
[ghc.git] / compiler / vectorise / Vectorise / Utils / PADict.hs
1 module Vectorise.Utils.PADict (
2 paDictArgType,
3 paDictOfType,
4 paMethod,
5 prDictOfReprType,
6 prDictOfPReprInstTyCon
7 ) where
8
9 import Vectorise.Monad
10 import Vectorise.Builtins
11 import Vectorise.Utils.Base
12
13 import CoreSyn
14 import CoreUtils
15 import FamInstEnv
16 import Coercion
17 import Type
18 import TyCoRep
19 import TyCon
20 import CoAxiom
21 import Var
22 import Outputable
23 import DynFlags
24 import FastString
25 import Control.Monad
26
27
-- |Build the type of the 'PA' dictionary argument that accompanies a type
-- variable. The shape of the result follows the kind of the variable:
--
-- For @v :: *@ it is simply @PA v@. For @v :: (* -> *) -> *@ it is
--
-- > forall (a :: * -> *). (forall (b :: *). PA b -> PA (a b)) -> PA (v a)
--
-- 'Nothing' is returned when a non-arrow kind is reached that is not the
-- lifted type kind.
--
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTypeFor (mkTyVarTy tv) (tyVarKind tv)
  where
    -- Arrow kind: invent a fresh binder for the argument position (a coercion
    -- variable if the argument kind is a coercion type, a type variable
    -- otherwise) and work out the dictionary type for that binder first.
    dictTypeFor ty (FunTy argKind resKind)
      = do
          binder <- if isCoercionType argKind
                      then newCoVar (fsLit "c") argKind
                      else newTyVar (fsLit "a") argKind
          let binderTy = mkTyVarTy binder
          mbArgDict <- dictTypeFor binderTy argKind
          case mbArgDict of
            -- the argument needs no dictionary: carry on with the result kind
            -- without applying 'ty' to the binder
            Nothing      -> dictTypeFor ty resKind
            -- otherwise quantify over the binder and take the argument's
            -- dictionary as a parameter of the result's dictionary
            Just argDict ->
              do
                mbResDict <- dictTypeFor (mkAppTy ty binderTy) resKind
                return $ fmap (mkInvForAllTy binder . mkFunTy argDict) mbResDict

    dictTypeFor ty kind
      -- a fully applied type of lifted kind gets a plain @PA ty@ predicate
      | isLiftedTypeKind kind
      = do
          paCls <- builtin paClass
          return (Just (mkClassPred paCls [ty]))

      | otherwise
      = return Nothing
55
56
-- |Construct the 'PA' dictionary expression for a type.
--
-- The type is split into a head and its arguments, and the head decides how
-- the dictionary is obtained:
--
-- For a type variable head, the expression found by 'lookupTyVarPA' is
-- applied to the argument types and then to the 'PA' dictionaries of those
-- arguments.
--
-- For an argument-free type constructor head, the dfun found by
-- 'lookupTyConPA' is applied to the argument types, then (only when there is
-- at least one argument) to the PR dictionary produced by
-- 'prDictOfPReprInst', and finally to the arguments' 'PA' dictionaries.
--
-- Any other head, and any failed lookup, is reported via 'cantVectorise' /
-- 'maybeCantVectoriseM'.
--
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty
  = uncurry dictForApp (splitAppTys ty)
  where
    dictForApp :: Type -> [Type] -> VM CoreExpr
    dictForApp headTy argTys
      -- first expand the head as far as 'coreView' allows
      | Just headTy' <- coreView headTy
      = dictForApp headTy' argTys

      | otherwise
      = case headTy of
          -- type variable: use the dictionary recorded for the variable
          TyVarTy tv ->
            do
              dfun <- maybeCantVectoriseM "No PA dictionary for type variable"
                        (ppr tv <+> text "in" <+> ppr ty)
                        (lookupTyVarPA tv)
              argDicts <- mapM paDictOfType argTys
              return (mkApps (mkTyApps dfun argTys) argDicts)

          -- type constructor: besides the arguments' PA dictionaries, the
          -- dfun also takes the PR dictionary of the representation type
          -- whenever type arguments are present
          TyConApp tc [] ->
            do
              dfun <- maybeCantVectoriseM
                        "No PA dictionary for type constructor (did you import 'Data.Array.Parallel'?)"
                        (ppr tc <+> text "in" <+> ppr ty)
                        (lookupTyConPA tc)
              superDicts <- superDictsFor tc
              argDicts <- mapM paDictOfType argTys
              return (mkApps (mkApps (mkTyApps (Var dfun) argTys) superDicts) argDicts)

          -- nothing else has a PA dictionary we know how to build
          _ ->
            do
              dflags <- getDynFlags
              cantVectorise dflags "Can't construct PA dictionary for type" (ppr ty)
      where
        -- zero or one PR dictionaries, depending on whether the constructor
        -- is applied to any type arguments
        superDictsFor tc
          | null argTys = return []
          | otherwise
          = do
              pr <- prDictOfPReprInst (TyConApp tc argTys)
              return [pr]
104
-- |Produce code that refers to a method of the 'PA' class for the given type.
--
-- The first argument selects the builtin used for ordinary types; the result
-- is that builtin applied to the type and to the type's 'PA' dictionary.
-- The second argument is consulted instead when 'splitPrimTyCon' recognises
-- the type (e.g., 'Int#' from 'GHC.Prim'); it selects a builtin for that
-- specific type constructor, which is returned as is.
--
paMethod :: (Builtins -> Var) -> (TyCon -> Builtins -> Var) -> Type -> VM CoreExpr
paMethod method query ty
  = case splitPrimTyCon ty of
      -- primitive type: a dedicated builtin exists, no dictionary is involved
      Just tycon ->
        do
          prim_fn <- builtin (query tycon)
          return (Var prim_fn)

      -- any other type: go through the PA dictionary
      Nothing ->
        do
          fn <- builtin method
          dict <- paDictOfType ty
          return (Var fn `mkApps` [Type ty, dict])
117
-- |Given a type @ty@, return the PR dictionary for @PRepr ty@.
--
-- The family instance matching @ty@ is obtained from 'preprFamInst'; its
-- axiom and instantiating types are then handed to 'prDictOfPReprInstTyCon'.
--
prDictOfPReprInst :: Type -> VM CoreExpr
prDictOfPReprInst ty
  = preprFamInst ty >>= fromMatch
  where
    fromMatch (FamInstMatch { fim_instance = inst, fim_tys = inst_tys })
      = prDictOfPReprInstTyCon ty (famInstAxiom inst) inst_tys
127
-- |Given a type @ty@, the coercion axiom of its PRepr instance and the types
-- the axiom is instantiated at, return the PR dictionary for @PRepr ty@.
-- Suppose we have:
--
-- > type instance PRepr (T a1 ... an) = t
--
-- which is internally translated into
--
-- > type :R:PRepr a1 ... an = t
--
-- together with the corresponding coercion axiom. The function instantiates
-- the right-hand side @t@ at the given types, builds the PR dictionary for
-- that representation type and casts it with the PR coercion applied to the
-- symmetric axiom instance.
--
-- Note that the @ty@ argument is currently not used at all.
--
prDictOfPReprInstTyCon :: Type -> CoAxiom Unbranched -> [Type] -> VM CoreExpr
prDictOfPReprInstTyCon _ty prepr_ax prepr_args
  = do
      repr_dict <- prDictOfReprType' repr_ty
      pr_co <- mkBuiltinCo prTyCon
      return (mkCast repr_dict (mkAppCo pr_co (mkSymCo ax_co)))
  where
    -- right-hand side of the axiom at the given types: the representation type
    repr_ty = mkUnbranchedAxInstRHS prepr_ax prepr_args []

    -- the axiom instance itself, used at nominal role
    ax_co = mkUnbranchedAxInstCo Nominal prepr_ax prepr_args []
153
-- |Get the PR dictionary for a type. The argument must be a representation
-- type.
--
-- For an application of the builtin PRepr tycon to a single type, the PR
-- dictionary is selected (via the builtin 'paPRSel') out of that type's PA
-- dictionary. For an application of any other tycon, the tycon's PR dfun is
-- looked up and applied with 'prDFunApply'; a missing dfun is a 'VM' failure
-- raised through 'maybeV'. For everything else the PR dictionary is selected
-- out of the PA dictionary of the type itself.
--
-- An application of PRepr to anything but exactly one argument is reported
-- through 'cantVectorise' instead of failing on a partial pattern match.
--
prDictOfReprType :: Type -> VM CoreExpr
prDictOfReprType ty
  | Just (tycon, tyargs) <- splitTyConApp_maybe ty
  = do
      prepr <- builtin preprTyCon
      if tycon == prepr
        then case tyargs of
          [ty'] ->
            do
              pa <- paDictOfType ty'
              sel <- builtin paPRSel
              return $ Var sel `App` Type ty' `App` pa
          -- PRepr is expected to take exactly one argument; complain with a
          -- proper message rather than with an irrefutable-pattern error
          _ ->
            do
              dflags <- getDynFlags
              cantVectorise dflags "PRepr applied to an unexpected number of arguments"
                            (ppr ty)
        else do
          -- a representation tycon must have a PR instance
          dfun <- maybeV (text "look up PR dictionary for" <+> ppr tycon) $
                    lookupTyConPR tycon
          prDFunApply dfun tyargs

  | otherwise
  = do
      -- it is a tyvar or an application of a tyvar
      -- determine the PR dictionary from its PA dictionary
      --
      -- NOTE: This assumes that PRepr t ~ t holds for all representation
      --       types t
      --
      -- FIXME: This doesn't work for kinds other than * at the moment. We'd
      --        have to simply abstract the term over the missing type arguments.
      pa <- paDictOfType ty
      prsel <- builtin paPRSel
      return $ Var prsel `mkApps` [Type ty, pa]
187
-- |Like 'prDictOfReprType', but if that computation fails in the 'VM' monad
-- the failure is turned into a 'cantVectorise' error naming the type.
--
prDictOfReprType' :: Type -> VM CoreExpr
prDictOfReprType' ty = orElseV (prDictOfReprType ty) noDict
  where
    -- fallback used by 'orElseV' when no dictionary could be produced
    noDict :: VM CoreExpr
    noDict
      = getDynFlags >>= \dflags ->
          cantVectorise dflags "No PR dictionary for representation type" (ppr ty)
193
-- | Apply a tycon's PR dfun to the given argument types and to the dictionary
-- arguments (PR or PA) demanded by the dfun's context.
--
-- The context is read off the dfun's type: if that type is
-- @(PA a, PR b) => PR (C a b)@ then the context tycons are @[PA, PR]@. The
-- i-th context entry is paired with the i-th argument type, and a PA or PR
-- dictionary is built for it accordingly. A context entry that is not a tycon
-- application, a context tycon other than PA or PR, or a context whose length
-- differs from the number of argument types is reported via 'cantVectorise'.
--
prDFunApply :: Var -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = case ctxTyCons of
      -- no context at all (e.g. PR (a :-> b)): only type arguments are needed
      Just [] -> return (mkTyApps (Var dfun) tys)

      -- one dictionary per argument type
      Just tycons
        | length tycons == length tys ->
            do
              pa <- builtin paTyCon
              pr <- builtin prTyCon
              dflags <- getDynFlags
              dicts <- zipWithM (dictFor dflags pa pr) tys tycons
              return (mkApps (mkTyApps (Var dfun) tys) dicts)

      _ -> getDynFlags >>= invalid
  where
    -- strip the quantifiers, then take the argument types of the remaining
    -- function type as the dfun's context
    (_, dfunRho) = splitForAllTys (varType dfun)
    (dfunCtxt, _) = splitFunTys dfunRho

    -- head tycon of every context entry; 'Nothing' as soon as one entry is
    -- not a tycon application
    ctxTyCons
      = do
          parts <- mapM splitTyConApp_maybe dfunCtxt
          return (map fst parts)

    -- dictionary for one argument type, as demanded by its context tycon
    dictFor dflags pa pr ty tycon
      | tycon == pa = paDictOfType ty
      | tycon == pr = prDictOfReprType ty
      | otherwise   = invalid dflags

    invalid dflags
      = cantVectorise dflags "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)