Remove '-favoid-vect' and add '-fvectorisation-avoidance'
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1 {-# LANGUAGE TupleSections #-}
2
3 -- |Vectorisation of expressions.
4
5 module Vectorise.Exp
6 ( -- * Vectorise right-hand sides of toplevel bindings
7 vectTopExpr
8 , vectTopExprs
9 , vectScalarFun
10 , vectScalarDFun
11 )
12 where
13
14 #include "HsVersions.h"
15
16 import Vectorise.Type.Type
17 import Vectorise.Var
18 import Vectorise.Convert
19 import Vectorise.Vect
20 import Vectorise.Env
21 import Vectorise.Monad
22 import Vectorise.Builtins
23 import Vectorise.Utils
24
25 import CoreUtils
26 import MkCore
27 import CoreSyn
28 import CoreFVs
29 import Class
30 import DataCon
31 import TyCon
32 import TcType
33 import Type
34 import TypeRep
35 import Var
36 import VarEnv
37 import VarSet
38 import NameSet
39 import Id
40 import BasicTypes( isStrongLoopBreaker )
41 import Literal
42 import TysPrim
43 import Outputable
44 import FastString
45 import DynFlags
46 import Util
47 import MonadUtils
48
49 import Control.Monad
50 import Data.Maybe
51 import Data.List
52
53
54 -- Main entry point to vectorise expressions -----------------------------------
55
-- |Vectorise the right-hand side of a /non-recursive/ top-level binding.
--
-- The result is 'Nothing' if the annotated expression is tagged as encapsulated (i.e., it is
-- scalar). Otherwise, the result carries (1) a flag telling whether the expression is parallel
-- (tagged as 'VIParr'), (2) the inlining hint for the binding, and (3) the vectorised expression.
--
-- Non-recursive bindings are treated separately from recursive groups (see 'vectTopExprs'), because
-- here the vectorisation information never needs to be computed a second time.
--
vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
vectTopExpr var expr = do
  annExpr <- vectAvoidInfo emptyVarSet (freeVars expr) >>= encapsulateScalars
  if isVIEncaps annExpr
    then return Nothing
    else do
      vExpr  <- closedV . inBind var $ vectAnnPolyExpr False annExpr
      inline <- computeInline annExpr
      return $ Just (isVIParr annExpr, inline, vectorised vExpr)
80
-- Determine the inlining hint for the right-hand side of a top-level binding.
--
-- Dictionary expressions are never inlined, ticks are looked through, and a lambda abstraction is
-- inlined at the arity derived (by 'polyArity') from its leading type binders. Anything else is
-- not inlined.
--
computeInline :: CoreExprWithVectInfo -> VM Inline
computeInline aexpr
  = case aexpr of
      ((_, VIDict), _)    -> return DontInline
      (_, AnnTick _ body) -> computeInline body
      (_, AnnLam _ _)     -> Inline <$> polyArity (fst $ collectAnnTypeBinders aexpr)
      _                   -> return DontInline
90
-- |Vectorise a recursive group of top-level polymorphic expressions.
--
-- The result is 'Nothing' if every binding of the group is tagged as encapsulated (i.e., the whole
-- group is scalar). Otherwise, the 'Bool' tells whether any of the bindings is parallel (tagged as
-- 'VIParr'), and the list holds the inlining hint and vectorised right-hand side of each binding,
-- in the order of the input.
--
vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
vectTopExprs binds = do
    initialVIs <- mapM (annotate emptyVarSet) rhss
    if all isVIEncaps initialVIs
      then return Nothing                 -- all bindings are scalar => leave the group alone
      else do
        let anyParr = any isVIParr initialVIs
        -- If some binding is parallel, the vectorisation information has to be recomputed with
        -- the binders of the group passed to 'vectAvoidInfo'; otherwise, the first pass is
        -- already usable.
        finalVIs <- if anyParr
                      then mapM (annotate (mkVarSet bndrs)) rhss
                      else return initialVIs
        vRhss <- zipWithM vectBinding bndrs finalVIs
        return $ Just (anyParr, vRhss)
  where
    (bndrs, rhss) = unzip binds

    -- Compute vectorisation avoidance information and encapsulate scalar subexpressions.
    annotate parrVars rhs = vectAvoidInfo parrVars (freeVars rhs) >>= encapsulateScalars

    -- Vectorise a single binding of the group; loop breakers are flagged for 'vectAnnPolyExpr'.
    vectBinding bndr rhsVI = do
      vRhs   <- closedV . inBind bndr $
                  vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhsVI
      inline <- computeInline rhsVI
      return (inline, vectorised vRhs)
130
-- |Vectorise a polymorphic expression annotated with vectorisation information.
--
-- The first argument states whether the binding whose right-hand side we process is a loop
-- breaker; it is passed on to 'vectFnExpr'.
--
-- The special case of dictionary functions is currently handled separately. (Would be neater to
-- integrate them, though!) For dictionary expressions only the vectorised component of the
-- resulting pair is meaningful; demanding the lifted component raises a descriptive panic.
--
vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
    -- traverse through ticks
  = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
vectAnnPolyExpr loop_breaker expr
  | isVIDict expr
    -- special case the right-hand side of dictionary functions
  = (, unused) <$> vectDictExpr (deAnnotate expr)
  | otherwise
    -- collect and vectorise type abstractions; then, descend into the body
  = polyAbstract tvs $ \args ->
      mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
  where
    (tvs, mono) = collectAnnTypeBinders expr

    -- Placeholder for the lifted version of a dictionary expression (same convention as in
    -- 'mkScalarFun'): it must never be demanded.
    unused = panic "Vectorise.Exp.vectAnnPolyExpr: we don't lift dictionary expressions"
150
-- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
-- lambda abstraction over all its free variables followed by the corresponding application to those
-- variables. We can, then, avoid the vectorisation of the encapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
-- * The expression is sufficiently complex (to warrant special treatment). For now, that is
--   every expression that is not constant and contains at least one operation.
--
-- The user has an option to choose between aggressive and minimal vectorisation avoidance. With
-- minimal vectorisation avoidance, we only encapsulate individual scalar operations. With
-- aggressive vectorisation avoidance, we encapsulate subexpressions that are as big as possible.
--
encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
encapsulateScalars ce@(ann, e)
  = case e of
      AnnType _ -> return ce
      AnnVar _
          -- NB: diverts from the paper: encapsulate scalar variables (including functions)
        | VISimple <- vi -> liftSimpleAndCase ce
        | otherwise      -> return ce
      AnnLit _ -> return ce
      AnnTick tck body      -> descendWith (AnnTick tck) body
      AnnCast body coercion -> descendWith (`AnnCast` coercion) body
      AnnLam bndr body -> do
        aggressive  <- isVectAvoidanceAggressive
        fvsScalar   <- allScalarVarTypeSet fvs
        -- NB: diverts from the paper: we need to check the scalarness of bound variables as
        --     well, as 'vectScalarFun' will handle them just the same as those introduced for
        --     the 'fvs' by encapsulation.
        bndrsScalar <- allScalarVarType (fst $ collectAnnBndrs ce)
        liftOrDescend (aggressive && fvsScalar && bndrsScalar) $
          AnnLam bndr <$> encapsulateScalars body
      AnnApp fun arg -> do
        aggressive <- isVectAvoidanceAggressive
        fvsScalar  <- allScalarVarTypeSet fvs
        liftOrDescend ((aggressive || isSimpleApplication ce) && fvsScalar) $ do
          fun' <- encapsulateScalars fun
          arg' <- encapsulateScalars arg
          return $ AnnApp fun' arg'
      AnnCase scrut bndr ty alts -> do
        aggressive <- isVectAvoidanceAggressive
        fvsScalar  <- allScalarVarTypeSet fvs
        liftOrDescend (aggressive && fvsScalar) $ do
          scrut' <- encapsulateScalars scrut
          alts'  <- mapM encAlt alts
          return $ AnnCase scrut' bndr ty alts'
      AnnLet (AnnNonRec bndr rhs) body -> do
        aggressive <- isVectAvoidanceAggressive
        fvsScalar  <- allScalarVarTypeSet fvs
        liftOrDescend (aggressive && fvsScalar) $ do
          rhs'  <- encapsulateScalars rhs
          body' <- encapsulateScalars body
          return $ AnnLet (AnnNonRec bndr rhs') body'
      AnnLet (AnnRec binds) body -> do
        aggressive <- isVectAvoidanceAggressive
        fvsScalar  <- allScalarVarTypeSet fvs
        liftOrDescend (aggressive && fvsScalar) $ do
          binds' <- mapM encBind binds
          body'  <- encapsulateScalars body
          return $ AnnLet (AnnRec binds') body'
      _ -> panic "Vectorise.Exp.encapsulateScalars: unknown constructor"
  where
    (fvs, vi) = ann

    -- Keep the annotation, but replace the given subexpression by its encapsulated form.
    descendWith wrap sub = (\sub' -> (ann, wrap sub')) <$> encapsulateScalars sub

    -- If the expression is tagged 'VISimple' and the side condition holds, lift the expression
    -- as a whole; otherwise, keep the annotation and use the node computed by 'descend'.
    liftOrDescend cond descend
      | VISimple <- vi, cond = liftSimpleAndCase ce
      | otherwise            = (\node -> (ann, node)) <$> descend

    encAlt (con, bndrs, rhs) = (con, bndrs,) <$> encapsulateScalars rhs
    encBind (bndr, rhs)      = (bndr,) <$> encapsulateScalars rhs

    -- An application is simple if (ignoring ticks and casts) it is a simple expression or
    -- applies a simple expression to a simple application.
    isSimpleApplication :: CoreExprWithVectInfo -> Bool
    isSimpleApplication (_, AnnTick _ inner) = isSimpleApplication inner
    isSimpleApplication (_, AnnCast inner _) = isSimpleApplication inner
    isSimpleApplication app@(_, node)
      | isSimple app           = True
      | AnnApp fun arg <- node = isSimple fun && isSimpleApplication arg
      | otherwise              = False

    -- Simple expressions are types, variables, and literals, possibly under ticks and casts.
    isSimple :: CoreExprWithVectInfo -> Bool
    isSimple (_, node)
      = case node of
          AnnType {}       -> True
          AnnVar  {}       -> True
          AnnLit  {}       -> True
          AnnTick _ inner  -> isSimple inner
          AnnCast inner _  -> isSimple inner
          _                -> False
272
-- Lambda-lift the given simple expression and apply it to the abstracted free variables.
--
-- A case expression is lifted as a whole only if 'vectAvoidInfoTypeOf' classifies its scrutinee
-- as 'VISimple'. Otherwise, each alternative is lifted individually and the case expression is
-- re-annotated with the classification obtained for the scrutinee.
--
liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimpleAndCase aexpr@((fvs, _vi), AnnCase scrut bndr ty alts) = do
  scrutVI <- vectAvoidInfoTypeOf scrut
  if scrutVI == VISimple
    then liftSimple aexpr       -- scalar scrutinee: no special treatment needed
    else do
      alts' <- forM alts $ \(con, bndrs, rhs) -> (con, bndrs,) <$> liftSimpleAndCase rhs
      return ((fvs, scrutVI), AnnCase scrut bndr ty alts')
liftSimpleAndCase aexpr = liftSimple aexpr
291
-- Lambda-lift a simple expression: abstract over all its free variables that are not toplevel
-- and immediately apply the abstraction to those same variables. The abstraction is annotated
-- with 'VIEncaps' and the applications with 'VISimple'.
--
-- Panics if the expression is neither covered by the variable special case nor tagged 'VISimple'.
--
liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimple ((fvs, vi), AnnVar v)
  | v `elemVarSet` fvs        -- special case to avoid producing: (\v -> v) v
  , not (isToplevel v)        -- NB: if 'v' not free or is toplevel, we must get the 'VIEncaps'
  = return ((fvs, vi), AnnVar v)
liftSimple aexpr@((allFvs, VISimple), expr) = do
    traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate lifted)
    return lifted
  where
    -- only include 'Id's that are not toplevel
    localFvs = filterVarSet (not . isToplevel) allFvs
    vars     = varSetElems localFvs

    -- the abstraction applied to all abstracted variables, first variable first
    lifted = foldl applyTo abstraction vars

    -- Wrap one lambda per variable around the expression; the first variable ends up outermost.
    -- Each lambda body is annotated with the variables still free at that point.
    (leftover, lamBody) = foldr abstractOver (localFvs, expr) vars
    abstractOver v (fvsIn, body) = (fvsIn `delVarSet` v, AnnLam v ((fvsIn, VIEncaps), body))

    abstraction = ASSERT(isEmptyVarSet leftover)
                  ((emptyVarSet, VIEncaps), lamBody)

    applyTo fun@((funFvs, _), _) v
      = ((funFvs `extendVarSet` v, VISimple), AnnApp fun ((unitVarSet v, VISimple), AnnVar v))
liftSimple aexpr
  = pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
323
-- Classify a variable as toplevel on the basis of its unfolding. Non-'Id's and 'Id's without an
-- unfolding are never considered toplevel.
--
isToplevel :: Var -> Bool
isToplevel v
  | not (isId v) = False
  | otherwise
  = case realIdUnfolding v of
      CoreUnfolding {uf_is_top = top} -> top
      DFunUnfolding {}                -> True
      OtherCon {}                     -> True
      NoUnfolding                     -> False
331
-- |Vectorise an expression.
--
-- The equations are tried in order; in particular, encapsulated expressions are handled before
-- the structure of the expression is inspected, and the special forms of applications are handled
-- before the general value application.
--
vectExpr :: CoreExprWithVectInfo -> VM VExpr

vectExpr aexpr
    -- encapsulated expression of functional type => try to vectorise as a scalar subcomputation
  | (isFunTy . annExprType $ aexpr) && isVIEncaps aexpr
  = vectFnExpr True False aexpr
    -- encapsulated constant => vectorise as a scalar constant
  | isVIEncaps aexpr
  = do
      traceVt "vectExpr (encapsulated constant):" (ppr scalar)
      vectConst scalar
  where
    scalar = deAnnotate aexpr

vectExpr (_, AnnVar v) = vectVar v

vectExpr (_, AnnLit lit) = vectConst (Lit lit)

vectExpr aexpr@(_, AnnLam _ _) = do
  traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr)
  vectFnExpr True False aexpr

  -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
  --   its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
  --   happy.
-- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do
      (vty, lty) <- vectAndLiftType ty
      let abortAt ty' = mkCoreApps (Var v) [Type ty', deAnnotate err]
      return (abortAt vty, abortAt lty)

  -- type application (handle multiple consecutive type applications simultaneously to ensure the
  -- PA dictionaries are put at the right places)
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e

  -- Lifted literal
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | isJust (isDataConId_maybe v)
  = do
      let boxed = App (Var v) (Lit lit)
      lifted <- liftPD boxed
      return (boxed, lifted)

  -- value application (dictionary or user value)
