Collect hoisted vectorised functions
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 splitClosureTy,
4 mkPADictType, mkPArrayType,
5 paDictArgType, paDictOfType,
6 lookupPArrayFamInst,
7 hoistExpr
8 ) where
9
10 #include "HsVersions.h"
11
12 import VectMonad
13
14 import CoreSyn
15 import CoreUtils
16 import Type
17 import TypeRep
18 import TyCon
19 import Var
20 import PrelNames
21
22 import Outputable
23 import FastString
24
25 import Control.Monad ( liftM )
26
-- | Strip the type arguments off an annotated application.
--
-- Returns the remaining function expression together with the stripped
-- type arguments in left-to-right (source) order. Stripping starts at the
-- outermost application and stops at the first application whose argument
-- is not an 'AnnType'.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- 'acc' holds the type arguments already peeled, innermost first.
    peel acc node = case node of
      (_, AnnApp fun (_, AnnType arg_ty)) -> peel (arg_ty : acc) fun
      _                                   -> (node, acc)
32
-- | Strip the leading type lambdas off an annotated expression.
--
-- Returns the bound type variables (outermost first) and the body that
-- remains once a lambda whose binder is not a type variable, or a
-- non-lambda, is reached.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    -- 'seen' accumulates binders innermost-first; reversed on exit.
    strip seen node = case node of
      (_, AnnLam bndr body) | isTyVar bndr -> strip (bndr : seen) body
      _                                    -> (reverse seen, node)
38
-- | Is this annotated expression a type argument (an 'AnnType' node)?
isAnnTypeArg :: AnnExpr b ann -> Bool
-- The carried type is irrelevant here; a wildcard avoids an unused-binding
-- warning under -Wall.
isAnnTypeArg (_, AnnType _) = True
isAnnTypeArg _              = False
42
-- | Does this type constructor's unique match 'closureTyConKey'?
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConKey) . tyConUnique
45
-- | Split a closure type into its argument and result types.
--
-- Expects the closure type constructor (see 'isClosureTyCon') applied to
-- exactly two arguments. Any other type triggers 'pprPanic'.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty =
  case splitTyConApp_maybe ty of
    Just (tc, [arg_ty, res_ty]) | isClosureTyCon tc -> (arg_ty, res_ty)
    _                                               -> pprPanic "splitClosureTy" (ppr ty)
53
-- | Build the PA dictionary type for @ty@ by applying the builtin
-- 'paDictTyCon' to it.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\dict_tc -> TyConApp dict_tc [ty]) (builtin paDictTyCon)
59
-- | Build the parallel-array type for @ty@ by applying the builtin
-- 'parrayTyCon' to it.
mkPArrayType :: Type -> VM Type
mkPArrayType ty =
  builtin parrayTyCon >>= \arr_tc -> return (TyConApp arr_tc [ty])
65
-- | Compute the type of the PA dictionary argument needed for the type
-- variable @tv@, or 'Nothing' if none is needed. The result depends on the
-- kind of @tv@:
--
-- * Lifted kind: a plain PA dictionary type for the (possibly applied)
--   variable, built with 'mkPADictType'.
-- * Arrow kind @k1 -> k2@: a fresh variable @a :: k1@ is introduced. If @a@
--   itself needs a dictionary @d1@, the result is
--   @forall a. d1 -> <dictionary for (ty a) at kind k2>@. If @a@ needs no
--   dictionary, the result is whatever @ty@ (not applied to @a@) yields at
--   @k2@. NOTE(review): this non-application looks deliberate but is
--   unverified.
-- * Any other kind: 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    -- Normalise the kind with 'kindView' before inspecting it.
    go ty k | Just k' <- kindView k = go ty k'

    go ty (FunTy k1 k2)
      = do
          -- Fresh variable for the argument kind; named apart from the
          -- outer 'tv' so it does not shadow it.
          arg_tv <- newTyVar FSLIT("a") k1
          mty1   <- go (TyVarTy arg_tv) k1
          case mty1 of
            Just ty1 -> do
                          mty2 <- go (AppTy ty (TyVarTy arg_tv)) k2
                          return $ fmap (ForAllTy arg_tv . FunTy ty1) mty2
            Nothing  -> go ty k2

    go ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    go _ _ = return Nothing
85
-- | Build a PA dictionary expression for an arbitrary type.
--
-- Splits the type into its head and arguments with 'splitAppTys' and
-- delegates to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
90
-- | Build a PA dictionary expression for a type head applied to arguments.
--
-- * A head that 'coreView' can expand is expanded and re-split, so any
--   applications exposed by the expansion become ordinary arguments.
-- * A type-variable head uses the dictionary obtained from 'lookupTyVarPA'
--   via 'maybeV' (presumably failing the VM computation when none is
--   recorded; confirm in VectMonad).
-- * A type-constructor head looks up the PA instance for the fully applied
--   type with 'lookupInst', then applies the returned dfun with
--   'paDFunApply'.
-- * Any other head is a 'pprPanic'.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn
  -- The expansion may itself be an application (an 'AppTy' or a
  -- 'TyConApp' with arguments). Re-split it so those arguments are kept
  -- rather than dropped or reported as an unsupported head.
  = let (ty_fn'', ty_args') = splitAppTys ty_fn'
    in  paDictOfTyApp ty_fn'' (ty_args' ++ ty_args)
paDictOfTyApp (TyVarTy tv) ty_args
  = do
      dfun <- maybeV (lookupTyVarPA tv)
      paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc tc_args) ty_args
  = do
      pa_class <- builtin paClass
      -- Include the constructor's own arguments so the instance is looked
      -- up at the complete type. They are empty when called via
      -- 'paDictOfType', but the original discarded them otherwise.
      (dfun, inst_tys) <- lookupInst pa_class [TyConApp tc (tc_args ++ ty_args)]
      paDFunApply (Var dfun) inst_tys
paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
104
-- | Apply a dictionary function to the given types, then to a PA
-- dictionary (built with 'paDictOfType') for each of those same types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys =
  liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
110
-- | Look up the family instance of the builtin 'parrayTyCon' at the given
-- type, delegating to 'lookupFamInst'.
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty = do
  parray_tc <- builtin parrayTyCon
  lookupFamInst parray_tc [ty]
113
-- | Hoist an expression into a fresh local binding.
--
-- Creates a new local variable (named from @fs@, typed by 'exprType' of the
-- expression) and prepends the pair (variable, expression) to the local
-- environment's 'local_bindings'. Returns the new variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr = do
  hoisted <- newLocalVar fs (exprType expr)
  let addBinding env =
        env { local_bindings = (hoisted, expr) : local_bindings env }
  updLEnv addBinding
  return hoisted
121