PA is now an explicit record instead of a typeclass
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 splitClosureTy,
5 mkPADictType, mkPArrayType,
6 paDictArgType, paDictOfType, paDFunType,
7 paMethod, lengthPA, replicatePA, emptyPA, liftPA,
8 polyAbstract, polyApply, polyVApply,
9 lookupPArrayFamInst,
10 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
11 buildClosure, buildClosures,
12 mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon ( dataConWrapId )
27 import Var
28 import Id ( mkWildId )
29 import MkId ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad ( liftM, zipWithM_ )
38
-- | Strip the trailing type arguments from an annotated application.
-- Returns the remaining head expression together with the stripped type
-- arguments, listed in the order in which they were applied.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- Arguments are peeled outermost-first, so consing rebuilds source order.
    peel acc (_, AnnApp fun (_, AnnType ty)) = peel (ty : acc) fun
    peel acc hd                              = (hd, acc)
44
-- | Split off the leading type-variable lambdas of an annotated expression.
-- Returns the bound type variables (outermost first) and the expression
-- underneath them.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders (_, AnnLam bndr body)
  | isTyVar bndr = let (bndrs, inner) = collectAnnTypeBinders body
                   in (bndr : bndrs, inner)
collectAnnTypeBinders expr = ([], expr)
50
-- | Split off the leading value lambdas (binders satisfying 'isId') of an
-- annotated expression.  Returns the binders (outermost first) and the
-- expression underneath them.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders expr
  = case expr of
      (_, AnnLam bndr body)
        | isId bndr -> let (rest, inner) = collectAnnValBinders body
                       in (bndr : rest, inner)
      _             -> ([], expr)
56
-- | Is this annotated expression a type argument ('AnnType')?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, expr)
  = case expr of
      AnnType _ -> True
      _         -> False
60
-- | Does the type constructor carry the closure type constructor's name?
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName
63
-- | Decompose a closure type into its argument and result types.
-- Panics (via 'pprPanic') if the type is not the closure type constructor
-- applied to exactly two arguments.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg_ty, res_ty])
        | isClosureTyCon tc -> (arg_ty, res_ty)
      _                     -> pprPanic "splitClosureTy" (ppr ty)
71
-- | Does the type constructor carry the parallel array type constructor's
-- name?
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName
74
-- | Extract the element type from a parallel array type.
-- Panics (via 'pprPanic') if the type is not the parallel array type
-- constructor applied to exactly one argument.
splitPArrayTy :: Type -> Type
splitPArrayTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [elt_ty])
        | isPArrayTyCon tc -> elt_ty
      _                    -> pprPanic "splitPArrayTy" (ppr ty)
82
-- | Build the closure type with the given argument and result types, using
-- the builtin closure type constructor.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = liftM (\clo_tc -> mkTyConApp clo_tc [arg_ty, res_ty])
          (builtin closureTyCon)
88
-- | Build a right-nested chain of closure types: one closure per argument
-- type, ending in the given result type.  With no argument types the result
-- type is returned as is.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = liftM nest (builtin closureTyCon)
  where
    nest clo_tc = foldr (\arg acc -> mkTyConApp clo_tc [arg, acc])
                        res_ty
                        arg_tys
96
-- | The type of the PA dictionary for the given type: the builtin PA type
-- constructor applied to it.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\pa_tc -> TyConApp pa_tc [ty]) (builtin paTyCon)
102
-- | The parallel array type with the given element type: the builtin
-- parallel array type constructor applied to it.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = liftM (\parr_tc -> TyConApp parr_tc [ty]) (builtin parrayTyCon)
108
-- | Compute the type of the PA dictionary argument that accompanies a type
-- variable, driven by the variable's kind.  Returns 'Nothing' when the kind
-- calls for no dictionary.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    dictTy ty kind
      -- Look through kind synonyms first.
      | Just kind' <- kindView kind     = dictTy ty kind'
      | FunTy arg_kind res_kind <- kind = funDictTy ty arg_kind res_kind
      | isLiftedTypeKind kind           = liftM Just (mkPADictType ty)
      | otherwise                       = return Nothing

    -- Higher-kinded case: quantify over a fresh argument variable.  If that
    -- argument itself needs a dictionary, the result is a function from the
    -- argument's dictionary; otherwise only the result kind is considered.
    funDictTy ty arg_kind res_kind
      = do
          arg_tv <- newTyVar FSLIT("a") arg_kind
          marg   <- dictTy (TyVarTy arg_tv) arg_kind
          case marg of
            Nothing     -> dictTy ty res_kind
            Just arg_dt ->
              liftM (fmap (ForAllTy arg_tv . FunTy arg_dt))
                    (dictTy (AppTy ty (TyVarTy arg_tv)) res_kind)
