Collect hoisted vectorised functions
[ghc.git] / compiler / vectorise / Vectorise.hs
1 module Vectorise( vectorise )
2 where
3
4 #include "HsVersions.h"
5
6 import VectMonad
7 import VectUtils
8
9 import DynFlags
10 import HscTypes
11
12 import CoreLint ( showPass, endPass )
13 import CoreSyn
14 import CoreUtils
15 import CoreFVs
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Var
21 import VarEnv
22 import VarSet
23 import Name ( mkSysTvName )
24 import NameEnv
25 import Id
26 import MkId ( unwrapFamInstScrut )
27
28 import DsMonad hiding (mapAndUnzipM)
29 import DsUtils ( mkCoreTup, mkCoreTupTy )
30
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes ( Boxity(..) )
34
35 import Outputable
36 import FastString
37 import Control.Monad ( liftM, liftM2, mapAndUnzipM, zipWithM_ )
38 import Data.Maybe ( maybeToList )
39
-- | Entry point of the vectorisation pass.
--
-- Returns the module unchanged unless 'Opt_Vectorise' is set.  Otherwise it
-- combines the home-package vectorisation info ('hptVectInfo') with that of
-- the external package state ('eps_vect_info') and runs 'vectModule' via
-- 'initV'.  It then hands the resulting bindings to 'endPass' (with
-- 'Opt_D_dump_vect') and stores the updated vectorisation info in the guts.
vectorise :: HscEnv -> ModGuts -> IO ModGuts
vectorise hsc_env guts
  | not (Opt_Vectorise `dopt` dflags) = return guts
  | otherwise
  = do
      showPass dflags "Vectorisation"
      eps <- hscEPS hsc_env
      let info = hptVectInfo hsc_env `plusVectInfo` eps_vect_info eps
      r <- initV hsc_env guts info (vectModule guts)
      case r of
        -- Matching 'Just' directly in the do block used to turn a failed
        -- 'initV' into an anonymous IO pattern-match failure; report it
        -- explicitly instead.
        Nothing -> panic "vectorise: initV failed to vectorise the module"
        Just (info', guts') ->
          do
            endPass dflags "Vectorisation" Opt_D_dump_vect (mg_binds guts')
            return $ guts' { mg_vect_info = info' }
  where
    dflags = hsc_dflags hsc_env
53
-- | Vectorise a whole module.  Currently a stub that hands the module back
-- unchanged.
vectModule :: ModGuts -> VM ModGuts
vectModule = return
56
57 -- ----------------------------------------------------------------------------
58 -- Bindings
59
-- | Vectorise a binder.
--
-- Returns a pair of binders, both derived from @v@ by 'Id.setIdType':
--
-- * the scalar binder, whose type is the vectorised type of @v@;
-- * the lifted binder, whose type is the 'mkPArrayType' of that type.
--
-- The local environment is extended so that @v@ maps to both binders.
vectBndr :: Var -> VM (Var, Var)
vectBndr v
  = do
      scalarTy <- vectType (idType v)
      liftedTy <- mkPArrayType scalarTy
      let scalarVar = Id.setIdType v scalarTy
          liftedVar = Id.setIdType v liftedTy
      updLEnv $ \env ->
        env { local_vars = extendVarEnv (local_vars env) v
                                        (Var scalarVar, Var liftedVar) }
      return (scalarVar, liftedVar)
71
-- | Vectorise one binder with 'vectBndr', then run @p@ in its scope.
--
-- The whole thing runs under 'localV', which presumably confines the
-- environment extension to this call.
vectBndrIn :: Var -> VM a -> VM (Var, Var, a)
vectBndrIn v p
  = localV $ vectBndr v >>= \(vv, lv) ->
             p          >>= \x ->
             return (vv, lv, x)
79
-- | Vectorise several binders with 'vectBndr', then run @p@ in their scope
-- under 'localV'.
--
-- Returns the scalar binders and the lifted binders, each in the order of
-- @vs@, followed by the result of @p@.
vectBndrsIn :: [Var] -> VM a -> VM ([Var], [Var], a)
vectBndrsIn vs p
  = localV
  $ do
      pairs <- mapM vectBndr vs
      x     <- p
      let (vvs, lvs) = unzip pairs
      return (vvs, lvs, x)
