Add VECTORISE [SCALAR] type pragma
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1
2 -- | Vectorisation of expressions.
3 module Vectorise.Exp (
4
5 -- Vectorise a polymorphic expression
6 vectPolyExpr,
7
8 -- Vectorise a scalar expression of functional type
9 vectScalarFun
10 ) where
11
12 #include "HsVersions.h"
13
14 import Vectorise.Type.Type
15 import Vectorise.Var
16 import Vectorise.Vect
17 import Vectorise.Env
18 import Vectorise.Monad
19 import Vectorise.Builtins
20 import Vectorise.Utils
21
22 import CoreSyn
23 import CoreUtils
24 import MkCore
25 import CoreFVs
26 import DataCon
27 import TyCon
28 import Type
29 import NameSet
30 import Var
31 import VarEnv
32 import VarSet
33 import Id
34 import BasicTypes( isStrongLoopBreaker )
35 import Literal
36 import TysWiredIn
37 import TysPrim
38 import Outputable
39 import FastString
40 import Control.Monad
41 import Data.List
42
43
-- | Vectorise a polymorphic expression.
--
-- Notes wrapping the expression are peeled off and re-applied around the
-- vectorised inner expression.  Otherwise the leading type binders are
-- abstracted over (via 'polyAbstract') and the remaining monomorphic body is
-- handed to 'vectFnExpr'; the resulting inlining decision is extended with the
-- arity computed by 'polyArity'.
--
-- The result is the inlining decision, whether the body was vectorised as a
-- scalar function, and the vectorised/lifted expression pair.
--
vectPolyExpr :: Bool             -- ^ When vectorising the RHS of a binding, whether that
                                 --   binding is a loop breaker.
             -> [Var]            -- ^ Names of functions in the same recursive binding group.
             -> CoreExprWithFVs  -- ^ Expression to vectorise.
             -> VM (Inline, Bool, VExpr)
vectPolyExpr loopBreaker recFns (_, AnnNote note inner)
  = do { (inl, scalarFn, vinner) <- vectPolyExpr loopBreaker recFns inner
       ; return (inl, scalarFn, vNote note vinner)
       }
vectPolyExpr loopBreaker recFns expr
  = do { tyArity <- polyArity tyvars
       ; polyAbstract tyvars $ \absArgs -> do
           { (inl, scalarFn, vmono) <- vectFnExpr False loopBreaker recFns mono
           ; let wrap = mkLams (tyvars ++ absArgs)
           ; return (addInlineArity inl tyArity, scalarFn, mapVect wrap vmono)
           }
       }
  where
    (tyvars, mono) = collectAnnTypeBinders expr
64
65
-- | Vectorise an expression, producing the pair of its vectorised and its
-- lifted form.
--
vectExpr :: CoreExprWithFVs -> VM VExpr
vectExpr (_, AnnType ty)
  = do { vty <- vectType ty
       ; return (vType vty)
       }

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectLiteral lit

vectExpr (_, AnnNote note inner)
  = do { vinner <- vectExpr inner
       ; return (vNote note vinner)
       }

-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do { (vty, lty) <- vectAndLiftType ty
       ; return (errorCall vty, errorCall lty)
       }
  where
    plainErr    = deAnnotate err
    errorCall t = mkCoreApps (Var v) [Type t, plainErr]

-- Application to type arguments: collect all of them and vectorise the head.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = let (fn, tys) = collectAnnTypeArgs e
    in  vectTyAppExpr fn tys

-- Boxing of a literal with the 'Int', 'Float' or 'Double' data constructor is
-- kept as is; its lifted form is obtained with 'liftPD'.
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just con <- isDataConId_maybe v
  , con `elem` [intDataCon, floatDataCon, doubleDataCon]
  = do { let boxed = App (Var v) (Lit lit)
       ; lifted <- liftPD boxed
       ; return (boxed, lifted)
       }

-- TODO: Avoid using closure application for dictionaries.
-- vectExpr (_, AnnApp fn arg)
--  | if is application of dictionary
--    just use regular app instead of closure app.
--
--  for lifted version.
--      do liftPD (sub a dNumber)
--      lift the result of the selection, not sub and dNumber separately.

