0f101bd3cb11134be62a94e00728f2bff3288fba
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 mkDataConTag,
5 splitClosureTy,
6 mkPRepr, mkToPRepr, mkFromPRepr,
7 mkPADictType, mkPArrayType, mkPReprType,
8 parrayReprTyCon, parrayReprDataCon, mkVScrut,
9 prDictOfType, prCoerce,
10 paDictArgType, paDictOfType, paDFunType,
11 paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12 polyAbstract, polyApply, polyVApply,
13 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
14 buildClosure, buildClosures,
15 mkClosureApp
16 ) where
17
18 #include "HsVersions.h"
19
20 import VectCore
21 import VectMonad
22
23 import DsUtils
24 import CoreSyn
25 import CoreUtils
26 import Coercion
27 import Type
28 import TypeRep
29 import TyCon
30 import DataCon ( DataCon, dataConWrapId, dataConTag )
31 import Var
32 import Id ( mkWildId )
33 import MkId ( unwrapFamInstScrut )
34 import Name ( Name )
35 import PrelNames
36 import TysWiredIn
37 import BasicTypes ( Boxity(..) )
38
39 import Outputable
40 import FastString
41
42 import Data.List ( zipWith4 )
43 import Control.Monad ( liftM, liftM2, zipWithM_ )
44
-- | Peel the type arguments off the spine of an annotated application.
-- Returns the head expression together with the stripped type arguments,
-- listed in the order they are applied (leftmost first).
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                              = (other, acc)
50
-- | Split the leading type lambdas off an annotated expression.
-- Returns the type binders (outermost first) and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    strip seen (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : seen) body
    strip seen rest  = (reverse seen, rest)
56
-- | Split the leading value (Id) lambdas off an annotated expression.
-- Returns the value binders (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    strip seen (_, AnnLam bndr body)
      | isId bndr = strip (bndr : seen) body
    strip seen rest = (reverse seen, rest)
62
-- | Is this annotated expression a type argument (an 'AnnType' node)?
-- The wrapped type itself is irrelevant, so it is matched with a wildcard
-- rather than bound to an unused name.
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, AnnType _) = True
isAnnTypeArg _              = False
66
-- | Build a boxed Int expression (an application of 'intDataCon') holding
-- the constructor tag of the given data constructor.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)
69
-- | Deconstruct an application of the type constructor called @name@ to
-- exactly one argument and return that argument. Any other type is a
-- compiler panic labelled with @s@.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg_ty]) | tyConName tc == name -> arg_ty
      _                                          -> pprPanic s (ppr ty)
77
-- | Deconstruct an application of the type constructor called @name@ to
-- exactly two arguments and return them as a pair. Any other type is a
-- compiler panic labelled with @s@.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [ty1, ty2]) | tyConName tc == name -> (ty1, ty2)
      _                                            -> pprPanic s (ppr ty)
85
-- | Return the argument types of @ty@, which must be an application of
-- exactly the type constructor @tc@; panics otherwise.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  = case splitTyConApp_maybe ty of
      Just (tc', tys) | tc == tc' -> tys
      _                           -> pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
93
-- | Return the argument of a type built with the builtin embed tycon
-- (the one named 'embedTyConName'); panics on any other type.
splitEmbedTy :: Type -> Type
splitEmbedTy ty = splitUnTy "splitEmbedTy" embedTyConName ty
96
-- | Return the (argument, result) types of a closure type (an application
-- of the tycon named 'closureTyConName'); panics on any other type.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty
99
-- | Return the element type of a parallel array type (an application of
-- the tycon named 'parrayTyConName'); panics on any other type.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty
102
-- | Look up a builtin type constructor and apply it to the given types.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)
108
-- | Right-nest a builtin binary type constructor @T@ over @tys@, ending in
-- @ty@: @[a, b]@ and @r@ give @T a (T b r)@; an empty list gives @ty@.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty = liftM nest (builtin get_tc)
  where
    nest tc = foldr (\lhs rhs -> mkTyConApp tc [lhs, rhs]) ty tys
116
-- | Right-nest a builtin binary type constructor @T@ over a list of types,
-- with the last element closing the chain: @[a, b, c]@ gives
-- @T a (T b c)@ and a singleton gives the element itself.
-- An empty list gives the default type @dft@ without looking up the builtin.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- tys is non-empty here because the first equation handles [],
      -- so foldr1 cannot fail.
      return $ foldr1 (\ty1 ty2 -> mkTyConApp tc [ty1, ty2]) tys
