bdee5ea87c608715972fd928e4052e3592971217
[ghc.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
-- The above warning suppression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 -- http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10 collectAnnValBinders,
11 dataConTagZ, mkDataConTag, mkDataConTagLit,
12
13 newLocalVVar,
14
15 mkBuiltinCo,
16 mkPADictType, mkPArrayType, mkPReprType,
17
18 parrayReprTyCon, parrayReprDataCon, mkVScrut,
19 prDFunOfTyCon,
20 paDictArgType, paDictOfType, paDFunType,
21 paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, liftPA,
22 polyAbstract, polyApply, polyVApply,
23 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
24 buildClosure, buildClosures,
25 mkClosureApp
26 ) where
27
28 #include "HsVersions.h"
29
30 import VectCore
31 import VectMonad
32
33 import DsUtils
34 import CoreSyn
35 import CoreUtils
36 import Coercion
37 import Type
38 import TypeRep
39 import TyCon
40 import DataCon
41 import Var
42 import Id ( mkWildId )
43 import MkId ( unwrapFamInstScrut )
44 import Name ( Name )
45 import PrelNames
46 import TysWiredIn
47 import TysPrim ( intPrimTy )
48 import BasicTypes ( Boxity(..) )
49 import Literal ( Literal, mkMachInt )
50
51 import Outputable
52 import FastString
53
54 import Data.List ( zipWith4 )
55 import Control.Monad ( liftM, liftM2, zipWithM_ )
56
-- | Strip the type arguments off an annotated application spine.
-- Returns the head expression together with the stripped types, in the
-- order they were applied (the innermost, i.e. first-applied, type first).
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- walk down the function position, pushing each type argument
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)
62
-- | Split the leading type lambdas off an annotated expression.
-- Returns the type binders (outermost first) and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = loop []
  where
    -- binders are accumulated in reverse and flipped once at the end
    loop seen (_, AnnLam bndr body)
      | isTyVar bndr = loop (bndr : seen) body
    loop seen rest   = (reverse seen, rest)
68
-- | Split the leading value (Id) lambdas off an annotated expression.
-- Returns the value binders (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = loop []
  where
    -- binders are accumulated in reverse and flipped once at the end
    loop seen (_, AnnLam bndr body)
      | isId bndr = loop (bndr : seen) body
    loop seen rest = (reverse seen, rest)
74
-- | Does this annotated expression consist of a type (i.e. a type argument)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, expr) = case expr of
                           AnnType _ -> True
                           _         -> False
78
-- | The constructor's 'dataConTag' minus 'fIRST_TAG' (a zero-based tag).
dataConTagZ :: DataCon -> Int
dataConTagZ con = subtract fIRST_TAG (dataConTag con)
81
-- | The zero-based tag ('dataConTagZ') of a constructor as a machine-int literal.
mkDataConTagLit :: DataCon -> Literal
mkDataConTagLit con = mkMachInt (toInteger (dataConTagZ con))
84
-- | The zero-based tag ('dataConTagZ') of a constructor as an Int literal expression.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag con = mkIntLitInt (dataConTagZ con)
87
-- | If the type is a primitive type constructor applied to no arguments,
-- return that constructor; otherwise 'Nothing'.
splitPrimTyCon :: Type -> Maybe TyCon
splitPrimTyCon ty =
  case splitTyConApp_maybe ty of
    Just (tycon, []) | isPrimTyCon tycon -> Just tycon
    _                                    -> Nothing
95
-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)
101
-- | Right-nest the binary builtin type constructor selected by @get_tc@
-- over @tys@, finishing with @ty@: for @[t1, t2]@ the result is
-- @TC t1 (TC t2 ty)@.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = builtin get_tc >>= \tc ->
    let apply2 lhs rhs = mkTyConApp tc [lhs, rhs]
    in  return (foldr apply2 ty tys)
