6d6a473b44f5ac0eb2c960836b01f1540b42530c
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1
-- | Vectorisation of expressions.
module Vectorise.Exp
  ( -- Vectorise a polymorphic expression
    vectPolyExpr

    -- Vectorise a scalar expression of functional type
  , vectScalarFun
  ) where
11
12 #include "HsVersions.h"
13
14 import Vectorise.Type.Type
15 import Vectorise.Var
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19 import Vectorise.Builtins
20 import Vectorise.Utils
21
22 import CoreSyn
23 import CoreUtils
24 import MkCore
25 import CoreFVs
26 import DataCon
27 import TyCon
28 import Type
29 import Var
30 import VarEnv
31 import VarSet
32 import Id
33 import BasicTypes( isStrongLoopBreaker )
34 import Literal
35 import TysWiredIn
36 import TysPrim
37 import Outputable
38 import FastString
39 import Control.Monad
40 import Data.List
41
42
-- | Vectorise a polymorphic expression.
--
-- An outer note is peeled off, the remainder is vectorised recursively, and the note is
-- re-attached to the result.  Otherwise the outer type binders are split off
-- ('collectAnnTypeBinders'), abstracted over with 'polyAbstract', and the monomorphic body is
-- handed to 'vectFnExpr'.
--
-- The result triple is: the inlining decision (extended by the polymorphic arity), whether the
-- body was vectorised as a scalar function, and the vectorised expression.
--
vectPolyExpr :: Bool             -- ^ When vectorising the RHS of a binding, whether that
                                 --   binding is a loop breaker.
             -> [Var]            -- ^ Names of functions in the same recursive binding group.
             -> CoreExprWithFVs
             -> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnNote note inner)
  = liftM (\(inline, isScalarFn, vinner) -> (inline, isScalarFn, vNote note vinner))
          (vectPolyExpr loop_breaker recFns inner)
vectPolyExpr loop_breaker recFns expr
  = do { arity <- polyArity tyvars
       ; polyAbstract tyvars $ \args ->
           liftM (\(inline, isScalarFn, vmono) ->
                    ( addInlineArity inline arity
                    , isScalarFn
                    , mapVect (mkLams $ tyvars ++ args) vmono ))
                 (vectFnExpr False loop_breaker recFns mono)
       }
  where
    (tyvars, mono) = collectAnnTypeBinders expr
63
64
-- |Vectorise an expression, yielding its vectorised and lifted forms as a 'VExpr'.
--
-- Expressions not matched by any of the equations below (for example type lambdas) are
-- rejected with 'cantVectorise'.
--
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = liftM vType (vectType ty)

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectLiteral lit

vectExpr (_, AnnNote note expr)
  = liftM (vNote note) (vectExpr expr)

-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do { (vty, lty) <- vectAndLiftType ty
       ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
       }
  where
    err' = deAnnotate err

-- Applications to type arguments are handled by 'vectTyAppExpr'.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectTyAppExpr fn tys
  where
    (fn, tys) = collectAnnTypeArgs e

-- Boxing a literal with 'I#', 'F#' or 'D#': the vectorised form is the application itself
-- and the lifted form is obtained with 'liftPD'.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]

-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.
--
--    for lifted version.
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber separately.

-- General value application: vectorised as a closure application.
vectExpr (_, AnnApp fn arg)
  = do
      arg_ty' <- vectType arg_ty
      res_ty' <- vectType res_ty

      fn'     <- vectExpr fn
      arg'    <- vectExpr arg

      mkClosureApp arg_ty' res_ty' fn' arg'
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- Only case expressions over algebraic types can be vectorised.
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vbndr vrhs) vbody

vectExpr (_, AnnLet (AnnRec bs) body)
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                $ liftM2 (,)
                                    (zipWithM vect_rhs bndrs rhss)
                                    (vectExpr body)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    -- Each RHS is vectorised with its own strong-loop-breaker status.
    vect_rhs bndr rhs = localV
                      . inBind bndr
                      . liftM (\(_,_,z)->z)
                      $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs

-- Value lambdas go through 'vectFnExpr'; type lambdas fall through to the catch-all below.
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e

vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
168
-- | Vectorise an expression with an outer lambda abstraction.
--
-- For a value lambda, a scalar vectorisation ('vectScalarFun') is attempted first, combined
-- via 'orElseV' with the full closure-based vectorisation ('vectLam').  Any other expression
-- (including a type lambda) is vectorised with 'vectExpr' and marked 'DontInline'.
--
vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding
                               --   should be inlined
           -> Bool             -- ^ Whether the binding is a loop breaker
           -> [Var]            -- ^ Names of function in same recursive binding group
           -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
           -> VM (Inline, Bool, VExpr)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
  | isId bndr
  = as_scalar `orElseV` as_closure
  where
    as_scalar  = mark DontInline True  $ vectScalarFun False recFns (deAnnotate expr)
    as_closure = mark inlineMe   False $ vectLam inline loop_breaker expr
vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e
182
-- | Run a vectorisation action and pair its result with an inlining decision and a flag
-- recording whether it was vectorised as a scalar function.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn = liftM (\x -> (b, isScalarFn, x))
185
-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
-- involve parallel arrays.  Such functionals do not require the full-blown vectorisation
-- transformation; instead, they can be lifted by application of a member of the zipWith family
-- (i.e., 'map', 'zipWith', 'zipWith3', etc.)
--
-- Unless the user forces it, the expression is only treated as scalar if all argument types and
-- the result type are primitive (see 'is_prim_ty'), it passes 'is_scalar', and it actually
-- calls one of the known scalar functions (see 'uses').
--
vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
              -> [Var]      -- ^ Functions names in same recursive binding group
              -> CoreExpr   -- ^ Expression to be vectorised
              -> VM VExpr
vectScalarFun forceScalar recFns expr
  = do { gscalars <- globalScalars
       ; let scalars           = gscalars `extendVarSetList` recFns
             (arg_tys, res_ty) = splitFunTys (exprType expr)
       ; MASSERT( not $ null arg_tys )
       ; onlyIfV (forceScalar                    -- user asserts the functions is scalar
                  ||
                  all is_prim_ty arg_tys         -- check whether the function is scalar
                  && is_prim_ty res_ty
                  && is_scalar scalars expr
                  && uses scalars expr)
         $ mkScalarFun arg_tys res_ty expr
       }
  where
    -- FIXME: This is woefully insufficient!!!  We need a scalar pragma for types!!!
    is_prim_ty ty
      | Just (tycon, []) <- splitTyConApp_maybe ty
      = tycon == intTyCon
        || tycon == floatTyCon
        || tycon == doubleTyCon
      | otherwise = False

    -- Checks whether an expression contains a non-scalar subexpression.
    --
    -- Precondition: The variables in the first argument are scalar.
    --
    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
    -- them to the list of scalar variables) and then check them.  If one of them turns out not
    -- to be scalar, the entire group is regarded as not being scalar.
    --
    -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
    --        data constructor as scalar.  Should be changed once scalar types are passed
    --        through VectInfo.
    --
    is_scalar :: VarSet -> CoreExpr -> Bool
    is_scalar scalars  (Var v)         = v `elemVarSet` scalars
    is_scalar _scalars (Lit _)         = True
    is_scalar scalars  e@(App e1 e2)
      | maybe_parr_ty (exprType e)     = False
      | otherwise                      = is_scalar scalars e1 && is_scalar scalars e2
    is_scalar scalars  (Lam var body)
      | maybe_parr_ty (varType var)    = False
      | otherwise                      = is_scalar (scalars `extendVarSet` var) body
    is_scalar scalars  (Let bind body) = bindsAreScalar && is_scalar scalars' body
      where
        (bindsAreScalar, scalars') = is_scalar_bind scalars bind
    is_scalar scalars  (Case e var ty alts)
      | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts
      | otherwise     = False
      where
        scalars' = scalars `extendVarSet` var
    is_scalar scalars  (Cast e _coe)   = is_scalar scalars e
    is_scalar scalars  (Note _ e)      = is_scalar scalars e
    is_scalar _scalars (Type {})       = True
    is_scalar _scalars (Coercion {})   = True

    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
    is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
    is_scalar_bind scalars (Rec bnds)     = (all (is_scalar scalars') es, scalars')
      where
        (vars, es) = unzip bnds
        scalars'   = scalars `extendVarSetList` vars

    is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList` vars) e

    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
    -- constructor is a type family, we conservatively assume that it may be a parallel array
    -- type.
    maybe_parr_ty :: Type -> Bool
    maybe_parr_ty ty
      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
    maybe_parr_ty _ = False

    -- Checks whether the expression calls at least one function from 'funs'.
    --
    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identity function
    --        is called by some other function that is otherwise scalar, it would be very bad
    --        that just this call to the identity makes it not be scalar.
    -- A scalar function has to actually compute something.  Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad.  Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    --
    -- Lambda binders are removed from (not added to) 'funs': merely returning or passing on an
    -- argument is not a use of a scalar function, so (\x -> x) must yield 'False'.  Casts,
    -- notes and recursive lets are looked through, mirroring 'is_scalar'.
    uses funs (Var v)      = v `elemVarSet` funs
    uses funs (App e1 e2)  = uses funs e1 || uses funs e2
    uses funs (Lam b body) = uses (funs `delVarSet` b) body
    uses funs (Let (NonRec _b letExpr) body)
      = uses funs letExpr || uses funs body
    uses funs (Let (Rec bnds) body)
      = any (uses funs . snd) bnds || uses funs body
    uses funs (Case e _eId _ty alts)
      = uses funs e || any (uses_alt funs) alts
    uses funs (Cast e _co) = uses funs e
    uses funs (Note _ e)   = uses funs e
    uses _ _ = False

    uses_alt funs (_, _bs, e) = uses funs e
286
-- | Build the vectorised form of a scalar function.
--
-- The function is bound via 'hoistExpr' (name hint "fn"), a zip-style lifting of it is built
-- with 'zipScalars', and both are packed into a closure by 'scalarClosure'.  That closure is
-- itself bound via 'hoistExpr' (name hint "clo"); the result pairs a reference to it with its
-- 'liftPD'-lifted form.
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  = do
      fnVar   <- hoistExpr (fsLit "fn") expr DontInline
      zipper  <- zipScalars arg_tys res_ty
      let fnRef = Var fnVar
      closure <- scalarClosure arg_tys res_ty fnRef (zipper `App` fnRef)
      cloVar  <- hoistExpr (fsLit "clo") closure DontInline
      lifted  <- liftPD (Var cloVar)
      return (Var cloVar, lifted)
296
-- | Vectorise a lambda abstraction.
--
-- The free variables of the lambda that are bound in the local vectorisation environment
-- become extra closure arguments ahead of the lambda's own value binders.  The body is
-- vectorised under all of these binders and wrapped by 'hoistPolyVExpr' and 'buildClosures'.
-- For a loop breaker, the lifted body is guarded by a case on the lifting context: the
-- alternative for literal 0 yields 'emptyPD' instead of evaluating the body.
--
vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool             -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs  -- ^ The lambda abstraction itself; must be an `AnnLam`.
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
  = do { tyvars    <- localTyVars
       ; (vs, vvs) <- readLEnv $ \env ->
                        unzip [ (var, vv)
                              | var     <- varSetElems fvs
                              , Just vv <- [lookupVarEnv (local_vars env) var] ]
       ; arg_tys   <- mapM (vectType . idType) bndrs
       ; res_ty    <- vectType (exprType $ deAnnotate body)
       ; buildClosures tyvars vvs arg_tys res_ty
           $ hoistPolyVExpr tyvars (inline_spec (length vs + length bndrs))
           $ do { lc              <- builtin liftingContext
                ; (vbndrs, vbody) <- vectBndrsIn (vs ++ bndrs) (vectExpr body)
                ; vbody'          <- break_loop lc res_ty vbody
                ; return $ vLams lc vbndrs vbody'
                }
       }
  where
    (bndrs, body) = collectAnnValBinders expr

    inline_spec n | inline    = Inline n
                  | otherwise = DontInline

    break_loop lc ty (ve, le)
      | not loop_breaker = return (ve, le)
      | otherwise
      = do { empty_pd <- emptyPD ty
           ; lty      <- mkPDataType ty
           ; return ( ve
                    , mkWildCase (Var lc) intPrimTy lty
                        [ (DEFAULT,               [], le)
                        , (LitAlt (mkMachInt 0), [], empty_pd) ] )
           }
vectLam _ _ _ = panic "vectLam"
337
338
-- | Vectorise an expression applied to type arguments.  Only a variable applied to types is
-- supported (via 'vectPolyVar'); any other head is rejected with 'cantVectorise'.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr fn tys
  = case fn of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> cantVectorise "Can't vectorise expression (vectTyExpr)"
                                     (ppr $ deAnnotate fn `mkTyApps` tys)
343
344
-- | Vectorise an algebraic case expression.
-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type.  We also
-- have to handle the case where v is a wild var correctly.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr

-- A single alternative that binds no variables, either 'DEFAULT' or a data constructor without
-- fields.  The body cannot depend on which constructor matched, so both forms are vectorised
-- alike via 'vCaseDEFAULT'.  Any other single field-less alternative falls through to the
-- general equation below.
vectAlgCase _tycon _ty_args scrut bndr ty [(alt_con, [], body)]
  | is_default_or_data_alt alt_con
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody
  where
    is_default_or_data_alt DEFAULT     = True
    is_default_or_data_alt (DataAlt _) = True
    is_default_or_data_alt _           = False

-- A single data constructor alternative that binds fields.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      vect_dc <- maybeV (lookupDataCon dc)
      let [pdata_dc] = tyConDataCons pdata_tc

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- A dead case binder still needs a vectorised binder, so a fresh one named "scrut" is made.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

-- General case: several alternatives.  Alternatives are sorted by constructor tag (DEFAULT
-- first).  The lifted version combines the per-alternative lifted bodies with 'combinePD',
-- driven by a selector variable bound in the PData pattern.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc    <- maybeV (lookupTyCon tycon)
      (vty, lty) <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty    <- builtin (selTy arity)
      sel_bndr  <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vect_scrut_bndr
                        $ mapM (proc_alt arity sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
      let [pdata_dc] = tyConDataCons pdata_tc

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lc    <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    -- A dead case binder still needs a vectorised binder, so a fresh one named "scrut" is made.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one data constructor alternative.  In the lifted body, every local free
    -- variable is rebound to its 'packByTagPD'-packed version for this alternative's tag.
    -- A DEFAULT alternative is not supported here and panics.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do
          vect_dc <- maybeV (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags <- liftM (`App` sel) (builtin (selTags arity))
          lc       <- builtin liftingContext
          elems    <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds <- mapM (pack_var (Var lc) sel_tags tag)
                        . filter isLocalId
                        $ varSetElems fvs
                 (ve, le) <- vectExpr body
                 return (ve, Case (elems `App` sel) lc lty
                              [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                              $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a variable with a local vectorised binding, clone its lifted variable and bind the
    -- clone to the packed lifted value, updating the local environment to use the clone.
    -- Other variables produce no bindings.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                      (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []
493