127
-- | Compute the representation type for a data type, given the field types
-- of each constructor (one inner list per constructor).
--
-- * Each field type is wrapped in the builtin embed tycon.
-- * The fields of a constructor are combined with the product tycon of
--   matching arity.
-- * The constructors are combined with the sum tycon of matching arity.
--
-- An empty sum or product becomes the unit type; a single component is
-- used unwrapped.
mkPRepr :: [[Type]] -> VM Type
mkPRepr tys
  = do
      embed_tc <- builtin embedTyCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let combine _     []  = unitTy
          combine _     [t] = t
          combine tc_of ts  = mkTyConApp (tc_of (length ts)) ts

          embed t = mkTyConApp embed_tc [t]

          con_repr fields = combine prod_tcs (map embed fields)

      return (combine sum_tcs (map con_repr tys))
148
-- | Build expressions in the representation layout computed by 'mkPRepr'.
-- @ess@ holds the field expressions of each constructor. Each field is
-- wrapped in the embed data constructor, the fields are packed with the
-- product data constructor of matching arity, and each packed constructor
-- is injected into the sum with the corresponding sum data constructor.
-- Returns one injected expression per constructor together with the
-- overall representation type. An empty list yields the unit value and type.
mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
mkToPRepr ess
  = do
      embed_tc <- builtin embedTyCon
      embed_dc <- builtin embedDataCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let to_embed e
            = (mkConApp embed_dc [Type e_ty, e], mkTyConApp embed_tc [e_ty])
            where
              e_ty = exprType e

          to_prod []       = (Var unitDataConId, unitTy)
          to_prod [single] = single
          to_prod parts
            = (mkConApp prod_dc (map Type part_tys ++ part_es),
               mkTyConApp prod_tc part_tys)
            where
              (part_es, part_tys) = unzip parts
              prod_tc             = prod_tcs (length parts)
              [prod_dc]           = tyConDataCons prod_tc

          to_sum []          = ([Var unitDataConId], unitTy)
          to_sum [(e, e_ty)] = ([e], e_ty)
          to_sum branches
            = (zipWith inject (tyConDataCons sum_tc) br_es,
               mkTyConApp sum_tc br_tys)
            where
              (br_es, br_tys) = unzip branches
              sum_tc          = sum_tcs (length branches)
              inject dc e     = mkConApp dc (map Type br_tys ++ [e])

      return (to_sum [to_prod (map to_embed fields) | fields <- ess])
180
-- | Build an expression that takes apart @scrut@, a value laid out as in
-- 'mkPRepr'/'mkToPRepr' (a sum of products of embedded fields).
-- Each element of @alts@ pairs the variables to bind for one constructor's
-- fields with the body to run for that constructor; @res_ty@ is the type
-- of every body.
--
-- Fresh variables are created for the sum alternatives (\"p\") and for
-- product components (\"x\"). Allocation happens in the same order as the
-- nested case expressions are built.
mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
mkFromPRepr scrut res_ty alts
  = do
      embed_dc <- builtin embedDataCon
      sum_tcs  <- builtins sumTyCon
      prod_tcs <- builtins prodTyCon

      let from_embed e ty var body
            = Case e (mkWildId ty) res_ty [(DataAlt embed_dc, [var], body)]

          from_prod _ _  []    body = return body
          from_prod e ty [var] body = return (from_embed e ty var body)
          from_prod e ty vars  body
            = do
                xs <- mapM (newLocalVar FSLIT("x")) field_tys
                let unwrapped = foldr ($) body
                              $ zipWith3 (\x t v -> from_embed (Var x) t v) xs field_tys vars
                return (Case e (mkWildId ty) res_ty [(DataAlt prod_dc, xs, unwrapped)])
            where
              prod_tc   = prod_tcs (length vars)
              [prod_dc] = tyConDataCons prod_tc
              field_tys = splitFixedTyConApp prod_tc ty

          from_sum e ty [(vars, body)] = from_prod e ty vars body
          from_sum e ty branches
            = do
                ps     <- mapM (newLocalVar FSLIT("p")) alt_tys
                bodies <- sequence (zipWith4 from_prod (map Var ps) alt_tys vars rhss)
                return (Case e (mkWildId ty) res_ty
                          (zipWith3 (\dc p b -> (DataAlt dc, [p], b)) sum_dcs ps bodies))
            where
              (vars, rhss) = unzip branches
              sum_tc       = sum_tcs (length branches)
              sum_dcs      = tyConDataCons sum_tc
              alt_tys      = splitFixedTyConApp sum_tc ty

      from_sum scrut (exprType scrut) alts
223
-- | The type of closures from @arg_ty@ to @res_ty@, built with the builtin
-- closure tycon.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = do
      clo_tc <- builtin closureTyCon
      return (mkTyConApp clo_tc [arg_ty, res_ty])
226
-- | The curried closure type taking each of @arg_tys@ in turn and finally
-- producing @res_ty@ (right-nested closure tycon applications).
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
229
-- | Apply the builtin PRepr tycon to @ty@.
mkPReprType :: Type -> VM Type
mkPReprType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin preprTyCon)
232
-- | The type of PA dictionaries for @ty@ (the builtin PA tycon applied to it).
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin paTyCon)
235
-- | The parallel array type with elements of type @ty@.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin parrayTyCon)
238
-- | Look up the family instance of the builtin PArray tycon at @ty@,
-- returning the representation tycon and its argument types.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
241
-- | The sole data constructor of the PArray representation tycon for @ty@,
-- together with that tycon's argument types.
--
-- The representation tycon is expected to have exactly one constructor.
-- If it does not, a descriptive panic naming the tycon is raised; the
-- check stays lazy (only when the constructor is demanded), like the
-- irrefutable pattern it replaces.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      let dc = case tyConDataCons tc of
                 [con] -> con
                 _     -> pprPanic "parrayReprDataCon" (ppr tc)
      return (dc, arg_tys)
