Follow introduction of MkCore in VectUtils
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6 newLocalVVar,
7
8 mkBuiltinCo,
9 mkPADictType, mkPArrayType, mkPReprType,
10
11 parrayReprTyCon, parrayReprDataCon, mkVScrut,
12 prDFunOfTyCon,
13 paDictArgType, paDictOfType, paDFunType,
14 paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
15 polyAbstract, polyApply, polyVApply,
16 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
17 buildClosure, buildClosures,
18 mkClosureApp
19 ) where
20
21 import VectCore
22 import VectMonad
23
24 import MkCore
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import Id ( mkWildId )
34 import MkId ( unwrapFamInstScrut )
35 import TysWiredIn
36 import BasicTypes ( Boxity(..) )
37 import Literal ( Literal, mkMachInt )
38
39 import Outputable
40 import FastString
41
42 import Control.Monad
43
44
-- | Peel the type arguments off an annotated application spine.
-- Returns the head of the spine together with the stripped type
-- arguments, listed in left-to-right application order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- Walks from the outermost application inwards, so prepending
    -- yields the arguments in source order.
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)
50
-- | Split off the leading type-variable lambdas of an annotated
-- expression, stopping at the first binder that is not a type variable.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : acc) body
    strip acc rest   = (reverse acc, rest)
56
-- | Split off the leading term-level (Id) lambdas of an annotated
-- expression, stopping at the first binder that is not an Id.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isId bndr = strip (bndr : acc) body
    strip acc rest = (reverse acc, rest)
62
-- | Is this annotated expression a type argument?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, e) = case e of
                        AnnType _ -> True
                        _         -> False
66
-- | Zero-based constructor tag: 'dataConTag' shifted down by 'fIRST_TAG'.
dataConTagZ :: DataCon -> Int
dataConTagZ con = subtract fIRST_TAG (dataConTag con)
69
-- | The zero-based tag of a data constructor as a machine-int literal.
mkDataConTagLit :: DataCon -> Literal
mkDataConTagLit con = mkMachInt (toInteger (dataConTagZ con))
72
-- | The zero-based tag of a data constructor as a Core Int expression.
mkDataConTag :: DataCon -> CoreExpr
mkDataConTag con = mkIntLitInt (dataConTagZ con)
75
-- | If the type is a nullary application of a primitive type
-- constructor, return that constructor; otherwise 'Nothing'.
splitPrimTyCon :: Type -> Maybe TyCon
splitPrimTyCon ty
  = case splitTyConApp_maybe ty of
      Just (tycon, []) | isPrimTyCon tycon -> Just tycon
      _                                    -> Nothing
83
-- | Apply the builtin type constructor selected by @get_tc@ to @tys@.
mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
mkBuiltinTyConApp get_tc tys = liftM (`mkTyConApp` tys) (builtin get_tc)
89
-- | Right-nest applications of the binary builtin tycon selected by
-- @get_tc@: for @tys = [t1, ..., tn]@ the result is
-- @tc t1 (tc t2 (... (tc tn ty)))@, and just @ty@ when @tys@ is empty.
mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
mkBuiltinTyConApps get_tc tys ty
  = liftM nest (builtin get_tc)
  where
    nest tc = foldr (\arg res -> mkTyConApp tc [arg, res]) ty tys