109
-- | Right-nest the binary builtin type constructor selected by @get_tc@
-- over a list of types, the last type being the innermost argument:
-- for @[t1, t2, t3]@ the result is @TC t1 (TC t2 t3)@.
-- An empty list yields the default type @dft@.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- tys is non-empty here (the [] case is the equation above), so
      -- foldr1 cannot fail.
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
120
-- | Apply the builtin 'closureTyCon' to the argument and result types.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = do
      clo_tc <- builtin closureTyCon
      return (mkTyConApp clo_tc [arg_ty, res_ty])
123
-- | Curried closure type: 'closureTyCon' right-nested over the argument
-- types, ending in the result type.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
126
-- | Apply the builtin 'preprTyCon' to a type.
mkPReprType :: Type -> VM Type
mkPReprType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin preprTyCon)
129
-- | Apply the builtin 'paTyCon' to a type.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin paTyCon)
132
-- | The parallel-array type for elements of type @ty@.
-- A nullary primitive element type uses the array tycon found by
-- 'lookupPrimPArray'; any other type is wrapped in the builtin 'parrayTyCon'.
mkPArrayType :: Type -> VM Type
mkPArrayType ty
  = case splitPrimTyCon ty of
      Just tycon -> liftM (\arr -> mkTyConApp arr [])
                          (traceMaybeV "mkPArrayType" (ppr tycon)
                                       (lookupPrimPArray tycon))
      Nothing    -> mkBuiltinTyConApp parrayTyCon [ty]
141
-- | The builtin type constructor selected by @get_tc@, applied to no
-- arguments and returned as a 'Coercion'.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = builtin get_tc >>= \tc -> return (mkTyConApp tc [])
147
-- | Look up ('lookupFamInst') the instance of the builtin 'parrayTyCon'
-- at @ty@, returning its tycon and type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parr <- builtin parrayTyCon
      lookupFamInst parr [ty]
150
-- | The data constructor (and type arguments) of the 'parrayReprTyCon'
-- representation for @ty@. The representation tycon is expected to have
-- exactly one data constructor.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      -- Kept lazy, like the original irrefutable pattern, but a tycon
      -- without exactly one constructor now panics with its name instead
      -- of an anonymous pattern-match failure.
      let dc = case tyConDataCons tc of
                 [con] -> con
                 _     -> pprPanic "parrayReprDataCon" (ppr tc)
      return (dc, arg_tys)
157
-- | Prepare a vectorised scrutinee. Looks up the 'parrayReprTyCon'
-- representation for the scalar half's type and wraps the lifted half
-- with 'unwrapFamInstScrut'. The tycon and type args are also returned.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = parrayReprTyCon (exprType ve) >>= \(repr_tc, repr_args) ->
    return ( (ve, unwrapFamInstScrut repr_tc repr_args le)
           , repr_tc
           , repr_args )
163
-- | Look up (via 'lookupTyConPR') the variable registered for @tycon@ and
-- return it as an expression. The lookup goes through 'traceMaybeV'; its
-- label now names this function (it previously read "prDictOfTyCon", a
-- name that does not exist), matching the other traceMaybeV call sites.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
  = liftM Var (traceMaybeV "prDFunOfTyCon" (ppr tycon) (lookupTyConPR tycon))
167
-- | The type of the PA dictionary argument required for a type variable,
-- determined by its kind, or 'Nothing' if none is required.
--
-- * Kinds are first expanded with 'kindView'.
-- * A lifted kind gives 'mkPADictType' applied to the variable's type.
-- * An arrow kind @k1 -> k2@ introduces a fresh variable @b :: k1@.
--   If @b@ itself needs a dictionary of type @d@, the result is
--   @forall b. d -> r@, where @r@ is the dictionary type for @(ty b)@ at
--   kind @k2@. Otherwise @b@ is dropped and @ty@ is examined at @k2@.
-- * Any other kind gives 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    dictTy ty kind
      | Just kind' <- kindView kind = dictTy ty kind'
    dictTy ty (FunTy arg_kind res_kind)
      = do
          arg_tv  <- newTyVar FSLIT("a") arg_kind
          arg_res <- dictTy (TyVarTy arg_tv) arg_kind
          case arg_res of
            Nothing       -> dictTy ty res_kind
            Just arg_dict ->
              do
                res_dict <- dictTy (AppTy ty (TyVarTy arg_tv)) res_kind
                return (fmap (\r -> ForAllTy arg_tv (FunTy arg_dict r)) res_dict)
    dictTy ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing
187
-- | Build a PA dictionary expression for a type. The type is split into
-- its head and arguments ('splitAppTys'), which go to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
192
-- | Build a PA dictionary for a type head applied to @ty_args@.
--
-- The head is first looked through with 'coreView'. Then:
--
-- * a type-variable head uses the expression found by 'lookupTyVarPA';
-- * a type-constructor head uses the variable found by 'lookupTyConPA'.
--
-- In both cases 'paDFunApply' applies the result to the arguments. Any
-- other head is a 'pprPanic'.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  = case coreView ty_fn of
      Just ty_fn' -> paDictOfTyApp ty_fn' ty_args
      Nothing     -> dictForHead ty_fn
  where
    dictForHead (TyVarTy tv)
      = maybeV (lookupTyVarPA tv) >>= \dfun -> paDFunApply dfun ty_args
    dictForHead (TyConApp tc _)
      = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
          >>= \dfun -> paDFunApply (Var dfun) ty_args
    dictForHead ty = pprPanic "paDictOfTyApp" (ppr ty)
205
-- | The type of the PA dfun for the type constructor @tc@:
-- @forall tvs. d1 -> ... -> dn -> mkPADictType (tc tvs)@.
-- The @di@ are the 'paDictArgType' results for those of @tc@'s type
-- variables that need a dictionary.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      margs   <- mapM paDictArgType tvs
      dict_ty <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let dict_args = [arg | Just arg <- margs]
      return (mkForAllTys tvs (mkFunTys dict_args dict_ty))
  where
    tvs = tyConTyVars tc
216
-- | Apply a PA dfun first to the types @tys@ and then to a PA dictionary
-- ('paDictOfType') for each of them.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
222
-- | A PA method: a selector for the method's variable in 'Builtins', paired
-- with the method's name. 'paMethod' uses the name when looking up a
-- primitive type's method.
type PAMethod = (Builtins -> Var, String)

-- Explicit signatures for the known PA methods (previously unannotated).
pa_length, pa_replicate, pa_empty, pa_pack :: PAMethod
pa_length    = (lengthPAVar,    "lengthPA")
pa_replicate = (replicatePAVar, "replicatePA")
pa_empty     = (emptyPAVar,     "emptyPA")
pa_pack      = (packPAVar,      "packPA")
229
-- | An expression for a PA method at element type @ty@.
--
-- * For a nullary primitive element type, the method is looked up by name
--   with 'lookupPrimMethod'.
-- * Otherwise the builtin method variable is applied to @ty@ and to the PA
--   dictionary for @ty@.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty
  = case splitPrimTyCon ty of
      Just tycon ->
        liftM Var (traceMaybeV "paMethod" (ppr tycon <+> text name)
                               (lookupPrimMethod tycon name))
      Nothing ->
        do
          fn   <- builtin method
          dict <- paDictOfType ty
          return (mkApps (Var fn) [Type ty, dict])
243
-- | Apply the builtin 'mkPRVar' to @ty@ and to the PA dictionary for @ty@.
mkPR :: Type -> VM CoreExpr
mkPR ty
  = builtin mkPRVar >>= \mk_pr ->
    paDictOfType ty >>= \pa_dict ->
    return (Var mk_pr `mkApps` [Type ty, pa_dict])
250
-- | Apply the @lengthPA@ method at element type @ty@ to the array expression @x@.
lengthPA :: Type -> CoreExpr -> VM CoreExpr
lengthPA ty x = paMethod pa_length ty >>= \len_fn -> return (App len_fn x)
253
-- | Apply the @replicatePA@ method, taken at the type of @x@, to @len@ and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod pa_replicate (exprType x)
      return (mkApps rep_fn [len, x])
