Thread lifting context implicitly in the vectorisation monad
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 splitClosureTy,
5 mkPADictType, mkPArrayType,
6 paDictArgType, paDictOfType,
7 paMethod, lengthPA, replicatePA, emptyPA, liftPA,
8 polyAbstract, polyApply, polyVApply,
9 lookupPArrayFamInst,
10 hoistExpr, hoistPolyVExpr, takeHoisted,
11 buildClosure, buildClosures,
12 mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon ( dataConWrapId )
27 import Var
28 import Id ( mkWildId )
29 import MkId ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad ( liftM, zipWithM_ )
38
-- | Peel the type arguments off an annotated application.
--
-- Returns the innermost non-type-application head together with the type
-- arguments it was applied to, in left-to-right (source) order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs expr = peel expr []
  where
    -- Walking from the outermost application inward and consing onto the
    -- accumulator leaves the arguments in source order.
    peel e acc
      = case e of
          (_, AnnApp fun (_, AnnType arg_ty)) -> peel fun (arg_ty : acc)
          _                                   -> (e, acc)
44
-- | Strip the leading type-variable lambdas of an annotated expression.
--
-- Returns the stripped binders in the order they are bound, together with
-- the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    -- Binders are accumulated innermost-first, hence the final 'reverse'.
    strip acc (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : acc) body
    strip acc rest   = (reverse acc, rest)
50
-- | Strip the leading value (Id) lambdas of an annotated expression.
--
-- Returns the stripped binders in the order they are bound, together with
-- the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    -- Binders are accumulated innermost-first, hence the final 'reverse'.
    strip acc (_, AnnLam bndr body)
      | isId bndr  = strip (bndr : acc) body
    strip acc rest = (reverse acc, rest)
56
-- | Is this annotated expression a type argument (an 'AnnType' node)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, AnnType _) = True
isAnnTypeArg _              = False
60
-- | Does this type constructor have the name of the closure type constructor
-- ('closureTyConName')?
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName
63
-- | Split a closure type into its argument and result types.
--
-- Panics, reporting the offending type, unless the type is an application
-- of the closure type constructor to exactly two arguments.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg_ty, res_ty]) | isClosureTyCon tc -> (arg_ty, res_ty)
      _ -> pprPanic "splitClosureTy" (ppr ty)
71
-- | Does this type constructor have the name of the parallel-array type
-- constructor ('parrayTyConName')?
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName
74
-- | Extract the element type of a parallel-array type.
--
-- Panics, reporting the offending type, unless the type is an application
-- of the parallel-array type constructor to exactly one argument.
splitPArrayTy :: Type -> Type
splitPArrayTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [elt_ty]) | isPArrayTyCon tc -> elt_ty
      _ -> pprPanic "splitPArrayTy" (ppr ty)
82
-- | The closure type from @arg_ty@ to @res_ty@, built with the builtin
-- closure type constructor.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = liftM (\clo_tc -> mkTyConApp clo_tc [arg_ty, res_ty])
          (builtin closureTyCon)
88
-- | The curried closure type taking each of @arg_tys@ in turn and finally
-- returning @res_ty@.  With no argument types this is just @res_ty@.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = do
      clo_tc <- builtin closureTyCon
      let arrow dom cod = mkTyConApp clo_tc [dom, cod]
      return (foldr arrow res_ty arg_tys)
