Find the correct array type for primitive tycons
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 mkDataConTag,
5 splitClosureTy,
6
7 mkBuiltinCo,
8 mkPADictType, mkPArrayType, mkPReprType,
9
10 parrayReprTyCon, parrayReprDataCon, mkVScrut,
11 prDFunOfTyCon,
12 paDictArgType, paDictOfType, paDFunType,
13 paMethod, mkPR, lengthPA, replicatePA, emptyPA, liftPA,
14 polyAbstract, polyApply, polyVApply,
15 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
16 buildClosure, buildClosures,
17 mkClosureApp
18 ) where
19
20 #include "HsVersions.h"
21
22 import VectCore
23 import VectMonad
24
25 import DsUtils
26 import CoreSyn
27 import CoreUtils
28 import Coercion
29 import Type
30 import TypeRep
31 import TyCon
32 import DataCon
33 import Var
34 import Id ( mkWildId )
35 import MkId ( unwrapFamInstScrut )
36 import Name ( Name )
37 import PrelNames
38 import TysWiredIn
39 import TysPrim ( intPrimTy )
40 import BasicTypes ( Boxity(..) )
41
42 import Outputable
43 import FastString
44
45 import Data.List ( zipWith4 )
46 import Control.Monad ( liftM, liftM2, zipWithM_ )
47
-- Strip the type arguments off an annotated application.  Returns the
-- expression underneath the outermost run of type applications together
-- with the stripped types, listed innermost application first.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- Walk from the outermost application inwards, consing each type onto
    -- the accumulator so the innermost argument ends up at the head.
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)
53
-- Split off the leading type-variable lambdas of an annotated expression.
-- Returns the type variables, outermost first, and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    -- Binders are accumulated in reverse and flipped once at the end.
    strip seen (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : seen) body
    strip seen rest  = (reverse seen, rest)
59
-- Split off the leading value (Id) lambdas of an annotated expression.
-- Returns the binders, outermost first, and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    -- Binders are accumulated in reverse and flipped once at the end.
    strip seen (_, AnnLam bndr body)
      | isId bndr   = strip (bndr : seen) body
    strip seen rest = (reverse seen, rest)
65
-- True exactly when the annotated expression is a type argument (AnnType).
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, AnnType _) = True   -- the type itself is irrelevant here
isAnnTypeArg _              = False
69
-- Build a Core expression holding the tag of a data constructor: the
-- constructor's tag as an integer literal wrapped in intDataCon.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag dc = mkConApp intDataCon [tag_lit]
  where
    tag_lit = mkIntLitInt (dataConTag dc)
72
-- Deconstruct an application of the unary type constructor called 'name'
-- and return its single argument.  Panics, using 's' as the panic label,
-- if 'ty' is not such an application.
splitUnTy :: String -> Name -> Type -> Type
splitUnTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg])
        | tyConName tc == name -> arg
      _                        -> pprPanic s (ppr ty)
80
-- Deconstruct an application of the binary type constructor called 'name'
-- and return its two arguments.  Panics, using 's' as the panic label,
-- if 'ty' is not such an application.
splitBinTy :: String -> Name -> Type -> (Type, Type)
splitBinTy s name ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg1, arg2])
        | tyConName tc == name -> (arg1, arg2)
      _                        -> pprPanic s (ppr ty)
88
-- Return the arguments of 'ty', which must be an application of exactly
-- the given type constructor; panics otherwise.
splitFixedTyConApp :: TyCon -> Type -> [Type]
splitFixedTyConApp tc ty
  = case splitTyConApp_maybe ty of
      Just (head_tc, args)
        | tc == head_tc -> args
      _                 -> pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
96
-- Split a closure type (an application of the tycon named closureTyConName)
-- into its two type arguments; panics on any other type.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty = splitBinTy "splitClosureTy" closureTyConName ty
99
-- Extract the single type argument of an application of the tycon named
-- parrayTyConName; panics on any other type.
splitPArrayTy :: Type -> Type
splitPArrayTy ty = splitUnTy "splitPArrayTy" parrayTyConName ty
102
-- If 'ty' is a primitive type constructor applied to no arguments, return
-- that constructor; otherwise Nothing.
splitPrimTyCon :: Type -> Maybe TyCon
splitPrimTyCon ty
  = case splitTyConApp_maybe ty of
      Just (tc, [])
        | isPrimTyCon tc -> Just tc
      _                  -> Nothing
