Vectorisation of enumeration types
[ghc.git] / compiler / vectorise / VectUtils.hs
-- | Helper functions for the vectoriser: inspecting annotated Core, building
-- the builtin vectorisation types, obtaining PA/PR dictionaries and methods,
-- hoisting bindings to top level, and constructing closures.
module VectUtils (
  -- * Annotated Core
  collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
  collectAnnValBinders,

  -- * Constructor tags
  mkDataConTag, mkDataConTagLit,

  -- * Builtin types
  splitClosureTy,
  mkBuiltinCo,
  mkPADictType, mkPArrayType, mkPReprType,

  -- * PArray representations
  parrayReprTyCon, parrayReprDataCon, mkVScrut,

  -- * Dictionaries and PA methods
  prDFunOfTyCon,
  paDictArgType, paDictOfType, paDFunType,
  paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,

  -- * Polymorphism
  polyAbstract, polyApply, polyVApply,

  -- * Hoisting
  hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,

  -- * Closures
  buildClosure, buildClosures,
  mkClosureApp
) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id ( mkWildId )
35 import MkId ( unwrapFamInstScrut )
36 import Name ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim ( intPrimTy )
40 import BasicTypes ( Boxity(..) )
41 import Literal ( Literal, mkMachInt )
42
43 import Outputable
44 import FastString
45
46 import Data.List ( zipWith4 )
47 import Control.Monad ( liftM, liftM2, zipWithM_ )
48
-- | Peel type arguments off the head of an annotated application.
-- Returns the remaining function expression and the peeled type arguments
-- in application order (leftmost argument first).
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc e
      = case e of
          (_, AnnApp fun (_, AnnType arg_ty)) -> peel (arg_ty : acc) fun
          _                                   -> (e, acc)
54
-- | Collect the leading type-variable lambdas of an annotated expression.
-- Returns the binders outermost-first, together with the body that remains
-- after the last such lambda.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders e0 = loop e0 []
  where
    loop (_, AnnLam bndr body) seen
      | isTyVar bndr = loop body (bndr : seen)
    loop rest seen   = (reverse seen, rest)
60
-- | Collect the leading value (Id) lambdas of an annotated expression.
-- Returns the binders outermost-first, together with the body that remains
-- after the last such lambda.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders e0 = walk [] e0
  where
    walk acc e
      = case e of
          (_, AnnLam b body) | isId b -> walk (b : acc) body
          _                           -> (reverse acc, e)
66
-- | Is this annotated expression a type argument (an 'AnnType' node)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, AnnType _) = True
isAnnTypeArg _              = False
70
-- | The tag of a data constructor as a machine-integer literal.
-- The tag is rebased by 'fIRST_TAG', so the first constructor maps to 0.
mkDataConTagLit :: DataCon -> Literal
mkDataConTagLit con = mkMachInt (toInteger zero_based)
  where
    zero_based = dataConTag con - fIRST_TAG
74
-- | The tag of a data constructor, rebased by 'fIRST_TAG' (so the first
-- constructor maps to 0), as a Core expression built with 'mkIntLitInt'.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag con = mkIntLitInt tag
  where
    tag = dataConTag con - fIRST_TAG
77
-- | Given a type @T ty@ where @T@ is the type constructor called @name@,
-- return @ty@.  Any other shape panics, labelled with @s@.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [ty']) | tyConName tc == name -> ty'
      _                                       -> pprPanic s (ppr ty)
85
-- | Given a type @T ty1 ty2@ where @T@ is the type constructor called @name@,
-- return @(ty1, ty2)@.  Any other shape panics, labelled with @s@.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [ty1, ty2]) | tyConName tc == name -> (ty1, ty2)
      _                                            -> pprPanic s (ppr ty)