96
-- | The type of a PA dictionary for @ty@: the builtin PA-dictionary type
-- constructor applied to @ty@.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\dict_tc -> TyConApp dict_tc [ty]) (builtin paDictTyCon)
102
-- | The parallel-array type with elements of type @ty@: the builtin PArray
-- type constructor applied to @ty@.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = liftM (\parr_tc -> TyConApp parr_tc [ty]) (builtin parrayTyCon)
108
-- | The type of the PA-dictionary argument needed when abstracting over the
-- type variable @tv@, or 'Nothing' if its kind admits no dictionary.
--
-- * For a lifted kind the result is the PA dictionary type of the variable
--   itself (via 'mkPADictType').
-- * For an arrow kind @k1 -> k2@ a fresh type variable @a :: k1@ is made.
--   If @k1@ admits a dictionary @d1@, the result is @forall a. d1 -> d2@,
--   where @d2@ is the dictionary type of @tv a@ at kind @k2@.
-- * Any other kind yields 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    go ty k | Just k' <- kindView k = go ty k'
    go ty (FunTy k1 k2)
      = do
          -- fresh name avoids shadowing the outer 'tv'
          arg_tv <- newTyVar FSLIT("a") k1
          mty1   <- go (TyVarTy arg_tv) k1
          case mty1 of
            Just ty1 -> do
                          mty2 <- go (AppTy ty (TyVarTy arg_tv)) k2
                          return $ fmap (ForAllTy arg_tv . FunTy ty1) mty2
            -- The argument kind admits no dictionary, so no dictionary
            -- argument is abstracted for it.
            -- NOTE(review): 'ty' is passed on here without being applied to
            -- the fresh variable -- confirm this is intended.
            Nothing  -> go ty k2

    go ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    go _ _ = return Nothing
128
-- | Build a PA dictionary expression for @ty@.
--
-- The type is split into its head and the types it is applied to
-- ('splitAppTys'), which are then handled by 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
133
-- | Build a PA dictionary for a type given as a head applied to @ty_args@.
--
-- * If 'coreView' can see through the head, retry with the expanded head.
-- * A type-variable head uses the dictionary recorded for that variable
--   ('lookupTyVarPA').
-- * A type-constructor head uses the PA instance returned by 'lookupInst'
--   for the constructor applied to @ty_args@.
--
-- In both cases the dictionary function is applied with 'paDFunApply'.
-- Any other head panics.
--
-- NOTE(review): the TyConApp equation ignores the constructor's own
-- arguments.  This presumably relies on 'splitAppTys' (in 'paDictOfType')
-- having moved them into @ty_args@ -- confirm.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
paDictOfTyApp (TyVarTy tv) ty_args
  = do
      dfun <- maybeV (lookupTyVarPA tv)
      paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc _) ty_args
  = do
      pa_class <- builtin paClass
      (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
      paDFunApply (Var dfun) ty_args'
paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
147
-- | Apply a PA dictionary function to the types @tys@ and then to the PA
-- dictionaries of those same types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
153
-- | Instantiate a builtin PA method at type @ty@.
--
-- The method is selected with 'builtin' and applied to @ty@ and to the PA
-- dictionary for @ty@.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = builtin method  >>= \fn ->
    paDictOfType ty >>= \dict ->
    return (mkApps (Var fn) [Type ty, dict])
160
-- | Apply the PA @length@ method to the parallel array @x@.
--
-- The element type is read off @x@'s type with 'splitPArrayTy', which
-- panics if @x@ is not a parallel array.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App len_fn x)
165
-- | Apply the PA @replicate@ method, at the type of @x@, to the length
-- @len@ and the value @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])
169
-- | The PA @empty@ method instantiated at element type @ty@.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty
172
-- | Replicate @x@ using the lifting-context variable as the replication
-- length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
178
-- | Make a fresh vectorised variable pair named @fs@.
--
-- The scalar variable has type @vty@; the lifted variable has the
-- parallel-array type of @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lifted_ty <- mkPArrayType vty
      scalar_v  <- newLocalVar fs vty
      lifted_v  <- newLocalVar fs lifted_ty
      return (scalar_v, lifted_v)
186
-- | Abstract a computation over the type variables @tvs@ and their PA
-- dictionaries.
--
-- Inside a 'localV' scope, each type variable is registered locally:
-- with a fresh dictionary variable when 'paDictArgType' says it needs one,
-- otherwise without.  The continuation receives a function that wraps an
-- expression in lambdas for @tvs@ followed by the dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p = localV body
  where
    body
      = do
          mdicts <- mapM dictVarFor tvs
          zipWithM_ bindTyVar tvs mdicts
          p (abstractOver mdicts)

    -- A fresh dictionary variable, if this type variable's kind needs one.
    dictVarFor tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    bindTyVar tv Nothing     = defLocalTyVar tv
    bindTyVar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    abstractOver mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
