More vectorisation-related built-ins
[ghc.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
-- The above warning suppression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 -- http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10 collectAnnValBinders,
11 dataConTagZ, mkDataConTag, mkDataConTagLit,
12
13 newLocalVVar,
14
15 mkBuiltinCo,
16 mkPADictType, mkPArrayType, mkPReprType,
17
18 parrayReprTyCon, parrayReprDataCon, mkVScrut,
19 prDFunOfTyCon,
20 paDictArgType, paDictOfType, paDFunType,
21 paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
22 polyAbstract, polyApply, polyVApply,
23 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
24 buildClosure, buildClosures,
25 mkClosureApp
26 ) where
27
28 #include "HsVersions.h"
29
30 import VectCore
31 import VectMonad
32
33 import DsUtils
34 import CoreSyn
35 import CoreUtils
36 import Coercion
37 import Type
38 import TypeRep
39 import TyCon
40 import DataCon
41 import Var
42 import Id ( mkWildId )
43 import MkId ( unwrapFamInstScrut )
44 import Name ( Name )
45 import PrelNames
46 import TysWiredIn
47 import TysPrim ( intPrimTy )
48 import BasicTypes ( Boxity(..) )
49 import Literal ( Literal, mkMachInt )
50
51 import Outputable
52 import FastString
53
54 import Data.List ( zipWith4 )
55 import Control.Monad ( liftM, liftM2, zipWithM_ )
56
-- | Peel the type arguments off an annotated application spine.
-- Returns the head left after removing all type applications, together with
-- the removed type arguments in application (left-to-right) order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc e = case e of
      (_, AnnApp fun (_, AnnType t)) -> peel (t : acc) fun
      _                              -> (e, acc)
62
-- | Split the leading type-variable lambdas off an annotated expression.
-- Returns the binders in binding order and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = gather []
  where
    gather acc (_, AnnLam bndr body)
      | isTyVar bndr = gather (bndr : acc) body
    gather acc rest  = (reverse acc, rest)
68
-- | Split the leading value (Id) lambdas off an annotated expression.
-- Returns the binders in binding order and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = gather []
  where
    gather acc (_, AnnLam bndr body)
      | isId bndr = gather (bndr : acc) body
    gather acc rest = (reverse acc, rest)
74
-- | Is this annotated expression a type argument?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg expr = case snd expr of
  AnnType _ -> True
  _         -> False
78
-- | Zero-based constructor tag: 'dataConTag' shifted down by 'fIRST_TAG'.
dataConTagZ :: DataCon -> Int
dataConTagZ = subtract fIRST_TAG . dataConTag
81
-- | The zero-based tag of a constructor as a machine-integer literal.
mkDataConTagLit :: DataCon -> Literal
mkDataConTagLit con = mkMachInt (toInteger (dataConTagZ con))
84
-- | The zero-based tag of a constructor as an Int literal expression.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag con = mkIntLitInt (dataConTagZ con)
87
-- | If the type is a nullary application of a primitive type constructor,
-- return that constructor; otherwise 'Nothing'.
splitPrimTyCon :: Type -> Maybe TyCon
splitPrimTyCon ty =
  case splitTyConApp_maybe ty of
    Just (tycon, []) | isPrimTyCon tycon -> Just tycon
    _                                    -> Nothing
95
-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (\tc -> mkTyConApp tc tys) (builtin get_tc)
101
-- | Right-nested application of a binary builtin type constructor:
-- for @tys = [t1, ..., tn]@ this builds @tc t1 (tc t2 (... (tc tn ty)))@,
-- and just @ty@ when @tys@ is empty.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = do
      tc <- builtin get_tc
      let app2 arg res = mkTyConApp tc [arg, res]
      return (foldr app2 ty tys)