93
-- | The argument types of @ty@, which must be an application of exactly the
-- type constructor @tc@; panics otherwise.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  = case splitTyConApp_maybe ty of
      Just (tc', tys) | tc == tc' -> tys
      _                           -> pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
101
-- | Split an application of the closure type constructor
-- ('closureTyConName') into its argument and result types.
-- Any other type panics.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty
104
-- | The element type of an application of the PArray type constructor
-- ('parrayTyConName').  Any other type panics.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty
107
-- | If the type is a primitive type constructor applied to no arguments,
-- return that constructor.
splitPrimTyCon :: Type -> Maybe TyCon
splitPrimTyCon ty
  = case splitTyConApp_maybe ty of
      Just (tycon, []) | isPrimTyCon tycon -> Just tycon
      _                                    -> Nothing
115
-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)
121
-- | Right-nest a binary builtin type constructor @T@ over @tys@, ending in
-- @ty@: @[t1, t2]@ and @r@ give @T t1 (T t2 r)@.  An empty list gives @ty@.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = liftM nest (builtin get_tc)
  where
    nest tc = foldr (\arg rest -> mkTyConApp tc [arg, rest]) ty tys
129
-- | Right-nest a binary builtin type constructor @T@ over a list of types
-- with no extra terminator: @[t1, t2, t3]@ gives @T t1 (T t2 t3)@, and a
-- single type is returned unchanged.  An empty list yields the default
-- @dft@; in that case the builtin is not looked up.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft []         = return dft
mkBuiltinTyConApps1 get_tc _   (ty : tys)
  = do
      tc <- builtin get_tc
      return (nest tc ty tys)
  where
    -- Total replacement for foldr1: the list is known to be non-empty here.
    nest _  t1 []          = t1
    nest tc t1 (t2 : rest) = mkTyConApp tc [t1, nest tc t2 rest]
140
-- | The builtin closure type from @arg_ty@ to @res_ty@.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = do
      clo_tc <- builtin closureTyCon
      return (mkTyConApp clo_tc [arg_ty, res_ty])
143
-- | The curried closure type taking @arg_tys@ in order and returning @res_ty@.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
146
-- | The builtin @PRepr@ type constructor applied to @ty@.
mkPReprType :: Type -> VM Type
mkPReprType ty = liftM (\tc -> mkTyConApp tc [ty]) (builtin preprTyCon)
149
-- | The type of a PA dictionary for @ty@ (builtin 'paTyCon' applied to @ty@).
mkPADictType :: Type -> VM Type
mkPADictType ty
  = do
      pa_tc <- builtin paTyCon
      return (mkTyConApp pa_tc [ty])
152
-- | The parallel-array type for elements of type @ty@.
-- A nullary primitive type uses the dedicated array tycon found by
-- 'lookupPrimPArray'.  A failed lookup goes through 'traceMaybeV' under the
-- label "mkPArrayType".  Every other type gets the builtin 'parrayTyCon'
-- applied to it.
mkPArrayType :: Type -> VM Type
mkPArrayType ty
  = case splitPrimTyCon ty of
      Just tycon ->
        do arr <- traceMaybeV "mkPArrayType" (ppr tycon) (lookupPrimPArray tycon)
           return (mkTyConApp arr [])
      Nothing    -> mkBuiltinTyConApp parrayTyCon [ty]
161
-- | A coercion built from the selected builtin tycon applied to no arguments.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = liftM (\tc -> mkTyConApp tc []) (builtin get_tc)
167
168 parrayReprTyCon :: Type -> VM (TyCon, [Type])
169 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
170
-- | The sole data constructor of the PArray representation tycon for
-- element type @ty@ (see 'parrayReprTyCon'), with that tycon's type
-- arguments.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      -- Bound lazily, so the constructor list is only inspected when @dc@ is
      -- demanded.  If the tycon does not have exactly one constructor, this
      -- panics with a message naming the tycon.
      let dc = case tyConDataCons tc of
                 [con] -> con
                 _     -> pprPanic "parrayReprDataCon" (ppr tc)
      return (dc, arg_tys)
