Utility functions for vectorisation
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 mkDataConTag,
5 splitClosureTy,
6 mkPlusType, mkPlusTypes, mkCrossType, mkCrossTypes, mkEmbedType,
7 mkPlusAlts, mkCrosses, mkEmbed,
8 mkPADictType, mkPArrayType,
9 parrayReprTyCon, parrayReprDataCon, mkVScrut,
10 paDictArgType, paDictOfType, paDFunType,
11 paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12 polyAbstract, polyApply, polyVApply,
13 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
14 buildClosure, buildClosures,
15 mkClosureApp
16 ) where
17
18 #include "HsVersions.h"
19
20 import VectCore
21 import VectMonad
22
23 import DsUtils
24 import CoreSyn
25 import CoreUtils
26 import Type
27 import TypeRep
28 import TyCon
29 import DataCon ( DataCon, dataConWrapId, dataConTag )
30 import Var
31 import Id ( mkWildId )
32 import MkId ( unwrapFamInstScrut )
33 import PrelNames
34 import TysWiredIn
35 import BasicTypes ( Boxity(..) )
36
37 import Outputable
38 import FastString
39
40 import Control.Monad ( liftM, zipWithM_ )
41
-- | Peel type applications off an annotated expression, from the outside in.
-- Returns the remaining head expression together with the peeled type
-- arguments in application order.  Peeling stops at the first application
-- whose argument is not a type, or at any non-application.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)
47
-- | Split off the leading type-variable lambdas of an annotated expression.
-- Returns the binders in source order together with the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = loop []
  where
    loop seen (_, AnnLam bndr body)
      | isTyVar bndr = loop (bndr : seen) body
    loop seen rest   = (reverse seen, rest)
53
-- | Split off the leading value (Id) lambdas of an annotated expression.
-- Returns the binders in source order together with the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = loop []
  where
    loop seen (_, AnnLam bndr body)
      | isId bndr = loop (bndr : seen) body
    loop seen rest = (reverse seen, rest)
59
-- | Is this annotated expression a type (i.e. a type argument)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg expr
  = case snd expr of
      AnnType _ -> True
      _         -> False
63
-- | A boxed 'Int' literal holding the constructor tag of the given 'DataCon'.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)
66
-- | Is this tycon the one named by 'closureTyConName'?
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName
69
-- | Split a closure type (the closure tycon applied to exactly an argument
-- type and a result type) into those two types.  Panics on any other type.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg_ty, res_ty])
        | isClosureTyCon tc -> (arg_ty, res_ty)
      _                     -> pprPanic "splitClosureTy" (ppr ty)
77
-- | Is this tycon the one named by 'parrayTyConName'?
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName
80
-- | Extract the element type from a PArray type (the PArray tycon applied to
-- exactly one argument).  Panics on any other type.
splitPArrayTy :: Type -> Type
splitPArrayTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [elt_ty])
        | isPArrayTyCon tc -> elt_ty
      _                    -> pprPanic "splitPArrayTy" (ppr ty)
88
-- | Apply a builtin type constructor (selected from 'Builtins') to the given
-- type arguments.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys
  = liftM (`mkTyConApp` tys) (builtin get_tc)
94
-- | Right-nest a binary builtin type constructor @T@ over a list of types,
-- terminated by a final type: @[t1, t2]@ and @ty@ give @T t1 (T t2 ty)@.
-- An empty list yields @ty@ itself.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = do
      tc <- builtin get_tc
      let nest arg_ty inner_ty = mkTyConApp tc [arg_ty, inner_ty]
      return (foldr nest ty tys)
102
-- | Right-nest a binary builtin type constructor @T@ over a list of types:
-- @[t1, t2, t3]@ gives @T t1 (T t2 t3)@ and a singleton gives its element.
-- An empty list yields the default type @dft@ without looking up the builtin.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- tys is non-empty here (the empty case is the equation above), so
      -- foldr1 is safe and no separate empty-list branch is needed.
      return $ foldr1 (\ty1 ty2 -> mkTyConApp tc [ty1, ty2]) tys