vectExpr e@(_, AnnApp fn arg)
  | isPredTy argTy   -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise        -- user value
  = do
      -- vectorise the types
      vArgTy <- vectType argTy
      vResTy <- vectType resTy

      -- vectorise the function and argument expression
      vFn  <- vectExpr fn
      vArg <- vectExpr arg

      -- the vectorised function is a closure; apply it to the vectorised argument
      mkClosureApp vArgTy vResTy vFn vArg
  where
    (argTy, resTy) = splitFunTy . exprType $ deAnnotate fn

vectExpr (_, AnnCase scrut bndr ty alts)
  = case splitTyConApp_maybe scrutTy of
      Just (tycon, tyArgs)
        | isAlgTyCon tycon -> vectAlgCase tycon tyArgs scrut bndr ty alts
      _ -> do
        dflags <- getDynFlags
        cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
          ppr scrutTy
  where
    scrutTy = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do
  traceVt "let binding (non-recursive)" empty
  vRhs <- localV . inBind bndr $ vectAnnPolyExpr False rhs
  traceVt "let body (non-recursive)" empty
  (vBndr, vBody) <- vectBndrIn bndr (vectExpr body)
  return $ vLet (vNonRec vBndr vRhs) vBody

vectExpr (_, AnnLet (AnnRec bs) body) = do
    (vBndrs, (vRhss, vBody)) <- vectBndrsIn bndrs vectGroup
    return $ vLet (vRec vBndrs vRhss) vBody
  where
    (bndrs, rhss) = unzip bs

    -- right-hand sides and body are both vectorised inside 'vectBndrsIn' for the group's binders
    vectGroup = do
      traceVt "let bindings (recursive)" empty
      vRhss <- zipWithM vectRhs bndrs rhss
      traceVt "let body (recursive)" empty
      vBody <- vectExpr body
      return (vRhss, vBody)

    vectRhs bndr rhs = localV . inBind bndr $
                         vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs

vectExpr (_, AnnTick tickish expr)
  = vTick tickish <$> vectExpr expr

vectExpr (_, AnnType ty)
  = vType <$> vectType ty

vectExpr e = do
  dflags <- getDynFlags
  cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
456
-- |Vectorise an expression that *may* have an outer lambda abstraction. If the expression is marked
-- as encapsulated ('VIEncaps'), vectorise it as a scalar computation (using a generalised scalar
-- zip).
--
-- Type variables are not handled at this point; a type lambda is reported as an error, as type
-- abstractions are expected to have been stripped off by the caller. Abstractions over predicate
-- (dictionary) arguments remain ordinary lambdas over the vectorised binder.
--
vectFnExpr :: Bool                  -- ^If we process the RHS of a binding, whether that binding
                                    --  should be inlined
           -> Bool                  -- ^Whether the binding is a loop breaker
           -> CoreExprWithVectInfo  -- ^Expression to vectorise; must have an outer `AnnLam`
           -> VM VExpr
vectFnExpr inline loop_breaker aexpr@(_ann, AnnLam bndr body)
    -- type lambdas must not occur here
  | not (isId bndr)
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Vectorise.Exp.vectFnExpr: Unexpected type lambda" $
        ppr (deAnnotate aexpr)
    -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
  | isPredTy (idType bndr)
  = do
      vBndr <- vectBndr bndr
      vBody <- vectFnExpr inline loop_breaker body
      return $ mapVect (mkLams [vectorised vBndr]) vBody
    -- encapsulated non-predicate abstraction: vectorise as a scalar computation
  | isVIEncaps aexpr
  = vectScalarFun (deAnnotate aexpr)
    -- non-predicate abstraction: vectorise as a non-scalar computation
  | otherwise
  = vectLam inline loop_breaker aexpr