128
-- | Build the PA dictionary expression for a type by splitting it into a
-- head and its arguments and delegating to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
133
-- | Build the PA dictionary for a type head applied to argument types.
-- The head is first expanded with 'coreView'.  A type variable head uses
-- the dictionary function recorded for that variable, and a type
-- constructor head uses the one recorded for the constructor; either
-- lookup goes through 'maybeV'.  Any other head is a panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp head_ty arg_tys
  | Just head_ty' <- coreView head_ty = paDictOfTyApp head_ty' arg_tys
  | otherwise
  = case head_ty of
      TyVarTy tv    -> maybeV (lookupTyVarPA tv) >>= (`paDFunApply` arg_tys)
      TyConApp tc _ -> maybeV (lookupTyConPA tc) >>= \dfun ->
                       paDFunApply (Var dfun) arg_tys
      _             -> pprPanic "paDictOfTyApp" (ppr head_ty)
146
-- | The type of the PA dictionary function for a type constructor:
-- quantified over the constructor's type variables, taking one dictionary
-- argument for each variable that needs one (see 'paDictArgType'), and
-- returning the PA dictionary of the constructor applied to its variables.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      dict_tys <- mapM paDictArgType tc_tvs
      res_ty   <- mkPADictType . mkTyConApp tc $ mkTyVarTys tc_tvs
      let present = [dict_ty | Just dict_ty <- dict_tys]
      return (mkForAllTys tc_tvs (mkFunTys present res_ty))
  where
    tc_tvs = tyConTyVars tc
157
-- | Apply a dictionary function to the given types and then to the PA
-- dictionaries of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
163
-- | Select a builtin PA operation and instantiate it at a type: the
-- operation's variable applied to the type and to that type's PA dictionary.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = do
      method_var <- builtin method
      pa_dict    <- paDictOfType ty
      return (Var method_var `App` Type ty `App` pa_dict)
170
-- | Apply the builtin length operation to an array expression.  The element
-- type is recovered from the expression's type with 'splitPArrayTy', which
-- panics if that type is not a parallel array.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App len_fn x)
175
-- | Apply the builtin replicate operation, instantiated at the type of the
-- second argument, to a length expression and that argument.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])
179
-- | The builtin empty-array operation instantiated at the given type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty
182
-- | Replicate an expression using the builtin lifting context variable as
-- the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
188
-- | Create a fresh pair of local variables sharing one name: the first has
-- the given type, the second the parallel array type over it.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar name vect_ty
  = mkPArrayType vect_ty       >>= \lifted_ty  ->
    newLocalVar name vect_ty   >>= \vect_var   ->
    newLocalVar name lifted_ty >>= \lifted_var ->
    return (vect_var, lifted_var)
196
-- | Run a computation inside 'localV' with the given type variables
-- registered locally.  Each variable that needs a PA dictionary (see
-- 'paDictArgType') gets a fresh dictionary variable and is registered
-- together with it.  The computation receives a function that wraps an
-- expression in lambdas over the type variables followed by the dictionary
-- variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV
  $ do
      mdicts <- mapM newDictVar tvs
      zipWithM_ register tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    -- A fresh dictionary variable, if this type variable needs one.
    newDictVar tv
      = paDictArgType tv
          >>= maybe (return Nothing) (liftM Just . newLocalVar FSLIT("dPA"))

    register tv Nothing     = defLocalTyVar tv
    register tv (Just dict) = defLocalTyVarWithPA tv (Var dict)
212
-- | Apply an expression to the given types and then to the PA dictionaries
-- of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
218
-- | Like 'polyApply', but for a vectorised expression: the type and
-- dictionary applications are mapped over it with 'mapVect'.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = liftM instantiate (mapM paDictOfType tys)
  where
    instantiate dicts = mapVect (\e -> mkApps (mkTyApps e tys) dicts) expr
224
-- | Look up the family instance of the builtin parallel array type
-- constructor at the given type.
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
227
-- | Record a binding by prepending it to the global environment's
-- @global_bindings@ list.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv addBinding
  where
    addBinding env = env { global_bindings = (v, e) : global_bindings env }
231
-- | Bind an expression to a fresh variable of the expression's type,
-- record the binding with 'hoistBinding', and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >> return var
238
-- | Hoist both components of a vectorised expression.  The variable names
-- are the current binding name prefixed with @v@ for the first component
-- and @l@ for the second.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      base <- getBindName
      let hoistAs tag = hoistExpr (tag `consFS` base)
      vv <- hoistAs 'v' ve
      lv <- hoistAs 'l' le
      return (vv, lv)
246
-- | Run the computation under 'closedV' and 'polyAbstract', so its result
-- is abstracted over the type variables and their dictionaries.  The
-- abstracted expression is hoisted, and the hoisted variable is returned
-- re-applied to the same type variables.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs abstractBody)
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
  where
    abstractBody abstract = liftM (mapVect abstract) p
254
-- | Return the bindings recorded in the global environment's
-- @global_bindings@ list and reset that list to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \genv ->
    setGEnv (genv { global_bindings = [] }) >>
    return (global_bindings genv)
