Add code for looking up PA methods of primitive TyCons
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 mkDataConTag,
5 splitClosureTy,
6
7 mkBuiltinCo,
8 mkPADictType, mkPArrayType, mkPReprType,
9
10 parrayReprTyCon, parrayReprDataCon, mkVScrut,
11 prDFunOfTyCon,
12 paDictArgType, paDictOfType, paDFunType,
13 paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
14 polyAbstract, polyApply, polyVApply,
15 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
16 buildClosure, buildClosures,
17 mkClosureApp
18 ) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id ( mkWildId )
35 import MkId ( unwrapFamInstScrut )
36 import Name ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim ( intPrimTy )
40 import BasicTypes ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List ( zipWith4 )
46 import Control.Monad ( liftM, liftM2, zipWithM_ )
47
-- | Strip the type arguments off an annotated application.  Returns the head
-- expression together with the peeled type arguments in left-to-right order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc e                                   = (e, acc)
53
-- | Collect the leading type lambdas of an annotated expression.  Returns the
-- type-variable binders (outermost first) and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = walk []
  where
    walk acc (_, AnnLam bndr body)
      | isTyVar bndr = walk (bndr : acc) body
    walk acc e       = (reverse acc, e)
59
-- | Collect the leading value (Id) lambdas of an annotated expression.
-- Returns the binders (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = walk []
  where
    walk acc (_, AnnLam bndr body)
      | isId bndr = walk (bndr : acc) body
    walk acc e    = (reverse acc, e)
65
-- | Is this annotated expression a type argument (an 'AnnType' node)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg e = case snd e of
                   AnnType _ -> True
                   _         -> False
69
-- | A boxed 'Int' expression holding the constructor tag of a data constructor.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)
72
-- | Split a type of the form @T ty@, where @T@ has the given 'Name', and
-- return @ty@.  Any other shape panics, using @s@ as the panic label.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [ty']) | tyConName tc == name -> ty'
      _                                       -> pprPanic s (ppr ty)
80
-- | Split a type of the form @T ty1 ty2@, where @T@ has the given 'Name', and
-- return both arguments.  Any other shape panics, using @s@ as the panic label.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [ty1, ty2]) | tyConName tc == name -> (ty1, ty2)
      _                                            -> pprPanic s (ppr ty)
88
-- | Return the arguments of @ty@, which must be an application of exactly the
-- given TyCon; otherwise panic.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  = case splitTyConApp_maybe ty of
      Just (tc', tys) | tc == tc' -> tys
      _                           -> pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
-- | Argument and result types of a closure type (the TyCon named
-- 'closureTyConName'); panics on any other type.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty
99
-- | Element type of a parallel-array type (the TyCon named 'parrayTyConName');
-- panics on any other type.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty
102
-- | Apply the builtin TyCon selected by @get_tc@ to the given argument types.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)
108
-- | Right-nest a binary builtin TyCon over a list of types, ending in @ty@:
-- @[a, b] ty@ becomes @T a (T b ty)@.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = liftM nest (builtin get_tc)
  where
    nest tc = foldr (\arg res -> mkTyConApp tc [arg, res]) ty tys