110
-- Apply a builtin type constructor, selected from the Builtins record by
-- 'get_tc', to the given type arguments.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys
  = liftM (`mkTyConApp` tys) (builtin get_tc)
116
-- Right-nest a binary builtin type constructor over a list of types,
-- ending in 'ty':  [t1, ..., tn] and ty  give  tc t1 (tc t2 (... (tc tn ty))).
-- With an empty list the result is just 'ty'.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = liftM nest (builtin get_tc)
  where
    nest tc = foldr (\arg rest -> mkTyConApp tc [arg, rest]) ty tys
124
-- Right-nest a binary builtin type constructor over a list of types:
-- [t1, ..., tn]  gives  tc t1 (tc t2 (... (tc t(n-1) tn))).
-- For an empty list the default type 'dft' is returned and the builtin is
-- not looked up at all.
mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
mkBuiltinTyConApps1 _      dft [] = return dft
mkBuiltinTyConApps1 get_tc _   tys
  = do
      tc <- builtin get_tc
      -- 'tys' is non-empty here (the empty case is handled by the first
      -- equation), so foldr1 cannot fail.
      return $ foldr1 (mk tc) tys
  where
    mk tc ty1 ty2 = mkTyConApp tc [ty1, ty2]
135
-- The closure type with the given argument and result types, i.e. the
-- builtin closureTyCon applied to both.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon clo_args
  where
    clo_args = [arg_ty, res_ty]
138
-- The type of a multi-argument closure: closureTyCon is right-nested over
-- the argument types, ending in the result type.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
141
-- Apply the builtin preprTyCon to a type.
mkPReprType :: Type -> VM Type
mkPReprType ty = mkBuiltinTyConApp preprTyCon ty_args
  where
    ty_args = [ty]
144
-- The type of the PA dictionary for a type: the builtin paTyCon applied
-- to it.
mkPADictType :: Type -> VM Type
mkPADictType ty = mkBuiltinTyConApp paTyCon ty_args
  where
    ty_args = [ty]
147
-- The array type for elements of type 'ty'.
--
-- A nullary primitive tycon gets the array tycon found by lookupPrimPArray,
-- applied to no arguments; any other type gets the builtin parrayTyCon
-- applied to it.  The primitive lookup goes through traceMaybeV, labelled
-- "mkPArrayType" (presumably it reports the label when the lookup fails --
-- traceMaybeV is defined elsewhere).
mkPArrayType :: Type -> VM Type
mkPArrayType ty
  = case splitPrimTyCon ty of
      Just prim_tc ->
        do
          arr_tc <- traceMaybeV "mkPArrayType" (ppr prim_tc)
                                (lookupPrimPArray prim_tc)
          return (mkTyConApp arr_tc [])
      Nothing      -> mkBuiltinTyConApp parrayTyCon [ty]
156
-- Turn a builtin type constructor into a coercion by applying it to no
-- arguments.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = liftM (`mkTyConApp` []) (builtin get_tc)
162
-- Look up the family instance of the builtin parrayTyCon at the given
-- element type.  Returns what lookupFamInst returns: a tycon together with
-- a list of types (presumably the instance tycon and its arguments).
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
165
-- The unique data constructor of the tycon returned by parrayReprTyCon for
-- the given element type, paired with the type arguments from that lookup.
--
-- Panics with the offending tycon if it does not have exactly one data
-- constructor.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      -- Kept as a lazy 'let' so the constructor list is only inspected when
      -- the result is actually demanded.
      let dc = case tyConDataCons tc of
                 [con] -> con
                 _     -> pprPanic "parrayReprDataCon" (ppr tc)
      return (dc, arg_tys)
172
-- Prepare a vectorised expression for use as a case scrutinee.
--
-- The array tycon is looked up (via parrayReprTyCon) at the type of the
-- first component; the second component is passed through
-- unwrapFamInstScrut for that tycon while the first is returned unchanged.
-- The tycon and its type arguments are returned alongside.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = liftM wrap (parrayReprTyCon (exprType ve))
  where
    wrap (tc, arg_tys) = ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
178
-- The variable found by lookupTyConPR for a type constructor, as a Core
-- expression.  The lookup goes through traceMaybeV; the label passed to it
-- is this function's own name so that any diagnostic points here.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
  = liftM Var (traceMaybeV "prDFunOfTyCon" (ppr tycon) (lookupTyConPR tycon))