vectFnExpr _ _ aexpr
    -- encapsulated function: vectorise as a scalar computation
  | (isFunTy . annExprType $ aexpr) && isVIEncaps aexpr
  = vectScalarFun (deAnnotate aexpr)
    -- not an abstraction: vectorise as a non-scalar vanilla expression
    -- NB: we can get here due to the recursion in the first equation and from 'vectAnnPolyExpr'
  | otherwise
  = vectExpr aexpr
499
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
-- involve two sets of type variables and dictionaries. Consider,
--
-- > class C a where
-- >   m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
  | (_, AnnVar var) <- headExpr
  = vectHeadedBy var
  | otherwise
  = pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- Peel the arguments off from the outside in; if there is only one set of variables or
    -- dictionaries, it will be the outer set.
    (belowDictsOuter, dictsOuter) = collectAnnDictArgs e0
    (belowTysOuter,   tysOuter)   = collectAnnTypeArgs belowDictsOuter
    (belowDictsInner, dictsInner) = collectAnnDictArgs belowTysOuter
    (headExpr,        tysInner)   = collectAnnTypeArgs belowDictsInner

    -- class method selectors and dictionary functions
    isDictComp var = isJust (isClassOpId_maybe var) || isDFunId var

    vectHeadedBy var = do
      -- get the vectorised form of the variable
      vVar <- lookupVar var
      traceVt "vectPolyApp of" (ppr var)

      -- vectorise type and dictionary arguments
      vDictsOuter <- mapM (vectDictExpr . deAnnotate) dictsOuter
      vDictsInner <- mapM (vectDictExpr . deAnnotate) dictsInner
      vTysOuter   <- mapM vectType tysOuter
      vTysInner   <- mapM vectType tysInner

      let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter

      case vVar of
        Local (vv, lv) -> do
          MASSERT( null dictsInner )    -- local vars cannot be class selectors
          traceVt " LOCAL" (text "")
          vexpr <- reconstructOuter (Var vv)
          lexpr <- reconstructOuter (Var lv)
          return (vexpr, lexpr)
        Global vv
          | isDictComp var -> do        -- dictionary computation
              -- in a dictionary computation, the innermost, non-empty set of arguments are
              -- non-vectorised arguments, where no 'PA' dictionaries are needed for the type
              -- variables
              ve <- if null dictsInner
                      then return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                      else reconstructOuter
                             (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
              traceVt " GLOBAL (dict):" (ppr ve)
              vectConst ve
          | otherwise -> do             -- non-dictionary computation
              MASSERT( null dictsInner )
              ve <- reconstructOuter (Var vv)
              traceVt " GLOBAL (non-dict):" (ppr ve)
              vectConst ve
562
-- |Vectorise the body of a dfun.
--
-- Dictionary computations are special for the following reasons. The application of dictionary
-- functions are always saturated, so there is no need to create closures. Dictionary computations
-- don't depend on array values, so they are always scalar computations whose result we can
-- replicate (instead of executing them in parallel).
--
-- NB: To keep things simple, we are not rewriting any of the bindings introduced in a dictionary
--     computation. Consequently, the variable case needs to deal with cases where binders are
--     in the vectoriser environments and where that is not the case.
--
-- Literals are not expected in dictionary computations (panic); casts and coercions are not
-- supported (reported via 'pprSorry').
--
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr
  = case expr of
      Var var -> do
        mbScope <- lookupVar_maybe var
        return . Var $ case mbScope of
          Nothing                -> var     -- binder from within the dict. computation
          Just (Local (vVar, _)) -> vVar    -- local vectorised variable
          Just (Global vVar)     -> vVar    -- global vectorised variable
      Lit lit ->
        pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
      Lam bndr body ->
        Lam bndr <$> vectDictExpr body
      App fn arg -> do
        vFn  <- vectDictExpr fn
        vArg <- vectDictExpr arg
        return $ App vFn vArg
      Case scrut bndr ty alts -> do
        vScrut <- vectDictExpr scrut
        vTy    <- vectType ty
        vAlts  <- mapM vectDictAlt alts
        return $ Case vScrut bndr vTy vAlts
      Let bnd body -> do
        vBnd  <- vectDictBind bnd
        vBody <- vectDictExpr body
        return $ Let vBnd vBody
      Cast _ _ ->
        pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
      Tick tickish body ->
        Tick tickish <$> vectDictExpr body
      Type ty ->
        Type <$> vectType ty
      Coercion coe ->
        pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    -- binders of alternatives and bindings are left untouched (see the NB above)
    vectDictAlt (con, bs, rhs) = do
      vCon <- vectDictAltCon con
      vRhs <- vectDictExpr rhs
      return (vCon, bs, vRhs)

    vectDictAltCon (DataAlt datacon) = DataAlt <$> maybeV dataConErr (lookupDataCon datacon)
      where
        dataConErr = ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon
    vectDictAltCon (LitAlt lit)      = return $ LitAlt lit
    vectDictAltCon DEFAULT           = return DEFAULT

    vectDictBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectDictBind (Rec bnds)        = Rec <$> mapM (\(bndr, rhs) -> (bndr,) <$> vectDictExpr rhs) bnds
611
-- |Vectorise an expression of functional type, where all arguments and the result are of primitive
-- types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar' type class) and
-- which does not contain any subcomputations that involve parallel arrays. Such functionals do not
-- require the full blown vectorisation transformation; instead, they can be lifted by application
-- of a member of the zipWith family (i.e., 'map', 'zipWith', 'zipWith3', etc.)
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
-- instead they become dictionaries of vectorised methods). We treat them differently, though; see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr = do
  traceVt "vectScalarFun:" (ppr expr)
  -- split the function type into argument types and result type and build the scalar closure
  uncurry mkScalarFun (splitFunTys $ exprType expr) expr