257
-- | The @emptyPA@ method at element type @ty@.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
260
-- | Apply the @packPA@ method, at element type @ty@, to @xs@, @len@ and @sel@
-- in that order. Presumably these are the source array, the result length
-- and the selector (TODO confirm against the packPA builtin).
--
-- Fix: @xs@ was previously accepted but never passed, so the method was
-- applied to @[len, sel]@ only.
packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packPA ty xs len sel = liftM (`mkApps` [xs, len, sel])
                             (paMethod pa_pack ty)
264
-- | 'replicatePA' of @x@, using the builtin 'liftingContext' variable as
-- the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
270
-- | Create a pair of fresh local variables, both named @fs@. The first has
-- type @vty@; the second has the corresponding 'mkPArrayType' type.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
278
-- | Run @p@ inside 'localV' with each of @tvs@ defined as a local type
-- variable. Each one that needs a PA dictionary ('paDictArgType') gets a
-- fresh @dPA@ variable, registered through 'defLocalTyVarWithPA'.
-- @p@ receives a function wrapping an expression in lambdas over @tvs@
-- followed by those dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV
  $ do
      mdicts <- mapM mk_dict_var tvs
      zipWithM_ bind_tv tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    bind_tv tv Nothing     = defLocalTyVar tv
    bind_tv tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    mk_dict_var tv
      = paDictArgType tv >>= \mty ->
        case mty of
          Nothing -> return Nothing
          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
294
-- | Apply @expr@ to the types @tys@ and then to a PA dictionary
-- ('paDictOfType') for each of them.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
300
-- | Like 'polyApply', but applied to both halves of a vectorised expression
-- (via 'mapVect'). The dictionaries are computed once.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = mapM paDictOfType tys >>= \dicts ->
    return (mapVect (\e -> mkApps (mkTyApps e tys) dicts) expr)
306
-- | Prepend the binding @(v, e)@ to @global_bindings@ in the global
-- environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add
  where
    add env = env { global_bindings = (v, e) : global_bindings env }
310
-- | Bind @expr@ to a fresh variable named @fs@ via 'hoistBinding', and
-- return that variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var
317
-- | Hoist both halves of a vectorised expression ('hoistExpr'). They are
-- named after the current bind name ('getBindName'), prefixed with @v@
-- (scalar half) and @l@ (lifted half).
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      base <- getBindName
      let hoist_as c e = hoistExpr (consFS c base) e
      liftM2 (,) (hoist_as 'v' ve) (hoist_as 'l' le)
325
-- | Run @p@ inside 'closedV' under 'polyAbstract' over @tvs@, hoist the
-- abstracted result ('hoistVExpr'), and return the hoisted variables
-- re-applied to @tvs@ and their dictionaries ('polyVApply').
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
333
-- | Return the global environment's @global_bindings@ and reset that list
-- to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \genv ->
    setGEnv (genv { global_bindings = [] }) >>
    return (global_bindings genv)
340
-- | Build a vectorised closure. The scalar half applies the builtin
-- 'mkClosureVar', and the lifted half 'mkClosurePVar', to:
--
-- * the types @arg_ty@, @res_ty@ and @env_ty@;
-- * the PA dictionary for @env_ty@;
-- * both function halves;
-- * the matching environment half.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let tys       = [arg_ty, res_ty, env_ty]
          build f e = mkApps (mkTyApps (Var f) tys) [dict, vfn, lfn, e]
      return (build mkv venv, build mkl lenv)