116
-- | Right-nest a binary builtin TyCon over a list of types without a separate
-- terminator: @[a, b, c]@ becomes @T a (T b c)@ and a singleton list yields
-- its only element.  An empty list yields the default type @dft@.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _ dft [] = return dft
-- The list is known to be non-empty here, so 'foldr1' is total.
mkBuiltinTyConApps1 get_tc _ tys
  = do
      tc <- builtin get_tc
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1, ty2]
127
-- | Apply the builtin closure TyCon to an argument type and a result type.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg res = mkBuiltinTyConApp closureTyCon [arg, res]
130
-- | Curried closure type for several arguments: the builtin closure TyCon is
-- right-nested over @args@, ending in @res@.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes args res = mkBuiltinTyConApps closureTyCon args res
133
-- | Apply the builtin PRepr TyCon to a type.
mkPReprType :: Type -> VM Type
mkPReprType = mkBuiltinTyConApp preprTyCon . (: [])
136
-- | Apply the builtin PA (dictionary) TyCon to a type.
mkPADictType :: Type -> VM Type
mkPADictType = mkBuiltinTyConApp paTyCon . (: [])
139
-- | Apply the builtin parallel-array TyCon to an element type.
mkPArrayType :: Type -> VM Type
mkPArrayType = mkBuiltinTyConApp parrayTyCon . (: [])
142
-- | A coercion built by applying the selected builtin TyCon to no arguments.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = liftM (`mkTyConApp` []) (builtin get_tc)
148
-- | Look up the family-instance TyCon (and its instance arguments) that
-- represents the builtin parallel-array type at element type @ty@.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
151
-- | The data constructor of the parallel-array representation TyCon for
-- element type @ty@, together with the instance type arguments.
--
-- The representation TyCon is expected to have exactly one constructor.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      -- Kept lazy: the constructor list is only inspected when the DataCon is
      -- demanded.  A TyCon without exactly one constructor now panics with a
      -- descriptive message instead of an irrefutable-pattern failure.
      let dc = case tyConDataCons tc of
                 [con] -> con
                 cons  -> pprPanic "parrayReprDataCon" (ppr tc <+> ppr cons)
      return (dc, arg_tys)
158
-- | Prepare a vectorised expression for use as a case scrutinee.  The lifted
-- component is unwrapped via 'unwrapFamInstScrut' for the parallel-array
-- representation TyCon of the vectorised component's type.  The TyCon and its
-- instance arguments are returned as well.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = liftM wrap (parrayReprTyCon (exprType ve))
  where
    wrap (tc, arg_tys) = ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
164
-- | The PR dictionary function registered for a TyCon, as a Core expression.
-- If 'lookupTyConPR' finds nothing, 'traceMaybeV' receives this function's
-- name as its diagnostic label.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
  = liftM Var (traceMaybeV "prDFunOfTyCon" (ppr tycon) (lookupTyConPR tycon))
168
-- | The type of the PA-dictionary argument required for a type variable.
--
-- * For a variable of lifted kind @*@ this is @PA tv@.
-- * For a higher-kinded variable, each lifted-kind argument @a@ adds
--   @forall a. PA a -> ...@ around the dictionary for the applied type.
-- * @Nothing@ means no dictionary is needed (e.g. unlifted kinds).
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dict_for (TyVarTy tv) (tyVarKind tv)
  where
    -- Expand kind synonyms before inspecting the kind's shape.
    dict_for ty kind
      | Just kind' <- kindView kind = dict_for ty kind'
    dict_for ty (FunTy arg_kind res_kind)
      = do
          arg_tv    <- newTyVar FSLIT("a") arg_kind
          m_arg_ty  <- dict_for (TyVarTy arg_tv) arg_kind
          case m_arg_ty of
            Nothing     -> dict_for ty res_kind
            Just arg_ty -> do
                m_res_ty <- dict_for (AppTy ty (TyVarTy arg_tv)) res_kind
                return (fmap (ForAllTy arg_tv . FunTy arg_ty) m_res_ty)
    dict_for ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
    dict_for _ _ = return Nothing
188
-- | Build the PA dictionary for a type by splitting it into its head and its
-- type arguments.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
193
-- | Build the PA dictionary for a type head applied to type arguments.
-- Synonyms are expanded with 'coreView'.  Type-variable heads use the PA
-- dictionary in scope for the variable, and TyCon heads use the registered PA
-- dfun.  The dfun is then instantiated with the arguments and their
-- dictionaries.  Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn
  = paDictOfTyApp expanded ty_args
paDictOfTyApp (TyVarTy tv) ty_args
  = maybeV (lookupTyVarPA tv) >>= \dfun -> paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc _) ty_args
  = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
      >>= \dfun -> paDFunApply (Var dfun) ty_args
paDictOfTyApp ty _
  = pprPanic "paDictOfTyApp" (ppr ty)