629
-- Generate code for a scalar function by generating a scalar closure. If the function is a
-- dictionary function (its result type is a predicate type), vectorise it as dictionary code
-- instead; in that case, the lifted component of the result must never be demanded.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = (, unused) <$> vectDictExpr expr
  | otherwise
  = do
      traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr (mkFunTys arg_tys res_ty)

      -- hoist the scalar function, wrap it (together with its zipped version) into a closure,
      -- and hoist that closure as well
      fn_var  <- hoistExpr (fsLit "fn") expr DontInline
      zipf    <- zipScalars arg_tys res_ty
      clo     <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
      clo_var <- hoistExpr (fsLit "clo") clo DontInline
      lclo    <- liftPD (Var clo_var)
      return (Var clo_var, lclo)
  where
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
651
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'. The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
--     expression of the right-hand side of the dictionary function. We cannot proceed like this as
--     'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have access
--     to the Core code of the unvectorised dfun.
--
-- Here an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) #-}
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- > $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- > $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- > $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- >                D:V:Eq $(vectScalarFun True recFns
-- >                         [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVect di)' sub-expressions between the different selectors, but duplicate
--   the application of the unvectorised dfun, to enable the dictionary selection rules to fire.
--
-- If the class-dictionary data constructor has no vectorised counterpart, the computation fails
-- through 'maybeV' with a descriptive message (instead of a pattern-match failure).
--
vectScalarDFun :: Var        -- ^ Original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do {   -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

           -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr

           -- vectorise superclass dictionaries and methods as scalar expressions
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps

           -- vectorised applications of the class-dictionary data constructor
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy  ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty       -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls

    dataConErr = text "vectScalarDFun: data constructor not vectorised" <+> ppr dataCon
717
-- Build a value of the dictionary before vectorisation from original, unvectorised type and an
-- expression computing the vectorised dictionary.
--
-- Given the vectorised version of a dictionary 'vd :: V:C vt1..vtn', generate code that computes
-- the unvectorised version, thus:
--
-- > D:C op1 .. opm
-- > where
-- >   opi = $(fromVect opTyi [| vSeli @vt1..vtk vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary.
--
-- Panics if the type constructor of the given type does not belong to a class.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e = do
    vTys <- mapM vectType tys
    -- apply every selector to the vectorised type arguments and the vectorised dictionary
    let select sel = Var sel `mkTyApps` vTys `mkApps` [e]
    scOps <- zipWithM fromVect methTys (map select selIds)
    return $ mkCoreConApps dataCon (map Type tys ++ scOps)
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
    selIds = classAllSelIds $
               fromMaybe (panic "Vectorise.Exp.unVectDict: no class") (tyConClass_maybe tycon)
743
-- Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- All non-dictionary free variables go into the closure's environment, whereas the dictionary
-- variables are passed explicitly (as conventional arguments) into the body during closure
-- construction.
--
vectLam :: Bool                 -- ^ Should the RHS of a binding be inlined?
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithVectInfo -- ^ Body of abstraction.
        -> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _) = do
    traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)

    -- grab the in-scope type variables
    tyvars <- localTyVars

    -- collect and vectorise all /local/ free variables (local == is in the local var env)
    vfvs <- readLEnv $ \env ->
              [ (var, vv)
              | var     <- varSetElems fvs
              , Just vv <- [lookupVarEnv (local_vars env) var]
              ]

    -- separate dictionary from non-dictionary variables in the free variable set
    let (dictFvs, nondictFvs)       = partition (isPredTy . varType . fst) vfvs
        vDictVars                   = map (vectorised . snd) dictFvs
        (nondictVars, vNondictVars) = unzip nondictFvs

    -- compute the type of the vectorised closure
    argTys <- mapM (vectType . idType) bndrs
    resTy  <- vectType (exprType $ deAnnotate body)

    let arity = length nondictVars + length bndrs
    buildClosures tyvars vDictVars vNondictVars argTys resTy
      . hoistPolyVExpr tyvars vDictVars (inlineHint arity)
      $ do
          -- generate the vectorised body of the lambda abstraction
          lc              <- builtin liftingContext
          (vbndrs, vbody) <- vectBndrsIn (nondictVars ++ bndrs) $ vectExpr body

          vbody' <- breakLoop lc resTy vbody
          return $ vLams lc vbndrs vbody'
  where
    (bndrs, body) = collectAnnValBinders expr

    inlineHint n
      | inline    = Inline n
      | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion termination test
    -- to the /lifted/ version of the function body. The termination test checks if the lifting
    -- context is empty. If so, it returns an empty array of the (lifted) result type instead of
    -- executing the function body. This is the test from the last line (defining \mathcal{L}')
    -- in Figure 6 of HtM.
    breakLoop lc ty (ve, le)
      | loop_breaker
      = do
          emptyArr <- emptyPD ty
          lty      <- mkPDataType ty
          return (ve, mkWildCase (Var lc) intPrimTy lty
                        [ (DEFAULT, [], le)
                        , (LitAlt (mkMachInt 0), [], emptyArr)
                        ])
      | otherwise
      = return (ve, le)
vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
809
-- Vectorise an algebraic case expression.
--
-- A source expression of the form
--
--   case e :: t of v { ... }
--
-- is converted into the pair
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- The lifted version has to take this detour because v must have the type [:V(T):], whereas
-- the scrutinee must be cast to the representation type.  A wild (dead) case binder needs to be
-- handled correctly as well.
--
-- The equations distinguish, in this order:
--
-- * a single DEFAULT alternative,
-- * a single constructor alternative that binds no fields,
-- * a single constructor alternative that binds fields, and
-- * the general case of multiple alternatives.
--

-- FIXME: this is too lazy...is it?
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithVectInfo)]
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = vectAlgCaseNoFields "scrutinee (DEFAULT only)"
                        "alternative body (DEFAULT only)"
                        scrut bndr ty body
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = vectAlgCaseNoFields "scrutinee (one shot w/o binders)"
                        "alternative body (one shot w/o binders)"
                        scrut bndr ty body
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      traceVt "scrutinee (one shot w/ binders)" empty
      vScrutExpr <- vectExpr scrut
      (vTy, lTy) <- vectAndLiftType ty
      traceVt "alternative body (one shot w/ binders)" empty
      -- the case binder scopes over the field binders, which scope over the alternative's body
      (vCaseBndr, (vAltBndrs, (vBody, lBody)))
        <- bindScrutinee (vectBndrsIn bndrs (vectExpr body))
      let (vFieldBndrs, lFieldBndrs) = unzip vAltBndrs
      (vUnwrapped, lUnwrapped, pdataCon) <- pdataUnwrapScrut (vVar vCaseBndr)
      vDataCon <- maybeV notVectorisedErr (lookupDataCon dc)

      let vCase = singleAltCase vUnwrapped vTy vDataCon vFieldBndrs vBody
          lCase = singleAltCase lUnwrapped lTy pdataCon lFieldBndrs lBody

      return $ vLet (vNonRec vCaseBndr vScrutExpr) (vCase, lCase)
  where
    -- a dead case binder is replaced by a fresh variable named "scrut"
    bindScrutinee
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    -- case expression with a wild binder and exactly one constructor alternative
    singleAltCase scrutExpr resTy con fieldBndrs rhs
      = mkWildCase scrutExpr (exprType scrutExpr) resTy [(DataAlt con, fieldBndrs, rhs)]

    notVectorisedErr = text "vectAlgCase: data constructor not vectorised" <+> ppr dc

vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      traceVt "scrutinee (general case)" empty
      vScrutExpr <- vectExpr scrut

      vTyCon     <- vectTyCon tycon
      (vTy, lTy) <- vectAndLiftType ty

      -- the selector type depends on the number of constructors of the vectorised tycon
      let nCons = length (tyConDataCons vTyCon)
      selectorTy <- builtin (selTy nCons)
      selBndr    <- newLocalVar (fsLit "sel") selectorTy
      let selector = Var selBndr

      traceVt "alternatives' body (general case)" empty
      (vCaseBndr, vAlts) <- bindScrutinee (mapM (vectAlt nCons selector vTy lTy) sortedAlts)
      let (vDataCons, vFieldBndrss, lFieldBndrss, vBodies) = unzip4 vAlts

      (vUnwrapped, lUnwrapped, pdataCon) <- pdataUnwrapScrut (vVar vCaseBndr)

      let (vectBodies, liftBodies) = unzip vBodies

      vWild <- newDummyVar (exprType vUnwrapped)
      lWild <- newDummyVar (exprType lUnwrapped)
      let vCase = Case vUnwrapped vWild vTy
                       (zipWith3 mkVectAlt vDataCons vFieldBndrss vectBodies)

      -- the lifted alternatives are merged into one result by 'combinePD'
      lc       <- builtin liftingContext
      combined <- combinePD vTy (Var lc) selector liftBodies
      let lCase = Case lUnwrapped lWild lTy
                       [(DataAlt pdataCon, selBndr : concat lFieldBndrss, combined)]

      return $ vLet (vNonRec vCaseBndr vScrutExpr) (vCase, lCase)
  where
    -- a dead case binder is replaced by a fresh variable named "scrut"
    bindScrutinee
      | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
      | otherwise         = vectBndrIn bndr

    -- alternatives in constructor tag order, with DEFAULT (if any) first
    sortedAlts = sortBy cmpAlt alts

    cmpAlt (con1, _, _) (con2, _, _) = cmpCon con1 con2

    cmpCon (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmpCon DEFAULT       DEFAULT       = EQ
    cmpCon DEFAULT       _             = LT
    cmpCon _             DEFAULT       = GT
    cmpCon _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative.  In the lifted body, the free local variables of
    -- the alternative are rebound to versions packed by constructor tag, and the lifting
    -- context is rebound by a case on the selector's element count for this constructor.
    vectAlt nCons sel _vTy lTy (DataAlt dc, altBndrs, altBody@((altFVs, _), _))
      = do
          vDataCon <- maybeV notVectorisedErr (lookupDataCon dc)
          let conIndex   = dataConTagZ vDataCon
              conTag     = mkDataConTag vDataCon
              capturedVs = altFVs `delVarSetList` altBndrs

          selTagsExpr <- (`App` sel) <$> builtin (selTags nCons)
          lc          <- builtin liftingContext
          elemsFn     <- builtin (selElements nCons conIndex)

          (vAltBndrs, vAltBody)
            <- vectBndrsIn altBndrs . localV $ do
                 packBinds <- mapM (packVar (Var lc) selTagsExpr conTag)
                                   (filter isLocalId (varSetElems capturedVs))
                 traceVt "case alternative:" (ppr (deAnnotate altBody))
                 (ve, le) <- vectExpr altBody
                 return (ve, Case (elemsFn `App` sel) lc lTy
                                  [(DEFAULT, [], mkLets (concat packBinds) le)])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vFieldBndrs, lFieldBndrs) = unzip vAltBndrs
          return (vDataCon, vFieldBndrs, lFieldBndrs, vAltBody)
      where
        notVectorisedErr = text "vectAlgCase: data constructor not vectorised" <+> ppr dc

    vectAlt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mkVectAlt vDataCon fieldBndrs rhs = (DataAlt vDataCon, fieldBndrs, rhs)

    -- Pack a variable for a case alternative context *if* the variable is vectorised.  If it
    -- isn't, ignore it as scalar variables don't need to be packed.
    packVar len tags tag v
      = do
          mbVar <- lookupVar_maybe v
          case mbVar of
            Just (Local (vv, lv)) -> do
              lv'    <- cloneVar lv
              packed <- packByTagPD (idType vv) (Var lv) len tags tag
              -- from here on, 'v' refers to the packed clone in the local environment
              updLEnv $ \env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') }
              return [NonRec lv' packed]
            _ -> return []

-- Vectorise a case expression whose only alternative binds no constructor fields (either a
-- lone DEFAULT or a lone nullary constructor pattern).  The two strings are the trace messages
-- emitted before the scrutinee and before the alternative's body are vectorised, respectively.
--
vectAlgCaseNoFields :: String -> String -> CoreExprWithVectInfo -> Var -> Type
                    -> CoreExprWithVectInfo -> VM VExpr
vectAlgCaseNoFields scrutMsg bodyMsg scrut bndr ty body
  = do
      traceVt scrutMsg empty
      vScrut         <- vectExpr scrut
      (vTy, lTy)     <- vectAndLiftType ty
      traceVt bodyMsg empty
      (vBndr, vBody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vScrut vBndr vTy lTy vBody
976
977
978 -- Support to compute information for vectorisation avoidance ------------------
979
-- Annotation attached to Core AST nodes.  It describes how a node should be handled during
-- vectorisation and, in particular, whether vectorisation of the corresponding computation can
-- be avoided.
--
data VectAvoidInfo
  = VIParr       -- tree contains parallel computations
  | VISimple     -- result type is scalar & no parallel subcomputation
  | VIComplex    -- any result type, no parallel subcomputation
  | VIEncaps     -- tree encapsulated by 'liftSimple'
  | VIDict       -- dictionary computation (never parallel)
  deriving (Eq, Show)
989
-- Annotated Core expression in which every node carries a pair of (1) a variable set (the
-- node's free variables) and (2) its vectorisation avoidance information.
--
type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo)
993
-- Compute the type of an annotated Core expression by stripping its annotations first.
--
annExprType :: AnnExpr Var ann -> Type
annExprType annExpr = exprType (deAnnotate annExpr)
998
-- Extract the vectorisation information stored at the root of an annotated Core expression.
--
vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
vectAvoidInfoOf = snd . fst
1003
-- Is the root of this expression tagged 'VIParr'?
--
isVIParr :: CoreExprWithVectInfo -> Bool
isVIParr annExpr = vectAvoidInfoOf annExpr == VIParr
1008
-- Is the root of this expression tagged 'VIEncaps'?
--
isVIEncaps :: CoreExprWithVectInfo -> Bool
isVIEncaps annExpr = vectAvoidInfoOf annExpr == VIEncaps
1013
-- Is the root of this expression tagged 'VIDict'?
--
isVIDict :: CoreExprWithVectInfo -> Bool
isVIDict annExpr = vectAvoidInfoOf annExpr == VIDict
1018
-- Combine two pieces of vectorisation information: the result is 'VIParr' as soon as either
-- argument is 'VIParr'; in all other cases, the first argument is returned unchanged.
--
unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
unlessVIParr vi other
  | other == VIParr = VIParr
  | otherwise       = vi
1024
-- Variant of 'unlessVIParr' whose second argument is an annotated expression: the result is
-- 'VIParr' if either the first argument or the expression's vectorisation information is
-- 'VIParr'; otherwise, the first argument is produced.
--
-- Declared left-associative, so that chains fold the information of several expressions into
-- the initial value.
--
unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
infixl `unlessVIParrExpr`
unlessVIParrExpr vi annExpr = unlessVIParr vi (vectAvoidInfoOf annExpr)
1031
-- Annotate a Core expression (already annotated with free variables) with the information
-- that determines for which subexpressions vectorisation can be avoided.
--
-- * The first argument is the set of free, local variables whose evaluation may entail
--   parallelism.
--
vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo

-- Variables: parallel if known to be parallel (locally or globally); otherwise, determined by
-- their type.
vectAvoidInfo pvs ce@(fvs, AnnVar v)
  = do
      gpvs <- globalParallelVars
      let isLocalPar  = v `elemVarSet` pvs
          isGlobalPar = v `elemVarSet` gpvs
          reason | isLocalPar  = text "local"
                 | isGlobalPar = text "global"
                 | otherwise   = text "parallel type"
      vi <- if isLocalPar || isGlobalPar
              then return VIParr
              else vectAvoidInfoTypeOf ce
      viTrace ce vi []
      when (vi == VIParr) $
        traceVt " reason:" reason

      return ((fvs, vi), AnnVar v)

-- Literals: determined by their type.
vectAvoidInfo _pvs ce@(fvs, AnnLit lit)
  = do
      vi <- vectAvoidInfoTypeOf ce
      viTrace ce vi []
      return ((fvs, vi), AnnLit lit)

-- Applications: the information of the result type, unless function or argument are parallel.
vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
  = do
      appTyVI <- vectAvoidInfoTypeOf ce
      funVI   <- vectAvoidInfo pvs e1
      argVI   <- vectAvoidInfo pvs e2
      let vi = (appTyVI `unlessVIParrExpr` funVI) `unlessVIParrExpr` argVI
      -- viTrace ce vi [funVI, argVI]
      return ((fvs, vi), AnnApp funVI argVI)

-- Lambdas: the information of the body, unless the binder's type is parallel.
vectAvoidInfo pvs (fvs, AnnLam var body)
  = do
      bodyVI <- vectAvoidInfo pvs body
      bndrVI <- vectAvoidInfoType (varType var)
      let vi = vectAvoidInfoOf bodyVI `unlessVIParr` bndrVI
      -- viTrace ce vi [bodyVI]
      return ((fvs, vi), AnnLam var bodyVI)

-- Non-recursive let: a parallel right-hand side of non-scalar type makes the binder parallel
-- for the body and the whole expression parallel; otherwise, the binding doesn't affect
-- parallelism.
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
  = do
      letTyVI    <- vectAvoidInfoTypeOf ce
      rhsVI      <- vectAvoidInfo pvs e
      scalarBndr <- isScalar (varType var)
      let bndrIsPar = isVIParr rhsVI && not scalarBndr
          bodyPvs | bndrIsPar = pvs `extendVarSet` var
                  | otherwise = pvs
      bodyVI <- vectAvoidInfo bodyPvs body
      let vi | bndrIsPar = VIParr
             | otherwise = letTyVI `unlessVIParrExpr` bodyVI
      -- viTrace ce vi [rhsVI, bodyVI]
      return ((fvs, vi), AnnLet (AnnNonRec var rhsVI) bodyVI)

-- Recursive let: the bindings are analysed once; if any of them turns out to be parallel, they
-- are analysed again (together with the body) with the parallel binders added to the set of
-- parallel variables.
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
  = do
      letTyVI   <- vectAvoidInfoTypeOf ce
      firstPass <- mapM (annotateBnd pvs) bnds
      parBndrs  <- map fst <$> filterM isParallelBnd firstPass
      if null parBndrs
        then do         -- demanded bindings cannot trigger parallelism
          bodyVI <- vectAvoidInfo pvs body
          let vi = letTyVI `unlessVIParrExpr` bodyVI
          -- viTrace ce vi (map snd firstPass ++ [bodyVI])
          return ((fvs, vi), AnnLet (AnnRec firstPass) bodyVI)
        else do         -- body may trigger parallelism via at least one binding
          nonScalarBndrs <- filterM (fmap not . isScalar . varType) parBndrs
          let pvs' = pvs `extendVarSetList` nonScalarBndrs
          secondPass <- mapM (annotateBnd pvs') bnds
          bodyVI     <- vectAvoidInfo pvs' body
          -- viTrace ce VIParr (map snd secondPass ++ [bodyVI])
          return ((fvs, VIParr), AnnLet (AnnRec secondPass) bodyVI)
  where
    annotateBnd parVars (var, rhs) = (var,) <$> vectAvoidInfo parVars rhs

    -- a binding counts as parallel if its right-hand side is parallel and its type isn't scalar
    isParallelBnd (var, rhsVI)
      = do
          scalarBndr <- isScalar (varType var)
          return (isVIParr rhsVI && not scalarBndr)

-- Case: the information of the result type, unless the scrutinee or any alternative is
-- parallel.
vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
  = do
      caseTyVI <- vectAvoidInfoTypeOf ce
      scrutVI  <- vectAvoidInfo pvs e
      altsVI   <- mapM (annotateAlt (isVIParr scrutVI)) alts
      let rhsVIs = [rhsVI | (_, _, rhsVI) <- altsVI]
          vi     = foldl unlessVIParrExpr caseTyVI (scrutVI : rhsVIs)  -- NB: same effect as in the paper
      -- viTrace ce vi (scrutVI : rhsVIs)
      return ((fvs, vi), AnnCase scrutVI var ty altsVI)
  where
    -- the binders of an alternative become parallel variables if the scrutinee is parallel and
    -- not all of the binders are of scalar type
    annotateAlt scrutIsPar (con, bndrs, rhs)
      = do
          allScalar <- allScalarVarType bndrs
          let altPvs = if scrutIsPar && not allScalar
                         then pvs `extendVarSetList` bndrs
                         else pvs
          rhsVI <- vectAvoidInfo altPvs rhs
          return (con, bndrs, rhsVI)

-- Casts: same information as the expression being cast; the coercion itself is 'VISimple'.
vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann))
  = annotateCast <$> vectAvoidInfo pvs e
  where
    annotateCast eVI = ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann))