202
-- | Instantiate a polymorphic expression: apply it to the types @tys@ and
-- then to the PA dictionaries of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
208
-- | Like 'polyApply', but for both halves of a vectorised expression.
--
-- The same type arguments and PA dictionaries are applied to each half.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let instantiate e = mkApps (mkTyApps e tys) dicts
      return (mapVect instantiate expr)
214
-- | Look up the instance of the builtin PArray family at type @ty@.
--
-- Returns whatever 'lookupFamInst' returns: a type constructor and a list
-- of type arguments.
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
217
-- | Hoist @expr@ out into a binding of a fresh variable named @fs@.
--
-- The pair (variable, expression) is consed onto 'global_bindings' in the
-- global environment (where 'takeHoisted' later collects it), and the
-- variable is returned for use in place of the expression.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = do
      var <- newLocalVar fs (exprType expr)
      updGEnv (addBinding var)
      return var
  where
    addBinding var genv
      = genv { global_bindings = (var, expr) : global_bindings genv }
225
-- | Hoist both halves of a vectorised expression.
--
-- The hoisted variables are named after the current binding name
-- ('getBindName'), prefixed with @v@ for the scalar half and @l@ for the
-- lifted half.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (vexpr, lexpr)
  = do
      name  <- getBindName
      v_var <- hoistExpr (consFS 'v' name) vexpr
      l_var <- hoistExpr (consFS 'l' name) lexpr
      return (v_var, l_var)
233
-- | Hoist a vectorised expression, abstracting it over @tvs@.
--
-- The expression is built in a 'closedV' scope and abstracted over @tvs@
-- and their PA dictionaries ('polyAbstract').  Both halves are then
-- hoisted, and the hoisted pair is re-applied to @tvs@ and their
-- dictionaries ('polyVApply').
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      poly_expr <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      fn        <- hoistVExpr poly_expr
      polyVApply (vVar fn) (mkTyVarTys tvs)
241
-- | Return all bindings hoisted so far and clear the list in the global
-- environment.
--
-- Because 'hoistExpr' conses onto the front, the most recently hoisted
-- binding comes first.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      genv <- readGEnv id
      setGEnv (genv { global_bindings = [] })
      return (global_bindings genv)
248
-- | Build both halves of a closure value.
--
-- Each half applies the builtin closure constructor (scalar:
-- 'mkClosureVar', lifted: 'mkClosurePVar') at types
-- @[arg_ty, res_ty, env_ty]@ to:
--
-- * the PA dictionary for @env_ty@,
-- * the scalar and lifted function halves,
-- * the matching half of the environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mk_v <- builtin mkClosureVar
      mk_l <- builtin mkClosurePVar
      let tys          = [arg_ty, res_ty, env_ty]
          build fn env = Var fn `mkTyApps` tys `mkApps` [dict, vfn, lfn, env]
      return (build mk_v venv, build mk_l lenv)
257
-- | Apply a vectorised closure to a vectorised argument.
--
-- The closure's argument and result types are read from the type of the
-- scalar closure ('splitClosureTy').  The scalar half uses
-- 'applyClosureVar' and the lifted half uses 'applyClosurePVar'.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)
267
-- | Build a curried vectorised closure that captures @vars@ and takes
-- arguments of the given types.
--
-- * With a single argument type this is just 'buildClosure'.
-- * With several, the outer closure takes the first argument @x@.  Its body
--   (hoisted and abstracted over @tvs@) builds the closure for the remaining
--   argument types, now capturing @vars ++ [x]@.  The body from @mk_body@
--   is used only at the innermost level.
-- * An empty list of argument types is rejected with an explicit panic,
--   rather than an anonymous pattern-match failure.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      -- the outer closure returns a closure over the remaining arguments
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg <- newLocalVVar FSLIT("x") arg_ty
      buildClosure tvs vars arg_ty res_ty'
        . hoistPolyVExpr tvs
        $ do
            lc <- builtin liftingContext
            clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
            return $ vLams lc (vars ++ [arg]) clo
