Improve closure generation for functions with multiple parameters
[ghc.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2 collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3 collectAnnValBinders,
4 splitClosureTy,
5 mkPADictType, mkPArrayType,
6 paDictArgType, paDictOfType,
7 paMethod, lengthPA, replicatePA, emptyPA,
8 polyAbstract, polyApply, polyVApply,
9 lookupPArrayFamInst,
10 hoistExpr, hoistPolyVExpr, takeHoisted,
11 buildClosure, buildClosures
12 ) where
13
14 #include "HsVersions.h"
15
16 import VectCore
17 import VectMonad
18
19 import DsUtils
20 import CoreSyn
21 import CoreUtils
22 import Type
23 import TypeRep
24 import TyCon
25 import DataCon ( dataConWrapId )
26 import Var
27 import Id ( mkWildId )
28 import MkId ( unwrapFamInstScrut )
29 import PrelNames
30 import TysWiredIn
31 import BasicTypes ( Boxity(..) )
32
33 import Outputable
34 import FastString
35
36 import Control.Monad ( liftM, zipWithM_ )
37
-- Strip the type arguments off an annotated application spine.  Returns
-- the head that remains once the outermost application no longer has a
-- type argument, together with the stripped types in left-to-right order.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs expr = peel [] expr
  where
    peel acc e
      = case e of
          (_, AnnApp fun (_, AnnType arg_ty)) -> peel (arg_ty : acc) fun
          _                                   -> (e, acc)
43
-- Peel off the leading run of lambdas whose binders are type variables,
-- returning those binders (outermost first) and the remaining body.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = collectAnnBindersBy isTyVar

-- Peel off the leading run of lambdas whose binders are Ids,
-- returning those binders (outermost first) and the remaining body.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = collectAnnBindersBy isId

-- Shared worker: strip lambdas for as long as the binder satisfies the
-- predicate.  Binders are accumulated in reverse and flipped at the end.
collectAnnBindersBy :: (Var -> Bool) -> AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnBindersBy wanted = strip []
  where
    strip acc (_, AnnLam bndr body)
      | wanted bndr = strip (bndr : acc) body
    strip acc rest  = (reverse acc, rest)
55
-- True exactly when the annotated expression node is a type (AnnType).
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node)
  = case node of
      AnnType _ -> True
      _         -> False
59
-- Is this the closure type constructor?  Compared by Name against
-- closureTyConName.
isClosureTyCon :: TyCon -> Bool
isClosureTyCon = (== closureTyConName) . tyConName

-- Split a closure type constructor applied to exactly two types into its
-- argument and result types.  Anything else is a compiler panic.
splitClosureTy :: Type -> (Type, Type)
splitClosureTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [arg_ty, res_ty]) | isClosureTyCon tc -> (arg_ty, res_ty)
      _ -> pprPanic "splitClosureTy" (ppr ty)
70
-- Is this the parallel-array type constructor?  Compared by Name against
-- parrayTyConName.
isPArrayTyCon :: TyCon -> Bool
isPArrayTyCon = (== parrayTyConName) . tyConName

-- Extract the element type from a parallel-array type constructor applied to
-- exactly one type.  Anything else is a compiler panic.
splitPArrayTy :: Type -> Type
splitPArrayTy ty
  = case splitTyConApp_maybe ty of
      Just (tc, [elt_ty]) | isPArrayTyCon tc -> elt_ty
      _ -> pprPanic "splitPArrayTy" (ppr ty)
81
-- The closure type from arg_ty to res_ty; the one-argument instance of
-- mkClosureTypes.
mkClosureType :: Type -> Type -> VM Type
mkClosureType arg_ty res_ty = mkClosureTypes [arg_ty] res_ty

-- Right-nested closure types: for arguments [a1, ..., an] and result r this
-- is the closure tycon applied as a1 to (a2 to ... (an to r)).
mkClosureTypes :: [Type] -> Type -> VM Type
mkClosureTypes arg_tys res_ty
  = do
      clo_tc <- builtin closureTyCon
      let arrow a r = mkTyConApp clo_tc [a, r]
      return (foldr arrow res_ty arg_tys)
95
-- The type of the PA dictionary for ty: the builtin paDictTyCon applied to ty.
mkPADictType :: Type -> VM Type
mkPADictType ty = liftM (\tc -> TyConApp tc [ty]) (builtin paDictTyCon)