261
-- | Build the two components of a closure.  Each applies a builtin closure
-- constructor to the argument, result and environment types, and then to
-- the environment type's PA dictionary, both function components, and the
-- matching environment component.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let construct con env
            = mkApps (mkTyApps (Var con) [arg_ty, res_ty, env_ty])
                     [dict, vfn, lfn, env]
      return (construct mkv venv, construct mkl lenv)
270
-- | Apply a closure to an argument, component by component, using the
-- builtin closure application functions.  The argument and result types
-- come from splitting the type of the closure's first component, so this
-- panics if that type is not a closure type.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

    apply fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
280
-- | Build a chain of nested closures, one per argument type, around the
-- body produced by @mk_body@.
--
--  * With no argument types there is nothing to close over, so the body is
--    returned unchanged.  Without this equation the function was partial
--    and failed with a pattern-match error on an empty list.
--
--  * With one argument type this is just 'buildClosure'.
--
--  * Otherwise the outermost closure takes the first argument.  Its body is
--    the hoisted closure for the remaining arguments, abstracted over the
--    variables captured so far plus the new argument.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _   _    []       _      mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      -- The outer closure's result is itself a closure over the remaining
      -- argument types.
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg     <- newLocalVVar FSLIT("x") arg_ty
      buildClosure tvs vars arg_ty res_ty'
        . hoistPolyVExpr tvs
        $ do
            lc  <- builtin liftingContext
            clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
            return $ vLams lc (vars ++ [arg]) clo
294
-- | Build a single closure over the given variables.  The result has the
-- shape
--
-- >   (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- > where
-- >   f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >   f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The function pair @<f,f^>@ is hoisted with 'hoistPolyVExpr'.  It takes
-- the environment and the argument, unpacks the environment with the binder
-- from 'buildEnv', and applies the body to the captured variables followed
-- by the argument.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env_expr, bind_env) <- buildEnv vars
      env_var <- newLocalVVar FSLIT("env") env_ty
      arg_var <- newLocalVVar FSLIT("arg") arg_ty

      let closure_fn
            = do
                lc      <- builtin liftingContext
                body    <- mk_body
                wrapped <- bind_env (vVar env_var)
                                    (vVarApps lc body (vars ++ [arg_var]))
                return (vLamsWithoutLC [env_var, arg_var] wrapped)

      fn <- hoistPolyVExpr tvs closure_fn
      mkClosure arg_ty res_ty env_ty fn env_expr
314
315 mkClosure arg_ty res_ty env_ty fn env
316
-- | Build the environment for a set of captured variables.  Returns the
-- environment type, the pair of environment expressions (from 'mkVectEnv'
-- and 'mkLiftEnv'), and a function that, given an environment pair and a
-- body pair, rebinds the captured variables around each body component.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      (lifted_env, bind_lifted) <- mkLiftEnv lc var_tys lifted_vars
      let bind_both (venv, lenv) (vbody, lbody)
            = liftM ((,) (bind_vect venv vbody)) (bind_lifted lenv lbody)
      return (env_ty, (vect_env, lifted_env), bind_both)
  where
    (vect_vars, lifted_vars)      = unzip vvs
    var_tys                       = map idType vect_vars
    (env_ty, vect_env, bind_vect) = mkVectEnv var_tys vect_vars
332
-- | Build the first (non-array) environment for the given variables.
-- Returns the environment type, the expression that constructs the
-- environment, and a binder that brings the variables back into scope
-- around a body given an environment expression:
--
--  * no variables: the unit value, and the body is left as is;
--
--  * one variable: the variable itself, rebound with a @let@;
--
--  * otherwise: a boxed tuple, taken apart with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty  = mkCoreTupTy tys
    tup_con = tupleCon Boxed (length vs)

    unpack env body = Case env (mkWildId tup_ty) (exprType body)
                           [(DataAlt tup_con, vs, body)]
341
-- | Build the second (array) environment for the given variables.  Returns
-- the expression that constructs the environment and a monadic binder that
-- brings the variables and the lifting context variable @lc@ back into
-- scope around a body.
--
-- With exactly one variable, the environment is the variable itself.  The
-- binder rebinds it with a @let@ and uses @lc@ as the case binder of a
-- @case@ on the array's length.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind_single)
  where
    bind_single env body
      = do
          len <- lengthPA (Var v)
          return (Let (NonRec v env)
                      (Case len lc (exprType body) [(DEFAULT, [], body)]))

-- General case: the environment is the data constructor of the parallel
-- array family instance at the tuple type of the variables, applied to @lc@
-- and the variables.
-- NOTE: this also covers the empty environment, where a wildcard binder of
-- unit type takes the place of the variables in the case alternative.
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- lookupPArrayFamInst (mkCoreTupTy tys)
      let [env_con] = tyConDataCons env_tc

          con_app = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                              (lc : vs)

          unpack env body
            = return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                           [(DataAlt env_con, lc : fields, body)])
            where
              scrut = unwrapFamInstScrut env_tc env_tyargs env
      return (con_app, unpack)
  where
    fields = if null vs then [mkWildId unitTy] else vs
371