More refactoring
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 splitClosureTy,
5 mkPADictType, mkPArrayType,
6 paDictArgType, paDictOfType,
7 paMethod, lengthPA, replicatePA, emptyPA, liftPA,
8 polyAbstract, polyApply, polyVApply,
9 lookupPArrayFamInst,
10 hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
11 buildClosure, buildClosures,
12 mkClosureApp
13 ) where
14
15 #include "HsVersions.h"
16
17 import VectCore
18 import VectMonad
19
20 import DsUtils
21 import CoreSyn
22 import CoreUtils
23 import Type
24 import TypeRep
25 import TyCon
26 import DataCon ( dataConWrapId )
27 import Var
28 import Id ( mkWildId )
29 import MkId ( unwrapFamInstScrut )
30 import PrelNames
31 import TysWiredIn
32 import BasicTypes ( Boxity(..) )
33
34 import Outputable
35 import FastString
36
37 import Control.Monad ( liftM, zipWithM_ )
38
-- | Peel the trailing type arguments off an annotated application.
-- Returns the innermost expression that is not an application to a type,
-- together with the removed type arguments in left-to-right order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs = peel []
  where
    -- Arguments are met right-to-left, so consing onto the accumulator
    -- restores source order.
    peel acc (_, AnnApp fun (_, AnnType arg_ty)) = peel (arg_ty : acc) fun
    peel acc other                               = (other, acc)
44
-- | Strip the leading type-variable lambdas from an annotated expression.
-- Returns the stripped binders, outermost first, together with the first
-- sub-expression that is not a lambda over a type variable.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = gather []
  where
    -- 'seen' holds the binders found so far, most recent first.
    gather seen ann_expr
      = case ann_expr of
          (_, AnnLam bndr body)
            | isTyVar bndr -> gather (bndr : seen) body
          _                -> (reverse seen, ann_expr)
50
-- | Strip the leading value lambdas (binders satisfying 'isId') from an
-- annotated expression.  Returns the binders, outermost first, and the
-- remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = gather []
  where
    -- 'seen' holds the binders found so far, most recent first.
    gather seen ann_expr
      = case ann_expr of
          (_, AnnLam bndr body)
            | isId bndr -> gather (bndr : seen) body
          _             -> (reverse seen, ann_expr)
56
-- | Is this annotated expression a type argument (an 'AnnType' node)?
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node)
  = case node of
      AnnType _ -> True
      _         -> False
60
-- | Does this type constructor carry the name 'closureTyConName'?
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName
63
-- | Decompose a closure type into its argument and result types.
-- Panics (with the offending type) if the type is not the closure type
-- constructor applied to exactly two arguments.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [dom, cod])
        | isClosureTyCon tycon -> (dom, cod)
      _                        -> pprPanic "splitClosureTy" (ppr ty)
71
-- | Does this type constructor carry the name 'parrayTyConName'?
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName
74
-- | Extract the element type from a parallel-array type.
-- Panics (with the offending type) if the type is not the parallel-array
-- type constructor applied to exactly one argument.
splitPArrayTy :: Type -> Type
splitPArrayTy ty
  = case splitTyConApp_maybe ty of
      Just (tycon, [elem_ty])
        | isPArrayTyCon tycon -> elem_ty
      _                       -> pprPanic "splitPArrayTy" (ppr ty)
82
-- | Build the closure type from an argument type and a result type, using
-- the builtin 'closureTyCon'.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty
  = liftM (`mkTyConApp` [arg_ty, res_ty]) (builtin closureTyCon)
88
-- | Build a right-nested chain of closure types, one per argument type,
-- ending in the given result type.  An empty argument list yields the
-- result type unchanged.
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = liftM nest (builtin closureTyCon)
  where
    nest clo_tc = foldr (\dom cod -> mkTyConApp clo_tc [dom, cod]) res_ty arg_tys
96
-- | Apply the builtin 'paDictTyCon' to the given type.
mkPADictType :: Type -> VM Type
mkPADictType ty
  = builtin paDictTyCon >>= \dict_tc ->
    return (TyConApp dict_tc [ty])
102
-- | Apply the builtin 'parrayTyCon' to the given element type.
mkPArrayType :: Type -> VM Type
mkPArrayType ty
  = builtin parrayTyCon >>= \arr_tc ->
    return (TyConApp arr_tc [ty])