248
-- | Prepare a vectorised scrutinee. The PArray representation tycon is
-- looked up at the type of the vectorised expression, and the lifted
-- expression is unwrapped with 'unwrapFamInstScrut'. The representation
-- tycon and its argument types are returned as well.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = parrayReprTyCon (exprType ve) >>= \(tc, arg_tys) ->
    return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
254
-- | Construct the PR dictionary for a type constructor application.
-- The tycon's PR dfun is looked up and applied to the type arguments
-- (see 'prDFunApply').
--
-- Types that are not tycon applications have no PR dictionary here. They
-- raise a descriptive panic instead of an anonymous pattern-match failure.
prDictOfType :: Type -> VM CoreExpr
prDictOfType orig_ty
  | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
  = do
      dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
      prDFunApply (Var dfun) ty_args

  | otherwise = pprPanic "prDictOfType" (ppr orig_ty)
261
-- | Instantiate a PR dfun at @tys@, then apply it to one dictionary for
-- each of its remaining function arguments (built by 'mkDFunArg').
prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
prDFunApply dfun tys
  = liftM (mkApps inst_dfun) (mapM mkDFunArg dict_tys)
  where
    inst_dfun = mkTyApps dfun tys
    dict_tys  = fst (splitFunTys (exprType inst_dfun))
270
-- | Build the dictionary expected by a dfun parameter.
-- A parameter of type @PA t@ gets a PA dictionary and one of type @PR t@
-- gets a PR dictionary. Anything else is a compiler panic.
mkDFunArg :: Type -> VM CoreExpr
mkDFunArg ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [arg])
        | tyConName tycon == paTyConName -> paDictOfType arg
        | tyConName tycon == prTyConName -> prDictOfType arg
      _ -> pprPanic "mkDFunArg" (ppr ty)
284
-- | Coerce @expr@ along the symmetric of @repr_tc@'s family-instance
-- coercion, instantiated at @args@ and lifted under the builtin PR tycon.
--
-- @repr_tc@ must be a family-instance representation tycon. If it has no
-- family coercion, a descriptive panic naming the tycon is raised instead
-- of an anonymous "non-exhaustive guards" failure.
prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
prCoerce repr_tc args expr
  | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
  = do
      pr_tc <- builtin prTyCon

      let co = mkAppCoercion (mkTyConApp pr_tc [])
                             (mkSymCoercion (mkTyConApp arg_co args))

      return $ mkCoerce co expr

  | otherwise = pprPanic "prCoerce" (ppr repr_tc)