-- The parallel-array type of ty: the builtin parrayTyCon applied to ty.
mkPArrayType :: Type -> VM Type
mkPArrayType ty = liftM (\tc -> TyConApp tc [ty]) (builtin parrayTyCon)
107
-- The type of the PA-dictionary argument to abstract over for the type
-- variable tv, or Nothing when its kind calls for no dictionary.
--
--   * a lifted kind gives the PA dictionary type of the (applied) type;
--   * an arrow kind k1 -> k2 introduces a fresh variable a :: k1.  If k1
--     needs a dictionary d1, the result is  forall a. d1 -> (result for the
--     type applied to a at kind k2).  If k1 needs none, the result is that
--     for the unapplied type at k2.  NOTE(review): the type is not applied
--     to a in that branch -- confirm this is intended;
--   * any other kind gives Nothing.
paDictArgType :: TyVar -> VM (Maybe Type)
paDictArgType tv = dictTy (TyVarTy tv) (tyVarKind tv)
  where
    dictTy ty kind
      | Just kind' <- kindView kind = dictTy ty kind'
    dictTy ty (FunTy arg_kind res_kind)
      = do
          a <- newTyVar FSLIT("a") arg_kind
          m_arg_dict <- dictTy (TyVarTy a) arg_kind
          case m_arg_dict of
            Nothing       -> dictTy ty res_kind
            Just arg_dict -> liftM (fmap (ForAllTy a . FunTy arg_dict))
                                   (dictTy (AppTy ty (TyVarTy a)) res_kind)
    dictTy ty kind
      | isLiftedTypeKind kind = liftM Just (mkPADictType ty)
      | otherwise             = return Nothing
127
-- Build the PA dictionary expression for a type by splitting it into its
-- head and argument types.
paDictOfType :: Type -> VM CoreExpr
paDictOfType ty = uncurry paDictOfTyApp (splitAppTys ty)

-- PA dictionary for a head type applied to ty_args.  Synonyms are expanded
-- with coreView first.  A type-variable head uses the dictionary recorded
-- for it (lookupTyVarPA).  A tycon head uses the PA instance found by
-- lookupInst.  Any other head panics.
paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
paDictOfTyApp ty_fn ty_args
  | Just expanded <- coreView ty_fn = paDictOfTyApp expanded ty_args
  | otherwise
  = case ty_fn of
      TyVarTy tv    -> maybeV (lookupTyVarPA tv) >>= \dfun -> paDFunApply dfun ty_args
      TyConApp tc _ -> do
                         pa_class <- builtin paClass
                         (dfun, inst_args) <- lookupInst pa_class [TyConApp tc ty_args]
                         paDFunApply (Var dfun) inst_args
      other         -> pprPanic "paDictOfTyApp" (ppr other)
146
-- Instantiate a dictionary function at tys, then pass it the PA dictionary
-- for each of those types, in the same order.
paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
paDFunApply dfun tys
  = liftM (mkApps (mkTyApps dfun tys)) (mapM paDictOfType tys)
152
-- A PA method selected from the builtins, applied to the type ty and to the
-- PA dictionary for ty.
paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
paMethod method ty
  = builtin method  >>= \fn ->
    paDictOfType ty >>= \dict ->
    return (mkApps (Var fn) [Type ty, dict])
159
-- Apply the PA length method to x, whose type must be a parallel array;
-- the element type is recovered from exprType x via splitPArrayTy.
lengthPA :: CoreExpr -> VM CoreExpr
lengthPA x
  = do
      len_fn <- paMethod lengthPAVar (splitPArrayTy (exprType x))
      return (App len_fn x)

-- Apply the PA replicate method, instantiated at the type of x, to len and x.
replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
replicatePA len x
  = do
      rep_fn <- paMethod replicatePAVar (exprType x)
      return (mkApps rep_fn [len, x])

-- The PA empty method instantiated at the given element type.
emptyPA :: Type -> VM CoreExpr
emptyPA ty = paMethod emptyPAVar ty
171
-- A fresh vectorised/lifted variable pair with the same name: the first has
-- type vty, the second the parallel-array type of vty.
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
  = mkPArrayType vty   >>= \lty ->
    newLocalVar fs vty >>= \vv ->
    newLocalVar fs lty >>= \lv ->
    return (vv, lv)
179
-- Run p in a local scope where each type variable in tvs is registered,
-- together with a fresh PA-dictionary variable when paDictArgType gives a
-- dictionary type for it.  p receives a function that wraps an expression in
-- lambdas over tvs followed by the dictionary variables that were created.
polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
polyAbstract tvs p
  = localV $ do
      mdicts <- mapM dictVarFor tvs
      zipWithM_ registerTyVar tvs mdicts
      p (mkLams (tvs ++ [dict | Just dict <- mdicts]))
  where
    -- fresh dictionary variable for tv, if its kind needs one
    dictVarFor tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    registerTyVar tv Nothing     = defLocalTyVar tv
    registerTyVar tv (Just dict) = defLocalTyVarWithPA tv (Var dict)