108
-- | Compute the type of the dictionary argument that accompanies a type
-- variable, driven by the variable's kind.  Returns 'Nothing' when the
-- kind is neither a function kind nor the lifted type kind (after
-- expanding with 'kindView').
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    -- Look through kind synonyms first.
    dictTy ty kind
      | Just kind' <- kindView kind = dictTy ty kind'

    -- Function kind: invent a fresh type variable for the argument and, if
    -- it needs a dictionary itself, quantify over it and abstract over that
    -- dictionary; otherwise continue with the result kind at the same type.
    dictTy ty (FunTy arg_kind res_kind)
      = do
          arg_tv     <- newTyVar FSLIT("a") arg_kind
          m_arg_dict <- dictTy (TyVarTy arg_tv) arg_kind
          case m_arg_dict of
            Nothing       -> dictTy ty res_kind
            Just arg_dict ->
              liftM (fmap (ForAllTy arg_tv . FunTy arg_dict))
                    (dictTy (AppTy ty (TyVarTy arg_tv)) res_kind)

    dictTy ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing
128
-- | Build the dictionary expression for a type by splitting it into a head
-- and its arguments ('splitAppTys') and delegating to 'paDictOfTyApp'.
paDictOfType :: Type -> VM CoreExpr
paDictOfType = uncurry paDictOfTyApp . splitAppTys
133
-- | Build the dictionary expression for a type given as a head plus its
-- arguments.  The head is first expanded with 'coreView'; a type variable
-- is resolved through 'lookupTyVarPA', a type constructor through an
-- instance lookup on the builtin 'paClass'.  Any other head is a panic.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  = case coreView ty_fn of
      Just expanded -> paDictOfTyApp expanded ty_args
      Nothing       -> fromHead ty_fn
  where
    fromHead (TyVarTy tv)
      = maybeV (lookupTyVarPA tv) >>= \dfun ->
        paDFunApply dfun ty_args

    fromHead (TyConApp tc _)
      = do
          pa_cls            <- builtin paClass
          (dfun, inst_args) <- lookupInst pa_cls [TyConApp tc ty_args]
          paDFunApply (Var dfun) inst_args

    fromHead other = pprPanic "paDictOfTyApp" (ppr other)
147
-- | Apply a dictionary function to the given types and then to the
-- dictionaries computed for each of those types.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
153
-- | Instantiate a builtin method at a type: the selected builtin variable
-- applied to the type and to the dictionary computed for that type.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = builtin method  >>= \sel     ->
    paDictOfType ty >>= \pa_dict ->
    return (Var sel `mkApps` [Type ty, pa_dict])
160
-- | Apply the 'lengthPAVar' method to an expression.  The method is
-- instantiated at the element type recovered from the expression's type
-- with 'splitPArrayTy'.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar elem_ty
      return (App len_fn x)
  where
    elem_ty = splitPArrayTy (exprType x)
165
-- | Apply the 'replicatePAVar' method, instantiated at the type of @x@, to
-- a length expression and @x@.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])
169
-- | The 'emptyPAVar' method instantiated at the given type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty
172
-- | Replicate an expression, using the builtin 'liftingContext' variable
-- as the length.
liftPA :: CoreExpr -> VM CoreExpr
liftPA x
  = builtin liftingContext >>= \lc ->
    replicatePA (Var lc) x
178
-- | Create a pair of fresh local variables sharing one name: the first at
-- the given type, the second at the corresponding parallel-array type.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty        >>= \lifted_ty ->
    newLocalVar fs vty       >>= \scalar_var ->
    newLocalVar fs lifted_ty >>= \lifted_var ->
    return (scalar_var, lifted_var)
186
-- | Run a computation in a local scope ('localV') in which the given type
-- variables are defined.  Each variable for which 'paDictArgType' yields a
-- type gets a fresh dictionary variable.  The computation receives a
-- function that lambda-abstracts an expression over all the type variables
-- followed by the dictionary variables that were created.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      m_dict_vars <- mapM dictVarFor tvs
      zipWithM_ register tvs m_dict_vars
      p (mkLams (tvs ++ concatMap (maybe [] (: [])) m_dict_vars))
  where
    -- A fresh dictionary variable, if this type variable needs one.
    dictVarFor tv
      = paDictArgType tv
          >>= maybe (return Nothing) (liftM Just . newLocalVar FSLIT("dPA"))

    -- Record the type variable, with its dictionary when it has one.
    register tv Nothing     = defLocalTyVar tv
    register tv (Just dict) = defLocalTyVarWithPA tv (Var dict)