buildClosures _ _ [] _ _
  = pprPanic "buildClosures" (text "empty list of argument types")
281
-- | Build a single-argument vectorised closure from @arg_ty@ to @res_ty@
-- that captures @vars@.  Schematically the result is
--
-- >  (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >    where
-- >      f  = \env v -> case env of <x1,...,xn>        -> e    x1  ... xn  v
-- >      f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l xs1 ... xsn v
--
-- Here @e@/@e^@ is the pair produced by @mk_body@.  The function pair is
-- built under 'hoistPolyVExpr', so it is hoisted and abstracted over @tvs@.
-- The environment comes from 'buildEnv', and the two halves are combined
-- with 'mkClosure'.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      -- The function pair: unpack the environment, then apply the body to
      -- the captured variables and the argument.
      let fn_code = do
                      lc       <- builtin liftingContext
                      body     <- mk_body
                      unpacked <- bind (vVar env_bndr)
                                       (vVarApps lc body (vars ++ [arg_bndr]))
                      return (vLamsWithoutLC [env_bndr, arg_bndr] unpacked)

      fn <- hoistPolyVExpr tvs fn_code
      mkClosure arg_ty res_ty env_ty fn env
303
-- | Package the vectorised variables @vvs@ into a closure environment.
--
-- Returns three things:
--
-- * the scalar environment type (from 'mkVectEnv');
-- * the environment pair: scalar from 'mkVectEnv', lifted from 'mkLiftEnv'
--   (which uses the lifting-context variable);
-- * a function that takes an environment pair and a body pair, and re-binds
--   the variables around each half of the body.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (env_ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind_both (venv', lenv') (vbody, lbody)
            = do
                lbody' <- lbind lenv' lbody
                return (vbind venv' vbody, lbody')
      return (env_ty, (venv, lenv), bind_both)
  where
    (vs, ls) = unzip vvs
    tys      = map idType vs
319
-- | Build the scalar closure environment for the variables @vs@ of types
-- @tys@.
--
-- Returns the environment type, the expression that builds it, and a
-- function that re-binds the variables from an environment around a body:
--
-- * no variables: the unit type/value, and the body is left unchanged;
-- * one variable: the variable itself, re-bound with a non-recursive @let@;
-- * several: a boxed tuple of the variables, taken apart with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
                      \env body -> Case env (mkWildId ty) (exprType body)
                                     [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
  where
    ty = mkCoreTupTy tys
328
-- | Build the lifted closure environment for the (lifted) variables @vs@,
-- whose scalar types are @tys@.
--
-- Returns the environment expression and a function that re-binds @lc@ and
-- the variables from an environment around a body.
--
-- * One variable: the environment is the variable itself.  Unpacking
--   re-binds it with a @let@ and binds @lc@ (as a case binder) to its
--   'lengthPA'.
-- * Otherwise: the environment is the single data constructor of the PArray
--   family instance for the tuple of @tys@, applied to @lc@ and the
--   variables.  Unpacking scrutinises it and binds @lc@ and the variables.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind_single)
  where
    bind_single env body
      = do
          len <- lengthPA (Var v)
          return . Let (NonRec v env)
                 $ Case len lc (exprType body) [(DEFAULT, [], body)]

-- NOTE: this transparently deals with empty environments
-- (no variables: the constructor gets a single wildcard of unit type, see
-- 'bndrs' below).
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- lookupPArrayFamInst vty
      -- NOTE(review): assumes the instance tycon has exactly one data
      -- constructor; otherwise this lazy pattern fails when it is forced.
      let [env_con] = tyConDataCons env_tc

          env = Var (dataConWrapId env_con)
                  `mkTyApps`  env_tyargs
                  `mkVarApps` (lc : vs)

          bind env_expr body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env_expr
              in
              return $ Case scrut (mkWildId (exprType scrut))
                            (exprType body)
                            [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
358