ae7483a612d0cb32226cc0b2fda2fa897a847428
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1 {-# LANGUAGE CPP, TupleSections #-}
2
3 -- |Vectorisation of expressions.
4
5 module Vectorise.Exp
6 ( -- * Vectorise right-hand sides of toplevel bindings
7 vectTopExpr
8 , vectTopExprs
9 , vectScalarFun
10 , vectScalarDFun
11 )
12 where
13
14 #include "HsVersions.h"
15
16 import Vectorise.Type.Type
17 import Vectorise.Var
18 import Vectorise.Convert
19 import Vectorise.Vect
20 import Vectorise.Env
21 import Vectorise.Monad
22 import Vectorise.Builtins
23 import Vectorise.Utils
24
25 import CoreUtils
26 import MkCore
27 import CoreSyn
28 import CoreFVs
29 import Class
30 import DataCon
31 import TyCon
32 import TcType
33 import Type
34 import TypeRep
35 import Var
36 import VarEnv
37 import VarSet
38 import NameSet
39 import Id
40 import BasicTypes( isStrongLoopBreaker )
41 import Literal
42 import TysPrim
43 import Outputable
44 import FastString
45 import DynFlags
46 import Util
47 #if __GLASGOW_HASKELL__ < 709
48 import MonadUtils
49 #endif
50
51 import Control.Monad
52 import Data.Maybe
53 import Data.List
54
55
56 -- Main entry point to vectorise expressions -----------------------------------
57
-- |Vectorise the (possibly polymorphic) right-hand side of a top-level *non-recursive* binding.
--
-- The result is 'Nothing' when the whole right-hand side ends up encapsulated as a scalar
-- computation ('VIEncaps'). Otherwise, it is 'Just (isParallel, inlineHint, vectorisedRhs)',
-- where 'isParallel' says whether the right-hand side is tagged 'VIParr'.
--
-- Non-recursive bindings get their own entry point (rather than going through 'vectTopExprs')
-- because their vectorisation-avoidance information never needs to be computed twice.
--
vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
vectTopExpr var expr = do
  annotated <- encapsulateScalars =<< vectAvoidInfo emptyVarSet (freeVars expr)
  if isVIEncaps annotated
    then return Nothing
    else do
      vExpr  <- closedV . inBind var $ vectAnnPolyExpr False annotated
      inline <- computeInline annotated
      return $ Just (isVIParr annotated, inline, vectorised vExpr)
82
-- Compute the inlining hint for the right-hand side of a top-level binding.
--
-- Dictionary right-hand sides ('VIDict') and anything that is not a lambda abstraction get
-- 'DontInline'. Ticks are looked through. A lambda abstraction gets 'Inline' with the arity that
-- 'polyArity' computes from its outer type binders.
--
computeInline :: CoreExprWithVectInfo -> VM Inline
computeInline annExpr@((_, info), body)
  | VIDict <- info
  = return DontInline
  | otherwise
  = case body of
      AnnTick _ inner -> computeInline inner
      AnnLam _ _      -> Inline <$> polyArity (fst (collectAnnTypeBinders annExpr))
      _               -> return DontInline
92
-- |Vectorise the right-hand sides of a recursive group of top-level (possibly polymorphic)
-- bindings.
--
-- The result is 'Nothing' when every right-hand side is encapsulated as a scalar computation
-- ('VIEncaps'). Otherwise, it is 'Just (anyParallel, results)':
--
-- * 'anyParallel' says whether at least one right-hand side is tagged 'VIParr'.
-- * 'results' holds one inlining hint and vectorised right-hand side per binding, in order.
--
vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
vectTopExprs binds = do
  firstPass <- mapM (annotate emptyVarSet) rhss
  if all isVIEncaps firstPass
    then return Nothing       -- the whole group is scalar: don't vectorise it
    else do
      let anyParallel = any isVIParr firstPass
      -- When some binding is parallel, the vectorisation-avoidance information is recomputed.
      -- This time the group's own binders are passed in the variable set given to
      -- 'vectAvoidInfo' (presumably the set of variables to treat as parallel).
      annotated <- if anyParallel
                     then mapM (annotate (mkVarSet bndrs)) rhss
                     else return firstPass
      vRhss <- zipWithM vectOne bndrs annotated
      return $ Just (anyParallel, vRhss)
  where
    (bndrs, rhss) = unzip binds

    -- compute vectorisation-avoidance info for one right-hand side and encapsulate its scalars
    annotate parVars rhs = vectAvoidInfo parVars (freeVars rhs) >>= encapsulateScalars

    -- vectorise one annotated right-hand side in the scope of its binder
    vectOne bndr annRhs = do
      vRhs   <- closedV . inBind bndr $
                  vectAnnPolyExpr (isStrongLoopBreaker (idOccInfo bndr)) annRhs
      inline <- computeInline annRhs
      return (inline, vectorised vRhs)
132
-- |Vectorise a (possibly polymorphic) expression annotated with vectorisation information.
--
-- * Ticks are traversed, and re-applied to the vectorised result with 'vTick'.
-- * Right-hand sides of dictionary functions ('VIDict') are handled separately by
--   'vectDictExpr'. (It would be neater to integrate them, though!) Dictionary expressions
--   are never lifted. The lifted component of such a result is therefore a 'panic' with an
--   explanatory message, and must not be used.
-- * Otherwise, the outer type abstractions are collected and the monomorphic body is vectorised
--   with 'vectFnExpr'. The result is then re-abstracted over the type variables plus the
--   arguments supplied by 'polyAbstract'.
--
-- The 'Bool' argument says whether the enclosing binding is a loop breaker.
--
vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
    -- traverse through ticks
  = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
vectAnnPolyExpr loop_breaker expr
  | isVIDict expr
    -- special case the right-hand side of dictionary functions (never lifted)
  = (, noLiftedDict) <$> vectDictExpr (deAnnotate expr)
  | otherwise
    -- collect and vectorise type abstractions; then, descend into the body
  = polyAbstract tvs $ \args ->
      mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
  where
    (tvs, mono) = collectAnnTypeBinders expr

    -- placeholder for the (non-existent) lifted version of a dictionary expression; the
    -- original code used a bare 'undefined' here, which gave no hint about what went wrong
    noLiftedDict = panic "Vectorise.Exp.vectAnnPolyExpr: we don't lift dictionary expressions"
152
-- Encapsulate each purely sequential subexpression of a (potentially) parallel expression.
--
-- The subexpression is wrapped in a lambda abstraction over its free variables, which is then
-- immediately applied to those same variables (see 'liftSimple'). Encapsulated subexpressions
-- are not vectorised.
--
-- Preconditions for encapsulating a subexpression:
--
-- * All of its free variables, and its result type, have /simple/ types.
-- * It is complex enough to warrant special treatment. For now, that means it is not a
--   constant and contains at least one operation.
--
-- The user chooses between minimal and aggressive vectorisation avoidance (queried via
-- 'isVectAvoidanceAggressive'):
--
-- * Minimal mode only encapsulates individual scalar operations: scalar variables and
--   simple applications.
-- * Aggressive mode encapsulates subexpressions that are as big as possible.
--
encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
encapsulateScalars ce@((fvs, vi), expr)
  = case expr of
      AnnType _ -> return ce
      AnnLit _  -> return ce

      AnnVar _
          -- NB: diverts from the paper: encapsulate scalar variables (including functions)
        | VISimple <- vi -> liftSimpleAndCase ce
        | otherwise      -> return ce

      AnnTick tck body -> reannotate . AnnTick tck <$> encapsulateScalars body

      AnnCast body co  -> reannotate . (`AnnCast` co) <$> encapsulateScalars body

      AnnLam bndr body -> do
        aggressive  <- isVectAvoidanceAggressive
        scalarFvs   <- allScalarVarTypeSet fvs
        -- NB: diverts from the paper: the bound variables must be scalar as well, because
        --     'vectScalarFun' treats them exactly like the variables abstracted for 'fvs'
        scalarBndrs <- allScalarVarType (fst (collectAnnBndrs ce))
        liftWhen (aggressive && scalarFvs && scalarBndrs) $
          reannotate . AnnLam bndr <$> encapsulateScalars body

      AnnApp fun arg -> do
        aggressive <- isVectAvoidanceAggressive
        scalarFvs  <- allScalarVarTypeSet fvs
        liftWhen ((aggressive || isSimpleApplication ce) && scalarFvs) $ do
          fun' <- encapsulateScalars fun
          arg' <- encapsulateScalars arg
          return $ reannotate (AnnApp fun' arg')

      AnnCase scrut bndr ty alts -> liftWhenAggressive $ do
        scrut' <- encapsulateScalars scrut
        alts'  <- mapM (\(con, bndrs, rhs) -> (con, bndrs,) <$> encapsulateScalars rhs) alts
        return $ reannotate (AnnCase scrut' bndr ty alts')

      AnnLet (AnnNonRec bndr rhs) body -> liftWhenAggressive $ do
        rhs'  <- encapsulateScalars rhs
        body' <- encapsulateScalars body
        return $ reannotate (AnnLet (AnnNonRec bndr rhs') body')

      AnnLet (AnnRec pairs) body -> liftWhenAggressive $ do
        pairs' <- mapM (\(bndr, rhs) -> (bndr,) <$> encapsulateScalars rhs) pairs
        body'  <- encapsulateScalars body
        return $ reannotate (AnnLet (AnnRec pairs') body')

      _ -> panic "Vectorise.Exp.encapsulateScalars: unknown constructor"
  where
    -- rebuild a node carrying the annotation of 'ce'
    reannotate node = ((fvs, vi), node)

    -- encapsulate 'ce' as a whole if it is 'VISimple' and 'ok' holds; otherwise run 'descend'
    liftWhen ok descend
      | VISimple <- vi, ok = liftSimpleAndCase ce
      | otherwise          = descend

    -- 'liftWhen' for nodes that are only encapsulated as a whole under aggressive avoidance
    -- and when all of their free variables have scalar types
    liftWhenAggressive descend = do
      aggressive <- isVectAvoidanceAggressive
      scalarFvs  <- allScalarVarTypeSet fvs
      liftWhen (aggressive && scalarFvs) descend

    -- an application whose function parts are all atoms (looking through ticks and casts)
    isSimpleApplication :: CoreExprWithVectInfo -> Bool
    isSimpleApplication (_, AnnTick _ e)  = isSimpleApplication e
    isSimpleApplication (_, AnnCast e _)  = isSimpleApplication e
    isSimpleApplication e | isSimple e    = True
    isSimpleApplication (_, AnnApp e1 e2) = isSimple e1 && isSimpleApplication e2
    isSimpleApplication _                 = False

    -- a type, variable or literal (looking through ticks and casts)
    isSimple :: CoreExprWithVectInfo -> Bool
    isSimple (_, AnnType {})  = True
    isSimple (_, AnnVar {})   = True
    isSimple (_, AnnLit {})   = True
    isSimple (_, AnnTick _ e) = isSimple e
    isSimple (_, AnnCast e _) = isSimple e
    isSimple _                = False
274
-- Lambda-lift a simple expression with 'liftSimple', applying the result to the abstracted free
-- variables.
--
-- A case expression whose scrutinee type is not scalar (its 'vectAvoidInfoTypeOf' is not
-- 'VISimple') is not lifted as a whole. Instead, each alternative is lifted individually, and
-- the case expression is rebuilt around them. The rebuilt case is annotated with the
-- vectorisation-avoidance info computed for the scrutinee.
--
liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimpleAndCase aexpr@((fvs, _), AnnCase scrut bndr ty alts)
  = do
      scrutInfo <- vectAvoidInfoTypeOf scrut
      if scrutInfo == VISimple
        then liftSimple aexpr     -- scalar scrutinee: no special treatment needed
        else do
          alts' <- forM alts $ \(con, bndrs, rhs) -> (con, bndrs,) <$> liftSimpleAndCase rhs
          return ((fvs, scrutInfo), AnnCase scrut bndr ty alts')
liftSimpleAndCase other
  = liftSimple other
293
-- Lambda-lift a 'VISimple' expression over its non-toplevel free variables, and immediately
-- apply the abstraction to those variables: 'e' becomes '(\x1 .. xn -> e) x1 .. xn'.
--
-- * The lambda abstraction is annotated 'VIEncaps'; the applications are annotated 'VISimple'.
-- * A variable that is contained in its own free-variable set and is not toplevel is returned
--   unchanged. This avoids producing '(\v -> v) v'.
-- * Any expression that is not 'VISimple' causes a panic.
--
liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimple aexpr@((fvs, _), AnnVar v)
  | v `elemVarSet` fvs
  , not (isToplevel v)   -- NB: if 'v' is not free or is toplevel, we must get the 'VIEncaps'
  = return aexpr
liftSimple aexpr@((allFvs, VISimple), expr)
  = do
      traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate lifted)
      return lifted
  where
    -- only abstract over 'Id's that are not toplevel
    localFvs = filterVarSet (not . isToplevel) allFvs
    params   = varSetElems localFvs
    lifted   = applyToVars (abstractOver (reverse params) localFvs expr) params

    -- Wrap 'body' in lambdas, consuming the list from its head; so the last variable in the list
    -- becomes the outermost binder. 'remaining' is the set of variables still free at each level.
    abstractOver :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
    abstractOver [] remaining body
      = ASSERT(isEmptyVarSet remaining)
        ((emptyVarSet, VIEncaps), body)
    abstractOver (x:xs) remaining body
      = abstractOver xs (remaining `delVarSet` x) (AnnLam x ((remaining, VIEncaps), body))

    -- apply 'fun' to the given variables, left to right
    applyToVars :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo
    applyToVars fun [] = fun
    applyToVars fun@((funFvs, _), _) (x:xs)
      = applyToVars ((funFvs `extendVarSet` x, VISimple),
                     AnnApp fun ((unitVarSet x, VISimple), AnnVar x))
                    xs
liftSimple aexpr
  = pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
325
-- Whether a variable is a toplevel 'Id', judged by its unfolding:
--
-- * a 'CoreUnfolding' records the answer in its 'uf_is_top' field;
-- * 'OtherCon' and dfun unfoldings count as toplevel;
-- * a variable without an unfolding does not.
--
-- Variables that are not 'Id's are never toplevel.
--
isToplevel :: Var -> Bool
isToplevel v
  | not (isId v) = False
  | otherwise
  = case realIdUnfolding v of
      CoreUnfolding { uf_is_top = top } -> top
      DFunUnfolding {}                  -> True
      OtherCon {}                       -> True
      NoUnfolding                       -> False
333
-- |Vectorise an annotated expression, producing its vectorised and lifted forms.
--
-- Encapsulated ('VIEncaps') expressions are treated as scalar computations. Those of functional
-- type go through 'vectFnExpr'; all others become constants via 'vectConst'. The remaining
-- equations dispatch on the shape of the expression; anything unsupported ends in
-- 'cantVectorise'.
--
vectExpr :: CoreExprWithVectInfo -> VM VExpr

vectExpr aexpr
  | isFunTy (annExprType aexpr), isVIEncaps aexpr
    -- encapsulated function: try to vectorise it as a scalar subcomputation
  = vectFnExpr True False aexpr
  | isVIEncaps aexpr
    -- encapsulated non-function: vectorise it as a scalar constant
  = do
      traceVt "vectExpr (encapsulated constant):" (ppr (deAnnotate aexpr))
      vectConst (deAnnotate aexpr)

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectConst (Lit lit)

vectExpr aexpr@(_, AnnLam _ _)
  = do
      traceVt "vectExpr [AnnLam]:" (ppr (deAnnotate aexpr))
      vectFnExpr True False aexpr

-- SPECIAL CASE: 'patError @ ty err' is vectorised/lifted by vectorising/lifting only the type
-- 'ty'. The call merely aborts the program, but the type still has to be adjusted to keep
-- CoreLint happy.
-- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do
      (vty, lty) <- vectAndLiftType ty
      let errExpr     = deAnnotate err
          patErrorAt t = mkCoreApps (Var v) [Type t, errExpr]
      return (patErrorAt vty, patErrorAt lty)

-- Type application. Runs of consecutive type applications are handled together, so that the PA
-- dictionaries end up in the right places.
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e

-- Lifted literal: a data constructor applied to a literal
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | isJust (isDataConId_maybe v)
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)

-- Value application: either a dictionary argument or a user value
vectExpr e@(_, AnnApp fn arg)
  | isPredTy argTy    -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise         -- user value: apply the vectorised closure to the vectorised argument
  = do
      vArgTy <- vectType argTy
      vResTy <- vectType resTy
      vFn    <- vectExpr fn
      vArg   <- vectExpr arg
      mkClosureApp vArgTy vResTy vFn vArg
  where
    (argTy, resTy) = splitFunTy (exprType (deAnnotate fn))

vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, tyArgs) <- splitTyConApp_maybe scrutTy
  , isAlgTyCon tycon
  = vectAlgCase tycon tyArgs scrut bndr ty alts
  | otherwise
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
        ppr scrutTy
  where
    scrutTy = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
      traceVt "let binding (non-recursive)" Outputable.empty
      vRhs <- localV . inBind bndr $ vectAnnPolyExpr False rhs
      traceVt "let body (non-recursive)" Outputable.empty
      (vBndr, vBody) <- vectBndrIn bndr (vectExpr body)
      return $ vLet (vNonRec vBndr vRhs) vBody

vectExpr (_, AnnLet (AnnRec pairs) body)
  = do
      (vBndrs, (vRhss, vBody)) <- vectBndrsIn bndrs $ do
        traceVt "let bindings (recursive)" Outputable.empty
        rhssV <- zipWithM vectRhs bndrs rhss
        traceVt "let body (recursive)" Outputable.empty
        bodyV <- vectExpr body
        return (rhssV, bodyV)
      return $ vLet (vRec vBndrs vRhss) vBody
  where
    (bndrs, rhss) = unzip pairs

    vectRhs bndr rhs
      = localV . inBind bndr $ vectAnnPolyExpr (isStrongLoopBreaker (idOccInfo bndr)) rhs

vectExpr (_, AnnTick tickish expr)
  = vTick tickish <$> vectExpr expr

vectExpr (_, AnnType ty)
  = vType <$> vectType ty

vectExpr e
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
458
-- |Vectorise an expression that *may* start with a lambda abstraction.
--
-- * Encapsulated ('VIEncaps') functions are vectorised as scalar computations, using a
--   generalised scalar zip (see 'vectScalarFun').
-- * Predicate (dictionary) abstractions stay ordinary abstractions; only the binder's predicate
--   type is vectorised.
-- * Other value abstractions are vectorised with 'vectLam'.
-- * A type abstraction is an error.
-- * Non-abstractions fall through to 'vectExpr'.
--
-- Type variables need not be handled here: 'vectAnnPolyExpr' has already stripped them off.
-- There is also only one set of dictionary arguments to worry about, because (1) we only deal
-- with Haskell 2010, and (2) class selectors are vectorised elsewhere.
--
vectFnExpr :: Bool                  -- ^If we process the RHS of a binding, whether that binding
                                    --  should be inlined
           -> Bool                  -- ^Whether the binding is a loop breaker
           -> CoreExprWithVectInfo  -- ^Expression to vectorise; may or may not be an `AnnLam`
           -> VM VExpr
vectFnExpr inline loop_breaker aexpr@(_, AnnLam bndr body)
  | not (isId bndr)
  = do
      dflags <- getDynFlags
      cantVectorise dflags "Vectorise.Exp.vectFnExpr: Unexpected type lambda" $
        ppr (deAnnotate aexpr)
  | isPredTy (idType bndr)
    -- predicate abstraction: keep it a normal abstraction, but vectorise the predicate type
  = do
      vBndr <- vectBndr bndr
      vBody <- vectFnExpr inline loop_breaker body
      return $ mapVect (mkLams [vectorised vBndr]) vBody
  | isVIEncaps aexpr
    -- encapsulated non-predicate abstraction: vectorise as a scalar computation
  = vectScalarFun (deAnnotate aexpr)
  | otherwise
    -- non-predicate abstraction: vectorise as a non-scalar computation
  = vectLam inline loop_breaker aexpr
vectFnExpr _ _ aexpr
  | isFunTy (annExprType aexpr) && isVIEncaps aexpr
    -- encapsulated function: vectorise as a scalar computation
  = vectScalarFun (deAnnotate aexpr)
  | otherwise
    -- not an abstraction: vectorise as a non-scalar vanilla expression
    -- NB: we get here via the recursion in the predicate case above and from 'vectAnnPolyExpr'
  = vectExpr aexpr
501
-- |Vectorise type and dictionary applications.
--
-- The head of such an application is always a variable, since higher-rank polymorphism is not
-- supported; any other head is reported with 'pprSorry'. There may be two sets of type arguments
-- and dictionaries. Consider,
--
-- >   class C a where
-- >     m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
  = case appHead of
      (_, AnnVar var) -> do
        -- the vectorised form of the head variable
        vVar <- lookupVar var
        traceVt "vectPolyApp of" (ppr var)

        -- vectorise both sets of dictionary and type arguments
        vDictsOuter <- mapM (vectDictExpr . deAnnotate) dictsOuter
        vDictsInner <- mapM (vectDictExpr . deAnnotate) dictsInner
        vTysOuter   <- mapM vectType tysOuter
        vTysInner   <- mapM vectType tysInner

        let withOuterArgs hd = (`mkApps` vDictsOuter) <$> polyApply hd vTysOuter

        case vVar of
          Local (vv, lv) -> do
            MASSERT( null dictsInner )    -- local vars cannot be class selectors
            traceVt " LOCAL" (text "")
            (,) <$> withOuterArgs (Var vv) <*> withOuterArgs (Var lv)
          Global vv
            | isDictComp var -> do
                -- dictionary computation: the innermost, non-empty set of arguments consists
                -- of non-vectorised arguments, for whose type variables no 'PA' dictionaries
                -- are needed
                ve <- if null dictsInner
                        then return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                        else withOuterArgs (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                traceVt " GLOBAL (dict):" (ppr ve)
                vectConst ve
            | otherwise -> do
                -- non-dictionary computation
                MASSERT( null dictsInner )
                ve <- withOuterArgs (Var vv)
                traceVt " GLOBAL (non-dict):" (ppr ve)
                vectConst ve
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- peel the arguments off from the outside in; if there is only one set of type arguments
    -- or dictionaries, it ends up as the outer set
    (noOuterDicts, dictsOuter) = collectAnnDictArgs e0
    (noOuterTys,   tysOuter)   = collectAnnTypeArgs noOuterDicts
    (noInnerDicts, dictsInner) = collectAnnDictArgs noOuterTys
    (appHead,      tysInner)   = collectAnnTypeArgs noInnerDicts

    -- class-op selectors and dfuns compute dictionaries
    isDictComp v = isJust (isClassOpId_maybe v) || isDFunId v
564
-- |Vectorise the body of a dfun.
--
-- Dictionary computations are special for two reasons:
--
-- * Applications of dictionary functions are always saturated, so no closures need to be
--   created.
-- * Dictionary computations don't depend on array values. They are therefore always scalar
--   computations whose result can be replicated, instead of being executed in parallel.
--
-- NB: To keep things simple, the bindings introduced inside a dictionary computation are not
--     rewritten. The variable case therefore has to cope with binders that are in the
--     vectoriser environments, as well as binders that are not.
--
-- Literals cause a panic; casts and coercions are reported with 'pprSorry'.
--
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr
  = case expr of
      Var var -> do
        mbScope <- lookupVar_maybe var
        case mbScope of
          Nothing              -> return $ Var var   -- binder from within the dict. computation
          Just (Local (vv, _)) -> return $ Var vv    -- local vectorised variable
          Just (Global vv)     -> return $ Var vv    -- global vectorised variable
      Lit lit ->
        pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
      Lam bndr body ->
        Lam bndr <$> vectDictExpr body
      App fn arg ->
        App <$> vectDictExpr fn <*> vectDictExpr arg
      Case scrut bndr ty alts ->
        Case <$> vectDictExpr scrut <*> pure bndr <*> vectType ty <*> mapM vectDictAlt alts
      Let bind body ->
        Let <$> vectDictBind bind <*> vectDictExpr body
      Cast _ _ ->
        pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
      Tick tickish body ->
        Tick tickish <$> vectDictExpr body
      Type ty ->
        Type <$> vectType ty
      Coercion coe ->
        pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    vectDictAlt (con, bndrs, rhs)
      = (,,) <$> vectDictAltCon con <*> pure bndrs <*> vectDictExpr rhs

    vectDictAltCon (DataAlt dc) = DataAlt <$> maybeV dataConErr (lookupDataCon dc)
      where
        dataConErr = ptext (sLit "Cannot vectorise data constructor:") <+> ppr dc
    vectDictAltCon (LitAlt lit) = return $ LitAlt lit
    vectDictAltCon DEFAULT      = return DEFAULT

    vectDictBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectDictBind (Rec pairs)       = Rec <$> mapM (\(bndr, rhs) -> (bndr,) <$> vectDictExpr rhs) pairs
613
-- |Vectorise an expression of functional type that does not involve parallel arrays.
--
-- The expression must meet two conditions:
--
-- * All its arguments and its result are of primitive types, i.e., 'Int', 'Float', 'Double'
--   etc., which have instances of the 'Scalar' type class.
-- * It contains no subcomputations that involve parallel arrays.
--
-- Such functions do not need the full-blown vectorisation transformation. Instead, they can be
-- lifted by applying a member of the zipWith family ('map', 'zipWith', 'zipWith3', etc.).
--
-- Dictionary functions are also scalar functions: dictionaries themselves are not vectorised;
-- instead, they become dictionaries of vectorised methods. They are treated differently,
-- though; see "Note [Scalar dfuns]" in 'Vectorise'.
--
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr = do
  traceVt "vectScalarFun:" (ppr expr)
  uncurry mkScalarFun (splitFunTys (exprType expr)) expr
631
-- Generate code for a scalar function with the given argument types and result type.
--
-- * If the result type is a predicate (dictionary) type, the expression is vectorised as
--   dictionary code with 'vectDictExpr'. It has no lifted version, and the lifted component
--   must not be used.
-- * Otherwise, the following steps are taken:
--   1. the function is hoisted to a binding named after "fn";
--   2. a scalar closure is built from it and from its 'zipScalars' zipper;
--   3. that closure is hoisted as "clo";
--   4. the result is the closure variable together with its 'liftPD'-lifted form.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = do
      vExpr <- vectDictExpr expr
      return (vExpr, unused)
  | otherwise
  = do
      traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr (mkFunTys arg_tys res_ty)

      fnVar  <- hoistExpr (fsLit "fn") expr DontInline
      zipFn  <- zipScalars arg_tys res_ty
      clo    <- scalarClosure arg_tys res_ty (Var fnVar) (zipFn `App` Var fnVar)
      cloVar <- hoistExpr (fsLit "clo") clo DontInline
      lclo   <- liftPD (Var cloVar)
      return (Var cloVar, lclo)
  where
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
653
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'. The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
--     expression of the right-hand side of the dictionary function. We cannot do that, because
--     'vectScalarDFun' must also work for *imported* dfuns. For those, we don't necessarily have
--     access to the Core code of the unvectorised dfun.
--
-- Here is an example — assume,
--
-- >   class Eq a where { (==) :: a -> a -> Bool }
-- >   instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- >   {-# VECTORISE SCALAR instance Eq (a, b) }
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- >   $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- >   $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- >   $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- >                  D:V:Eq $(vectScalarFun [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * The '$(unVect di)' sub-expressions are shared between the different selectors (they are
--   let-bound). The application of the unvectorised dfun is duplicated, however, so that the
--   dictionary selection rules can fire.
-- * If the class's data constructor has no vectorised counterpart, vectorisation fails via
--   'maybeV' (consistent with 'vectDictExpr'), rather than through a pattern-match failure.
--
vectScalarDFun :: Var  -- ^ Original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do { -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

         -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr

         -- vectorise superclass dictionaries and methods as scalar expressions
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps

         -- vectorised applications of the class-dictionary data constructor
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty      -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = ptext (sLit "Cannot vectorise data constructor:") <+> ppr dataCon
719
-- Rebuild the unvectorised version of a dictionary.
--
-- The inputs are the original, unvectorised dictionary type and an expression that computes the
-- corresponding vectorised dictionary. Given the vectorised dictionary 'vd :: V:C vt1..vtn', the
-- generated code computes the unvectorised dictionary, thus:
--
-- >   D:C op1 .. opm
-- >   where
-- >     opi = $(fromVect opTyi [| vSeli @vt1..vtk vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary.
--
-- 'ty' must be a class dictionary type: an application of a class tycon that is a data product
-- type. Otherwise, this function panics with a message naming the offending type. (The original
-- code used irrefutable 'Just' patterns, which failed without any useful diagnostic.)
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys <- mapM vectType tys
       ; let meths = map (\sel -> Var sel `mkTyApps` vTys `mkApps` [e]) selIds
       ; scOps <- zipWithM fromVect methTys meths
       ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
       }
  where
    (tycon, tys) = splitTyConApp ty
    dataCon      = fromMaybe (pprPanic "Vectorise.Exp.unVectDict: not a data product tycon" (ppr ty))
                             (isDataProductTyCon_maybe tycon)
    cls          = fromMaybe (pprPanic "Vectorise.Exp.unVectDict: not a class tycon" (ppr ty))
                             (tyConClass_maybe tycon)
    methTys      = dataConInstArgTys dataCon tys
    selIds       = classAllSelIds cls
745
-- Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- Only free variables that have an entry in the local variable environment are captured. Free
-- variables without one are presumably globals. Of the captured variables:
--
-- * all non-dictionary ones go into the closure's environment;
-- * the dictionary ones are passed explicitly (as conventional arguments) into the body during
--   closure construction.
--
-- If the enclosing binding is a loop breaker, a recursion-termination test is added to the
-- lifted body (see 'break_loop').
--
vectLam :: Bool                   -- ^ Should the RHS of a binding be inlined?
        -> Bool                   -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithVectInfo   -- ^ The lambda abstraction to vectorise (an 'AnnLam').
        -> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
  = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)

       ; let (bndrs, body) = collectAnnValBinders expr

         -- grab the in-scope type variables
       ; tyvars <- localTyVars

         -- collect and vectorise all /local/ free variables, i.e., those that have an entry in
         -- the local variable environment (the pattern-matching generator drops the others)
       ; vfvs <- readLEnv $ \env ->
                   [ (var, vv)
                   | var     <- varSetElems fvs
                   , Just vv <- [lookupVarEnv (local_vars env) var]
                   ]

         -- separate dictionary from non-dictionary variables in the free variable set
       ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
             vfvs_dict'                  = map (vectorised . snd) vvs_dict
             (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

         -- compute the type of the vectorised closure
       ; arg_tys <- mapM (vectType . idType) bndrs
       ; res_ty  <- vectType (exprType $ deAnnotate body)

       ; let arity = length fvs_nondict + length bndrs
       ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
         . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
         $ do { -- generate the vectorised body of the lambda abstraction
              ; lc              <- builtin liftingContext
              ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body

              ; vbody' <- break_loop lc res_ty vbody
              ; return $ vLams lc vbndrs vbody'
              }
       }
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion-termination
    -- test to the /lifted/ version of the function body. The test scrutinises the lifting
    -- context 'lc' (at type 'intPrimTy'):
    -- * if it is 0, the result is an empty array of the lifted result type ('emptyPD');
    -- * otherwise, the function body is executed.
    -- This is the test from the last line (defining \mathcal{L}') in Figure 6 of HtM.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { dflags   <- getDynFlags
           ; emptyArr <- emptyPD ty      -- named so as not to shadow 'Outputable.empty'
           ; lty      <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt dflags 0), [], emptyArr)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
812
813 -- Vectorise an algebraic case expression.
814 --
815 -- We convert
816 --
817 -- case e :: t of v { ... }
818 --
819 -- to
820 --
821 -- V: let v' = e in case v' of _ { ... }
822 -- L: let v' = e in case v' `cast` ... of _ { ... }
823 --
824 -- When lifting, we have to do it this way because v must have the type
825 -- [:V(T):] but the scrutinee must be cast to the representation type. We also
826 -- have to handle the case where v is a wild var correctly.
827 --
828
829 -- FIXME: this is too lazy...is it?
-- The equations below handle, in order: a case whose only alternative is DEFAULT; a case whose
-- only alternative is a data constructor that binds no fields; a case whose only alternative is
-- a data constructor with field binders; and the general case with several alternatives. Each
-- produces the pair of vectorised and lifted code (cf. the V:/L: scheme described above).
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithVectInfo)]
            -> VM VExpr
-- Only a DEFAULT alternative: vectorise the scrutinee, vectorise the body with the case binder
-- in scope, and combine both with 'vCaseDEFAULT'.
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
    { traceVt "scrutinee (DEFAULT only)" Outputable.empty
    ; vscrut <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (DEFAULT only)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
-- A single constructor alternative that binds no fields is handled exactly like the
-- DEFAULT-only case above (via 'vCaseDEFAULT').
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
    { traceVt "scrutinee (one shot w/o binders)" Outputable.empty
    ; vscrut <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/o binders)" Outputable.empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
-- A single constructor alternative with field binders. The vectorised scrutinee is let-bound
-- to the case binder (a fresh "scrut" binder if the original one is dead). The vectorised code
-- then takes it apart with a one-alternative case on the vectorised data constructor. The lifted
-- code instead uses a one-alternative case on the PData constructor obtained from
-- 'pdataUnwrapScrut'.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
    { traceVt "scrutinee (one shot w/ binders)" Outputable.empty
    ; vexpr <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/ binders)" Outputable.empty
    ; (vbndr, (vbndrs, (vect_body, lift_body)))
        <- vect_scrut_bndr
         . vectBndrsIn bndrs
         $ vectExpr body
    ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
    ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
    ; vect_dc <- maybeV dataConErr (lookupDataCon dc)

    ; let vcase = mk_wild_case vscrut vty vect_dc vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

    ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
    }
  where
    -- a dead case binder is replaced by a fresh binder, since the scrutinee must still be bound
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr

    -- a case with a wild binder and a single constructor alternative
    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

-- General case (several alternatives). The alternatives are sorted by constructor tag ('cmp'
-- puts DEFAULT first) and each is processed by 'proc_alt'. The vectorised code is an ordinary
-- multi-way case on the vectorised data constructors. The lifted code unpacks the PData
-- scrutinee, binding its first field to the selector 'sel', and merges the lifted alternative
-- bodies with 'combinePD'.
--
-- NOTE(review): 'proc_alt' only accepts 'DataAlt' alternatives and panics otherwise, so a DEFAULT
-- alternative reaching this equation ends in the "vectAlgCase/proc_alt" panic.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
    { traceVt "scrutinee (general case)" Outputable.empty
    ; vexpr <- vectExpr scrut

    ; vect_tc <- vectTyCon tycon
    ; (vty, lty) <- vectAndLiftType ty

    -- one selector slot per data constructor of the vectorised type constructor
    ; let arity = length (tyConDataCons vect_tc)
    ; sel_ty <- builtin (selTy arity)
    ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
    ; let sel = Var sel_bndr

    ; traceVt "alternatives' body (general case)" Outputable.empty
    ; (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
    ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

    ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)

    ; let (vect_bodies, lift_bodies) = unzip vbodies

    ; vdummy <- newDummyVar (exprType vect_scrut)
    ; ldummy <- newDummyVar (exprType lift_scrut)
    ; let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

    -- the lifted result combines the per-alternative lifted bodies according to 'sel'
    ; lc <- builtin liftingContext
    ; lbody <- combinePD vty (Var lc) sel lift_bodies
    ; let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

    ; return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
    }
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise = vectBndrIn bndr

    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    -- Order alternatives by constructor tag, DEFAULT before any constructor. Any other pairing
    -- (e.g., literal alternatives) panics.
    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT DEFAULT = EQ
    cmp DEFAULT _ = LT
    cmp _ DEFAULT = GT
    cmp _ _ = panic "vectAlgCase/cmp"

    -- Vectorise one constructor alternative. The result holds four things: the vectorised data
    -- constructor, the vectorised field binders, the lifted field binders, and the (vectorised,
    -- lifted) body.
    -- In the lifted body, each local free variable of the alternative (excluding its own binders)
    -- is rebound by 'pack_var' to a copy packed by this alternative's tag. The lifting-context
    -- variable 'lc' is rebound, as the binder of a case, to 'selElements arity ntag' applied
    -- to 'sel'.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
      = do
          dflags <- getDynFlags
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag = mkDataConTag dflags vect_dc
              fvs = fvs_body `delVarSetList` bndrs

          sel_tags <- liftM (`App` sel) (builtin (selTags arity))
          lc <- builtin liftingContext
          elems <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
               { binds <- mapM (pack_var (Var lc) sel_tags tag)
                        . filter isLocalId
                        $ varSetElems fvs
               ; traceVt "case alternative:" (ppr . deAnnotate $ body)
               ; (ve, le) <- vectExpr body
               ; return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
               }
          -- Disabled variant (kept for reference) that additionally had an explicit
          -- alternative returning an empty lifted result for a zero lifting context:
          -- empty <- emptyPD vty
          -- return (ve, Case (elems `App` sel) lc lty
          -- [(DEFAULT, [], Let (NonRec flags_var flags_expr)
          -- $ mkLets (concat binds) le),
          -- (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- Pack a variable for a case alternative context *if* the variable is vectorised. If it
    -- isn't, ignore it as scalar variables don't need to be packed.
    -- A packed variable's lifted part is replaced by a clone bound to the packed value. The
    -- local environment is updated so that later lookups of 'v' see the clone.
    pack_var len tags t v
      = do
        { r <- lookupVar_maybe v
        ; case r of
            Just (Local (vv, lv)) ->
              do
              { lv' <- cloneVar lv
              ; expr <- packByTagPD (idType vv) (Var lv) len tags t
              ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
              ; return [(NonRec lv' expr)]
              }
            _ -> return []
        }
980
981
982 -- Support to compute information for vectorisation avoidance ------------------
983
984 -- Annotation for Core AST nodes that describes how they should be handled during vectorisation
985 -- and especially if vectorisation of the corresponding computation can be avoided.
986 --
data VectAvoidInfo
  = VIParr     -- ^ the tree contains parallel computations
  | VISimple   -- ^ scalar result type and no parallel subcomputation
  | VIComplex  -- ^ arbitrary result type, but no parallel subcomputation
  | VIEncaps   -- ^ the tree is encapsulated by 'liftSimple'
  | VIDict     -- ^ a dictionary computation (never parallel)
  deriving (Eq, Show)
993
994 -- Core expression annotated with free variables and vectorisation-specific information.
995 --
-- Each node carries two things. The 'VarSet' component is the free-variable set copied from the
-- 'CoreExprWithFVs' input of 'vectAvoidInfo'. The 'VectAvoidInfo' component is the
-- classification computed for that node.
type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo)
997
998 -- Yield the type of an annotated core expression.
999 --
annExprType :: AnnExpr Var ann -> Type
annExprType annotated = exprType (deAnnotate annotated)   -- the annotation is irrelevant here
1002
1003 -- Project the vectorisation information from an annotated Core expression.
1004 --
vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
vectAvoidInfoOf (annotation, _) = snd annotation   -- annotation = (free vars, info)
1007
1008 -- Is this a 'VIParr' node?
1009 --
isVIParr :: CoreExprWithVectInfo -> Bool
isVIParr e = vectAvoidInfoOf e == VIParr
1012
1013 -- Is this a 'VIEncaps' node?
1014 --
isVIEncaps :: CoreExprWithVectInfo -> Bool
isVIEncaps e = vectAvoidInfoOf e == VIEncaps
1017
1018 -- Is this a 'VIDict' node?
1019 --
isVIDict :: CoreExprWithVectInfo -> Bool
isVIDict e = vectAvoidInfoOf e == VIDict
1022
1023 -- 'VIParr' if either argument is 'VIParr'; otherwise, the first argument.
1024 --
unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
unlessVIParr vi other
  = case other of
      VIParr -> VIParr   -- a parallel second argument dominates
      _      -> vi       -- otherwise keep the first (which may itself be 'VIParr')
1028
-- 'VIParr' if either the first argument or the vectorisation information of the second argument
-- (an annotated expression) is 'VIParr'; otherwise, the first argument.
1031 --
unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
-- Left-associative (precedence 9, as no level is given) so that chains such as
-- @vi `unlessVIParrExpr` e1 `unlessVIParrExpr` e2@ typecheck.
infixl `unlessVIParrExpr`
unlessVIParrExpr vi = unlessVIParr vi . vectAvoidInfoOf
1035
1036 -- Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
1037 --
1038 -- * The first argument is the set of free, local variables whose evaluation may entail parallelism.
1039 --
-- Each node of the result keeps its free-variable set from the input and gains a
-- 'VectAvoidInfo' classification; the equations below describe how each node is classified.
vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo
-- A variable is 'VIParr' if it is in the local parallel set 'pvs' or among the global parallel
-- variables; otherwise it is classified by its type. The reason for 'VIParr' is traced.
vectAvoidInfo pvs ce@(fvs, AnnVar v)
  = do
    { gpvs <- globalParallelVars
    ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
            then return VIParr
            else vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; when (vi == VIParr) $
        traceVt " reason:" $ if v `elemVarSet` pvs then text "local" else
                             if v `elemVarSet` gpvs then text "global" else text "parallel type"

    ; return ((fvs, vi), AnnVar v)
    }

-- A literal is classified by its type.
vectAvoidInfo _pvs ce@(fvs, AnnLit lit)
  = do
    { vi <- vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; return ((fvs, vi), AnnLit lit)
    }

-- An application is 'VIParr' if its function or argument is; otherwise it takes the
-- classification of its own type.
vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI1 <- vectAvoidInfo pvs e1
    ; eVI2 <- vectAvoidInfo pvs e2
    ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
    -- ; viTrace ce vi [eVI1, eVI2]
    ; return ((fvs, vi), AnnApp eVI1 eVI2)
    }

-- A lambda takes its body's classification, unless the binder's type is 'VIParr'. The type of
-- the lambda itself is not consulted.
vectAvoidInfo pvs (fvs, AnnLam var body)
  = do
    { bodyVI <- vectAvoidInfo pvs body
    ; varVI <- vectAvoidInfoType $ varType var
    ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
    -- ; viTrace ce vi [bodyVI]
    ; return ((fvs, vi), AnnLam var bodyVI)
    }

-- Non-recursive let. Suppose the right-hand side is 'VIParr' and the binder's type is not scalar.
-- Then the binder is added to the parallel set for the body, and the whole let is 'VIParr'.
-- Otherwise, the let is classified by its type unless the body is 'VIParr'.
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI <- vectAvoidInfo pvs e
    ; isScalarTy <- isScalar $ varType var
    ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
                      then do -- binding is parallel
                      { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
                      ; return (bodyVI, VIParr)
                      }
                      else do -- binding doesn't affect parallelism
                      { bodyVI <- vectAvoidInfo pvs body
                      ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
                      }
    -- ; viTrace ce vi [eVI, bodyVI]
    ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
    }

-- Recursive let. All right-hand sides are first classified against 'pvs'. Some binders may have
-- a non-scalar type and a 'VIParr' right-hand side. If so, those binders are added to the
-- parallel set, bindings and body are re-classified against the extended set, and the whole let
-- is 'VIParr'. If not, the let is classified by its type unless the body is 'VIParr'. In that
-- case the first-pass binding annotations are kept.
--
-- NOTE(review): only a single re-classification pass is made. A binder that becomes 'VIParr' only
-- in the second pass is not itself added to 'extendedPvs'. Confirm that no fixpoint is needed.
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; bndsVI <- mapM (vectAvoidInfoBnd pvs) bnds
    ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI
    ; if not . null $ parrBndrs
      then do -- body may trigger parallelism via at least one binding
      { new_pvs <- filterM ((not <$>) . isScalar . varType) parrBndrs
      ; let extendedPvs = pvs `extendVarSetList` new_pvs
      ; bndsVI <- mapM (vectAvoidInfoBnd extendedPvs) bnds
      ; bodyVI <- vectAvoidInfo extendedPvs body
      -- ; viTrace ce VIParr (map snd bndsVI ++ [bodyVI])
      ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI) bodyVI)
      }
      else do -- demanded bindings cannot trigger parallelism
      { bodyVI <- vectAvoidInfo pvs body
      ; let vi = ceVI `unlessVIParrExpr` bodyVI
      -- ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
      ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
      }
    }
  where
    -- classify one binding's right-hand side against the given parallel set
    vectAvoidInfoBnd pvs (var, e) = (var,) <$> vectAvoidInfo pvs e

    -- a binding is parallel if its right-hand side is 'VIParr' and its binder is not scalar
    isVIParrBnd (var, eVI)
      = do
        { isScalarTy <- isScalar (varType var)
        ; return $ isVIParr eVI && not isScalarTy
        }

-- A case is classified by its type unless the scrutinee or any alternative is 'VIParr'. An
-- alternative gets its binders added to the parallel set when two conditions hold. First, the
-- scrutinee is 'VIParr'. Second, at least one of the alternative's binders is neither toplevel
-- nor of a scalar type.
vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI <- vectAvoidInfo pvs e
    ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts
    ; let alteVIs = [eVI | (_, _, eVI) <- altsVI]
          vi = foldl unlessVIParrExpr ceVI (eVI:alteVIs) -- NB: same effect as in the paper
    -- ; viTrace ce vi (eVI : alteVIs)
    ; return ((fvs, vi), AnnCase eVI var ty altsVI)
    }
  where
    vectAvoidInfoAlt scrutIsPar (con, bndrs, e)
      = do
        { allScalar <- allScalarVarType bndrs
        ; let altPvs | scrutIsPar && not allScalar = pvs `extendVarSetList` bndrs
                     | otherwise = pvs
        ; (con, bndrs,) <$> vectAvoidInfo altPvs e
        }

-- A cast takes the classification of its body; the coercion's annotation is always 'VISimple'.
vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann))
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann))
    }

