Move code
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 splitClosureTy,
5 mkPADictType, mkPArrayType,
6 paDictArgType, paDictOfType,
7 paMethod, lengthPA, replicatePA, emptyPA,
8 polyAbstract, polyApply, polyVApply,
9 lookupPArrayFamInst,
10 hoistExpr, hoistPolyVExpr, takeHoisted,
11 buildClosure, buildClosures,
12 mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon ( dataConWrapId )
27 import Var
28 import Id ( mkWildId )
29 import MkId ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad ( liftM, zipWithM_ )
38
-- | Peel the type arguments off the head of an annotated application.
--
-- Returns the innermost function expression together with the type
-- arguments it was applied to, in application order (leftmost first).
-- Peeling stops at the first argument that is not an 'AnnType'.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)
44
-- | Split off the leading type lambdas of an annotated expression.
--
-- Returns the type-variable binders in binding order (outermost first)
-- and the remaining body. Stops at the first lambda whose binder is not
-- a type variable ('isTyVar').
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isTyVar bndr = strip (bndr : acc) body
    strip acc rest   = (reverse acc, rest)
50
-- | Split off the leading value lambdas of an annotated expression.
--
-- Returns the term binders in binding order (outermost first) and the
-- remaining body. Stops at the first lambda whose binder is not an 'Id'
-- ('isId').
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = strip []
  where
    strip acc (_, AnnLam bndr body)
      | isId bndr = strip (bndr : acc) body
    strip acc rest = (reverse acc, rest)
56
-- | Is this annotated expression a type argument ('AnnType')?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, AnnType _) = True
isAnnTypeArg _              = False
60
-- | Does this type constructor carry the name of the closure type
-- constructor ('closureTyConName')?
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName
63
-- | Decompose a closure type (a two-argument application of a tycon
-- accepted by 'isClosureTyCon') into its argument and result types.
-- Panics on any other type.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty =
  case splitTyConApp_maybe ty of
    Just (tc, [arg_ty, res_ty]) | isClosureTyCon tc -> (arg_ty, res_ty)
    _                                               -> pprPanic "splitClosureTy" (ppr ty)
71
-- | Does this type constructor carry the name of the parallel-array type
-- constructor ('parrayTyConName')?
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName
74
-- | Extract the element type from a parallel-array type (a one-argument
-- application of a tycon accepted by 'isPArrayTyCon'). Panics on any
-- other type.
splitPArrayTy :: Type -> Type
splitPArrayTy ty =
  case splitTyConApp_maybe ty of
    Just (tc, [elt_ty]) | isPArrayTyCon tc -> elt_ty
    _                                      -> pprPanic "splitPArrayTy" (ppr ty)
82
-- | Build the closure type from @arg_ty@ to @res_ty@, using the closure
-- type constructor taken from the builtins.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = liftM (\clo_tc -> mkTyConApp clo_tc [arg_ty, res_ty]) (builtin closureTyCon)
88
-- | Build a curried closure type from argument types and a result type:
-- @[a1, ..., an]@ and @r@ give @Clo a1 (Clo a2 (... (Clo an r)))@.
-- With no argument types the result type is returned unchanged.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = do
      clo_tc <- builtin closureTyCon
      let arrow a r = mkTyConApp clo_tc [a, r]
      return (foldr arrow res_ty arg_tys)