182
-- Compute the type of the PA dictionary argument that has to accompany a
-- type variable, if any.  The answer is found by recursion on the kind:
--
--   * kind synonyms are looked through (kindView);
--   * at a function kind k1 -> k2 a fresh type variable a :: k1 is made.
--     If 'a' itself needs a dictionary of type d, the result is
--     forall a. d -> <result for (ty a) at k2>; otherwise the result is
--     that of 'ty' at k2;
--   * at the kind of lifted types the result is the PA dictionary type of
--     'ty';
--   * any other kind yields Nothing.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    go ty k | Just k' <- kindView k = go ty k'
    go ty (FunTy k1 k2)
      = do
          arg_tv <- newTyVar FSLIT("a") k1
          let arg_ty = TyVarTy arg_tv
          marg_dict <- go arg_ty k1
          case marg_dict of
            Nothing       -> go ty k2
            Just arg_dict ->
              liftM (fmap (ForAllTy arg_tv . FunTy arg_dict))
                    (go (AppTy ty arg_ty) k2)

    go ty k
      | isLiftedTypeKind k = liftM Just (mkPADictType ty)
      | otherwise          = return Nothing
202
-- Build the PA dictionary expression for a type by splitting it into a
-- head and its arguments and handing both to paDictOfTyApp.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
207
-- Build the PA dictionary for a type given as a head and its arguments.
--
--   * The head is first expanded with coreView as far as possible.
--   * A type-variable head uses the expression from lookupTyVarPA.
--   * A tycon head uses the variable from lookupTyConPA; the arguments
--     carried inside the TyConApp itself are ignored.
--   * Any other head is a panic.
--
-- In the two successful cases the function found is applied to 'ty_args'
-- and their dictionaries by paDFunApply.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn = paDictOfTyApp expanded ty_args
paDictOfTyApp (TyVarTy tv) ty_args
  = maybeV (lookupTyVarPA tv) >>= \dfun ->
    paDFunApply dfun ty_args
paDictOfTyApp (TyConApp tc _) ty_args
  = traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc) >>= \dfun ->
    paDFunApply (Var dfun) ty_args
paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
220
-- The type of the PA dictionary function for a type constructor:
--
--   forall tvs. d1 -> ... -> dn -> PA (tc tvs)
--
-- where 'tvs' are the tycon's type variables and the di are the dictionary
-- argument types of those variables for which paDictArgType yields one.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = liftM2 dfun_ty (mapM paDictArgType tc_tvs)
                   (mkPADictType (mkTyConApp tc (mkTyVarTys tc_tvs)))
  where
    tc_tvs = tyConTyVars tc

    dfun_ty mdicts res
      = mkForAllTys tc_tvs (mkFunTys [dict | Just dict <- mdicts] res)
231
-- Apply a dictionary function first to the given types and then to the PA
-- dictionaries of those same types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
237
-- A PA method is described by two things: a selector that picks the
-- method's variable out of the Builtins record, and the method's name as a
-- string (paMethod passes the name to lookupPrimMethod for primitive
-- types).
type PAMethod = (Builtins -> Var, String)

pa_length :: PAMethod
pa_length    = (lengthPAVar,    "lengthPA")

pa_replicate :: PAMethod
pa_replicate = (replicatePAVar, "replicatePA")

pa_empty :: PAMethod
pa_empty     = (emptyPAVar,     "emptyPA")
243
-- The expression implementing a PA method at a given type.
--
-- For a nullary primitive tycon the method is the variable found by
-- lookupPrimMethod under the method's name.  For every other type it is
-- the builtin method variable applied to the type and to the type's PA
-- dictionary.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty
  = case splitPrimTyCon ty of
      Just prim_tc ->
        liftM Var (traceMaybeV "paMethod" (ppr prim_tc <+> text name)
                               (lookupPrimMethod prim_tc name))
      Nothing ->
        do
          sel  <- builtin method
          dict <- paDictOfType ty
          return (mkApps (Var sel) [Type ty, dict])
257
-- The builtin mkPRVar applied to a type and to that type's PA dictionary.
mkPR :: Type -> VM CoreExpr
mkPR ty
  = liftM2 apply (builtin mkPRVar) (paDictOfType ty)
  where
    apply mk_pr dict = mkApps (Var mk_pr) [Type ty, dict]