109
-- | Right-nested application of a binary builtin type constructor without a
-- terminal type: for @[t1, ..., tn]@ this builds
-- @tc t1 (tc t2 (... (tc t(n-1) tn)))@.
-- For an empty list the default @dft@ is returned, and the builtin is not
-- looked up.
--
-- The empty case is handled by a guard up front. This replaces the original
-- unreachable @[] -> pprPanic@ branch, so the fold below only ever sees a
-- non-empty list.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 get_tc dft tys
  | null tys  = return dft
  | otherwise
  = do
      tc <- builtin get_tc
      return $ foldr1 (\ty1 ty2 -> mkTyConApp tc [ty1, ty2]) tys
120
-- | The vectorised closure type from @arg_ty@ to @res_ty@.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = do
      clo_tc <- builtin closureTyCon
      return $ mkTyConApp clo_tc [arg_ty, res_ty]
123
-- | Curried vectorised closure type: one closure layer per argument type,
-- ending in the given result type.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = do
      clo_tc <- builtin closureTyCon
      return $ foldr (\arg res -> mkTyConApp clo_tc [arg, res]) res_ty arg_tys
126
-- | Apply the builtin PRepr type constructor to a type.
mkPReprType :: Type -> VM Type
mkPReprType ty
  = do
      repr_tc <- builtin preprTyCon
      return $ mkTyConApp repr_tc [ty]
129
-- | The type of the PA dictionary for a type.
mkPADictType :: Type -> VM Type
mkPADictType ty
  = do
      pa_tc <- builtin paTyCon
      return $ mkTyConApp pa_tc [ty]
132
-- | The parallel-array type for elements of type @ty@.
-- For a nullary primitive tycon this is the dedicated array tycon found by
-- 'lookupPrimPArray'. Otherwise it is the builtin PArray tycon applied to @ty@.
mkPArrayType :: Type -> VM Type
mkPArrayType ty =
  case splitPrimTyCon ty of
    Just tycon -> do
      arr <- traceMaybeV "mkPArrayType" (ppr tycon) (lookupPrimPArray tycon)
      return (mkTyConApp arr [])
    Nothing -> mkBuiltinTyConApp parrayTyCon [ty]
141
-- | A coercion given by a nullary application of a builtin tycon.
-- NOTE(review): this relies on 'Coercion' being represented as a 'Type'.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = liftM (\tc -> mkTyConApp tc []) (builtin get_tc)
147
-- | Look up the family instance of the builtin PArray tycon at @ty@.
-- Returns the instance tycon and its type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
150
-- | The single data constructor of the PArray representation tycon for
-- @ty@, together with that tycon's type arguments.
--
-- Panics, naming the tycon, if the representation does not have exactly
-- one constructor. The original partial @let [dc] = ...@ binding failed
-- with an uninformative irrefutable-pattern error instead.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      case tyConDataCons tc of
        [dc] -> return (dc, arg_tys)
        _    -> pprPanic "parrayReprDataCon" (ppr tc)