97
98 {-
99 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
100 mkBuiltinTyConApps1 _ dft [] = return dft
101 mkBuiltinTyConApps1 get_tc _ tys
102 = do
103 tc <- builtin get_tc
104 case tys of
105 [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
106 _ -> return $ foldr1 (mk tc) tys
107 where
108 mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkClosureType :: Type -> Type -> VM Type
111 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
112 -}
113
-- | Curried closure type for the given argument types and result type,
-- built by right-nesting the builtin closure tycon (see
-- 'mkBuiltinTyConApps').
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty = mkBuiltinTyConApps closureTyCon arg_tys res_ty
116
-- | Apply the builtin PRepr tycon to a single type.
mkPReprType :: Type -> VM Type
mkPReprType = mkBuiltinTyConApp preprTyCon . (: [])
119
-- | The type of a PA dictionary for the given type (builtin PA tycon
-- applied to it).
mkPADictType :: Type -> VM Type
mkPADictType = mkBuiltinTyConApp paTyCon . (: [])
122
-- | The parallel-array type for a given element type.  Nullary
-- primitive types use the array tycon found by 'lookupPrimPArray'
-- (wrapped in 'traceMaybeV'); every other type gets the builtin
-- PArray tycon applied to it.
mkPArrayType :: Type -> VM Type
mkPArrayType ty
  = case splitPrimTyCon ty of
      Just tycon -> do
                      arr <- traceMaybeV "mkPArrayType" (ppr tycon)
                               (lookupPrimPArray tycon)
                      return (mkTyConApp arr [])
      Nothing    -> mkBuiltinTyConApp parrayTyCon [ty]
131
-- | A coercion given by the nullary application of the builtin tycon
-- selected by @get_tc@.
mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
mkBuiltinCo get_tc = liftM (`mkTyConApp` []) (builtin get_tc)
137
-- | Look up the family instance of the builtin PArray tycon at @ty@,
-- returning the instance tycon and its type arguments.
parrayReprTyCon :: Type -> VM (TyCon, [Type])
parrayReprTyCon ty
  = do
      parr <- builtin parrayTyCon
      lookupFamInst parr [ty]
140
-- | The data constructor of the PArray representation tycon for @ty@,
-- together with the tycon's type arguments.
--
-- The representation tycon is expected to have exactly one data
-- constructor.  Anything else is reported with 'pprPanic', naming the
-- tycon, instead of an anonymous irrefutable-pattern failure.
parrayReprDataCon :: Type -> VM (DataCon, [Type])
parrayReprDataCon ty
  = do
      (tc, arg_tys) <- parrayReprTyCon ty
      case tyConDataCons tc of
        [dc] -> return (dc, arg_tys)
        _    -> pprPanic "parrayReprDataCon" (ppr tc)
147
-- | Prepare a vectorised expression for scrutinisation.  The scalar
-- half is unchanged.  The lifted half is passed through
-- 'unwrapFamInstScrut' for the PArray representation tycon of the
-- scalar half's type, which is returned alongside.
mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
mkVScrut (ve, le)
  = do
      (tc, arg_tys) <- parrayReprTyCon (exprType ve)
      let le' = unwrapFamInstScrut tc arg_tys le
      return ((ve, le'), tc, arg_tys)
153
-- | The PR dictionary function registered for a tycon (via
-- 'lookupTyConPR'), as a Core variable.
--
-- The trace label names this function, matching the other
-- 'traceMaybeV' call sites in this module.  It previously said
-- "prDictOfTyCon", which is not a function here.
prDFunOfTyCon :: TyCon -> VM CoreExpr
prDFunOfTyCon tycon
  = do
      dfun <- traceMaybeV "prDFunOfTyCon" (ppr tycon) (lookupTyConPR tycon)
      return (Var dfun)
157
-- | The type of the PA-dictionary argument required for a type variable,
-- determined by its kind:
--
--   * lifted kind: @PA tv@;
--
--   * arrow kind @k1 -> k2@: a fresh variable @a :: k1@ is introduced.
--     If @a@ itself needs a dictionary @d_a@, the result is
--     @forall a. d_a -> <dict for tv a at k2>@.  Otherwise the
--     dictionary is computed for @tv@ at @k2@;
--
--   * anything else: no dictionary ('Nothing').
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    dictTy ty kind
      | Just kind' <- kindView kind = dictTy ty kind'

    dictTy ty (FunTy argKind resKind)
      = do
          a <- newTyVar (fsLit "a") argKind
          mArgDict <- dictTy (TyVarTy a) argKind
          case mArgDict of
            Nothing      -> dictTy ty resKind
            Just argDict ->
              liftM (fmap (ForAllTy a . FunTy argDict))
                    (dictTy (AppTy ty (TyVarTy a)) resKind)

    dictTy ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing
177
-- | Construct the PA dictionary for a type by splitting it into an
-- application head and arguments (see 'paDictOfTyApp').
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)
182
-- | Build the PA dictionary for @ty_fn ty_args@.  Type synonyms in the
-- head are expanded with 'coreView' first.  A type-variable head uses
-- the dictionary recorded for the variable; a tycon head uses the tycon's
-- PA dfun.  Either is then applied to the argument types and their own
-- dictionaries.  Any other head is a panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
  | otherwise
  = case ty_fn of
      TyVarTy tv ->
        do
          dfun <- maybeV (lookupTyVarPA tv)
          paDFunApply dfun ty_args
      TyConApp tc _ ->
        do
          dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
          paDFunApply (Var dfun) ty_args
      _ -> pprPanic "paDictOfTyApp" (ppr ty_fn)
195
-- | The type of the PA dictionary function for a tycon:
-- @forall tvs. <dict args for tvs> -> PA (tc tvs)@.  Type variables
-- that need no dictionary contribute no argument.
paDFunType :: TyCon -> VM Type
paDFunType tc
  = do
      margs <- mapM paDictArgType tvs
      res   <- mkPADictType (mkTyConApp tc (mkTyVarTys tvs))
      let dict_args = [arg | Just arg <- margs]
      return (mkForAllTys tvs (mkFunTys dict_args res))
  where
    tvs = tyConTyVars tc
206
-- | Apply a dictionary function to the given types and then to the PA
-- dictionaries of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
212
-- | A PA method: the selector for its generic implementation in
-- 'Builtins', paired with the name under which a primitive-type
-- specialisation is looked up (see 'paMethod').
type PAMethod = (Builtins -> Var, String)

pa_length, pa_replicate, pa_empty, pa_pack :: PAMethod
pa_length    = (lengthPAVar,    "lengthPA")
pa_replicate = (replicatePAVar, "replicatePA")
pa_empty     = (emptyPAVar,     "emptyPA")
pa_pack      = (packPAVar,      "packPA")
220
-- | Instantiate a PA method at a type.  For nullary primitive types the
-- specialised method is looked up by name with 'lookupPrimMethod'.
-- Otherwise the generic builtin method is applied to the type and its
-- PA dictionary.
paMethod :: PAMethod -> Type -> VM CoreExpr
paMethod (method, name) ty
  = case splitPrimTyCon ty of
      Just tycon ->
        liftM Var
          $ traceMaybeV "paMethod" (ppr tycon <+> text name)
                        (lookupPrimMethod tycon name)
      Nothing ->
        do
          fn   <- builtin method
          dict <- paDictOfType ty
          return (mkApps (Var fn) [Type ty, dict])
234
-- | Apply the builtin @mkPR@ function to a type and its PA dictionary.
mkPR :: Type -> VM CoreExpr
mkPR ty
  = do
      pr_fn   <- builtin mkPRVar
      pa_dict <- paDictOfType ty
      return (Var pr_fn `mkApps` [Type ty, pa_dict])
241
-- | Apply the @lengthPA@ method (instantiated at @ty@) to @x@.
lengthPA :: Type -> CoreExpr -> VM CoreExpr
lengthPA ty x
  = do
      fn <- paMethod pa_length ty
      return (App fn x)
244
-- | Apply the @replicatePA@ method (instantiated at the type of @x@) to
-- @len@ and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      fn <- paMethod pa_replicate (exprType x)
      return (fn `mkApps` [len, x])
248
-- | The @emptyPA@ method instantiated at the given type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod pa_empty ty
251
-- | Apply the @packPA@ method (instantiated at @ty@) to @xs@, @len@ and
-- @sel@, in that order.
packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
packPA ty xs len sel
  = do
      fn <- paMethod pa_pack ty
      return (mkApps fn [xs, len, sel])
255
-- | Apply the arity-specific @combine<n>PA@ method (instantiated at
-- @ty@), where @n@ is the number of arrays in @xs@.  It is applied to
-- @len@, @sel@, @is@ and then each element of @xs@.
combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
          -> VM CoreExpr
combinePA ty len sel is xs
  = do
      fn <- paMethod method ty
      return (mkApps fn (len : sel : is : xs))
  where
    arity  = length xs
    method = (combinePAVar arity, "combine" ++ show arity ++ "PA")
263
-- | Replicate @x@ using the builtin lifting-context variable as the
-- length (see 'replicatePA').
liftPA :: CoreExpr -> VM CoreExpr
liftPA x = builtin liftingContext >>= \lc -> replicatePA (Var lc) x
269
-- | Create a vectorised local variable pair with the given name.  The
-- first component has type @vty@; the second has the parallel-array
-- type of @vty@.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = do
      lty <- mkPArrayType vty
      liftM2 (,) (newLocalVar fs vty) (newLocalVar fs lty)
277
-- | Run @p@ in a local scope where each type variable in @tvs@ is
-- registered, together with a fresh PA-dictionary variable when its kind
-- requires one (see 'paDictArgType').  @p@ receives a function that
-- wraps an expression in lambdas over @tvs@ followed by those dictionary
-- variables.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM dictVarFor tvs
      zipWithM_ register tvs mdicts
      p (mkLams (tvs ++ [d | Just d <- mdicts]))
  where
    dictVarFor tv
      = paDictArgType tv
          >>= maybe (return Nothing)
                    (liftM Just . newLocalVar (fsLit "dPA"))

    register tv Nothing  = defLocalTyVar tv
    register tv (Just d) = defLocalTyVarWithPA tv (Var d)
293
-- | Apply an expression to the given types and then to their PA
-- dictionaries.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
299
-- | Like 'polyApply', but for both halves of a vectorised expression.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let apply e = mkApps (mkTyApps e tys) dicts
      return (mapVect apply expr)
305
-- | Record a binding to be hoisted, prepending it to the
-- @global_bindings@ of the global environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv addBinding
  where
    addBinding env = env { global_bindings = (v, e) : global_bindings env }
309
-- | Bind an expression to a fresh local variable of the same type,
-- register that binding with 'hoistBinding', and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var
316
-- | Hoist both halves of a vectorised expression.  Their names are
-- derived from the current bind name with a @v@ (scalar) or @l@ (lifted)
-- prefix.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      let hoistAs prefix e = hoistExpr (prefix `consFS` name) e
      vv <- hoistAs 'v' ve
      lv <- hoistAs 'l' le
      return (vv, lv)
324
-- | Hoist a vectorised expression that may mention @tvs@.  The
-- expression is abstracted over @tvs@ (and their dictionaries) inside
-- 'closedV', hoisted, and the hoisted pair is then re-applied to @tvs@.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      let abstracted = polyAbstract tvs (\abstract -> liftM (mapVect abstract) p)
      expr <- closedV abstracted
      fn   <- hoistVExpr expr
      polyVApply (vVar fn) (mkTyVarTys tvs)
332
-- | Return the bindings accumulated by 'hoistBinding' (most recently
-- hoisted first) and reset the accumulated list to empty.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      env <- readGEnv id
      let hoisted = global_bindings env
      setGEnv env { global_bindings = [] }
      return hoisted
339
340 {-
341 boxExpr :: Type -> VExpr -> VM VExpr
342 boxExpr ty (vexpr, lexpr)
343 | Just (tycon, []) <- splitTyConApp_maybe ty
344 , isUnLiftedTyCon tycon
345 = do
346 r <- lookupBoxedTyCon tycon
347 case r of
348 Just tycon' -> let [dc] = tyConDataCons tycon'
349 in
350 return (mkConApp dc [vexpr], lexpr)
351 Nothing -> return (vexpr, lexpr)
352 -}
353
-- | Build the scalar and lifted closures from a function pair and an
-- environment pair.  Both use the same type arguments, the PA dictionary
-- of the environment type, and both function halves.  They differ only
-- in the builtin constructor and the environment half supplied.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let build fn env = Var fn `mkTyApps` [arg_ty, res_ty, env_ty]
                                `mkApps`   [dict, vfn, lfn, env]
      return (build mkv venv, build mkl lenv)
362
-- | Apply a vectorised closure to a vectorised argument, using the
-- builtin scalar and lifted closure-application functions respectively.
mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let apply fn clo arg
            = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
      return (apply vapply vclo varg, apply lapply lclo larg)
370
-- | Build a curried chain of closures, one per argument type.
--
-- With no argument types the body is built directly.  With one argument
-- type this is 'buildClosure'.  Otherwise the outer closure takes the
-- first argument.  Its body is a hoisted function that builds the
-- closures for the remaining arguments, with the new argument added to
-- the captured variables.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      inner_res_ty <- mkClosureTypes arg_tys res_ty
      x            <- newLocalVVar (fsLit "x") arg_ty
      let captured = vars ++ [x]
          inner
            = do
                lc  <- builtin liftingContext
                clo <- buildClosures tvs captured arg_tys res_ty mk_body
                return (vLams lc captured clo)
      buildClosure tvs vars arg_ty inner_res_ty (hoistPolyVExpr tvs inner)
386
-- | Build a closure capturing @vars@ for a function of argument type
-- @arg_ty@ and result type @res_ty@.  Schematically:
--
-- > (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The environment is built by 'buildEnv'.  The function pair is
-- hoisted with 'hoistPolyVExpr' over @tvs@.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bindEnv) <- buildEnv vars
      env_bndr <- newLocalVVar (fsLit "env") env_ty
      arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
      fn <- hoistPolyVExpr tvs (mkFun bindEnv env_bndr arg_bndr)
      mkClosure arg_ty res_ty env_ty fn env
  where
    -- Body applied to the captured variables and the argument, with the
    -- environment unpacked around it, abstracted over env and arg.
    mkFun bindEnv env_bndr arg_bndr
      = do
          lc    <- builtin liftingContext
          body  <- mk_body
          body' <- bindEnv (vVar env_bndr)
                           (vVarApps lc body (vars ++ [arg_bndr]))
          return (vLamsWithoutLC [env_bndr, arg_bndr] body')
408
-- | Build the closure environment for the vectorised variables @vvs@.
-- Returns the scalar environment type, the (scalar, lifted) environment
-- expressions, and a function that unpacks such an environment pair
-- around a (scalar, lifted) body pair.  The scalar half comes from
-- 'mkVectEnv' and the lifted half from 'mkLiftEnv'.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      let (ty, venv, vbind) = mkVectEnv tys vs
      (lenv, lbind) <- mkLiftEnv lc tys ls
      let bindBoth (venv', lenv') (vbody, lbody)
            = do
                lbody' <- lbind lenv' lbody
                return (vbind venv' vbody, lbody')
      return (ty, (venv, lenv), bindBoth)
  where
    (vs, ls) = unzip vvs
    tys      = map varType vs
424
-- | Scalar closure environment for the variables @vs@ of types @tys@.
-- Returns the environment type, the environment expression, and a
-- function that binds the variables from an environment around a body.
--
--   * no variables: unit, and the body is left unchanged;
--
--   * one variable: the variable itself, bound back with a @let@;
--
--   * several: a boxed tuple, taken apart with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv [] [] = (unitTy, Var unitDataConId, const id)
mkVectEnv [ty] [v] = (ty, Var v, bindOne)
  where
    bindOne env body = Let (NonRec v env) body
mkVectEnv tys vs = (tup_ty, mkCoreTup (map Var vs), unpack)
  where
    tup_ty  = mkCoreTupTy tys
    tup_con = tupleCon Boxed (length vs)
    unpack env body
      = Case env (mkWildId tup_ty) (exprType body) [(DataAlt tup_con, vs, body)]
433
-- | Lifted (parallel-array) closure environment for the variables @vs@
-- of types @tys@, given the lifting-context variable @lc@.  Returns the
-- environment expression and a function that unpacks such an
-- environment around a body.
--
-- Single variable: the environment is the variable itself.  Unpacking
-- re-binds @v@ with a @let@ and binds @lc@ to the result of 'lengthPA'
-- on @v@.
--
-- Otherwise the environment is the sole data constructor of the PArray
-- representation tycon of the tuple type of @tys@, applied to @lc@ and
-- the variables.  This equation also handles the empty environment: the
-- constructor is then applied to the unit value, and a dummy unit-typed
-- binder stands in for the missing variables.  (Presumably
-- 'mkCoreTupTy' of no types is the unit type; TODO confirm.)
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [ty] [v]
  = return (Var v, \env body ->
      do
        len <- lengthPA ty (Var v)
        return . Let (NonRec v env)
               $ Case len lc (exprType body) [(DEFAULT, [], body)])

mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- parrayReprTyCon vty

      bndrs <- if null vs
                 then do
                        v <- newDummyVar unitTy
                        return [v]
                 else return vs

      -- The representation tycon must have exactly one constructor; say
      -- which tycon broke that rather than failing an anonymous pattern.
      env_con <- case tyConDataCons env_tc of
                   [con] -> return con
                   _     -> pprPanic "mkLiftEnv" (ppr env_tc)

      let env = Var (dataConWrapId env_con)
                  `mkTyApps` env_tyargs
                  `mkApps`   (Var lc : args)

          bind scrut_env body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs scrut_env
              in
              return $ Case scrut (mkWildId (exprType scrut))
                            (exprType body)
                            [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    args | null vs   = [Var unitDataConId]
         | otherwise = map Var vs
468