206
-- | The type of the PA dfun for a TyCon:
-- @forall tvs. <dict args for tvs> -> PA (T tvs)@, where only type variables
-- that need a dictionary (see 'paDictArgType') contribute an argument.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      m_dict_tys <- mapM paDictArgType tvs
      res_ty     <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let dict_tys = [dict_ty | Just dict_ty <- m_dict_tys]
      return (mkForAllTys tvs (mkFunTys dict_tys res_ty))
  where
    tvs = tyConTyVars tc
217
-- | Instantiate a PA dfun: apply it to the type arguments and then to the PA
-- dictionary of each of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
223
-- | A PA method.  The first component selects the generic method from the
-- builtins.  The second is the name 'paMethod' passes to 'lookupPrimMethod'
-- to find a specialised version for primitive TyCons.
type PAMethod = (Builtins -> Var, String)

pa_length, pa_replicate, pa_empty :: PAMethod
pa_length    = (lengthPAVar,    "lengthPA")
pa_replicate = (replicatePAVar, "replicatePA")
pa_empty     = (emptyPAVar,     "emptyPA")
229
-- | Build a call to a PA method at type @ty@.
--
-- For a nullary primitive TyCon, the specialised method registered under the
-- method's name is used directly.  Otherwise the generic builtin method is
-- applied to @ty@ and its PA dictionary.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty
  = case splitTyConApp_maybe ty of
      Just (tycon, []) | isPrimTyCon tycon -> prim_method tycon
      _                                    -> dict_method
  where
    prim_method tycon
      = liftM Var
      . traceMaybeV "paMethod" (ppr tycon <+> text name)
      $ lookupPrimMethod tycon name

    dict_method
      = do
          fn   <- builtin method
          dict <- paDictOfType ty
          return (Var fn `mkApps` [Type ty, dict])
244
-- | Apply the builtin @mkPR@ function to a type and its PA dictionary.
mkPR :: Type -> VM CoreExpr
mkPR ty
  = liftM2 (\fn dict -> mkApps (Var fn) [Type ty, dict])
           (builtin mkPRVar)
           (paDictOfType ty)
251
-- | The length of a parallel-array expression.  The element type is taken
-- from the expression's (parallel-array) type.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod pa_length (splitPArrayTy (exprType x))
      return (App len_fn x)
256
-- | Apply the PA replicate method at the type of @x@ to a length and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod pa_replicate (exprType x)
      return (rep_fn `mkApps` [len, x])
260
-- | The PA "empty" method at the given element type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
263
-- | Lift a scalar expression by replicating it, using the lifting-context
-- variable as the replication length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
269
-- | A fresh vectorised variable pair.  The first component has type @vty@,
-- the second has the parallel-array type of @vty@, and both use the same base
-- name.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
277
-- | Run @p@ in a local environment where each type variable is bound.  Where
-- 'paDictArgType' requires one, a fresh PA-dictionary variable is bound too.
-- @p@ receives a function that wraps an expression in lambdas over the type
-- variables followed by those dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM mk_dict_var tvs
      zipWithM_ bind_tyvar tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    bind_tyvar tv Nothing     = defLocalTyVar tv
    bind_tyvar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    mk_dict_var tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))
293
-- | Instantiate a polymorphic expression: apply it to the types and then to
-- their PA dictionaries.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
299
-- | Like 'polyApply', but instantiates both components of a vectorised
-- expression with the same types and dictionaries.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let instantiate e = mkApps (mkTyApps e tys) dicts
      return (mapVect instantiate expr)
305
-- | Record a binding in the global environment's list of hoisted bindings.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add
  where
    add env = env { global_bindings = (v, e) : global_bindings env }
309
-- | Bind an expression to a fresh variable, hoist the binding, and return the
-- variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var
316
-- | Hoist both components of a vectorised expression.  The names are derived
-- from the current binding name, prefixed with 'v' and 'l'.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      base <- getBindName
      liftM2 (,) (hoistExpr (consFS 'v' base) ve)
                 (hoistExpr (consFS 'l' base) le)
324
-- | Build a vectorised expression in a closed environment, abstracted over
-- the type variables and their PA dictionaries.  The result is hoisted, and an
-- application of the hoisted variables back to the type variables (and their
-- dictionaries) is returned.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      fn         <- hoistVExpr abstracted
      polyVApply (vVar fn) (mkTyVarTys tvs)