87
88 -- ----------------------------------------------------------------------------
89 -- Expressions
90
-- | @replicateP expr len@ applies the builtin @replicatePA@ to the type of
-- @expr@, that type's PA dictionary, @expr@ and @len@.
--
-- Presumably this builds a parallel array of @len@ copies of @expr@.
replicateP :: CoreExpr -> CoreExpr -> VM CoreExpr
replicateP expr len
  = do
      dict <- paDictOfType elemTy
      rep  <- builtin replicatePAVar
      return $ Var rep `mkApps` [Type elemTy, dict, expr, len]
  where
    elemTy = exprType expr
99
-- | Apply a vectorised closure to a vectorised argument.
--
-- * The scalar function and argument are combined with the builtin
--   @applyClosure@.
-- * The lifted function and argument are combined with @applyClosureP@.
--
-- In both cases the argument and result types come from splitting the
-- closure type of the scalar function.
capply :: (CoreExpr, CoreExpr) -> (CoreExpr, CoreExpr) -> VM (CoreExpr, CoreExpr)
capply (vfn, lfn) (varg, larg)
  = do
      apply  <- builtin applyClosureVar
      applyP <- builtin applyClosurePVar
      let (arg_ty, res_ty) = splitClosureTy (exprType vfn)
          mk f clo arg     = mkApps (Var f) [Type arg_ty, Type res_ty, clo, arg]
      return (mk apply vfn varg, mk applyP lfn larg)
110
-- | Vectorise an occurrence of a monomorphic variable.
--
-- The local environment is tried first; a hit there already holds both the
-- scalar and the lifted expression.  Failing that, the variable's
-- vectorised form is taken from the global environment, and the lifted form
-- is produced by 'replicateP' with length @lc@.
vectVar :: CoreExpr -> Var -> VM (CoreExpr, CoreExpr)
vectVar lc v = fromLocal `orElseV` fromGlobal
  where
    fromLocal = maybeV . readLEnv $ \env -> lookupVarEnv (local_vars env) v

    fromGlobal
      = do
          vexpr <- maybeV . readGEnv $ \env -> lookupVarEnv (global_vars env) v
          lexpr <- replicateP vexpr lc
          return (vexpr, lexpr)
119
-- | Vectorise a variable applied to the type arguments @tys@.
--
-- Instantiation means applying an expression to the vectorised @tys@ (and
-- their PA dictionaries) via 'applyToTypes'.
--
-- * A locally bound variable has its scalar and lifted forms instantiated
--   separately.
-- * Otherwise the global vectorised form is instantiated, and the lifted
--   form is obtained with 'replicateP' using @lc@.
vectPolyVar :: CoreExpr -> Var -> [Type] -> VM (CoreExpr, CoreExpr)
vectPolyVar lc v tys
  = readLEnv (\env -> lookupVarEnv (local_vars env) v) >>= go
  where
    go (Just (vexpr, lexpr)) = liftM2 (,) (instantiate vexpr) (instantiate lexpr)
    go Nothing
      = do
          poly  <- maybeV (readGEnv $ \env -> lookupVarEnv (global_vars env) v)
          vexpr <- instantiate poly
          lexpr <- replicateP vexpr lc
          return (vexpr, lexpr)

    instantiate e = mapM vectType tys >>= applyToTypes e
134
-- | Run @p@ in a context that abstracts over the type variables @tvs@.
--
-- For each type variable, 'paDictArgType' decides whether it needs a PA
-- dictionary:
--
-- * If it does, a fresh local variable named @dPA@ is created and
--   registered with 'extendTyVarPA'.
-- * If it does not, 'deleteTyVarPA' is called for it.
--
-- @p@ receives a function that wraps an expression in lambdas over each
-- type variable, each one immediately followed by its dictionary variable
-- (if any).
abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
abstractOverTyVars tvs p
  = do
      mdicts <- mapM dictVarFor tvs
      zipWithM_ register tvs mdicts
      let binders = concat [tv : maybeToList mdict | (tv, mdict) <- zip tvs mdicts]
      p (mkLams binders)
  where
    dictVarFor tv
      = paDictArgType tv >>= maybe (return Nothing)
                                   (liftM Just . newLocalVar FSLIT("dPA"))

    register tv Nothing     = deleteTyVarPA tv
    register tv (Just dict) = extendTyVarPA tv (Var dict)