-- Ticks: same information as the expression being ticked.
vectAvoidInfo pvs (fvs, AnnTick tick e)
  = annotateTick <$> vectAvoidInfo pvs e
  where
    annotateTick eVI = ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)

-- Types and coercions are always 'VISimple'.
vectAvoidInfo _pvs (fvs, AnnType ty)
  = return ((fvs, VISimple), AnnType ty)

vectAvoidInfo _pvs (fvs, AnnCoercion coe)
  = return ((fvs, VISimple), AnnCoercion coe)
1161
-- Determine the vectorisation avoidance information implied by a type.
--
-- * Predicate types are dictionaries ('VIDict').
-- * Function types combine the information of their argument and result type.
-- * Any other type is 'VIParr' if it might be a parallel array type, 'VISimple' if it is a
--   scalar type, and 'VIComplex' otherwise.
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | isPredTy ty
  = return VIDict
  | Just (arg, res) <- splitFunTy_maybe ty
  = do
      argVI <- vectAvoidInfoType arg
      resVI <- vectAvoidInfoType res
      return (funTyVI argVI resVI)
  | otherwise
  = do
      parr <- maybeParrTy ty
      if parr
        then return VIParr
        else do
          scalar <- isScalar ty
          return $ if scalar then VISimple else VIComplex
  where
    funTyVI VISimple VISimple = VISimple      -- NB: diverges from the paper: scalar functions
    funTyVI _        VIDict   = VIDict
    funTyVI argVI    resVI    = (VIComplex `unlessVIParr` argVI) `unlessVIParr` resVI