96
-- | Build the type of a PA dictionary for @ty@, using the PA dictionary
-- type constructor taken from the builtins.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\pa_tc -> TyConApp pa_tc [ty]) (builtin paDictTyCon)
102
-- | Build the parallel-array type with element type @ty@, using the
-- PArray type constructor taken from the builtins.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = liftM (\parr_tc -> TyConApp parr_tc [ty]) (builtin parrayTyCon)
108
-- | Compute the type of the PA dictionary argument needed when
-- abstracting over the type variable @tv@, if any.
--
-- * Kind synonyms are expanded first ('kindView').
-- * For a lifted kind the result is the PA dictionary type of the
--   (possibly applied) type ('mkPADictType').
-- * For a kind @k1 -> k2@ a fresh type variable @a :: k1@ is created.
--   If @a@ itself needs a dictionary of type @d1@, the result is
--   @forall a. d1 -> d2@, where @d2@ is the dictionary type computed for
--   the type applied to @a@ at kind @k2@. Otherwise the computation
--   continues at kind @k2@ with the unapplied type.
-- * Any other kind yields 'Nothing'.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
  where
    go ty k | Just k' <- kindView k = go ty k'
    go ty (FunTy k1 k2)
      = do
          -- Named 'a' (not 'tv') to avoid shadowing the outer parameter.
          a    <- newTyVar FSLIT("a") k1
          mty1 <- go (TyVarTy a) k1
          case mty1 of
            Just ty1 -> do
              mty2 <- go (AppTy ty (TyVarTy a)) k2
              return $ fmap (ForAllTy a . FunTy ty1) mty2
            Nothing  -> go ty k2

    go ty k
      | isLiftedTypeKind k
      = liftM Just (mkPADictType ty)

    go _ _ = return Nothing
128
-- | Build a PA dictionary expression for @ty@ by splitting it into its
-- head and the types it is applied to ('splitAppTys') and delegating to
-- 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
133
-- | Build a PA dictionary for a type head applied to the given argument
-- types.
--
-- * Type synonyms in the head are expanded first ('coreView').
-- * A type-variable head uses the dictionary found by 'lookupTyVarPA'
--   (failing via 'maybeV' when there is none).
-- * A type-constructor head looks up the instance of the builtin PA class
--   for the application of the constructor to the arguments.
--
-- In both of the last two cases the dictionary function is then applied
-- with 'paDFunApply'. Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
paDictOfTyApp (TyVarTy tv) ty_args
  = do
      dfun <- maybeV (lookupTyVarPA tv)
      paDFunApply dfun ty_args
-- NOTE(review): the head's own tycon arguments are discarded here;
-- presumably they are always [] because splitAppTys moves them into
-- ty_args -- confirm.
paDictOfTyApp (TyConApp tc _) ty_args
  = do
      pa_class <- builtin paClass
      (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
      paDFunApply (Var dfun) ty_args'
paDictOfTyApp ty _ = pprPanic "paDictOfTyApp" (ppr ty)
147
-- | Apply a dictionary function first to the given types, and then to a
-- PA dictionary for each of those types (same order).
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
153
-- | Instantiate a builtin PA method at @ty@. The result is the method
-- selected from the builtins, applied to the type @ty@ and then to the
-- PA dictionary for @ty@.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = do
      method_var <- builtin method
      liftM (\d -> mkApps (Var method_var) [Type ty, d]) (paDictOfType ty)
160
-- | Apply the builtin length method ('lengthPAVar') to a parallel-array
-- expression. The element type is read off the expression's type with
-- 'splitPArrayTy', which panics if it is not a PArray type.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA arr
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType arr))
      return (App len_fn arr)
165
-- | Apply the builtin replicate method ('replicatePAVar'), instantiated at
-- the type of @x@, to a length expression and the element @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])
169
-- | The builtin empty-array method ('emptyPAVar') instantiated at the
-- given element type.
emptyPA :: Type -> VM CoreExpr
emptyPA elt_ty = paMethod emptyPAVar elt_ty
172
-- | Create a pair of fresh local variables, both named after @fs@: one of
-- the vectorised type @vty@ and one of the corresponding PArray type.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty >>= \lty ->
    newLocalVar fs vty >>= \vv ->
    liftM ((,) vv) (newLocalVar fs lty)
180
-- | Abstract over type variables and the PA dictionaries they require.
--
-- Runs under 'localV'. For every type variable, a fresh dictionary
-- binder named @dPA@ is created whenever 'paDictArgType' reports that a
-- dictionary is needed. The variable is then registered with
-- 'defLocalTyVarWithPA' (when it has a dictionary) or 'defLocalTyVar'
-- (when it does not).
--
-- The continuation receives a function that wraps an expression in
-- lambdas over all the type variables, followed by the dictionary
-- binders that were created.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM dictBinderFor tvs
      zipWithM_ register tvs mdicts
      p (mkLams (tvs ++ [d | Just d <- mdicts]))
  where
    dictBinderFor tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    register tv Nothing  = defLocalTyVar tv
    register tv (Just d) = defLocalTyVarWithPA tv (Var d)