113
-- | Apply a builtin data constructor (selected from 'Builtins') to the given
-- arguments.
mkBuiltinDataConApp :: (Builtins -> DataCon) -> [CoreExpr] -> VM CoreExpr
mkBuiltinDataConApp get_dc args
  = liftM (`mkConApp` args) (builtin get_dc)
119
-- | The builtin plus type constructor applied to two types.
mkPlusType :: Type -> Type -> VM Type
mkPlusType left_ty right_ty = mkBuiltinTyConApp plusTyCon [left_ty, right_ty]
122
-- | Right-nested builtin plus type over a list of types; an empty list gives
-- the supplied default type.
mkPlusTypes :: Type -> [Type] -> VM Type
mkPlusTypes dft tys = mkBuiltinTyConApps1 plusTyCon dft tys
125
-- | Inject each expression into the right-nested builtin plus type built
-- from all the expressions' types.  The first expression is wrapped in the
-- left constructor.  Each later one is placed, recursively, inside the right
-- constructor.  A single expression is returned as-is, and an empty list
-- gives no alternatives.
mkPlusAlts :: [CoreExpr] -> VM [CoreExpr]
mkPlusAlts [] = return []
mkPlusAlts exprs
  = do
      plus_tc  <- builtin plusTyCon
      left_dc  <- builtin leftDataCon
      right_dc <- builtin rightDataCon

      -- Returns the injected alternatives together with the type of the
      -- sum they inhabit.
      let inject [e] = ([e], exprType e)
          inject (e : es)
            = case inject es of
                (rest, r_ty) ->
                  let l_ty    = exprType e
                      ty_args = [Type l_ty, Type r_ty]
                  in ( mkConApp left_dc (ty_args ++ [e])
                         : [mkConApp right_dc (ty_args ++ [alt]) | alt <- rest]
                     , mkTyConApp plus_tc [l_ty, r_ty] )

      return (fst (inject exprs))
145
-- | The builtin cross type constructor applied to two types.
mkCrossType :: Type -> Type -> VM Type
mkCrossType fst_ty snd_ty = mkBuiltinTyConApp crossTyCon [fst_ty, snd_ty]
148
-- | Right-nested builtin cross type over a list of types; an empty list gives
-- the supplied default type.
mkCrossTypes :: Type -> [Type] -> VM Type
mkCrossTypes dft tys = mkBuiltinTyConApps1 crossTyCon dft tys
151
-- | Combine expressions into a right-nested chain of builtin cross
-- constructor applications.  An empty list gives the unit data constructor,
-- and a single expression is returned unchanged.
mkCrosses :: [CoreExpr] -> VM CoreExpr
mkCrosses [] = return (Var unitDataConId)
mkCrosses exprs
  = do
      cross_tc <- builtin crossTyCon
      cross_dc <- builtin crossDataCon
      let typed = zip exprs (map exprType exprs)
          pair (l, l_ty) (r, r_ty)
            = ( mkConApp cross_dc [Type l_ty, Type r_ty, l, r]
              , mkTyConApp cross_tc [l_ty, r_ty] )
      return (fst (foldr1 pair typed))
165
-- | The builtin embed type constructor applied to a type.
mkEmbedType :: Type -> VM Type
mkEmbedType inner_ty = mkBuiltinTyConApp embedTyCon [inner_ty]
168
-- | Wrap an expression in the builtin embed data constructor, instantiated
-- at the expression's own type.
mkEmbed :: CoreExpr -> VM CoreExpr
mkEmbed expr = mkBuiltinDataConApp embedDataCon [Type expr_ty, expr]
  where
    expr_ty = exprType expr
172
-- | The builtin closure type constructor applied to an argument type and a
-- result type.
mkClosureType :: Type -> Type -> VM Type
mkClosureType a_ty r_ty = mkBuiltinTyConApp closureTyCon [a_ty, r_ty]
175
-- | Curried closure type: argument types @[a, b]@ and result @r@ give the
-- closure tycon nested as @C a (C b r)@.  No arguments give @r@ itself.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
178
-- | The PA dictionary type for a type (the builtin PA tycon applied to it).
mkPADictType :: Type -> VM Type
mkPADictType elt_ty = mkBuiltinTyConApp paTyCon [elt_ty]
181
-- | The builtin PArray type constructor applied to an element type.
mkPArrayType :: Type -> VM Type
mkPArrayType elt_ty = mkBuiltinTyConApp parrayTyCon [elt_ty]
184
-- | Look up the family instance of the builtin PArray tycon at the given
-- type.  Returns the representation tycon and its type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parray_tc <- builtin parrayTyCon
      lookupFamInst parray_tc [ty]