-- A tick takes the classification of its body.
vectAvoidInfo pvs (fvs, AnnTick tick e)
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)
    }

-- Types and coercions are always 'VISimple'.
vectAvoidInfo _pvs (fvs, AnnType ty)
  = return ((fvs, VISimple), AnnType ty)

vectAvoidInfo _pvs (fvs, AnnCoercion coe)
  = return ((fvs, VISimple), AnnCoercion coe)
1165
1166 -- Compute vectorisation avoidance information for a type.
1167 --
-- Classification by shape of type:
--
-- * predicate (dictionary) types are 'VIDict';
-- * function types are classified from their argument and result classifications (see 'funTyVI');
-- * any other type is 'VIParr' if it may contain a parallel array, 'VISimple' if it is in the
--   'Scalar' class, and 'VIComplex' otherwise.
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | isPredTy ty
  = return VIDict
  | Just (argTy, resTy) <- splitFunTy_maybe ty
  = do { argVI <- vectAvoidInfoType argTy
       ; resVI <- vectAvoidInfoType resTy
       ; return (funTyVI argVI resVI)
       }
  | otherwise
  = do { mayBeParr <- maybeParrTy ty
       ; if mayBeParr
           then return VIParr
           else scalarTyVI <$> isScalar ty
       }
  where
    -- NB: deviates from the paper: a function between 'VISimple' types is itself 'VISimple'.
    -- A function with a dictionary result is 'VIDict'. Anything else is 'VIComplex', unless the
    -- argument or the result is 'VIParr'.
    funTyVI VISimple VISimple = VISimple
    funTyVI _        VIDict   = VIDict
    funTyVI argVI    resVI    = VIComplex `unlessVIParr` argVI `unlessVIParr` resVI

    scalarTyVI True  = VISimple
    scalarTyVI False = VIComplex