264
-- Apply the lengthPA method to an array expression.  The element type is
-- recovered from the expression's type with splitPArrayTy, which panics if
-- that type is not an application of the tycon named parrayTyConName.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod pa_length (splitPArrayTy (exprType x))
      return (App len_fn x)
269
-- Apply the replicatePA method, chosen at the type of 'x', to a length
-- expression and to 'x'.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod pa_replicate (exprType x)
      return (mkApps rep_fn [len, x])
273
-- The emptyPA method at the given type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
276
-- Replicate an expression, using the builtin lifting-context variable as
-- the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
282
-- Create a fresh pair of local variables sharing one base name: the first
-- has type 'vty', the second has the array type derived from 'vty' by
-- mkPArrayType.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty >>= \lty ->
    -- The scalar variable is allocated before the lifted one.
    liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
290
-- Run a computation in a local environment (localV) in which the given
-- type variables are in scope.
--
-- For each type variable that needs a PA dictionary (per paDictArgType) a
-- fresh dictionary variable is created and registered together with the
-- type variable; the remaining type variables are registered on their own.
-- The computation receives a function that wraps an expression in lambdas
-- over all the type variables followed by all the dictionary variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $
      do
        mdicts <- mapM mk_dict_var tvs
        zipWithM_ register tvs mdicts
        p (mk_lams mdicts)
  where
    mk_dict_var tv
      = paDictArgType tv
          >>= maybe (return Nothing) (liftM Just . newLocalVar FSLIT("dPA"))

    register tv Nothing     = defLocalTyVar tv
    register tv (Just dict) = defLocalTyVarWithPA tv (Var dict)

    mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
306
-- Apply an expression to the given types and then to the PA dictionaries
-- of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (\dicts -> expr `mkTyApps` tys `mkApps` dicts)
          (mapM paDictOfType tys)
312
-- The VExpr counterpart of polyApply: the type arguments and their PA
-- dictionaries are applied through mapVect.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = liftM instantiate (mapM paDictOfType tys)
  where
    instantiate dicts = mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
318
-- Record a binding in the global environment by consing it onto the front
-- of global_bindings.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv add_binding
  where
    add_binding env = env { global_bindings = (v, e) : global_bindings env }
322
-- Bind an expression to a fresh variable of the expression's type, record
-- that binding with hoistBinding, and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >> return var
329
-- Hoist both halves of a vectorised expression.  The variable names are
-- derived from the name given by getBindName, prefixed with 'v' for the
-- first half and 'l' for the second.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      base <- getBindName
      -- The first half is hoisted before the second.
      liftM2 (,) (hoistExpr (consFS 'v' base) ve)
                 (hoistExpr (consFS 'l' base) le)
337
-- Produce a vectorised expression under closedV, abstract it over the
-- given type variables (and their PA dictionaries) with polyAbstract,
-- hoist the result with hoistVExpr, and return the hoisted variables
-- applied back to those same type variables.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      poly_expr <- closedV (polyAbstract tvs abstract_over)
      hoisted   <- hoistVExpr poly_expr
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
  where
    abstract_over abstract = liftM (mapVect abstract) p
345
-- Return all bindings recorded so far and reset global_bindings to the
-- empty list.  The list comes back as stored; hoistBinding conses onto the
-- front, so the most recently hoisted binding is first.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \genv ->
    setGEnv (genv { global_bindings = [] }) >>
    return (global_bindings genv)
352
-- Build both halves of a closure.  Each half is one of the builtin closure
-- constructors (mkClosureVar for the first, mkClosurePVar for the second)
-- applied to the argument, result and environment types, the PA dictionary
-- of the environment type, both halves of the function, and the matching
-- half of the environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let close mk env = Var mk `mkTyApps` clo_tys
                                `mkApps`   [dict, vfn, lfn, env]
      return (close mkv venv, close mkl lenv)
  where
    clo_tys = [arg_ty, res_ty, env_ty]