187
-- | The data constructor of the PArray representation tycon for @ty@,
-- together with the representation's type arguments (see
-- 'parrayReprTyCon').  The representation tycon must have exactly one data
-- constructor; otherwise this panics, naming the offending tycon.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      -- Explicit match instead of the partial pattern `let [dc] = ...`, so a
      -- violated invariant is reported with the tycon rather than as an
      -- anonymous irrefutable-pattern failure.
      case tyConDataCons tc of
        [dc] -> return (dc, arg_tys)
        _    -> pprPanic "parrayReprDataCon" (ppr tc)
194
-- | Prepare a vectorised expression pair for use as a case scrutinee.
-- Looks up the PArray representation of the scalar component's type and
-- wraps the lifted component with 'unwrapFamInstScrut'.  The scalar
-- component is returned unchanged, along with the representation tycon and
-- its type arguments.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = do
      (repr_tc, repr_args) <- parrayReprTyCon (exprType ve)
      let le' = unwrapFamInstScrut repr_tc repr_args le
      return ((ve, le'), repr_tc, repr_args)
200
-- | The type of the PA dictionary argument needed for a type variable, if
-- any, determined by the variable's kind (expanded with 'kindView'):
--
--   * lifted kind: the PA dictionary type of the variable;
--   * arrow kind @k1 -> k2@: a fresh type variable of kind @k1@ is made.  If
--     it needs a dictionary, the result is that variable quantified over a
--     function from its dictionary to the dictionary of the applied type at
--     @k2@.  Otherwise the type is processed, unapplied, at @k2@;
--   * any other kind: no dictionary.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictFor (TyVarTy tv) (tyVarKind tv)
  where
    dictFor ty kind
      | Just kind' <- kindView kind
      = dictFor ty kind'

    dictFor ty (FunTy arg_kind res_kind)
      = do
          arg_tv   <- newTyVar FSLIT("a") arg_kind
          arg_dict <- dictFor (TyVarTy arg_tv) arg_kind
          case arg_dict of
            Nothing      -> dictFor ty res_kind
            Just arg_dty ->
              liftM (fmap (ForAllTy arg_tv . FunTy arg_dty))
                    (dictFor (AppTy ty (TyVarTy arg_tv)) res_kind)

    dictFor ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing
220
-- | Build the PA dictionary for a type.  The type is split into its head and
-- arguments with 'splitAppTys', and the work is delegated to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
225
-- | Build the PA dictionary for a type head applied to argument types.
-- Synonyms in the head are expanded with 'coreView'.  For a type-variable
-- head, the dfun comes from 'lookupTyVarPA'.  For a tycon head, it comes
-- from 'lookupTyConPA'.  Either way it is applied with 'paDFunApply'.  Any
-- other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn
  = paDictOfTyApp expanded ty_args

paDictOfTyApp (TyVarTy tv) ty_args
  = maybeV (lookupTyVarPA tv) >>= \dfun ->
    paDFunApply dfun ty_args

paDictOfTyApp (TyConApp tc _) ty_args
  = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc) >>= \dfun ->
    paDFunApply (Var dfun) ty_args

paDictOfTyApp ty _
  = pprPanic "paDictOfTyApp" (ppr ty)
238
-- | The type of the PA dfun for a tycon.  It is quantified over the tycon's
-- type variables and takes one dictionary argument for each variable that
-- needs one (per 'paDictArgType').  It returns the PA dictionary type of the
-- tycon applied to its own type variables.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      mb_dicts <- mapM paDictArgType tvs
      res_ty   <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let dict_tys = [dict_ty | Just dict_ty <- mb_dicts]
      return (mkForAllTys tvs (mkFunTys dict_tys res_ty))
  where
    tvs = tyConTyVars tc
249
-- | Apply a PA dfun to type arguments, then to the PA dictionaries of those
-- same types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
255
-- | Instantiate a builtin PA method at a type.  The method variable is
-- applied to the type and to the PA dictionary for that type.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = do
      method_var <- builtin method
      ty_dict    <- paDictOfType ty
      return (Var method_var `mkApps` [Type ty, ty_dict])
262
-- | Apply the PA length method to an expression of PArray type.  The element
-- type is taken from the expression's type via 'splitPArrayTy', which panics
-- on non-PArray types.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      length_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App length_fn x)
267
-- | Apply the PA replicate method, at the type of @x@, to a length and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      replicate_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps replicate_fn [len, x])
271
-- | The PA empty method instantiated at the given element type.
emptyPA :: Type -> VM CoreExpr
emptyPA elt_ty = paMethod emptyPAVar elt_ty
274
-- | Replicate @x@ using the builtin lifting-context variable as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \ctx -> replicatePA (Var ctx) x
280
-- | Create a fresh vectorised variable pair, both named @fs@: a scalar
-- variable of type @vty@ and a lifted variable of the PArray type over @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      arr_ty   <- mkPArrayType vty
      scalar_v <- newLocalVar fs vty
      lifted_v <- newLocalVar fs arr_ty
      return (scalar_v, lifted_v)