195
-- Instantiate a polymorphic expression at tys and pass the PA dictionary for
-- each of those types.
polyApply :: CoreExpr -> [Type] -> VM CoreExpr
polyApply expr tys
  = liftM (mkApps (mkTyApps expr tys)) (mapM paDictOfType tys)

-- As polyApply, but applied to both components of a vectorised expression.
polyVApply :: VExpr -> [Type] -> VM VExpr
polyVApply expr tys
  = do
      dicts <- mapM paDictOfType tys
      let inst e = mkApps (mkTyApps e tys) dicts
      return (mapVect inst expr)
207
-- Look up the parallel-array family instance for the element type ty.
lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
lookupPArrayFamInst ty
  = do
      parr_tc <- builtin parrayTyCon
      lookupFamInst parr_tc [ty]
210
-- Bind expr to a fresh variable named fs, prepend the binding to the global
-- bindings, and return the variable.
hoistExpr :: FastString -> CoreExpr -> VM Var
hoistExpr fs expr
  = do
      var <- newLocalVar fs (exprType expr)
      updGEnv (addBinding var)
      return var
  where
    addBinding var genv
      = genv { global_bindings = (var, expr) : global_bindings genv }

-- Hoist both halves of a vectorised expression.  The vectorised and lifted
-- parts get the name fs prefixed with v and l respectively.
hoistVExpr :: FastString -> VExpr -> VM VVar
hoistVExpr fs (vexpr, lexpr)
  = hoistExpr (consFS 'v' fs) vexpr >>= \vvar ->
    hoistExpr (consFS 'l' fs) lexpr >>= \lvar ->
    return (vvar, lvar)
225
-- Build a vectorised expression in a closed scope, abstracted over tvs and
-- their PA dictionaries, hoist it to the global bindings, and return its
-- use site instantiated back at tvs.
hoistPolyVExpr :: FastString -> [TyVar] -> VM VExpr -> VM VExpr
hoistPolyVExpr fs tvs p
  = closedV (polyAbstract tvs abstractOver)
      >>= hoistVExpr fs
      >>= \fn -> polyVApply (vVar fn) (mkTyVarTys tvs)
  where
    abstractOver abstract = liftM (mapVect abstract) p
233
-- Return the hoisted global bindings accumulated so far and clear them.
takeHoisted :: VM [(Var, CoreExpr)]
takeHoisted
  = readGEnv id >>= \genv ->
    setGEnv (genv { global_bindings = [] }) >>
    return (global_bindings genv)
240
241
-- Build the vectorised and lifted closure values from a vectorised function
-- pair and an environment pair.  Both use the same type arguments, the PA
-- dictionary for env_ty and both function components.  The vectorised
-- closure (mkClosureVar) takes the vectorised environment; the lifted one
-- (mkClosurePVar) takes the lifted environment.
mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      dict <- paDictOfType env_ty
      mkv  <- builtin mkClosureVar
      mkl  <- builtin mkClosurePVar
      let apply_to con env = mkApps (mkTyApps (Var con) [arg_ty, res_ty, env_ty])
                                    [dict, vfn, lfn, env]
      return (apply_to mkv venv, apply_to mkl lenv)
250
-- Build a curried closure for a function with the given argument types and
-- final result type res_ty, capturing the variables vars.
--
--   * With one argument type this is just buildClosure.
--   * With several, the closure for the first argument x gets as its body a
--     hoisted function of lc, vars and x.  That function builds the closure
--     for the remaining arguments, with x added to the captured variables.
--   * An empty argument list is a caller error and panics with a message,
--     rather than failing with an anonymous pattern-match error.
buildClosures :: [TyVar] -> Var -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
buildClosures _ _ _ [] res_ty _
  = pprPanic "buildClosures: no argument types" (ppr res_ty)
buildClosures tvs lc vars [arg_ty] res_ty mk_body
  = buildClosure tvs lc vars arg_ty res_ty mk_body