150
-- | Apply @expr@ to each type in @tys@.
--
-- Every type argument is immediately followed by that type's PA dictionary
-- (from 'paDictOfType').
applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
applyToTypes expr tys
  = do
      dicts <- mapM paDictOfType tys
      return . mkApps expr . concat $ zipWith (\ty dict -> [Type ty, dict]) tys dicts
157
158
-- | Vectorise an expression that may begin with type lambdas.
--
-- Steps:
--
-- 1. The leading type binders are split off with 'collectAnnTypeBinders'.
-- 2. The remaining body is vectorised inside 'abstractOverTyVars', all
--    under 'localV'.
-- 3. Both the scalar and the lifted result are re-wrapped in the resulting
--    type and dictionary lambdas.
vectPolyExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
vectPolyExpr lc expr
  = localV $ abstractOverTyVars tvs rewrap
  where
    (tvs, mono) = collectAnnTypeBinders expr

    -- FIXME: shadowing (tvs in lc)
    rewrap mk_lams = liftM (\(vmono, lmono) -> (mk_lams vmono, mk_lams lmono))
                           (vectExpr lc mono)
169
-- | Vectorise an annotated expression into a pair of expressions: the
-- scalar (vectorised) version and the lifted version.
--
-- @lc@ is the expression passed as the length argument whenever a scalar
-- value must be turned into a lifted one with 'replicateP'.
--
-- Not supported yet:
--
-- * @case@ expressions (panic);
-- * type lambdas reaching this function directly (panic).  Type lambdas
--   are expected to be peeled off by 'vectPolyExpr'.
vectExpr :: CoreExpr -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
vectExpr lc (_, AnnType ty)
  = do
      vty <- vectType ty
      return (Type vty, Type vty)

vectExpr lc (_, AnnVar v) = vectVar lc v

vectExpr lc (_, AnnLit lit)
  = do
      let vexpr = Lit lit
      lexpr <- replicateP vexpr lc
      return (vexpr, lexpr)

vectExpr lc (_, AnnNote note expr)
  = do
      (vexpr, lexpr) <- vectExpr lc expr
      return (Note note vexpr, Note note lexpr)

-- An application to a type argument: collect all type arguments and let
-- 'vectTyAppExpr' handle the head.
vectExpr lc e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr lc fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

vectExpr lc (_, AnnApp fn arg)
  = do
      fn'  <- vectExpr lc fn
      arg' <- vectExpr lc arg
      capply fn' arg'

vectExpr lc (_, AnnCase expr bndr ty alts)
  = panic "vectExpr: case"

