Rename functions
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 splitClosureTy,
4 mkPADictType, mkPArrayType,
5 paDictArgType, paDictOfType,
6 paMethod, lengthPA, replicatePA, emptyPA,
7 polyAbstract, polyApply, polyVApply,
8 lookupPArrayFamInst,
9 hoistExpr, hoistPolyVExpr, takeHoisted,
10 buildClosure
11 ) where
12
13 #include "HsVersions.h"
14
15 import VectCore
16 import VectMonad
17
18 import DsUtils
19 import CoreSyn
20 import CoreUtils
21 import Type
22 import TypeRep
23 import TyCon
24 import DataCon ( dataConWrapId )
25 import Var
26 import Id ( mkWildId )
27 import MkId ( unwrapFamInstScrut )
28 import PrelNames
29 import TysWiredIn
30 import BasicTypes ( Boxity(..) )
31
32 import Outputable
33 import FastString
34
35 import Control.Monad ( liftM, zipWithM_ )
36
-- | Peel the trailing type arguments off an annotated application.
-- The result is the remaining head expression together with the
-- collected types, listed in left-to-right application order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- The outermost argument is seen first, so consing onto the
    -- accumulator leaves the innermost (leftmost) argument at the front.
    peel acc e = case e of
      (_, AnnApp fun (_, AnnType t)) -> peel (t : acc) fun
      _                              -> (e, acc)
42
-- | Strip the leading type-variable lambdas from an annotated expression.
-- Returns the bound type variables (outermost first) and the expression
-- found underneath them.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders (_, AnnLam tv body)
  | isTyVar tv = (tv : tvs, inner)
  where
    (tvs, inner) = collectAnnTypeBinders body
-- Anything else (including a lambda over a non-type variable) stops here.
collectAnnTypeBinders other = ([], other)
48
-- | Does this annotated expression consist of a type (AnnType) node?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node) = case node of
  AnnType _ -> True
  _         -> False
52
-- | Recognise the closure type constructor by comparing its name
-- against closureTyConName.
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName
55
-- | Decompose a closure type into its argument and result types.
-- Panics (via pprPanic) when the type is not the closure tycon applied
-- to exactly two arguments.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [arg, res])
        | isClosureTyCon tycon -> (arg, res)
      _                        -> pprPanic "splitClosureTy" (ppr ty)
63
-- | Recognise the parallel array type constructor by comparing its name
-- against parrayTyConName.
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName
66
-- | Extract the element type from a parallel array type.
-- Panics (via pprPanic) when the type is not the parallel array tycon
-- applied to exactly one argument.
splitPArrayTy :: Type -> Type
splitPArrayTy ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [elt])
        | isPArrayTyCon tycon -> elt
      _                       -> pprPanic "splitPArrayTy" (ppr ty)
74
-- | Build the PA dictionary type for the given type, i.e. the builtin
-- paDictTyCon applied to it.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\dict_tc -> TyConApp dict_tc [ty]) (builtin paDictTyCon)
80
-- | Build the parallel array type over the given element type, i.e. the
-- builtin parrayTyCon applied to it.
mkPArrayType :: Type -> VM Type
mkPArrayType ty
  = builtin parrayTyCon >>= \parr_tc ->
    return (TyConApp parr_tc [ty])
86
-- | Work out the type of the PA dictionary argument that goes with a
-- type variable, by walking over the kind of the variable.
--
--   * a kind synonym is expanded with kindView first;
--   * for a function kind a fresh type variable of the argument kind is
--     made; when that variable itself needs a dictionary, the result
--     quantifies over it and takes its dictionary as an argument,
--     otherwise the walk simply continues with the result kind;
--   * at lifted type kind the answer is the PA dictionary type;
--   * any other kind yields Nothing.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    dictTy ty kind
      | Just kind' <- kindView kind = dictTy ty kind'
    dictTy ty (FunTy arg_kind res_kind)
      = do
          arg_tv   <- newTyVar FSLIT("a") arg_kind
          arg_dict <- dictTy (TyVarTy arg_tv) arg_kind
          case arg_dict of
            Nothing     -> dictTy ty res_kind
            Just arg_pa ->
              liftM (fmap (ForAllTy arg_tv . FunTy arg_pa))
                    (dictTy (AppTy ty (TyVarTy arg_tv)) res_kind)
    dictTy ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing
106
-- | Construct the PA dictionary expression for a type: the type is split
-- into its head and arguments (splitAppTys) and handed to paDictOfTyApp.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
111
-- | Construct the PA dictionary for a type given as a head applied to
-- a list of argument types.
--
--   * The head is first expanded with coreView where possible.
--   * A type variable head uses the dictionary function recorded for that
--     variable (lookupTyVarPA).
--   * A type constructor head looks up the instance of the builtin PA
--     class for the constructor applied to the arguments.
--   * Any other head is a panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn = paDictOfTyApp expanded ty_args
  | otherwise
  = case ty_fn of
      TyVarTy tv ->
        maybeV (lookupTyVarPA tv) >>= \dfun ->
        paDFunApply dfun ty_args

      -- Arguments already attached to the constructor are ignored here;
      -- only ty_args is used.
      TyConApp tc _ -> do
        cls              <- builtin paClass
        (dfun, inst_tys) <- lookupInst cls [TyConApp tc ty_args]
        paDFunApply (Var dfun) inst_tys

      _ -> pprPanic "paDictOfTyApp" (ppr ty_fn)
125
-- | Apply a dictionary function to the given types and then to the PA
-- dictionaries of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
131
-- | Instantiate a builtin PA method at a type: the selected builtin
-- variable is applied to the type and to the PA dictionary of the type.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = builtin method  >>= \fn ->
    paDictOfType ty >>= \dict ->
    return (mkApps (Var fn) [Type ty, dict])
138
-- | Expression applying the lengthPA method to the given array
-- expression. The element type is read off the type of the expression,
-- which therefore has to be a parallel array type.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App len_fn x)
143
-- | Expression applying the replicatePA method, instantiated at the type
-- of the second argument, to the two given expressions in order.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])
147
-- | The emptyPA method instantiated at the given type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty
150
-- | Make a fresh pair of local variables sharing one name: the first has
-- the given type, the second has the parallel array type over it.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty        >>= \lifted_ty ->
    newLocalVar fs vty       >>= \scalar_var ->
    newLocalVar fs lifted_ty >>= \lifted_var ->
    return (scalar_var, lifted_var)
158
-- | Run a computation in a local scope in which the given type variables
-- are defined, together with fresh PA dictionary variables for those of
-- them that need one. The computation receives a function that wraps an
-- expression in lambdas over all the type variables followed by all the
-- dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      dict_vars <- mapM new_dict_var tvs
      zipWithM_ register tvs dict_vars
      p (mkLams (tvs ++ [dv | Just dv <- dict_vars]))
  where
    -- A fresh dictionary variable, if the type variable needs one.
    new_dict_var tv
      = paDictArgType tv
        >>= maybe (return Nothing)
                  (liftM Just . newLocalVar FSLIT("dPA"))

    -- Record the type variable, with its dictionary when there is one.
    register tv Nothing   = defLocalTyVar tv
    register tv (Just dv) = defLocalTyVarWithPA tv (Var dv)
174
-- | Apply an expression to the given types and then to the PA
-- dictionaries of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = mapM paDictOfType tys >>= \dicts ->
    return (mkApps (mkTyApps expr tys) dicts)
180
-- | Like polyApply, but for a vectorised expression: the type and
-- dictionary applications are mapped over it with mapVect.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = liftM (\dicts -> mapVect (instantiate dicts) expr)
          (mapM paDictOfType tys)
  where
    instantiate dicts e = mkApps (mkTyApps e tys) dicts
186
-- | Look up the family instance of the builtin parallel array tycon at
-- the given type.
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
189
-- | Bind an expression to a fresh variable of the same type, push the
-- binding onto the front of global_bindings in the global environment,
-- and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    updGEnv (record var) >>
    return var
  where
    record var env
      = env { global_bindings = (var, expr) : global_bindings env }
197
-- | Hoist both components of a vectorised expression. The given name is
-- prefixed with v for the first component and with l for the second.
hoistVExpr :: FastString -> VExpr -> VM VVar
hoistVExpr fs (ve, le)
  = do
      vv <- hoist 'v' ve
      lv <- hoist 'l' le
      return (vv, lv)
  where
    hoist tag = hoistExpr (consFS tag fs)