295
-- | The type of the PA dictionary argument needed when abstracting over @tv@.
--
-- * Lifted kind: the result is @PA tv@.
-- * Function kind @k1 -> k2@: a fresh type variable @a : k1@ is introduced.
--   If @a@ itself needs a dictionary, the result is
--   @forall a. <dict for a> -> <dict for (tv a)>@. Otherwise the argument
--   kind is skipped.
-- * Any other kind: 'Nothing' (no dictionary needed).
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dict_ty_for (TyVarTy tv) (tyVarKind tv)
  where
    dict_ty_for ty k
      | Just k' <- kindView k = dict_ty_for ty k'
    dict_ty_for ty (FunTy k1 k2)
      = do
          arg_tv   <- newTyVar FSLIT("a") k1
          m_arg_ty <- dict_ty_for (TyVarTy arg_tv) k1
          case m_arg_ty of
            Nothing     -> dict_ty_for ty k2
            Just arg_ty -> liftM (fmap (ForAllTy arg_tv . FunTy arg_ty))
                                 (dict_ty_for (AppTy ty (TyVarTy arg_tv)) k2)
    dict_ty_for ty k
      | isLiftedTypeKind k = liftM Just (mkPADictType ty)
    dict_ty_for _ _ = return Nothing
315
-- | Construct the PA dictionary for @ty@. The type is split into its head
-- and arguments, and 'paDictOfTyApp' does the work.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
320
-- | Construct the PA dictionary for the type application @ty_fn ty_args@.
-- Synonyms in the head are expanded with 'coreView'. A type-variable head
-- uses the PA dictionary in scope for that variable. A tycon head uses the
-- tycon's PA dfun. In both cases the dfun is then applied to the argument
-- types and their dictionaries. Any other head is a compiler panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn = paDictOfTyApp expanded ty_args
paDictOfTyApp (TyVarTy tv) ty_args
  = maybeV (lookupTyVarPA tv) >>= \dfun ->
    paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc _) ty_args
  = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc) >>= \dfun ->
    paDFunApply (Var dfun) ty_args
paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
333
-- | The type of the PA dfun for tycon @tc@.
-- It quantifies over the tycon's type variables, takes a PA dictionary
-- argument for each variable that needs one (see 'paDictArgType'), and
-- returns @PA (tc tvs)@.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      let tvs = tyConTyVars tc
      dict_args <- mapM paDictArgType tvs
      inst_dict <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      return (mkForAllTys tvs (mkFunTys [d | Just d <- dict_args] inst_dict))
344
-- | Instantiate a PA dfun at @tys@ and apply it to the PA dictionary of
-- each of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
350
-- | Select a builtin PA method and apply it to the type @ty@ and the PA
-- dictionary for @ty@. The builtin is looked up before the dictionary is
-- built.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = liftM2 (\fn dict -> mkApps (Var fn) [Type ty, dict])
           (builtin method)
           (paDictOfType ty)
357
-- | The length of the parallel array expression @x@, via the builtin
-- length PA method at @x@'s element type.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App len_fn x)
362
-- | A parallel array of length @len@ whose elements are all @x@, built with
-- the builtin replicate PA method at @x@'s type.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])
366
-- | The empty parallel array at element type @ty@ (builtin empty PA method).
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty
369
-- | Lift a scalar expression by replicating it, using the builtin
-- lifting-context variable as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
375
-- | Create a fresh vectorised/lifted variable pair named after @fs@.
-- The first variable has type @vty@; the second has the parallel array
-- type over @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty >>= \lty ->
    liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
383
-- | Run @p@ in a local scope where the type variables @tvs@ are in scope.
-- A fresh PA dictionary variable is bound for each variable that needs one
-- (see 'paDictArgType'). @p@ receives a function that wraps an expression
-- in lambdas over the type variables, followed by those dictionary
-- variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM dict_var_for tvs
      zipWithM_ bind_tyvar tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    dict_var_for tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    bind_tyvar tv Nothing     = defLocalTyVar tv
    bind_tyvar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)
399
-- | Instantiate a polymorphic expression at @tys@ and pass it the PA
-- dictionaries for those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
405
-- | Like 'polyApply', but applied to both components of a vectorised
-- expression.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let instantiate e = mkApps (mkTyApps e tys) dicts
      return (mapVect instantiate expr)
411
-- | Record the binding @v = e@ by prepending it to the global environment's
-- 'global_bindings'.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add_binding
  where
    add_binding env = env { global_bindings = (v, e) : global_bindings env }
415
-- | Bind @expr@ to a fresh variable named after @fs@, hoist that binding
-- into the global bindings, and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >> return var
422
-- | Hoist both components of a vectorised expression. The variables are
-- named after the current binding name, prefixed with 'v' (vectorised
-- component) and 'l' (lifted component).
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      liftM2 (,) (hoistExpr ('v' `consFS` name) ve)
                 (hoistExpr ('l' `consFS` name) le)
430
-- | Build a vectorised expression in a closed scope and abstract it over
-- @tvs@ and their PA dictionaries. The result is hoisted to the top level
-- and an instantiation of the hoisted variables at @tvs@ is returned.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
438
-- | Return the bindings accumulated by 'hoistBinding' and clear them from
-- the global environment.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \env ->
    setGEnv (env { global_bindings = [] }) >>
    return (global_bindings env)