1192
1193 -- Compute vectorisation avoidance information for the type of a Core expression (with FVs).
1194 --
vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
vectAvoidInfoTypeOf e = vectAvoidInfoType (annExprType e)   -- the annotation is ignored
1197
1198 -- Checks whether the type might be a parallel array type.
1199 --
maybeParrTy :: Type -> VM Bool
maybeParrTy ty
    -- Classify the one-layer expansion given by 'coreView'.
    -- NOTE(review): an earlier comment said this looks through newtypes; 'coreView' is believed
    -- to expand type synonyms — confirm.
  | Just expanded <- coreView ty
  = fmap (== VIParr) (vectAvoidInfoType expanded)
    -- A constructor application may be parallel in two ways. Its constructor may be a known
    -- parallel type constructor. Otherwise, any of its arguments may be parallel.
  | Just (tc, tyArgs) <- splitTyConApp_maybe ty
  = do { parallelTcs <- globalParallelTyCons
       ; if tyConName tc `elemNameSet` parallelTcs
           then return True
           else or <$> mapM maybeParrTy tyArgs
       }
    -- look underneath a forall
  | ForAllTy _ body <- ty
  = maybeParrTy body
  | otherwise
  = return False
1215
-- Is every variable either a toplevel variable or of a type in the 'Scalar' class?
1217 --
1218 -- NB: 'liftSimple' does not abstract over toplevel variables.
1219 --
allScalarVarType :: [Var] -> VM Bool
allScalarVarType vs
  = do { oks <- mapM scalarOrToplevel vs
       ; return (and oks)
       }
  where
    -- Toplevel variables are accepted regardless of their type, because 'liftSimple' never
    -- abstracts over them.
    scalarOrToplevel v
      | isToplevel v = return True
      | otherwise    = isScalar (varType v)
1225
-- Is every variable in the set either a toplevel variable or of a type in the 'Scalar' class?
1227 --
allScalarVarTypeSet :: VarSet -> VM Bool
allScalarVarTypeSet vset = allScalarVarType (varSetElems vset)
1230
1231 -- Debugging support
1232 --
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
viTrace ce vi vTs
  = traceVt infoLine (ppr (deAnnotate ce))
  where
    -- The node's own classification comes first. It is followed by the classifications of the
    -- given subexpressions inside brackets, each followed by a single space.
    infoLine = "vect info: " ++ show vi ++ "[" ++ subInfos ++ "]"
    subInfos = concatMap (\sub -> show (vectAvoidInfoOf sub) ++ " ") vTs