vectExpr lc (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      (vrhs, lrhs) <- vectPolyExpr lc rhs
      (vbndr, lbndr, (vbody, lbody)) <- vectBndrIn bndr (vectExpr lc body)
      return (Let (NonRec vbndr vrhs) vbody,
              Let (NonRec lbndr lrhs) lbody)

vectExpr lc (_, AnnLet (AnnRec prs) body)
  = do
      (vbndrs, lbndrs, (vrhss, vbody, lrhss, lbody)) <- vectBndrsIn bndrs vect
      return (Let (Rec (zip vbndrs vrhss)) vbody,
              Let (Rec (zip lbndrs lrhss)) lbody)
  where
    (bndrs, rhss) = unzip prs

    -- Right-hand sides may be polymorphic (start with type lambdas), exactly
    -- as in the non-recursive case, so they must go through 'vectPolyExpr'.
    -- With plain 'vectExpr' a polymorphic RHS hits the type-lambda panic
    -- below.
    vect = do
             (vrhss, lrhss) <- mapAndUnzipM (vectPolyExpr lc) rhss
             (vbody, lbody) <- vectPolyExpr lc body
             return (vrhss, vbody, lrhss, lbody)

vectExpr lc e@(_, AnnLam bndr body)
  | isTyVar bndr = pprPanic "vectExpr" (ppr $ deAnnotate e)

-- A value lambda becomes a closure.
--
-- * Closure functions are built with 'mkClosureFns' and hoisted with
--   'hoistExpr'.
-- * The environments are built with 'mkClosureEnvs'.
-- * The result is an application of the builtin @mkClosure@ (scalar) or
--   @mkClosureP@ (lifted).
vectExpr lc (fvs, AnnLam bndr body)
  = do
      let tyvars = filter isTyVar (varSetElems fvs)
      info <- mkCEnvInfo fvs bndr body
      (poly_vfn, poly_lfn) <- mkClosureFns info tyvars bndr body

      vfn_var <- hoistExpr FSLIT("vfn") poly_vfn
      lfn_var <- hoistExpr FSLIT("lfn") poly_lfn

      let (venv, lenv) = mkClosureEnvs info lc

      let env_ty = cenv_vty info

      pa_dict <- paDictOfType env_ty

      arg_ty <- vectType (varType bndr)
      res_ty <- vectType (exprType $ deAnnotate body)

      -- FIXME: move the functions to the top level
      mono_vfn <- applyToTypes (Var vfn_var) (map TyVarTy tyvars)
      mono_lfn <- applyToTypes (Var lfn_var) (map TyVarTy tyvars)

      mk_clo  <- builtin mkClosureVar
      mk_cloP <- builtin mkClosurePVar

      let vclo = Var mk_clo  `mkTyApps` [arg_ty, res_ty, env_ty]
                             `mkApps`   [pa_dict, mono_vfn, mono_lfn, venv]

          lclo = Var mk_cloP `mkTyApps` [arg_ty, res_ty, env_ty]
                             `mkApps`   [pa_dict, mono_vfn, mono_lfn, lenv]

      return (vclo, lclo)
259
260
-- | Information needed to build and take apart the environment of a
-- closure (see 'mkCEnvInfo').
data CEnvInfo = CEnvInfo {
    cenv_vars         :: [Var]
      -- ^ free variables of the lambda that are bound in the local
      --   vectorisation environment
  , cenv_values       :: [(CoreExpr, CoreExpr)]
      -- ^ their scalar and lifted values, in the order of 'cenv_vars'
  , cenv_vty          :: Type
      -- ^ vectorised type of the environment
  , cenv_lty          :: Type
      -- ^ lifted environment type ('mkPArrayType' of 'cenv_vty')
  , cenv_repr_tycon   :: TyCon
      -- ^ representation tycon of the lifted environment; an error thunk
      --   when exactly one variable is captured
  , cenv_repr_tyargs  :: [Type]
      -- ^ type arguments for 'cenv_repr_tycon' (same caveat)
  , cenv_repr_datacon :: DataCon
      -- ^ sole data constructor of 'cenv_repr_tycon' (same caveat)
  }
270
-- | Collect the information needed for the environment of a closure over
-- the free variables @fvs@.
--
-- Only free variables found in the local vectorisation environment are
-- captured, with their scalar and lifted values taken from there.
--
-- * Exactly one captured variable: its vectorised type is the environment
--   type, and the @cenv_repr_*@ fields are left as error thunks that must
--   not be forced.
-- * Otherwise: the environment type is the core tuple of the vectorised
--   variable types, and the representation of its parallel array is looked
--   up with 'lookupPArrayFamInst'.
--
-- The lambda's argument and body are currently unused.
mkCEnvInfo :: VarSet -> Var -> CoreExprWithFVs -> VM CEnvInfo
mkCEnvInfo fvs arg body
  = do
      locals <- readLEnv local_vars
      let
        (vars, vals) = unzip
                        [(var, val) | var      <- varSetElems fvs
                                    , Just val <- [lookupVarEnv locals var]]
      vtys <- mapM (vectType . varType) vars

      (vty, repr_tycon, repr_tyargs, repr_datacon) <- mk_env_ty vtys
      lty <- mkPArrayType vty

      return $ CEnvInfo {
                 cenv_vars         = vars
               , cenv_values       = vals
               , cenv_vty          = vty
               , cenv_lty          = lty
               , cenv_repr_tycon   = repr_tycon
               , cenv_repr_tyargs  = repr_tyargs
               , cenv_repr_datacon = repr_datacon
               }
  where
    -- The messages name the actual record fields so a stray force is easy
    -- to trace.
    mk_env_ty [vty]
      = return (vty, error "absent cenv_repr_tycon"
                   , error "absent cenv_repr_tyargs"
                   , error "absent cenv_repr_datacon")

    mk_env_ty vtys
      = do
          let ty = mkCoreTupTy vtys
          (repr_tc, repr_tyargs) <- lookupPArrayFamInst ty
          -- Kept lazy like the original irrefutable binding, but with a
          -- clear panic if the representation tycon is not single-constructor.
          let repr_con = case tyConDataCons repr_tc of
                           [dc] -> dc
                           _    -> pprPanic "mkCEnvInfo: expected exactly one data constructor in"
                                            (ppr repr_tc)
          return (ty, repr_tc, repr_tyargs, repr_con)
305
306
307
-- | Build the scalar and lifted closure environments from the captured
-- values in @info@.
--
-- * No captured values: the scalar environment is the unit value, and the
--   lifted one applies the representation constructor to @lc@ and unit.
-- * One value: its scalar and lifted expressions are the environments.
-- * Several values: a core tuple of the scalar values, and the
--   representation constructor (instantiated at the representation type
--   arguments) applied to @lc@ followed by the lifted values.
mkClosureEnvs :: CEnvInfo -> CoreExpr -> (CoreExpr, CoreExpr)
mkClosureEnvs info lc
  = case cenv_values info of
      []             -> (unit, mkApps repr_con [lc, unit])
      [(vval, lval)] -> (vval, lval)
      vals           -> let (vvals, lvals) = unzip vals
                        in (mkCoreTup vvals,
                            repr_con `mkTyApps` cenv_repr_tyargs info
                                     `mkApps`   (lc : lvals))
  where
    unit     = Var unitDataConId
    -- Lazy: never forced in the single-value case, where the repr fields
    -- are absent.
    repr_con = Var (dataConWrapId (cenv_repr_datacon info))
325
-- | Build the scalar and lifted closure functions for a lambda.
--
-- The monomorphic functions from 'mkClosureMonoFns' are wrapped in the
-- type and dictionary lambdas for @tyvars@ produced by
-- 'abstractOverTyVars'.  Everything runs under 'closedV'.
mkClosureFns :: CEnvInfo -> [TyVar] -> Var -> CoreExprWithFVs
             -> VM (CoreExpr, CoreExpr)
mkClosureFns info tyvars arg body
  = closedV $ abstractOverTyVars tyvars build
  where
    build mk_tlams
      = liftM (\(vfn, lfn) -> (mk_tlams vfn, mk_tlams lfn))
              (mkClosureMonoFns info arg body)
335
-- | Build the monomorphic scalar and lifted closure functions.
--
-- Each function takes an environment binder and then the (vectorised or
-- lifted) argument.  Inside, the captured variables are bound from the
-- environment before the body runs.  The body is vectorised with a fresh
-- @lc@ binder as its lifting context.
--
-- * Scalar environment: unused (no variables), bound directly (one
--   variable), or taken apart with a boxed-tuple case (several variables).
-- * Lifted environment, one variable: bound directly, with @lc@ set to its
--   @lengthPA@.
-- * Lifted environment, otherwise: scrutinised via 'unwrapFamInstScrut' and
--   the representation constructor, whose first field binds @lc@.
mkClosureMonoFns :: CEnvInfo -> Var -> CoreExprWithFVs -> VM (CoreExpr, CoreExpr)
mkClosureMonoFns info arg body
  = do
      lc_bndr <- newLocalVar FSLIT("lc") intTy
      (varg : vbndrs, larg : lbndrs, (vbody, lbody))
        <- vectBndrsIn (arg : cenv_vars info)
                       (vectExpr (Var lc_bndr) body)

      venv_bndr <- newLocalVar FSLIT("env") vty
      lenv_bndr <- newLocalVar FSLIT("env") lty

      let vcase = bind_venv (Var venv_bndr) vbody vbndrs
      lcase <- bind_lenv (Var lenv_bndr) lbody lc_bndr lbndrs
      return (mkLams [venv_bndr, varg] vcase, mkLams [lenv_bndr, larg] lcase)
  where
    vty = cenv_vty info
    lty = cenv_lty info

    arity = length (cenv_vars info)

    bind_venv venv vbody []      = vbody
    bind_venv venv vbody [vbndr] = Let (NonRec vbndr venv) vbody
    bind_venv venv vbody vbndrs
      = Case venv (mkWildId vty) (exprType vbody)
             [(DataAlt (tupleCon Boxed arity), vbndrs, vbody)]

    -- The type given to 'Case' is the type of the whole case expression,
    -- i.e. of the alternative's right-hand side 'lbody', not the type of
    -- the scrutinee (an Int).  This matches the other two case
    -- constructions in this function.
    bind_lenv lenv lbody lc_bndr [lbndr]
      = do
          lengthPA <- builtin lengthPAVar
          return . Let (NonRec lbndr lenv)
                 $ Case (mkApps (Var lengthPA) [Type vty, (Var lbndr)])
                        lc_bndr
                        (exprType lbody)
                        [(DEFAULT, [], lbody)]

    bind_lenv lenv lbody lc_bndr lbndrs
      = return
      $ Case (unwrapFamInstScrut (cenv_repr_tycon info)
                                 (cenv_repr_tyargs info)
                                 lenv)
             (mkWildId lty)
             (exprType lbody)
             [(DataAlt (cenv_repr_datacon info), lc_bndr : lbndrs, lbody)]
379
-- | Vectorise an expression applied to the type arguments @tys@.
--
-- Only a variable head is supported, and it is handed to 'vectPolyVar'.
-- Any other head is a panic.
vectTyAppExpr :: CoreExpr -> CoreExprWithFVs -> [Type] -> VM (CoreExpr, CoreExpr)
vectTyAppExpr lc e tys
  = case e of
      (_, AnnVar v) -> vectPolyVar lc v tys
      _             -> pprPanic "vectTyAppExpr" (ppr $ deAnnotate e)
383
384 -- ----------------------------------------------------------------------------
385 -- Types
386
-- | Vectorise a type constructor.
--
-- * The function tycon becomes the builtin closure tycon.
-- * Boxed tuples and unlifted tycons are kept unchanged.
-- * Any other tycon is looked up with 'lookupTyCon'.  If it has no
--   counterpart, it is reported with 'pprTrace' and used unchanged.
vectTyCon :: TyCon -> VM TyCon
vectTyCon tc
  | isFunTyCon tc        = builtin closureTyCon
  | isBoxedTupleTyCon tc = return tc
  | isUnLiftedTyCon tc   = return tc
  | otherwise
  = do
      r <- lookupTyCon tc
      case r of
        Just tc' -> return tc'

        -- FIXME: just for now
        -- The trace label names this function (it used to say "ccTyCon").
        Nothing -> pprTrace "vectTyCon:" (ppr tc) $ return tc
399
-- | Vectorise a type.
--
-- * Synonyms are expanded through 'coreView' first.
-- * Type variables are kept.
-- * Applications are vectorised component-wise, with tycons going through
--   'vectTyCon'.
-- * Function types become applications of the builtin closure tycon.
-- * For a @forall@, the body is vectorised.  If 'paDictArgType' yields a
--   dictionary type for the bound variable, that type is added as an extra
--   function argument under the @forall@.
vectType :: Type -> VM Type
-- Recurse on the expanded type: recursing on 'ty' itself never terminates
-- for any type that 'coreView' can expand.
vectType ty | Just ty' <- coreView ty = vectType ty'
vectType (TyVarTy tv) = return $ TyVarTy tv
vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
vectType (FunTy ty1 ty2) = liftM2 TyConApp (builtin closureTyCon)
                                           (mapM vectType [ty1,ty2])
vectType (ForAllTy tv ty)
  = do
      r <- paDictArgType tv
      ty' <- vectType ty
      return $ ForAllTy tv (wrap r ty')
  where
    wrap Nothing      = id
    wrap (Just pa_ty) = FunTy pa_ty

vectType ty = pprPanic "vectType:" (ppr ty)
417