445
-- | Build a closure value in both vectorised and lifted form. Each form
-- applies the matching builtin closure constructor to the type arguments,
-- the environment's PA dictionary, both function components and the
-- matching environment component.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let build con env = Var con `mkTyApps` [arg_ty, res_ty, env_ty]
                                  `mkApps`   [dict, vfn, lfn, env]
      return (build mkv venv, build mkl lenv)
454
-- | Apply a vectorised closure to a vectorised argument. The vectorised and
-- lifted components use their respective builtin apply functions,
-- instantiated at the closure's argument and result types.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      return (call vapply vclo varg, call lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)
    call fn clo arg  = Var fn `mkTyApps` [arg_ty, res_ty] `mkApps` [clo, arg]
464
-- | Build a curried chain of closures, one per argument type in the list.
-- With no arguments the body is built directly. With one argument a single
-- closure is built. Otherwise the outer closure's body is a hoisted
-- closure over the remaining arguments, and the first argument becomes an
-- extra captured variable.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      inner_res_ty <- mkClosureTypes arg_tys res_ty
      arg          <- newLocalVVar FSLIT("x") arg_ty
      let vars' = vars ++ [arg]
      buildClosure tvs vars arg_ty inner_res_ty $ hoistPolyVExpr tvs $ do
          lc  <- builtin liftingContext
          clo <- buildClosures tvs vars' arg_tys res_ty mk_body
          return (vLams lc vars' clo)
480
-- | Build a closure capturing @vars@ with argument type @arg_ty@ and result
-- type @res_ty@. The code function is hoisted; its body comes from
-- @mk_body@, applied to the captured variables and the argument.
--
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      fn <- hoistPolyVExpr tvs (make_fn bind env_bndr arg_bndr)
      mkClosure arg_ty res_ty env_ty fn env
  where
    make_fn bind env_bndr arg_bndr
      = do
          lc    <- builtin liftingContext
          body  <- mk_body
          body' <- bind (vVar env_bndr) (vVarApps lc body (vars ++ [arg_bndr]))
          return (vLamsWithoutLC [env_bndr, arg_bndr] body')
502
-- | Package the vectorised variables @vvs@ into a closure environment.
-- Returns the vectorised environment type, the environment expression in
-- both forms, and a function that rebinds the variables from a given
-- environment around a given body (both components).
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let (ty, venv, vbind) = mkVectEnv tys vs
          bind (venv', lenv') (vbody, lbody)
            = liftM (\lbody' -> (vbind venv' vbody, lbody')) (lbind lenv' lbody)
      return (ty, (venv, lenv), bind)
  where
    (vs, ls) = unzip vvs
    tys      = map idType vs
518
-- | Build the vectorised environment for variables @vs@ of types @tys@.
-- Returns its type, the expression packing the variables, and a binder
-- that unpacks an environment around a body:
--
-- * no variables: unit, and the binder ignores the environment;
-- * one variable: the variable itself, rebound with a non-recursive let;
-- * several: a boxed tuple, taken apart by a case.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv [] [] = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys vs = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys
    unpack env body
      = Case env (mkWildId tup_ty) (exprType body)
          [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
527
-- | Build the lifted environment for the lifted variables @vs@ of
-- vectorised types @tys@. Returns the environment expression and a binder
-- that unpacks a given environment around a body.
--
-- With a single variable, the environment is that variable. The binder
-- rebinds it with a let, then binds @lc@ as the case binder of the
-- variable's length ('lengthPA').
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind_single)
  where
    bind_single env body
      = do
          len <- lengthPA (Var v)
          return (Let (NonRec v env)
                      (Case len lc (exprType body) [(DEFAULT, [], body)]))

-- NOTE: this transparently deals with empty environments
--
-- Otherwise the environment applies the sole constructor of the PArray
-- representation tycon for the tuple of @tys@ to @lc@ and the variables.
-- The binder cases on the unwrapped environment to rebind them; with no
-- variables a unit-typed wildcard is bound instead.
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      let [env_con] = tyConDataCons env_tc

          packed = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                             (lc : vs)

          unpack env body
            = return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                        [(DataAlt env_con, lc : bndrs, body)])
            where
              scrut = unwrapFamInstScrut env_tc env_tyargs env
      return (packed, unpack)
  where
    vty = mkCoreTupTy tys

    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
557