157
-- | Prepare a vectorised scrutinee.
-- Looks up the PArray representation tycon for the type of the scalar part,
-- then unwraps the lifted part with 'unwrapFamInstScrut'. Returns the new
-- expression pair plus the representation tycon and its type arguments.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = do
      (repr_tc, repr_args) <- parrayReprTyCon (exprType ve)
      let le' = unwrapFamInstScrut repr_tc repr_args le
      return ((ve, le'), repr_tc, repr_args)
163
-- | The PR dictionary function registered for a tycon, as a variable
-- expression. The lookup goes through 'traceMaybeV', which presumably fails
-- the VM computation with a trace when no entry exists (confirm in VectMonad).
--
-- The trace label now names this function. It previously read
-- "prDictOfTyCon", which matches no definition here.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
  = liftM Var (traceMaybeV "prDFunOfTyCon" (ppr tycon) (lookupTyConPR tycon))
167
-- | The type of the PA dictionary argument needed when abstracting over the
-- given type variable, or 'Nothing' if it needs none.
--
-- * Lifted kind: @PA tv@.
-- * Kind @k1 -> k2@: abstract over a fresh variable @a :: k1@. If @a@
--   itself needs a dictionary @d_a@, the result is
--   @forall a. d_a -> <dictionary for (tv a)>@. Otherwise recurse on @k2@.
-- * Any other kind: 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    -- Expand kind synonyms before inspecting the kind.
    go ty k | Just k' <- kindView k = go ty k'

    -- Higher kind: introduce a fresh type variable for the argument kind.
    go ty (FunTy k1 k2)
      = do
          tv <- newTyVar FSLIT("a") k1
          mty1 <- go (TyVarTy tv) k1
          case mty1 of
            -- The argument needs a dictionary: build a polymorphic function
            -- from that dictionary to the dictionary of the applied type.
            Just ty1 -> do
                          mty2 <- go (AppTy ty (TyVarTy tv)) k2
                          return $ fmap (ForAllTy tv . FunTy ty1) mty2
            -- NOTE(review): here @ty@ is passed on unapplied, so any
            -- dictionary is built for @ty@ itself at kind @k2@ -- confirm
            -- this is the intended behaviour.
            Nothing -> go ty k2

    -- Ordinary lifted types get a plain PA dictionary.
    go ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    -- Everything else (e.g. unlifted kinds) takes no dictionary.
    go ty k = return Nothing
187
-- | Build the PA dictionary for a type by splitting it into its head and
-- type arguments and delegating to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
192
-- | Build the PA dictionary for @ty_fn@ applied to @ty_args@.
-- Type synonyms in the head are expanded first. The dictionary function for
-- the head (a type variable or a tycon) is then applied to the arguments and
-- their dictionaries. Any other head is a panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
  | otherwise
  = case ty_fn of
      TyVarTy tv -> do
        dfun <- maybeV (lookupTyVarPA tv)
        paDFunApply dfun ty_args
      TyConApp tc _ -> do
        dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
        paDFunApply (Var dfun) ty_args
      _ -> pprPanic "paDictOfTyApp" (ppr ty_fn)
205
-- | Type of the PA dictionary function for a tycon.
-- It is abstracted over the tycon's type variables and takes one dictionary
-- argument for each variable that needs one (see 'paDictArgType'). It
-- returns the PA dictionary of the tycon applied to those variables.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      let tvs = tyConTyVars tc
      margs <- mapM paDictArgType tvs
      res   <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let dict_args = [arg | Just arg <- margs]
      return (mkForAllTys tvs (mkFunTys dict_args res))
216
-- | Apply a dictionary function to type arguments and then to the PA
-- dictionaries of those same types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
222
-- | A PA class method.
-- Pairs the selector for its 'Builtins' entry with the name used by
-- 'paMethod' to look up a primitive-type specialisation.
type PAMethod = (Builtins -> Var, String)

pa_length, pa_replicate, pa_empty, pa_pack :: PAMethod
pa_length    = (lengthPAVar,    "lengthPA")
pa_replicate = (replicatePAVar, "replicatePA")
pa_empty     = (emptyPAVar,     "emptyPA")
pa_pack      = (packPAVar,      "packPA")
229
-- | Instantiate a PA method at a type.
-- For a nullary primitive tycon, the method is looked up by name among the
-- primitive methods. Otherwise the generic builtin method is applied to the
-- type and its PA dictionary.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty =
  case splitPrimTyCon ty of
    Just tycon -> do
      fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
                        (lookupPrimMethod tycon name)
      return (Var fn)
    Nothing -> do
      fn   <- builtin method
      dict <- paDictOfType ty
      return (mkApps (Var fn) [Type ty, dict])
243
-- | Apply the builtin @mkPR@ function to a type and its PA dictionary.
mkPR :: Type -> VM CoreExpr
mkPR ty = liftM2 (\fn dict -> Var fn `mkApps` [Type ty, dict])
                 (builtin mkPRVar)
                 (paDictOfType ty)
250
-- | Apply the @lengthPA@ method at type @ty@ to the array expression @x@.
lengthPA :: Type -> CoreExpr -> VM CoreExpr
lengthPA ty x
  = do
      len_fn <- paMethod pa_length ty
      return (App len_fn x)
253
-- | Apply the @replicatePA@ method (at the type of @x@) to a length and an
-- element.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod pa_replicate (exprType x)
      return (rep_fn `mkApps` [len, x])
257
-- | The @emptyPA@ method instantiated at the given element type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
260
-- | Apply the @packPA@ method at element type @ty@.
-- The arguments are the array @xs@, the result length @len@ and the
-- selector @sel@, in that order.
--
-- Fix: the array argument @xs@ was previously accepted but dropped, so the
-- method was applied only to @[len, sel]@.
packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
                             (paMethod pa_pack ty)
264
-- | Apply the @combine<n>PA@ method at element type @ty@, where @n@ is the
-- number of arrays in @xs@. The arguments are @len@, @sel@, @is@ and then
-- the arrays.
combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
          -> VM CoreExpr
combinePA ty len sel is xs
  = do
      let arity = length xs
      comb <- paMethod (combinePAVar arity, "combine" ++ show arity ++ "PA") ty
      return (comb `mkApps` (len : sel : is : xs))
272
-- | Replicate @x@ using the current lifting context as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
278
-- | A fresh vectorised variable pair named from @fs@.
-- The scalar variable has type @vty@; the lifted one has the corresponding
-- parallel-array type from 'mkPArrayType'.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
286
-- | Run @p@ in a local scope where each type variable in @tvs@ is defined.
-- Variables that need a PA dictionary also get a fresh dictionary binder.
-- @p@ receives a function that wraps an expression in lambdas over @tvs@,
-- followed by those dictionary binders.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM mk_dict_var tvs
      zipWithM_ def_tyvar tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    -- Register the type variable, with its dictionary when it has one.
    def_tyvar tv Nothing     = defLocalTyVar tv
    def_tyvar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    -- A fresh dictionary binder for type variables that need one.
    mk_dict_var tv
      = paDictArgType tv
        >>= maybe (return Nothing) (liftM Just . newLocalVar FSLIT("dPA"))
302
-- | Apply a type-abstracted expression to types, then to the PA
-- dictionaries of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
308
-- | Like 'polyApply', but applies both halves of a vectorised expression.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let apply e = mkApps (mkTyApps e tys) dicts
      return (mapVect apply expr)
314
-- | Record a binding in the global environment's hoisted-bindings list.
-- New bindings are prepended.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv push
  where
    push genv = genv { global_bindings = (v, e) : global_bindings genv }
318
-- | Bind an expression to a fresh variable named from @fs@, hoist that
-- binding, and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >> return var
325
-- | Hoist both halves of a vectorised expression.
-- The variables are named after the current bind name, prefixed with 'v'
-- (scalar) and 'l' (lifted).
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      liftM2 (,) (hoistExpr (consFS 'v' name) ve)
                 (hoistExpr (consFS 'l' name) le)
333
-- | Hoist a vectorised expression, type-abstracted over @tvs@ and their PA
-- dictionaries, and return it re-applied to those type variables.
-- The expression is built by @p@ inside a closed scope.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      expr <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      fn   <- hoistVExpr expr
      polyVApply (vVar fn) (mkTyVarTys tvs)
341
-- | Return the hoisted bindings accumulated in the global environment and
-- reset that list to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv env { global_bindings = [] }
      return hoisted
348
-- | Build a vectorised closure from a function pair and an environment pair.
-- The scalar half uses the builtin @mkClosure@ and the lifted half uses
-- @mkClosureP@. Both take the same type arguments, the PA dictionary of the
-- environment type, the two function halves, and the matching environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let tys       = [arg_ty, res_ty, env_ty]
          build f e = Var f `mkTyApps` tys `mkApps` [dict, vfn, lfn, e]
      return (build mkv venv, build mkl lenv)
357
-- | Apply a vectorised closure to a vectorised argument.
-- The scalar half uses the builtin @applyClosure@ and the lifted half uses
-- @applyClosureP@.
mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let app f clo arg = mkApps (mkTyApps (Var f) [arg_ty, res_ty]) [clo, arg]
      return (app vapply vclo varg, app lapply lclo larg)
365
-- | Build a (possibly curried) vectorised closure for a function with the
-- given argument types and result type.
--
-- * No arguments: the body produced by @mk_body@.
-- * One argument: a single closure built by 'buildClosure'.
-- * Several arguments: a closure over the first argument. Its body is code,
--   hoisted with 'hoistPolyVExpr' over @tvs@, that builds the closures for
--   the remaining arguments, with the first argument added to @vars@.
--
-- @vars@ are the vectorised variables captured by the closure environment.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars [] res_ty mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      -- Type of the closure that remains after the first argument is supplied.
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg <- newLocalVVar FSLIT("x") arg_ty
      buildClosure tvs vars arg_ty res_ty'
        . hoistPolyVExpr tvs
        $ do
            lc <- builtin liftingContext
            clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
            return $ vLams lc (vars ++ [arg]) clo
381
-- | Build a vectorised closure for a function of one argument. Its shape is
--
-- > (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The captured variables @vars@ are packed into an environment by
-- 'buildEnv'. The function pair, which unpacks the environment and applies
-- the body, is hoisted via 'hoistPolyVExpr' over @tvs@. 'mkClosure' then
-- combines the function pair with the environment.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      fn <- hoistPolyVExpr tvs
          $ do
              lc <- builtin liftingContext
              body <- mk_body
              -- Apply the body to the captured variables and the argument,
              -- inside code that unpacks the environment binder.
              body' <- bind (vVar env_bndr)
                            (vVarApps lc body (vars ++ [arg_bndr]))
              return (vLamsWithoutLC [env_bndr, arg_bndr] body')

      mkClosure arg_ty res_ty env_ty fn env
403
-- | Package vectorised variables into a closure environment.
--
-- Returns three things:
--
-- * the type of the scalar environment (from 'mkVectEnv');
-- * the environment itself as a (scalar, lifted) pair;
-- * a binder that takes an environment pair and a body pair and wraps each
--   body in code that unpacks the corresponding environment, rebinding the
--   captured variables.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      return (ty, (venv, lenv),
              \(venv,lenv) (vbody,lbody) ->
              do
                let vbody' = vbind venv vbody
                lbody' <- lbind lenv lbody
                return (vbody', lbody'))
  where
    -- Scalar and lifted halves of the captured variables; the environment
    -- types are the types of the scalar halves.
    (vs,ls) = unzip vvs
    tys = map idType vs
419
-- | The scalar closure environment for the given variables.
-- Returns the environment type, the environment expression, and a function
-- that wraps a body so that it unpacks an environment into the variables.
--
-- * No variables: unit, and the body is left unchanged.
-- * One variable: the variable itself, rebound with a @let@.
-- * Several variables: a boxed tuple, taken apart with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys

    unpack env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
428
-- | The lifted closure environment for the captured variables.
-- @lc@ is the lifting-context variable. Returns the environment expression
-- and a binder that wraps a body in code unpacking that environment.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
-- A single variable is its own environment. Unpacking rebinds it with a
-- @let@ and binds @lc@ to the result of 'lengthPA' on it.
mkLiftEnv lc [ty] [v]
  = return (Var v, \env body ->
                   do
                     len <- lengthPA ty (Var v)
                     return . Let (NonRec v env)
                            $ Case len lc (exprType body) [(DEFAULT, [], body)])

-- NOTE: this transparently deals with empty environments
-- Otherwise the environment is the sole data constructor of the PArray
-- representation of the tuple of the variable types, applied to @lc@ and the
-- variables. Unpacking scrutinises it and binds @lc@ and the variables, or a
-- single unit wildcard when there are no variables.
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      let [env_con] = tyConDataCons env_tc

          env = Var (dataConWrapId env_con)
                `mkTyApps` env_tyargs
                `mkVarApps` (lc : vs)

          bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
                          in
                          return $ Case scrut (mkWildId (exprType scrut))
                                        (exprType body)
                                        [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
456 bndrs | null vs = [mkWildId unitTy]
457 | otherwise = vs
458