196
-- | Instantiate a polymorphic expression at the given types. All type
-- arguments are applied first, followed by one PA dictionary per type.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
202
-- | Like 'polyApply', but instantiates both components of a vectorised
-- expression pair (via 'mapVect'), using the same dictionaries for each.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let inst e = mkApps (mkTyApps e tys) dicts
      return (mapVect inst expr)
208
-- | Look up the family instance of the builtin PArray type constructor
-- at the given type argument ('lookupFamInst').
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
211
-- | Create a fresh variable named after @fs@ with the type of @expr@,
-- prepend the binding @(var, expr)@ to the global bindings, and return
-- the new variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = do
      var <- newLocalVar fs (exprType expr)
      updGEnv (addBinding var)
      return var
  where
    addBinding var genv
      = genv { global_bindings = (var, expr) : global_bindings genv }
219
-- | Hoist both components of a vectorised expression pair with
-- 'hoistExpr'. The hoisted names are the current binding name
-- ('getBindName') prefixed with @v@ (vectorised) and @l@ (lifted).
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      name <- getBindName
      let prefixed c = consFS c name
      vv <- hoistExpr (prefixed 'v') ve
      lv <- hoistExpr (prefixed 'l') le
      return (vv, lv)
227
-- | Hoist a vectorised expression pair abstracted over @tvs@ and their
-- PA dictionaries ('polyAbstract', run under 'closedV').
--
-- Returns the hoisted pair instantiated back at those type variables
-- ('polyVApply').
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      abstracted <- closedV (polyAbstract tvs (\abstract -> liftM (mapVect abstract) p))
      hoisted    <- hoistVExpr abstracted
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
235
-- | Return the accumulated global bindings (most recently hoisted first,
-- since 'hoistExpr' prepends) and reset them to the empty list.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = do
      genv <- readGEnv id
      let hoisted = global_bindings genv
      setGEnv (genv { global_bindings = [] })
      return hoisted
242
-- | Build the vectorised and lifted closure expressions.
--
-- The builtins 'mkClosureVar' (vectorised) and 'mkClosurePVar' (lifted)
-- are each instantiated at the argument, result and environment types.
-- Each is then applied to the PA dictionary of the environment type,
-- both function components, and its own environment component.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let mk_clo fn env_e
            = mkApps (mkTyApps (Var fn) [arg_ty, res_ty, env_ty]) [dict, vfn, lfn, env_e]
      return (mk_clo mkv venv, mk_clo mkl lenv)
251
-- | Apply a (vectorised, lifted) closure pair to an argument pair.
--
-- Uses the builtins 'applyClosureVar' and 'applyClosurePVar', each
-- instantiated at the argument and result types. These types are read
-- off the vectorised closure's type with 'splitClosureTy'.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      let mk_app fn clo arg = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
      return (mk_app vapply vclo varg, mk_app lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)
261
-- | Build a curried closure for a function of one or more arguments.
--
-- With a single argument type this is just 'buildClosure'. With more,
-- the closure for the first argument has as its body the closure for the
-- remaining arguments. That inner closure is built recursively with the
-- fresh argument variable appended to the captured variables, and is
-- hoisted via 'hoistPolyVExpr'.
--
-- An empty list of argument types is a caller error. It now panics with
-- an explicit message instead of failing with an opaque non-exhaustive
-- pattern error.
buildClosures :: [TyVar] -> Var -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ _ [] res_ty _
  = pprPanic "buildClosures" (text "no argument types; result type:" <+> ppr res_ty)
buildClosures tvs lc vars [arg_ty] res_ty mk_body
  = buildClosure tvs lc vars arg_ty res_ty mk_body
buildClosures tvs lc vars (arg_ty : arg_tys) res_ty mk_body
  = do
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg <- newLocalVVar FSLIT("x") arg_ty
      buildClosure tvs lc vars arg_ty res_ty'
        . hoistPolyVExpr tvs
        $ do
            clo <- buildClosures tvs lc (vars ++ [arg]) arg_tys res_ty mk_body
            return $ vLams lc (vars ++ [arg]) clo