177
-- | Prepare a vectorised scrutinee.  The vectorised half is kept as-is.
-- The lifted half is unwrapped from the PArray family instance chosen by the
-- vectorised expression's type.  Also returns the representation tycon and
-- its type arguments.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = do
      (repr_tc, repr_args) <- parrayReprTyCon (exprType ve)
      let le' = unwrapFamInstScrut repr_tc repr_args le
      return ((ve, le'), repr_tc, repr_args)
183
-- | The PR dictionary function registered for @tycon@ (via 'lookupTyConPR'),
-- as a Core variable.  A failed lookup is reported through 'traceMaybeV'
-- under this function's own name, matching the other trace labels here.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
  = liftM Var (traceMaybeV "prDFunOfTyCon" (ppr tycon) (lookupTyConPR tycon))
187
-- | The type of the PA-dictionary argument required for type variable @tv@,
-- driven by its kind:
--
--   * lifted kind: the PA dictionary type of the variable;
--   * arrow kind @k1 -> k2@: abstract over a fresh @a :: k1@.  If @a@ needs a
--     dictionary, the result is @forall a. dict(a) -> dict(tv a)@.
--     Otherwise the argument is skipped and @k2@ is examined for the same
--     type;
--   * any other kind: 'Nothing' (no dictionary needed).
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    go ty k | Just k' <- kindView k = go ty k'
    go ty (FunTy k1 k2)
      = do
          -- Fresh variable named so it does not shadow the outer @tv@.
          a    <- newTyVar FSLIT("a") k1
          mty1 <- go (TyVarTy a) k1
          case mty1 of
            Just ty1 -> do
                          mty2 <- go (AppTy ty (TyVarTy a)) k2
                          return $ fmap (ForAllTy a . FunTy ty1) mty2
            Nothing  -> go ty k2

    go ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    go _ _ = return Nothing
207
-- | Build a PA dictionary expression for @ty@ by splitting it into its head
-- and its type arguments.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
212
-- | Build a PA dictionary for the head type @ty_fn@ applied to @ty_args@.
-- Synonyms in the head are expanded first.  A type-variable head uses the
-- dictionary registered for it in the local environment.  A tycon head uses
-- its registered dfun.  Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  = case coreView ty_fn of
      Just ty_fn' -> paDictOfTyApp ty_fn' ty_args
      Nothing     -> dict_of ty_fn
  where
    dict_of (TyVarTy tv)
      = maybeV (lookupTyVarPA tv) >>= \dfun -> paDFunApply dfun ty_args
    dict_of (TyConApp tc _)
      = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
          >>= \dfun -> paDFunApply (Var dfun) ty_args
    dict_of ty = pprPanic "paDictOfTyApp" (ppr ty)
225
-- | The type of the PA dictionary function for tycon @tc@.  It quantifies
-- over @tc@'s type variables and takes one dictionary argument per variable
-- that needs one (see 'paDictArgType').  It returns the PA dictionary for
-- @tc@ applied to those variables.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      margs   <- mapM paDictArgType tyvars
      dict_ty <- mkPADictType (mkTyConApp tc (mkTyVarTys tyvars))
      let dict_args = [arg | Just arg <- margs]
      return (mkForAllTys tyvars (mkFunTys dict_args dict_ty))
  where
    tyvars = tyConTyVars tc
236
-- | Apply a dictionary function to type arguments @tys@, followed by one PA
-- dictionary per type argument.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
242
-- | A PA method.  The first component selects its generic implementation
-- from the builtins.  The second is the name used to look up a primitive
-- specialisation (see 'paMethod').
type PAMethod = (Builtins -> Var, String)

pa_length, pa_replicate, pa_empty :: PAMethod
pa_length    = (lengthPAVar,    "lengthPA")
pa_replicate = (replicatePAVar, "replicatePA")
pa_empty     = (emptyPAVar,     "emptyPA")
248
-- | Instantiate a PA method at type @ty@.
-- For a nullary primitive type, the specialised implementation is looked up
-- by name (traced under "paMethod").  Otherwise the generic builtin is
-- applied to @ty@ and @ty@'s PA dictionary.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty
  = case splitPrimTyCon ty of
      Just tycon -> liftM Var
                  $ traceMaybeV "paMethod" (ppr tycon <+> text name)
                                (lookupPrimMethod tycon name)
      Nothing    -> do
                      fn   <- builtin method
                      dict <- paDictOfType ty
                      return (mkApps (Var fn) [Type ty, dict])
262
-- | Apply the builtin 'mkPRVar' to @ty@ and @ty@'s PA dictionary.
mkPR :: Type -> VM CoreExpr
mkPR ty
  = do
      mk_pr   <- builtin mkPRVar
      pa_dict <- paDictOfType ty
      return (Var mk_pr `mkApps` [Type ty, pa_dict])
269
-- | Length of a parallel array.  @x@ must have a PArray type; its element
-- type selects the 'pa_length' instance.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod pa_length (splitPArrayTy (exprType x))
      return (App len_fn x)
274
-- | Apply the 'pa_replicate' method (at @x@'s type) to @len@ and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod pa_replicate (exprType x)
      return (mkApps rep_fn [len, x])
278
-- | The 'pa_empty' method instantiated at element type @ty@.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
281
-- | Replicate @x@, using the builtin lifting-context variable as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
287
-- | A fresh pair of local variables named after @fs@.  The first has type
-- @vty@; the second has the parallel-array type of @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      -- liftM2 runs its arguments left to right, so the vectorised variable
      -- is still created before the lifted one.
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
295
-- | Run @p@ in a local scope where each type variable in @tvs@ is registered.
-- Variables that need a PA dictionary (per 'paDictArgType') get a fresh
-- dictionary variable.  @p@ receives a function that wraps an expression in
-- lambdas over @tvs@ followed by those dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM dict_var_for tvs
      zipWithM_ bind_tyvar tvs mdicts
      p (abstract_over mdicts)
  where
    dict_var_for tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    bind_tyvar tv Nothing     = defLocalTyVar tv
    bind_tyvar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    abstract_over mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
311
-- | Apply a polymorphic expression to type arguments @tys@, followed by one
-- PA dictionary per type argument.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
317
-- | Like 'polyApply', but applied to both halves of a vectorised expression.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let apply_to e = mkApps (mkTyApps e tys) dicts
      return (mapVect apply_to expr)
323
-- | Record the binding @v = e@ by prepending it to 'global_bindings' in the
-- global environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add_binding
  where
    add_binding env = env { global_bindings = (v, e) : global_bindings env }
327
-- | Bind @expr@ to a fresh variable named after @fs@, hoist that binding
-- (see 'hoistBinding'), and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var
334
-- | Hoist both halves of a vectorised expression.  The variables are named
-- after the current binding name, prefixed with 'v' and 'l' respectively.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      liftM2 (,) (hoistExpr (consFS 'v' name) ve)
                 (hoistExpr (consFS 'l' name) le)
342
-- | Build a vectorised expression with @p@ in a closed scope, abstracted over
-- @tvs@ and their PA dictionaries, and hoist it.  Returns the hoisted
-- variables re-applied to @tvs@ (and their dictionaries).
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
350
-- | Return the hoisted bindings accumulated so far and reset the list to
-- empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv env { global_bindings = [] }
      return hoisted
357
-- | Build a closure from the function pair @(vfn, lfn)@ and the environment
-- pair @(venv, lenv)@.  The vectorised half applies 'mkClosureVar' and the
-- lifted half applies 'mkClosurePVar'.  Both receive the argument, result
-- and environment types, the environment's PA dictionary, both functions,
-- and the matching environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let build f env = Var f `mkTyApps` [arg_ty, res_ty, env_ty]
                              `mkApps`   [dict, vfn, lfn, env]
      return (build mkv venv, build mkl lenv)
366
-- | Apply a vectorised closure to a vectorised argument: the vectorised half
-- uses 'applyClosureVar' and the lifted half 'applyClosurePVar'.  The
-- argument and result types come from the vectorised closure's type.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let tys = [arg_ty, res_ty]
      return ( mkApps (mkTyApps (Var vapply) tys) [vclo, varg]
             , mkApps (mkTyApps (Var lapply) tys) [lclo, larg] )
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)
376
-- | Build a chain of closures, one per argument type.  With no argument
-- types the body is returned directly; with one, a single closure is built.
-- Otherwise the outer closure takes the first argument and returns
-- (hoisted) the closure chain for the remaining arguments, which captures
-- the argument as an extra free variable.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures tvs vars arg_tys0 res_ty mk_body
  = case arg_tys0 of
      []               -> mk_body
      [arg_ty]         -> buildClosure tvs vars arg_ty res_ty mk_body
      arg_ty : arg_tys -> curried arg_ty arg_tys
  where
    curried arg_ty arg_tys
      = do
          inner_res_ty <- mkClosureTypes arg_tys res_ty
          arg          <- newLocalVVar FSLIT("x") arg_ty
          let vars' = vars ++ [arg]
              inner = do
                        lc  <- builtin liftingContext
                        clo <- buildClosures tvs vars' arg_tys res_ty mk_body
                        return (vLams lc vars' clo)
          buildClosure tvs vars arg_ty inner_res_ty (hoistPolyVExpr tvs inner)
392
-- | Build a single closure over the free variables @vars@, producing
--
-- > (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The function pair is hoisted (abstracted over @tvs@) before the closure
-- is built with the environment produced by 'buildEnv'.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
      let lambda_body
            = do
                lc    <- builtin liftingContext
                body  <- mk_body
                bound <- bind (vVar env_bndr)
                              (vVarApps lc body (vars ++ [arg_bndr]))
                return (vLamsWithoutLC [env_bndr, arg_bndr] bound)
      fn <- hoistPolyVExpr tvs lambda_body
      mkClosure arg_ty res_ty env_ty fn env
414
-- | Build a closure environment for the variable pairs @vvs@.  Returns:
--
--   * the vectorised environment type;
--   * the (vectorised, lifted) environment expressions;
--   * a function that, given an environment pair and a body pair, wraps each
--     body in code unpacking the corresponding environment.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      -- Parameters are primed so they do not shadow the environments built
      -- above.
      let bind_both (venv', lenv') (vbody, lbody)
            = do
                lbody' <- lbind lenv' lbody
                return (vbind venv' vbody, lbody')
      return (ty, (venv, lenv), bind_both)
  where
    (vs, ls) = unzip vvs
    tys      = map idType vs
430
-- | The vectorised environment for variables @vs@ of types @tys@: its type,
-- the expression building it, and a function that unpacks it around a body.
-- No variables gives the unit value.  One variable is its own environment,
-- unpacked with a @let@.  More variables are packed into a boxed tuple and
-- unpacked with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpack_tuple)
  where
    tup_ty = mkCoreTupTy tys

    unpack_tuple env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
-- | The lifted environment for lifted variables @vs@ (with vectorised element
-- types @tys@), plus a function that unpacks such an environment around a
-- body.  The unpacking code binds the lifting-context variable @lc@.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
-- A single variable is its own environment.  Unpacking rebinds it and binds
-- @lc@ to its length (computed with 'lengthPA').
mkLiftEnv lc [_] [v]
  = return (Var v, \env body ->
                   do
                     len <- lengthPA (Var v)
                     return . Let (NonRec v env)
                            $ Case len lc (exprType body) [(DEFAULT, [], body)])

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      -- env_con is bound lazily; if the representation tycon does not have
      -- exactly one constructor, demanding it panics with the tycon's name.
      let env_con = case tyConDataCons env_tc of
                      [con] -> con
                      _     -> pprPanic "mkLiftEnv" (ppr env_tc)

          env = Var (dataConWrapId env_con)
                `mkTyApps`  env_tyargs
                `mkVarApps` (lc : vs)

          bind env_expr body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env_expr
              in
              return $ Case scrut (mkWildId (exprType scrut))
                            (exprType body)
                            [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- With no variables the constructor still has one field after @lc@,
    -- matched here by a unit-typed wildcard.
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
469