-- General value application: a closure application at the vectorised
-- argument and result types of the function.
vectExpr (_, AnnApp fn arg)
  = do { varg_ty <- vectType arg_ty
       ; vres_ty <- vectType res_ty
       ; vfn     <- vectExpr fn
       ; varg    <- vectExpr arg
       ; mkClosureApp varg_ty vres_ty vfn varg
       }
  where
    (arg_ty, res_ty) = splitFunTy (exprType (deAnnotate fn))

vectExpr (_, AnnCase scrut bndr ty alts)
  = case splitTyConApp_maybe scrut_ty of
      Just (tycon, ty_args)
        | isAlgTyCon tycon -> vectAlgCase tycon ty_args scrut bndr ty alts
      _                    -> cantVectorise "Can't vectorise expression" (ppr scrut_ty)
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do { vrhs <- localV . inBind bndr . liftM (\(_, _, vrhs') -> vrhs')
                 $ vectPolyExpr False [] rhs
       ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       ; return (vLet (vNonRec vbndr vrhs) vbody)
       }

vectExpr (_, AnnLet (AnnRec binds) body)
  = do { (vbndrs, (vrhss, vbody))
           <- vectBndrsIn bndrs $ do { vrhss' <- zipWithM vectRecRhs bndrs rhss
                                     ; vbody' <- vectExpr body
                                     ; return (vrhss', vbody')
                                     }
       ; return (vLet (vRec vbndrs vrhss) vbody)
       }
  where
    (bndrs, rhss) = unzip binds

    -- Each right-hand side is vectorised in a local scope, flagged as a loop
    -- breaker when its binder's occurrence info says so.
    vectRecRhs bndr rhs
      = localV . inBind bndr . liftM (\(_, _, vrhs) -> vrhs)
      $ vectPolyExpr (isStrongLoopBreaker (idOccInfo bndr)) [] rhs

vectExpr e@(_, AnnLam bndr _)
  | isId bndr
  = liftM (\(_, _, ve) -> ve) (vectFnExpr True False [] e)

vectExpr e
  = cantVectorise "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e))
169
-- | Vectorise an expression with an outer lambda abstraction.
--
-- A lambda over a value binder is first tried as a scalar function
-- ('vectScalarFun', never inlined); if that fails, 'orElseV' falls back to a
-- full closure ('vectLam').  Any other expression goes through 'vectExpr' and
-- is never inlined.  The 'Bool' in the result is 'True' only for the scalar path.
--
vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding should
                               --   be inlined
           -> Bool             -- ^ Whether the binding is a loop breaker
           -> [Var]            -- ^ Names of functions in same recursive binding group
           -> CoreExprWithFVs  -- ^ Expression to vectorise; value lambdas get special treatment
           -> VM (Inline, Bool, VExpr)
vectFnExpr inline loopBreaker recFns expr
  = case expr of
      (_, AnnLam bndr _) | isId bndr -> asScalar `orElseV` asClosure
      _                              -> mark DontInline False (vectExpr expr)
  where
    asScalar  = mark DontInline True  (vectScalarFun False recFns (deAnnotate expr))
    asClosure = mark inlineMe   False (vectLam inline loopBreaker expr)
183
-- | Run a vectorisation action and tag its result with an inlining decision
-- and a flag saying whether it was vectorised as a scalar function.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark inl scalarFn action
  = action >>= \result -> return (inl, scalarFn, result)
186
-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
-- involve parallel arrays.  Such functions do not require the full-blown vectorisation
-- transformation; instead, they can be lifted by application of a member of the zipWith family
-- (i.e., 'map', 'zipWith', 'zipWith3', etc.)
--
-- Unless the user forced scalar treatment, the expression is only accepted (via 'onlyIfV') when
-- its argument and result types have scalar type constructors, it is scalar according to
-- 'is_scalar', and it calls at least one known scalar function according to 'uses'.
--
vectScalarFun :: Bool       -- ^ Was the function marked as scalar by the user?
              -> [Var]      -- ^ Function names in same recursive binding group
              -> CoreExpr   -- ^ Expression to be vectorised
              -> VM VExpr
vectScalarFun forceScalar recFns expr
 = do { gscalarVars  <- globalScalarVars
      ; scalarTyCons <- globalScalarTyCons
      ; let scalarVars        = gscalarVars `extendVarSetList` recFns
            (arg_tys, res_ty) = splitFunTys (exprType expr)
      ; MASSERT( not $ null arg_tys )
      ; onlyIfV (forceScalar                              -- user asserts the function is scalar
                 ||
                 all (is_scalar_ty scalarTyCons) arg_tys  -- check whether the function is scalar
                 && is_scalar_ty scalarTyCons res_ty
                 && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
                 && uses scalarVars expr)
        $ mkScalarFun arg_tys res_ty expr
      }
 where
    is_scalar_ty scalarTyCons ty
      | Just (tycon, _) <- splitTyConApp_maybe ty
      = tyConName tycon `elemNameSet` scalarTyCons
      | otherwise = False

    -- Checks whether an expression is scalar, i.e., whether it contains no non-scalar
    -- subexpression.
    --
    -- Precondition: The variables in the first argument are scalar.
    --
    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
    -- them to the list of scalar variables) and then check them.  If one of them turns out not to
    -- be scalar, the entire group is regarded as not being scalar.
    --
    -- The second argument is a predicate that checks whether a type is scalar.
    --
    is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool
    is_scalar scalars _isScalarTC (Var v)     = v `elemVarSet` scalars
    is_scalar _scalars _isScalarTC (Lit _)    = True
    is_scalar scalars isScalarTC e@(App e1 e2)
      | maybe_parr_ty (exprType e) = False
      | otherwise                  = is_scalar scalars isScalarTC e1 &&
                                     is_scalar scalars isScalarTC e2
    is_scalar scalars isScalarTC (Lam var body)
      | maybe_parr_ty (varType var) = False
      | otherwise                   = is_scalar (scalars `extendVarSet` var) isScalarTC body
    is_scalar scalars isScalarTC (Let bind body)
      = bindsAreScalar && is_scalar scalars' isScalarTC body
      where
        (bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind
    is_scalar scalars isScalarTC (Case e var ty alts)
      | isScalarTC ty = is_scalar scalars' isScalarTC e &&
                        all (is_scalar_alt scalars' isScalarTC) alts
      | otherwise     = False
      where
        scalars' = scalars `extendVarSet` var
    is_scalar scalars isScalarTC (Cast e _coe)   = is_scalar scalars isScalarTC e
    is_scalar scalars isScalarTC (Note _ e)      = is_scalar scalars isScalarTC e
    is_scalar _scalars _isScalarTC (Type {})     = True
    is_scalar _scalars _isScalarTC (Coercion {}) = True

    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
    is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e,
                                                         scalars `extendVarSet` var)
    is_scalar_bind scalars isScalarTCs (Rec bnds)     = (all (is_scalar scalars' isScalarTCs) es,
                                                         scalars')
      where
        (vars, es) = unzip bnds
        scalars'   = scalars `extendVarSetList` vars

    is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList` vars)
                                                               isScalarTCs e

    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
    -- constructor is a type family, we conservatively assume that it may be a parallel array type.
    maybe_parr_ty :: Type -> Bool
    maybe_parr_ty ty
      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
    maybe_parr_ty _                               = False

    -- Checks whether the expression calls at least one of the functions in 'funs' (the global
    -- scalar variables plus the recursive binding group).  A scalar function has to actually
    -- compute something.  Without this check, we would treat (\(x :: Int) -> x) as a scalar
    -- function and lift it to (map (\x -> x)), which is very bad.  Normal lifting transforms it
    -- to (\n# x -> x), which is what we want.
    --
    -- Variables bound inside the expression (lambda, let, case and alternative binders) shadow
    -- entries of 'funs'; they are removed from the set rather than added to it, so merely
    -- mentioning a lambda-bound argument does not count as a use.  Casts and notes are looked
    -- through, as in 'is_scalar'.
    --
    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identity function
    --        is called by some other function that is otherwise scalar, it would be very bad
    --        that just this call to the identity makes it not be scalar.
    uses :: VarSet -> CoreExpr -> Bool
    uses funs (Var v)       = v `elemVarSet` funs
    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
    uses funs (Lam b body)  = uses (funs `delVarSet` b) body
    uses funs (Let (NonRec b letExpr) body)
                            = uses funs letExpr || uses (funs `delVarSet` b) body
    uses funs (Let (Rec bnds) body)
                            = let funs' = funs `delVarSetList` map fst bnds
                              in  any (uses funs' . snd) bnds || uses funs' body
    uses funs (Case e eId _ty alts)
                            = uses funs e || any (uses_alt (funs `delVarSet` eId)) alts
    uses funs (Cast e _coe) = uses funs e
    uses funs (Note _ e)    = uses funs e
    uses _ _                = False

    uses_alt :: VarSet -> CoreAlt -> Bool
    uses_alt funs (_, bs, e) = uses (funs `delVarSetList` bs) e
290
-- | Build the vectorised/lifted pair for a scalar function.
--
-- The function is hoisted (via 'hoistExpr', not inlined) and paired by
-- 'scalarClosure' with its application to the zipWith-style lifter returned by
-- 'zipScalars'.  That closure is hoisted as well.  The vectorised form is the
-- hoisted closure, and the lifted form is 'liftPD' applied to it.
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  = do
      fnVar   <- hoistExpr (fsLit "fn") expr DontInline
      lifter  <- zipScalars arg_tys res_ty
      let fn = Var fnVar
      closure <- scalarClosure arg_tys res_ty fn (lifter `App` fn)
      cloVar  <- hoistExpr (fsLit "clo") closure DontInline
      lifted  <- liftPD (Var cloVar)
      return (Var cloVar, lifted)
300
-- | Vectorise a lambda abstraction.
--
-- The free variables of the abstraction that have an entry in the local
-- environment are captured, together with the local type variables, by the
-- closures built with 'buildClosures'.  Their vectorised binders are then
-- abstracted in front of the value binders of the lambda.
--
vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool             -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs  -- ^ The whole abstraction; must be an 'AnnLam'.
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
  = do { let (bs, body) = collectAnnValBinders expr

       ; tyvars    <- localTyVars
       ; (vs, vvs) <- readLEnv $ \env ->
                        unzip [ (var, vv)
                              | var     <- varSetElems fvs
                              , Just vv <- [lookupVarEnv (local_vars env) var]
                              ]

       ; arg_tys <- mapM (vectType . idType) bs
       ; res_ty  <- vectType (exprType $ deAnnotate body)

       ; buildClosures tyvars vvs arg_tys res_ty
           . hoistPolyVExpr tyvars (inlineAnn (length vs + length bs))
           $ do { lc              <- builtin liftingContext
                ; (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body)
                ; vbody'          <- breakLoop lc res_ty vbody
                ; return (vLams lc vbndrs vbody')
                }
       }
  where
    inlineAnn n
      | inline    = Inline n
      | otherwise = DontInline

    -- For a loop breaker, the lifted body is wrapped in a case on the lifting
    -- context: a context of 0 yields 'emptyPD' of the result type, any other
    -- value yields the lifted body.  The vectorised body is left unchanged.
    breakLoop lc ty (ve, le)
      | loop_breaker
      = do { empty <- emptyPD ty
           ; lty   <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [ (DEFAULT, [], le)
                           , (LitAlt (mkMachInt 0), [], empty)
                           ])
           }
      | otherwise
      = return (ve, le)
vectLam _ _ _ = panic "vectLam"
341
342
-- | Vectorise an expression applied to type arguments.  Only a variable head
-- is supported; it is handed to 'vectPolyVar' together with the types.
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr fn tys
  = case fn of
      (_, AnnVar v) -> vectPolyVar v tys
      _             -> cantVectorise "Can't vectorise expression (vectTyExpr)"
                                     (ppr (deAnnotate fn `mkTyApps` tys))
347
348
-- | Vectorise an algebraic case expression.
-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type.  We also
-- have to handle the case where v is a wild var correctly.
--
-- Three shapes are handled:
--
--  * a single alternative without binders (DEFAULT or a data alternative),
--    which becomes a 'vCaseDEFAULT';
--  * a single data alternative with binders, which becomes a wild case over
--    the vectorised data constructor (vectorised form) and over the single
--    data constructor of the tycon returned by 'mkVScrut' (lifted form);
--  * several alternatives, which are sorted by constructor tag; the lifted
--    bodies are recombined with 'combinePD' using a selector variable.
--    A DEFAULT alternative in this shape reaches a panic.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)]
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(con, [], body)]
  | trivialAlt con
  = do { vscrut         <- vectExpr scrut
       ; (vty, lty)     <- vectAndLiftType ty
       ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       ; return (vCaseDEFAULT vscrut vbndr vty lty vbody)
       }
  where
    trivialAlt DEFAULT     = True
    trivialAlt (DataAlt _) = True
    trivialAlt _           = False

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do { (vty, lty) <- vectAndLiftType ty
       ; vexpr      <- vectExpr scrut
       ; (vbndr, (vbndrs, (vect_body, lift_body)))
           <- vect_scrut_bndr . vectBndrsIn bndrs $ vectExpr body
       ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
       ; (vscrut, lscrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
       ; vect_dc <- maybeV (lookupDataCon dc)
       ; let [pdata_dc] = tyConDataCons pdata_tc
             vcase      = wildCaseOn vscrut vty vect_dc  vect_bndrs vect_body
             lcase      = wildCaseOn lscrut lty pdata_dc lift_bndrs lift_body
       ; return (vLet (vNonRec vbndr vexpr) (vcase, lcase))
       }
  where
    vect_scrut_bndr
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    wildCaseOn scrutinee resTy con altBndrs altBody
      = mkWildCase scrutinee (exprType scrutinee) resTy [(DataAlt con, altBndrs, altBody)]

vectAlgCase tycon _ty_args scrut bndr ty alts
  = do { vect_tc    <- maybeV (lookupTyCon tycon)
       ; (vty, lty) <- vectAndLiftType ty

       ; let arity = length (tyConDataCons vect_tc)
       ; sel_ty   <- builtin (selTy arity)
       ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
       ; let sel = Var sel_bndr

       ; (vbndr, valts) <- vect_scrut_bndr
                         $ mapM (proc_alt arity sel vty lty) sorted_alts
       ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

       ; vexpr <- vectExpr scrut
       ; (vect_scrut, lift_scrut, pdata_tc, _arg_tys) <- mkVScrut (vVar vbndr)
       ; let [pdata_dc]                 = tyConDataCons pdata_tc
             (vect_bodies, lift_bodies) = unzip vbodies

       ; vdummy <- newDummyVar (exprType vect_scrut)
       ; ldummy <- newDummyVar (exprType lift_scrut)
       ; let vect_case = Case vect_scrut vdummy vty
                              (zipWith3 (\dc bs b -> (DataAlt dc, bs, b))
                                        vect_dcs vect_bndrss vect_bodies)

       ; lc    <- builtin liftingContext
       ; lbody <- combinePD vty (Var lc) sel lift_bodies
       ; let lift_case = Case lift_scrut ldummy lty
                              [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss, lbody)]

       ; return (vLet (vNonRec vbndr vexpr) (vect_case, lift_case))
       }
  where
    vect_scrut_bndr
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    sorted_alts = sortBy (\(con1, _, _) (con2, _, _) -> cmpAltCon con1 con2) alts

    -- DEFAULT sorts first; data alternatives are ordered by constructor tag.
    cmpAltCon (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmpAltCon DEFAULT       DEFAULT       = EQ
    cmpAltCon DEFAULT       _             = LT
    cmpAltCon _             DEFAULT       = GT
    cmpAltCon _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one data alternative.  Inside a local scope, the lifted versions
    -- of the free local variables of the body are rebound (see 'pack_var'),
    -- and the lifted body is wrapped in a case on the selector's elements.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
      = do { vect_dc <- maybeV (lookupDataCon dc)
           ; let ntag = dataConTagZ vect_dc
                 tag  = mkDataConTag vect_dc
                 fvs  = freeVarsOf body `delVarSetList` bndrs

           ; sel_tags <- liftM (`App` sel) (builtin (selTags arity))
           ; lc       <- builtin liftingContext
           ; elems    <- builtin (selElements arity ntag)

           ; (vbndrs, vbody)
               <- vectBndrsIn bndrs
                . localV
                $ do { binds    <- mapM (pack_var (Var lc) sel_tags tag)
                                 . filter isLocalId
                                 $ varSetElems fvs
                     ; (ve, le) <- vectExpr body
                     ; return (ve, Case (elems `App` sel) lc lty
                                        [(DEFAULT, [], mkLets (concat binds) le)])
                     }
           ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
           ; return (vect_dc, vect_bndrs, lift_bndrs, vbody)
           }
    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    -- For a variable with a local vectorised/lifted pair, bind a clone of the
    -- lifted variable to 'packByTagPD' of the original lifted variable and
    -- record the clone in the local environment.  Other variables produce no
    -- binding.
    pack_var len tags t v
      = do { r <- lookupVar v
           ; case r of
               Local (vv, lv) ->
                 do { lv'  <- cloneVar lv
                    ; expr <- packByTagPD (idType vv) (Var lv) len tags t
                    ; updLEnv (\env -> env { local_vars = extendVarEnv
                                                            (local_vars env) v (vv, lv') })
                    ; return [NonRec lv' expr]
                    }
               _ -> return []
           }
497