361
-- Apply a closure to an argument, in both halves: applyClosureVar for the
-- first and applyClosurePVar for the second.  The argument and result
-- types are read off the type of the first half of the closure;
-- splitClosureTy panics if that is not a closure type.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      return (apply_to vapply vclo varg, apply_to lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

    apply_to fn clo arg
      = Var fn `mkTyApps` [arg_ty, res_ty] `mkApps` [clo, arg]
371
-- Build a chain of nested closures, one per argument type.
--
--   * No argument types: the body is returned as is.
--   * One argument type: a single closure built by buildClosure.
--   * Otherwise: a closure over the first argument whose result type is the
--     closure type of the remaining arguments.  Its body is the hoisted,
--     recursively built closure for the remaining arguments, with a fresh
--     variable for the first argument appended to the captured variables.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _   _    []       _      mk_body = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      inner_res_ty <- mkClosureTypes arg_tys res_ty
      arg          <- newLocalVVar FSLIT("x") arg_ty
      let vars' = vars ++ [arg]
          inner = do
                    lc  <- builtin liftingContext
                    clo <- buildClosures tvs vars' arg_tys res_ty mk_body
                    return (vLams lc vars' clo)
      buildClosure tvs vars arg_ty inner_res_ty (hoistPolyVExpr tvs inner)
387
-- Build a closure that takes one argument and captures the given
-- variables.  Schematically the result is
--
--   (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--
-- where
--
--   f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
--   f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The environment and the code that unpacks it come from buildEnv; the
-- function pair <f,f^> is hoisted with hoistPolyVExpr.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      let -- The body applied to the captured variables and the argument,
          -- wrapped in the environment unpacking and abstracted over the
          -- environment and argument binders.
          mk_fn = do
                    lc       <- builtin liftingContext
                    body     <- mk_body
                    unpacked <- bind (vVar env_bndr)
                                     (vVarApps lc body (vars ++ [arg_bndr]))
                    return (vLamsWithoutLC [env_bndr, arg_bndr] unpacked)

      fn <- hoistPolyVExpr tvs mk_fn
      mkClosure arg_ty res_ty env_ty fn env
409
-- Build a closure environment from a list of variable pairs.  Returns
--
--   * the type of the environment (from mkVectEnv),
--   * the environment itself, as a pair of expressions (mkVectEnv for the
--     first half, mkLiftEnv for the second), and
--   * a function that, given an environment expression and a body, wraps
--     the body in the code that unpacks the environment.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (env_ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bind (venv', lenv') (vbody, lbody)
            = liftM ((,) (vbind venv' vbody)) (lbind lenv' lbody)
      return (env_ty, (venv, lenv), bind)
  where
    (vs, ls) = unzip vvs
    tys      = map idType vs
425
-- The first half of a closure environment: its type, the expression that
-- builds it, and a function wrapping a body in the code that unpacks it.
--
--   * No variables: the environment is the unit constructor and the body
--     is left untouched.
--   * One variable: the environment is that variable, rebound with a
--     non-recursive let.
--   * Otherwise: the environment is a tuple of the variables, taken apart
--     by a case on the boxed tuple constructor.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty  = mkCoreTupTy tys
    tup_con = tupleCon Boxed (length vs)

    unpack env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt tup_con, vs, body)]
434
-- The second half of a closure environment: the expression that builds it
-- and a function wrapping a body in the code that unpacks it.
--
-- With exactly one variable the environment is that variable itself.
-- Unpacking rebinds the variable and then binds the lifting context 'lc',
-- as a case binder, to the result of lengthPA on it.
--
-- In every other case the environment is the data constructor of the array
-- tycon (parrayReprTyCon) for the tuple type of 'tys', applied to 'lc' and
-- the variables; unpacking is a case on that constructor.  Panics with the
-- offending tycon if it does not have exactly one data constructor.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, \env body ->
                do
                  len <- lengthPA (Var v)
                  return . Let (NonRec v env)
                         $ Case len lc (exprType body) [(DEFAULT, [], body)])

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty
      let -- Kept as a lazy binding so the constructor list is inspected
          -- only when the environment or the unpacking code is demanded.
          env_con = case tyConDataCons env_tc of
                      [con] -> con
                      _     -> pprPanic "mkLiftEnv" (ppr env_tc)

          env = Var (dataConWrapId env_con)
                  `mkTyApps`  env_tyargs
                  `mkVarApps` (lc : vs)

          bind env_expr body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env_expr
              in
              return $ Case scrut (mkWildId (exprType scrut))
                            (exprType body)
                            [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- With no variables a wildcard binder of unit type is bound instead.
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
464