288
-- | Run @p@ under 'localV' with each type variable in @tvs@ registered.  A
-- variable that needs a PA dictionary (per 'paDictArgType') is registered
-- together with a fresh dictionary variable named "dPA".  @p@ receives a
-- function that wraps an expression in lambdas over @tvs@ followed by those
-- dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mb_dict_vars <- mapM newDictVar tvs
      zipWithM_ register tvs mb_dict_vars
      p (mkLams (tvs ++ [dv | Just dv <- mb_dict_vars]))
  where
    newDictVar tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    register tv Nothing   = defLocalTyVar tv
    register tv (Just dv) = defLocalTyVarWithPA tv (Var dv)
304
-- | Apply a polymorphic expression to type arguments, then to the PA
-- dictionaries of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
310
-- | Like 'polyApply', but for both components of a vectorised expression.
-- The dictionaries are computed once and shared between the components.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let apply e = mkApps (mkTyApps e tys) dicts
      return (mapVect apply expr)
316
-- | Prepend a binding to the global environment's 'global_bindings'.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add
  where
    add env = env { global_bindings = (v, e) : global_bindings env }
320
-- | Bind an expression to a fresh variable named @fs@ (of the expression's
-- type), hoist that binding, and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var
327
-- | Hoist both components of a vectorised expression.  The new variables are
-- named after the current binding name, prefixed with 'v' for the scalar
-- side and 'l' for the lifted side.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name   <- getBindName
      v_var  <- hoistExpr (consFS 'v' name) ve
      l_var  <- hoistExpr (consFS 'l' name) le
      return (v_var, l_var)
335
-- | Abstract the vectorised expression produced by @p@ over @tvs@ (and their
-- PA dictionaries, via 'polyAbstract'), running it under 'closedV'.  The
-- result is hoisted with 'hoistVExpr', and the hoisted variables are
-- returned re-applied to @tvs@ and their dictionaries.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
343
-- | Return the bindings hoisted so far and clear them from the global
-- environment.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv (env { global_bindings = [] })
      return hoisted
350
-- | Build a closure pair from a function pair and an environment pair.
-- Each side applies a builtin closure constructor (mkClosureVar for the
-- scalar side, mkClosurePVar for the lifted side) to:
--   * the argument, result and environment types;
--   * the environment's PA dictionary;
--   * both function components;
--   * the matching environment component.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let build con env = Var con `mkTyApps` [arg_ty, res_ty, env_ty]
                                  `mkApps`   [dict, vfn, lfn, env]
      return (build mkv venv, build mkl lenv)