332
-- | Return the hoisted bindings accumulated so far and clear them from the
-- global environment.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv (env { global_bindings = [] })
      return hoisted
339
-- | Build the vectorised and lifted closures from a function pair and an
-- environment pair.  The builtin closure constructors are applied to the
-- argument, result and environment types, the environment's PA dictionary,
-- both function components and the matching environment component.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let close con env = mkApps (mkTyApps (Var con) [arg_ty, res_ty, env_ty])
                                 [dict, vfn, lfn, env]
      return (close mkv venv, close mkl lenv)
348
-- | Apply a vectorised closure to a vectorised argument.  The builtin scalar
-- and lifted application functions are instantiated at the closure's argument
-- and result types.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)
    apply fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
358
-- | Build a curried chain of closures, one per argument type, capturing
-- @vars@.  With no argument types the body is built directly.  For several
-- arguments, the outer closure's body is a hoisted, type-abstracted
-- expression that builds the remaining closures over @vars@ plus the new
-- argument.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg     <- newLocalVVar FSLIT("x") arg_ty
      let vars' = vars ++ [arg]
          inner = do
                    lc  <- builtin liftingContext
                    clo <- buildClosures tvs vars' arg_tys res_ty mk_body
                    return (vLams lc vars' clo)
      buildClosure tvs vars arg_ty res_ty' (hoistPolyVExpr tvs inner)
374
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- | Build a single closure over @vars@ taking one argument of type @arg_ty@.
-- An environment capturing @vars@ is built.  The function pair is a hoisted,
-- type-abstracted lambda over (env, arg) that unpacks the environment and
-- applies the body to the captured variables and the argument.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
      let fn_body = do
            lc    <- builtin liftingContext
            body  <- mk_body
            body' <- bind (vVar env_bndr) (vVarApps lc body (vars ++ [arg_bndr]))
            return (vLamsWithoutLC [env_bndr, arg_bndr] body')
      fn <- hoistPolyVExpr tvs fn_body
      mkClosure arg_ty res_ty env_ty fn env
396
-- | Build a closure environment for a list of vectorised variables.  Returns
-- the scalar environment type, the (scalar, lifted) environment expressions,
-- and a function that wraps a (scalar, lifted) body in code unpacking a given
-- environment pair.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (vs, ls)          = unzip vvs
          tys               = map idType vs
          (ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind (venv', lenv') (vbody, lbody)
            = do
                lbody' <- lbind lenv' lbody
                return (vbind venv' vbody, lbody')
      return (ty, (venv, lenv), bind)
412
-- | The scalar environment for a list of variables: its type, the expression
-- building it, and a binder that unpacks an environment around a body.
--
-- * No variables: unit, and the body is left unchanged.
-- * One variable: the variable itself, unpacked with a non-recursive @let@.
-- * Otherwise: a boxed tuple, unpacked with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys
    unpack env body = Case env (mkWildId tup_ty) (exprType body)
                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
421
-- | The lifted environment for a list of lifted variables, given the
-- lifting-context variable @lc@.  Returns the expression building the
-- environment and a binder that unpacks an environment around a body.
--
-- With exactly one variable, the variable itself is the environment.
-- Unpacking binds it with @let@ and then cases on its 'lengthPA', using @lc@
-- as the case binder.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind_one)
  where
    bind_one env body
      = do
          len <- lengthPA (Var v)
          return (Let (NonRec v env)
                      (Case len lc (exprType body) [(DEFAULT, [], body)]))

-- NOTE: this transparently deals with empty environments
-- Otherwise the environment uses the (single) constructor of the
-- parallel-array representation of the tuple of the variables' types, applied
-- to @lc@ and the variables.  With no variables, a wildcard of unit type
-- stands in for the field binders.
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon (mkCoreTupTy tys)
      let [env_con] = tyConDataCons env_tc
          env       = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                                (lc : vs)
          bind env_expr body
            = return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                           [(DataAlt env_con, lc : bndrs, body)])
            where
              scrut = unwrapFamInstScrut env_tc env_tyargs env_expr
      return (env, bind)
  where
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
451