349
-- | Apply a vectorised closure to a vectorised argument. The scalar half
-- uses the builtin 'applyClosureVar' and the lifted half
-- 'applyClosurePVar'; each is applied to @arg_ty@ and @res_ty@, then to
-- the closure and argument halves.
mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply f clo arg = mkApps (mkTyApps (Var f) [arg_ty, res_ty]) [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
357
-- | Build a curried closure over the vectorised variables @vars@.
--
-- * With no argument types the result is just @mk_body@.
-- * With one argument type it is 'buildClosure'.
-- * With several, the closure for the first argument has a hoisted
--   ('hoistPolyVExpr') body that builds the closure for the remaining
--   arguments. That body also captures a fresh variable @x@ for the
--   first argument.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars arg_tys res_ty mk_body
  = case arg_tys of
      []            -> mk_body
      [arg_ty]      -> buildClosure tvs vars arg_ty res_ty mk_body
      arg_ty : rest -> do
        inner_ty <- mkClosureTypes rest res_ty
        arg      <- newLocalVVar FSLIT("x") arg_ty
        let vars' = vars ++ [arg]
            inner = do
                      lc  <- builtin liftingContext
                      clo <- buildClosures tvs vars' rest res_ty mk_body
                      return (vLams lc vars' clo)
        buildClosure tvs vars arg_ty inner_ty (hoistPolyVExpr tvs inner)
373
-- | Build a one-argument closure over the vectorised variables @vars@.
-- Schematically the result is
--
-- >   (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l xs1 ... xsn v
--
-- The environment comes from 'buildEnv'. The function pair is hoisted
-- with 'hoistPolyVExpr' over @tvs@. 'mkClosure' combines the two.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
      let mk_fn = do
                    lc    <- builtin liftingContext
                    body  <- mk_body
                    body' <- bind (vVar env_bndr)
                                  (vVarApps lc body (vars ++ [arg_bndr]))
                    return (vLamsWithoutLC [env_bndr, arg_bndr] body')
      fn <- hoistPolyVExpr tvs mk_fn
      mkClosure arg_ty res_ty env_ty fn env
395
-- | Build the closure environment for @vvs@. Returns three things:
--
-- * the scalar environment type;
-- * the environment pair: the scalar half from 'mkVectEnv' over the
--   scalar variables, the lifted half from 'mkLiftEnv' over the lifted
--   ones;
-- * a function that wraps a body pair in the matching unpacking code.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (vs, ls)          = unzip vvs
          tys               = map idType vs
          (ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind_both (venv', lenv') (vbody, lbody)
            = liftM (\lbody' -> (vbind venv' vbody, lbody')) (lbind lenv' lbody)
      return (ty, (venv, lenv), bind_both)
411
-- | The scalar environment for variables @vs@ of types @tys@. Returns its
-- type, the expression building it, and a function wrapping a body so
-- that it unpacks the environment.
--
-- * No variables: the unit value, with the body left unchanged.
-- * One variable: the variable itself, rebound with a 'Let'.
-- * Otherwise: a boxed tuple, unpacked with a 'Case'.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty = mkCoreTupTy tys
    unpack env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
420
-- | The lifted environment for the lifted variables @vs@ (element types
-- @tys@), given the lifting-context variable @lc@. Returns the expression
-- building it and a function wrapping a lifted body so that it unpacks it.
--
-- * With exactly one variable, the environment is that variable. Unpacking
--   rebinds it and binds @lc@ (as a case binder) to its 'lengthPA'.
-- * Otherwise the environment is the single data constructor of the
--   'parrayReprTyCon' representation of the tuple type, applied to @lc@
--   and the variables. Unpacking is a 'Case' on that constructor.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [ty] [v]
  = return (Var v, rebind)
  where
    rebind env body
      = lengthPA ty (Var v) >>= \len ->
        return (Let (NonRec v env)
                    (Case len lc (exprType body) [(DEFAULT, [], body)]))

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon (mkCoreTupTy tys)
      let [env_con] = tyConDataCons env_tc
          env = mkVarApps (mkTyApps (Var (dataConWrapId env_con)) env_tyargs)
                          (lc : vs)
          -- an empty environment still needs one (wildcard) field binder
          binders | null vs   = [mkWildId unitTy]
                  | otherwise = vs
          unpack env' body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env'
              in  return (Case scrut (mkWildId (exprType scrut)) (exprType body)
                               [(DataAlt env_con, lc : binders, body)])
      return (env, unpack)
450