buildClosures tvs lc vars (arg_ty : arg_tys) res_ty mk_body
  = do
      -- the first closure returns a closure over the remaining arguments
      res_ty' <- mkClosureTypes arg_tys res_ty
      arg     <- newLocalVVar FSLIT("x") arg_ty
      let vars' = vars ++ [arg]
      buildClosure tvs lc vars arg_ty res_ty'
        . hoistPolyVExpr FSLIT("fn") tvs
        $ do
            clo <- buildClosures tvs lc vars' arg_tys res_ty mk_body
            return (vLams lc vars' clo)
263
-- Build a closure for a function of one (further) argument that captures
-- the variables vars.  Schematically the result is
--
--   (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
--
-- where
--
--   f  = \env v -> case env of <x1,...,xn>      -> e x1 ... xn v
--   f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l xs1 ... xsn v
--
-- e and e^ are the two halves of the expression produced by mk_body.  How
-- the environment is built and unpacked (unit, single variable or tuple)
-- is decided by buildEnv.  The pair <f,f^> is hoisted to top level,
-- abstracted over tvs, before mkClosure assembles the closure.
buildClosure :: [TyVar] -> Var -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
buildClosure tvs lv vars arg_ty res_ty mk_body
  = do
      (env_ty, env, bind) <- buildEnv lv vars
      env_bndr <- newLocalVVar FSLIT("env") env_ty
      arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
      fn <- hoistPolyVExpr FSLIT("fn") tvs (mkFun bind env_bndr arg_bndr)
      mkClosure arg_ty res_ty env_ty fn env
  where
    -- \env arg -> unpack env around (body applied to lv, vars and arg)
    mkFun bind env_bndr arg_bndr
      = do
          body  <- mk_body
          body' <- bind (vVar env_bndr) (vVarApps lv body (vars ++ [arg_bndr]))
          return (vLamsWithoutLC [env_bndr, arg_bndr] body')
284
-- Build the vectorised and lifted environments capturing the given
-- variable pairs.  Returns
--   * the vectorised environment type,
--   * the pair of environment expressions, and
--   * a binder that, given a pair of environment expressions and a pair of
--     bodies, unpacks each environment around the matching body.
buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
buildEnv lv vvs
  = do
      (lenv, lbind) <- mkLiftEnv lv tys ls
      return (ty, (venv, lenv), bindBoth lbind)
  where
    (vs, ls) = unzip vvs
    tys      = map idType vs

    (ty, venv, vbind) = mkVectEnv tys vs

    bindBoth lbind (venv', lenv') (vbody, lbody)
      = do
          lbody' <- lbind lenv' lbody
          return (vbind venv' vbody, lbody')
299
-- The vectorised environment for variables vs of types tys, as
-- (environment type, environment expression, binder).
--   * No variables: the unit value; the binder leaves the body unchanged.
--   * One variable: the variable itself, rebound by a let.
--   * Otherwise: a boxed tuple, unpacked by a case.
mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
mkVectEnv tys vs
  = case (tys, vs) of
      ([],   [])  -> (unitTy, Var unitDataConId, \_ body -> body)
      ([ty], [v]) -> (ty, Var v, \env body -> Let (NonRec v env) body)
      _           -> (tup_ty, mkCoreTup (map Var vs), unpackTuple)
  where
    tup_ty = mkCoreTupTy tys

    unpackTuple env body
      = Case env (mkWildId tup_ty) (exprType body)
             [(DataAlt (tupleCon Boxed (length vs)), vs, body)]
308
-- The lifted environment for the lifted variables vs (element types tys),
-- returned with a binder that unpacks such an environment around a body.
--
--   * One variable: the variable is its own environment.  The binder
--     rebinds it with a let and binds the lifting-context variable lv to its
--     length (computed with lengthPA).
--   * Otherwise: the single data constructor of the parallel-array family
--     instance for the tuple of tys, applied to lv followed by the
--     variables.  The binder cases on the unwrapped scrutinee, binding lv
--     and the variables.
mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
mkLiftEnv lv [_] [v]
  = return (Var v, \env body ->
                     do
                       len <- lengthPA (Var v)
                       return . Let (NonRec v env)
                              $ Case len lv (exprType body) [(DEFAULT, [], body)])

-- NOTE: this also deals with empty environments.  The tuple type is then
-- unit, and the unpacking alternative binds lv plus one unit-typed wildcard
-- field.  The constructed environment must therefore supply a matching unit
-- argument (args below); previously it applied the constructor to lv alone,
-- which did not agree with the binders.
mkLiftEnv lv tys vs
  = do
      (env_tc, env_tyargs) <- lookupPArrayFamInst vty
      env_con <- case tyConDataCons env_tc of
                   [con] -> return con
                   _     -> pprPanic "mkLiftEnv: expected a single data constructor"
                                     (ppr env_tc)
      let env = Var (dataConWrapId env_con)
                `mkTyApps` env_tyargs
                `mkApps`   (Var lv : args)

          bind env' body
            = let scrut = unwrapFamInstScrut env_tc env_tyargs env'
              in
              return $ Case scrut (mkWildId (exprType scrut))
                            (exprType body)
                            [(DataAlt env_con, lv : bndrs, body)]
      return (env, bind)
  where
    vty = mkCoreTupTy tys

    -- constructor arguments after lv; must line up with bndrs
    args  | null vs   = [Var unitDataConId]
          | otherwise = map Var vs

    bndrs | null vs   = [mkWildId unitTy]
          | otherwise = vs
338