1188
-- Determine the vectorisation avoidance information implied by the type of an annotated Core
-- expression.
--
vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
vectAvoidInfoTypeOf annExpr = vectAvoidInfoType (annExprType annExpr)
1193
-- Check whether the given type might be a parallel array type.
--
maybeParrTy :: Type -> VM Bool
maybeParrTy ty
    -- look through whatever 'coreView' expands (e.g., newtypes)
  | Just ty' <- coreView ty
  = do
      vi <- vectAvoidInfoType ty'
      return (vi == VIParr)
    -- constructor application: parallel if the tycon is a known parallel one or if any of the
    -- type arguments may be parallel
  | Just (tc, ts) <- splitTyConApp_maybe ty
  = do
      parTyCons <- globalParallelTyCons
      if tyConName tc `elemNameSet` parTyCons
        then return True
        else or <$> mapM maybeParrTy ts
    -- universally quantified type: inspect the body
  | ForAllTy _ bodyTy <- ty
  = maybeParrTy bodyTy
  | otherwise
  = return False
1211
-- Do the types of all the given variables belong to the 'Scalar' class?
--
allScalarVarType :: [Var] -> VM Bool
allScalarVarType vs
  = do
      scalarFlags <- mapM (isScalar . varType) vs
      return (and scalarFlags)
1216
-- Do the types of all variables in the given set belong to the 'Scalar' class?
--
allScalarVarTypeSet :: VarSet -> VM Bool
allScalarVarTypeSet vars = allScalarVarType (varSetElems vars)
1221
-- Debugging support: trace the vectorisation information computed for an expression, followed
-- (in brackets) by the information of the given subexpressions, together with the
-- pretty-printed expression itself.
--
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
viTrace ce vi vTs
  = traceVt header (ppr (deAnnotate ce))
  where
    header = "vect info: " ++ show vi ++ "[" ++ subInfos ++ "]"

    -- every subexpression's information is followed by a single space
    subInfos = concatMap (\vT -> show (vectAvoidInfoOf vT) ++ " ") vTs