204
-- | Hoist a vectorised expression that is polymorphic in the given type
-- variables. The expression is produced inside closedV and polyAbstract,
-- wrapped in the abstraction that polyAbstract supplies, and hoisted.
-- The result is the hoisted variable applied back to the type variables
-- (and their dictionaries, via polyVApply).
hoistPolyVExpr :: FastString -> [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr fs tvs p
  = closedV (polyAbstract tvs abstracted) >>= \expr ->
    hoistVExpr fs expr                    >>= \fn ->
    polyVApply (vVar fn) (mkTyVarTys tvs)
  where
    abstracted wrap = liftM (mapVect wrap) p
212
-- | Return the bindings accumulated in global_bindings and reset that
-- field of the global environment to the empty list.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \genv ->
    setGEnv (genv { global_bindings = [] }) >>
    return (global_bindings genv)
219
220
-- | Build a pair of closure expressions from a pair of functions and a
-- pair of environments. Both components apply a builtin constructor
-- function (mkClosureVar for the first, mkClosurePVar for the second) to
-- the argument, result and environment types, the PA dictionary of the
-- environment type, both functions, and the matching environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let close con env = mkApps (mkTyApps (Var con) tys) [dict, vfn, lfn, env]
      return (close mkv venv, close mkl lenv)
  where
    tys = [arg_ty, res_ty, env_ty]
229
-- | Build a closure over the given variables.
--
-- Schematically the result is
--
-- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--   where
--     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The environment is built from the variables with buildEnv, fresh
-- binders are made for the environment and the argument, and the function
-- (the body applied to the variables and the argument, with the
-- environment unpacked around it) is hoisted with hoistPolyVExpr before
-- everything is packaged up by mkClosure.
buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs lv vars arg_ty res_ty mk_body
  = do
      (env_ty, env_expr, bind_env) <- buildEnv lv vars
      env_var <- newLocalVVar FSLIT("env") env_ty
      arg_var <- newLocalVVar FSLIT("arg") arg_ty

      let lam_params = [env_var, arg_var]

          -- The function to be hoisted; it is only run inside
          -- hoistPolyVExpr.
          closure_fn
            = do
                inner   <- mk_body
                wrapped <- bind_env (vVar env_var)
                                    (vVarApps lv inner (vars ++ [arg_var]))
                return (vLamsWithoutLC lam_params wrapped)

      hoisted <- hoistPolyVExpr FSLIT("fn") tvs closure_fn
      mkClosure arg_ty res_ty env_ty hoisted env_expr
250
-- | Build the environment for a list of variable pairs. Returns the
-- environment type, the pair of environment expressions, and a function
-- that, given a pair of environment expressions and a pair of bodies,
-- unpacks the environment around each body. The first components come
-- from mkVectEnv, the second components from mkLiftEnv.
buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv lv vvs
  = do
      (lifted_env, bind_lifted) <- mkLiftEnv lv elem_tys lifted_vars
      let bind_both (venv, lenv) (vbody, lbody)
            = liftM ((,) (bind_scalar venv vbody)) (bind_lifted lenv lbody)
      return (env_ty, (scalar_env, lifted_env), bind_both)
  where
    (scalar_vars, lifted_vars)        = unzip vvs
    elem_tys                          = map idType scalar_vars
    (env_ty, scalar_env, bind_scalar) = mkVectEnv elem_tys scalar_vars
265
-- | Environment for the first (non-lifted) components. Returns the
-- environment type, the environment expression, and a function that
-- unpacks an environment expression around a body.
--
--   * no variables: the unit value, and the body is left untouched;
--   * one variable: the variable itself, rebound with a let;
--   * otherwise: a boxed tuple, taken apart with a case expression.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty  = mkCoreTupTy tys
    tup_con = tupleCon Boxed (length vs)

    unpack env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt tup_con, vs, body)]
274
-- | Environment for the second (lifted) components. Returns the
-- environment expression and a monadic function that unpacks an
-- environment expression around a body, bringing the given length
-- variable into scope as well.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
-- One variable: the environment is the variable itself. Unpacking rebinds
-- it with a let and binds the length variable as the case binder of a
-- case on the length of the variable.
mkLiftEnv lv [_] [v]
  = return (Var v, bind_single)
  where
    bind_single env body
      = do
          len <- lengthPA (Var v)
          return (Let (NonRec v env)
                      (Case len lv (exprType body) [(DEFAULT, [], body)]))

-- NOTE: this transparently deals with empty environments
mkLiftEnv lv tys vs
  = do
      (env_tc, env_tyargs) <- lookupPArrayFamInst (mkCoreTupTy tys)
      -- The instance tycon is expected to have exactly one data
      -- constructor; the pattern is lazy, so it is only checked on use.
      let [env_con] = tyConDataCons env_tc

          packed = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                             (lv : vs)

          unpack env body
            = return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                           [(DataAlt env_con, lv : field_bndrs, body)])
            where
              scrut = unwrapFamInstScrut env_tc env_tyargs env
      return (packed, unpack)
  where
    -- With no variables a single wildcard binder of unit type is used.
    field_bndrs = if null vs then [mkWildId unitTy] else vs
304