202
-- | Apply an expression to the given types and then to the dictionaries
-- computed for each of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)
208
-- | Like 'polyApply', but maps the application over a 'VExpr' with
-- 'mapVect', using the same types and dictionaries throughout.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = liftM instantiate (mapM paDictOfType tys)
  where
    instantiate dicts = mapVect (\e -> mkApps (mkTyApps e tys) dicts) expr
214
-- | Look up the family instance of the builtin 'parrayTyCon' at the given
-- type.
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty
  = do
      arr_tc <- builtin parrayTyCon
      lookupFamInst arr_tc [ty]
217
-- | Prepend a binding to the 'global_bindings' of the global environment.
hoistBinding :: Var -> CoreExpr -> VM ()
hoistBinding v e = updGEnv record
  where
    record genv = genv { global_bindings = (v, e) : global_bindings genv }
221
-- | Bind an expression to a fresh local variable of the expression's type,
-- record the binding with 'hoistBinding', and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = newLocalVar fs (exprType expr) >>= \var ->
    hoistBinding var expr >>
    return var
228
-- | Hoist both components of a 'VExpr'.  The variable names are the
-- current bind name ('getBindName') prefixed with @v@ for the first
-- component and @l@ for the second.
hoistVExpr :: VExpr -> VM VVar
hoistVExpr (ve, le)
  = do
      base <- getBindName
      let hoistAs tag = hoistExpr (tag `consFS` base)
      vv <- hoistAs 'v' ve
      lv <- hoistAs 'l' le
      return (vv, lv)
236
-- | Run a computation in a closed scope ('closedV'), abstract its result
-- over the given type variables with 'polyAbstract', hoist the abstracted
-- expression, and return the hoisted variable re-applied to those type
-- variables.
hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr tvs p
  = do
      closed_expr <- closedV (polyAbstract tvs abstracted)
      hoisted     <- hoistVExpr closed_expr
      polyVApply (vVar hoisted) (mkTyVarTys tvs)
  where
    -- Wrap both components of the result in the abstraction.
    abstracted wrap = liftM (mapVect wrap) p
244
-- | Return the bindings currently held in 'global_bindings' and reset that
-- field to the empty list.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \genv ->
    setGEnv (genv { global_bindings = [] }) >>
    return (global_bindings genv)
251
-- | Build a closure as a pair of expressions.  The first applies the
-- builtin 'mkClosureVar', the second the builtin 'mkClosurePVar'.  Both
-- receive the same types (argument, result, environment), the dictionary
-- for the environment type and both function components; they differ only
-- in which environment component comes last.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let close con env_arg
            = mkApps (mkTyApps (Var con) [arg_ty, res_ty, env_ty])
                     [dict, vfn, lfn, env_arg]
      return (close mkv venv, close mkl lenv)
260
-- | Apply a closure to an argument, component-wise.  The first component
-- uses the builtin 'applyClosureVar', the second 'applyClosurePVar'.  The
-- argument and result types are read off the type of the first closure
-- component with 'splitClosureTy'.
mkClosureApp :: VExpr -> VExpr -> VM VExpr
mkClosureApp (vclo, lclo) (varg, larg)
  = do
      vapply <- builtin applyClosureVar
      lapply <- builtin applyClosurePVar
      return (applyWith vapply vclo varg, applyWith lapply lclo larg)
  where
    (arg_ty, res_ty) = splitClosureTy (exprType vclo)

    applyWith fn clo arg
      = mkApps (mkTyApps (Var fn) [arg_ty, res_ty]) [clo, arg]
270
-- | Build a chain of nested closures, one per argument type.
--
-- * With no argument types there is nothing to close over, so the body is
--   returned as is.  This agrees with 'mkClosureTypes', which returns the
--   result type unchanged for an empty argument list.  (Previously this
--   case fell through every equation and died with a pattern-match
--   failure.)
--
-- * With a single argument type this is just 'buildClosure'.
--
-- * Otherwise a closure over the first argument is built whose body is the
--   hoisted closure for the remaining arguments, with the new argument
--   variable added to the variables passed down.
buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ [] _ mk_body
  = mk_body
buildClosures tvs vars [arg_ty] res_ty mk_body
  = buildClosure tvs vars arg_ty res_ty mk_body
buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
  = do
      -- Result type of the outer closure: closures over the remaining args.
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg     <- newLocalVVar FSLIT("x") arg_ty
      buildClosure tvs vars arg_ty res_ty'
        . hoistPolyVExpr tvs
        $ do
            lc  <- builtin liftingContext
            clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
            return $ vLams lc (vars ++ [arg]) clo
284
-- | Build a single closure over the given variables.  Schematically the
-- result is
--
-- >   (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- > where
-- >   f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >   f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- The function pair @<f,f^>@ is hoisted and abstracted over the given type
-- variables with 'hoistPolyVExpr'.
buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, packed_env, unpack) <- buildEnv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty

      -- The closure's code: take the environment and the argument, unpack
      -- the environment, and apply the body to the captured variables
      -- followed by the argument.
      let closure_fn
            = do
                lc       <- builtin liftingContext
                body     <- mk_body
                unpacked <- unpack (vVar env_bndr)
                                   (vVarApps lc body (vars ++ [arg_bndr]))
                return (vLamsWithoutLC [env_bndr, arg_bndr] unpacked)

      fn <- hoistPolyVExpr tvs closure_fn

      mkClosure arg_ty res_ty env_ty fn packed_env
306
-- | Package a list of variable pairs into an environment.  Returns the
-- environment type, the pair of environment expressions (from 'mkVectEnv'
-- and 'mkLiftEnv'), and a function which, given a pair of environment
-- scrutinees and a pair of bodies, rebinds the variables around each body.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv vvs
  = do
      lc <- builtin liftingContext
      (lifted_env, bind_lifted) <- mkLiftEnv lc elem_tys lifted_vars
      let bind_both (v_scrut, l_scrut) (v_body, l_body)
            = do
                l_body' <- bind_lifted l_scrut l_body
                return (bind_vect v_scrut v_body, l_body')
      return (env_ty, (vect_env, lifted_env), bind_both)
  where
    (vect_vars, lifted_vars)      = unzip vvs
    elem_tys                      = map idType vect_vars
    (env_ty, vect_env, bind_vect) = mkVectEnv elem_tys vect_vars
322
-- | Build an environment from types and variables.  Returns the
-- environment type, the expression that constructs the environment, and a
-- function that rebinds the variables from an environment expression
-- around a body.
--
-- * No variables: the unit value; the body is returned unchanged.
--
-- * One variable: the variable itself, rebound with a @let@.
--
-- * Otherwise: a boxed tuple, taken apart with a @case@.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, ignoreEnv)
      ([ty], [v]) -> (ty, Var v, letBind v)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpackTuple)
  where
    tup_ty = mkCoreTupTy tys

    ignoreEnv _ body = body

    letBind v env body = Let (NonRec v env) body

    unpackTuple env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
331
-- | Build an environment expression from the lifting-context variable, the
-- element types and the variables.  Returns the expression that constructs
-- the environment and a function that rebinds the variables (and the
-- lifting context) from an environment expression around a body.
--
-- With exactly one variable the environment is that variable itself.  It
-- is rebound with a @let@, and the lifting context is rebound to the
-- result of 'lengthPA' on it via a @case@ with only a default alternative.
--
-- Otherwise the environment is built with the single data constructor of
-- the family instance found by 'lookupPArrayFamInst' at the tuple type of
-- the element types.  Panics if that tycon does not have exactly one data
-- constructor.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lc [_] [v]
  = return (Var v, bind)
  where
    bind env body
      = do
          len <- lengthPA (Var v)
          return . Let (NonRec v env)
                 $ Case len lc (exprType body) [(DEFAULT, [], body)]

-- NOTE: this transparently deals with empty environments
mkLiftEnv lc tys vs
  = do
      (env_tc, env_tyargs) <- lookupPArrayFamInst vty
      let -- Explicit panic instead of a partial lazy pattern, so that a
          -- failure names the offending tycon.
          env_con
            = case tyConDataCons env_tc of
                [con] -> con
                _     -> pprPanic "mkLiftEnv" (ppr env_tc)

          env = Var (dataConWrapId env_con)
                  `mkTyApps`  env_tyargs
                  `mkVarApps` (lc : vs)

          bind scrut_env body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs scrut_env
              in
              return $ Case scrut (mkWildId (exprType scrut))
                            (exprType body)
                            [(DataAlt env_con, lc : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- With no variables a wildcard binder of unit type is bound instead.
    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
361