274
-- | Build a closure capturing @vars@ for a function of one argument.
--
-- Schematically the result is
--
-- > (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The environment and its binding function come from 'buildEnv'. The
-- function pair is built from @mk_body@ applied (via 'vVarApps') to the
-- captured variables and the argument binder. It is wrapped in lambdas
-- over the environment and argument binders and hoisted via
-- 'hoistPolyVExpr'. The result is assembled with 'mkClosure'.
buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs lv vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv lv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
      let mk_fn = do
            body  <- mk_body
            body' <- bind (vVar env_bndr) (vVarApps lv body (vars ++ [arg_bndr]))
            return (vLamsWithoutLC [env_bndr, arg_bndr] body')
      fn <- hoistPolyVExpr tvs mk_fn
      mkClosure arg_ty res_ty env_ty fn env
295
-- | Build the environment for a closure capturing the given variable
-- pairs.
--
-- Returns three things:
--
-- * the vectorised environment type;
-- * the (vectorised, lifted) environment expression pair;
-- * a function that, given an environment pair and a body pair, binds
--   the captured variables in each body from the corresponding
--   environment.
--
-- The vectorised side is handled by 'mkVectEnv' and the lifted side by
-- 'mkLiftEnv', which also takes @lv@.
buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv lv vvs
  = do
      (lenv, lbind) <- mkLiftEnv lv tys ls
      return (env_ty, (venv, lenv), rebind lbind)
  where
    (vs, ls)              = unzip vvs
    tys                   = map idType vs
    (env_ty, venv, vbind) = mkVectEnv tys vs

    rebind lbind (venv_e, lenv_e) (vbody, lbody)
      = do
          lbody' <- lbind lenv_e lbody
          return (vbind venv_e vbody, lbody')
310
-- | Build the vectorised environment for the captured variables @vs@ of
-- types @tys@.
--
-- Returns the environment type, an expression constructing the
-- environment, and a function that binds the variables around a body,
-- given an environment expression:
--
-- * no variables: unit type and value; the body is returned unchanged;
-- * one variable: the variable itself, bound with a non-recursive @let@;
-- * otherwise: a boxed tuple, bound by a @case@ on its tuple constructor.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv []   []  = (unitTy, Var unitDataConId, \_ body -> body)
mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
                        \env body -> Case env (mkWildId ty) (exprType body)
                                       [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
  where
    ty = mkCoreTupTy tys
319
-- | Build the lifted environment for the captured lifted variables @vs@,
-- whose element types are @tys@. Also returns a function that binds
-- those variables around a body, given an environment expression.
--
-- * One variable: the environment is the variable itself. Binding it
--   also binds @lv@ to the result of 'lengthPA' applied to that variable.
-- * Otherwise, including no variables: the environment is built with
--   the single data constructor of the PArray family instance for the
--   tuple of @tys@, applied to @lv@ and @vs@. Binding scrutinises the
--   (unwrapped) environment against that constructor. With no variables,
--   a wildcard of unit type stands in for the field binders.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lv [_] [v]
  = return (Var v, bind)
  where
    bind env body
      = do
          len <- lengthPA (Var v)
          return . Let (NonRec v env)
                 $ Case len lv (exprType body) [(DEFAULT, [], body)]

-- NOTE: this transparently deals with empty environments
mkLiftEnv lv tys vs
  = do
      (env_tc, env_tyargs) <- lookupPArrayFamInst vty
      -- The instance tycon is expected to have exactly one constructor.
      -- Panic with a descriptive message if it does not, instead of the
      -- opaque irrefutable-pattern failure of a lazy list pattern.
      env_con <- case tyConDataCons env_tc of
                   [con] -> return con
                   _     -> pprPanic "mkLiftEnv" (ppr env_tc)

      let env = Var (dataConWrapId env_con)
                  `mkTyApps`  env_tyargs
                  `mkVarApps` (lv : vs)

          bind env_e body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env_e
              in  return $ Case scrut (mkWildId (exprType scrut))
                                (exprType body)
                                [(DataAlt env_con, lv : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
349