359
-- | Apply a closure pair to an argument pair, using the builtin scalar and
-- lifted closure-application functions.  The argument and result types are
-- read off the scalar closure's type with 'splitClosureTy'.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply fn clo arg = Var fn `mkTyApps` [arg_ty, res_ty] `mkApps` [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)
369
-- | Build a curried chain of closures over the given argument types, ending
-- in @res_ty@, that capture @vars@.
--
--   * No argument types: the body itself.
--   * One argument type: a single 'buildClosure'.
--   * More: a fresh variable "x" is made for the first argument.  The outer
--     closure's body is then the closure chain for the remaining arguments,
--     capturing @vars@ plus that variable.  This inner chain is wrapped with
--     'vLams' over the captured variables and hoisted with 'hoistPolyVExpr'.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars arg_tys res_ty mk_body
  = case arg_tys of
      []        -> mk_body
      [only_ty] -> buildClosure tvs vars only_ty res_ty mk_body
      first_ty : rest_tys ->
        do
          inner_res_ty <- mkClosureTypes rest_tys res_ty
          x            <- newLocalVVar FSLIT("x") first_ty
          let captured = vars ++ [x]
              inner
                = do
                    lc  <- builtin liftingContext
                    clo <- buildClosures tvs captured rest_tys res_ty mk_body
                    return (vLams lc captured clo)
          buildClosure tvs vars first_ty inner_res_ty (hoistPolyVExpr tvs inner)
385
-- | Build a closure pair for a function @arg_ty -> res_ty@ capturing @vars@.
-- Schematically:
--
-- > (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l xs1 ... xsn v
--
-- The environment and its binder come from 'buildEnv'.  The function pair
-- is hoisted with 'hoistPolyVExpr' and combined with the environment by
-- 'mkClosure'.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
      fn <- hoistPolyVExpr tvs (mkFun bind env_bndr arg_bndr)
      mkClosure arg_ty res_ty env_ty fn env
  where
    -- The function pair: unpack the environment, then apply the body to the
    -- captured variables and the argument.
    mkFun bind env_bndr arg_bndr
      = do
          lc    <- builtin liftingContext
          body  <- mk_body
          body' <- bind (vVar env_bndr) (vVarApps lc body (vars ++ [arg_bndr]))
          return (vLamsWithoutLC [env_bndr, arg_bndr] body')
407
-- | Build the closure environment for the captured vectorised variables.
-- Returns three things:
--   * the scalar environment type;
--   * the environment pair (scalar side from 'mkVectEnv', lifted side from
--     'mkLiftEnv');
--   * a function that, given an environment pair and a body pair,
--     re-introduces the variables around each body.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (vs, ls)              = unzip vvs
          tys                   = map idType vs
          (env_ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind (venv', lenv') (vbody, lbody)
            = do
                lbody' <- lbind lenv' lbody
                return (vbind venv' vbody, lbody')
      return (env_ty, (venv, lenv), bind)
423
-- | Scalar environment for the given variables and their types.  Returns the
-- environment type, the environment expression, and a binder that
-- re-introduces the variables around a body:
--
--   * no variables: unit, and the binder ignores the environment;
--   * one variable: the variable itself, rebound with a non-recursive @let@;
--   * several: a boxed tuple of the variables, taken apart with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys
    unpack env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
432
-- | Lifted environment for the given lifted variables and their scalar types,
-- given the lifting-context variable @lc@.  Returns the environment
-- expression and a binder that re-introduces the variables around a body.
--
-- For a single variable, the variable itself is the environment.  Its binder
-- rebinds the variable with a @let@ and then scrutinises 'lengthPA' of it,
-- with @lc@ as the case binder.
--
-- Otherwise, the environment is the single data constructor of the PArray
-- representation of the variables' tuple type, applied to @lc@ and the
-- variables.  Its binder takes that constructor apart again.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind)
  where
    bind env body
      = do
          len <- lengthPA (Var v)
          return . Let (NonRec v env)
                 $ Case len lc (exprType body) [(DEFAULT, [], body)]

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      -- The representation tycon must have exactly one data constructor.
      -- Report the tycon if not, instead of failing later on the partial
      -- pattern `let [env_con] = ...`.
      env_con <- case tyConDataCons env_tc of
                   [con] -> return con
                   _     -> pprPanic "mkLiftEnv" (ppr env_tc)
      let env = Var (dataConWrapId env_con)
                  `mkTyApps`  env_tyargs
                  `mkVarApps` (lc : vs)

          bind env_arg body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env_arg
              in return $ Case scrut (mkWildId (exprType scrut))
                               (exprType body)
                               [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- With no variables, a single unit-typed wildcard binder stands in for
    -- the (absent) variables in the constructor pattern.
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
462