Merge branch 'refs/heads/vect-avoid' into vect-avoid-merge
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1 {-# LANGUAGE TupleSections #-}
2
3 -- |Vectorisation of expressions.
4
5 module Vectorise.Exp
6 ( -- * Vectorise right-hand sides of toplevel bindings
7 vectTopExpr
8 , vectTopExprs
9 , vectScalarFun
10 , vectScalarDFun
11 )
12 where
13
14 #include "HsVersions.h"
15
16 import Vectorise.Type.Type
17 import Vectorise.Var
18 import Vectorise.Convert
19 import Vectorise.Vect
20 import Vectorise.Env
21 import Vectorise.Monad
22 import Vectorise.Builtins
23 import Vectorise.Utils
24
25 import CoreUtils
26 import MkCore
27 import CoreSyn
28 import CoreFVs
29 import Class
30 import DataCon
31 import TyCon
32 import TcType
33 import Type
34 import TypeRep
35 import Var
36 import VarEnv
37 import VarSet
38 import NameSet
39 import Id
40 import BasicTypes( isStrongLoopBreaker )
41 import Literal
42 import TysPrim
43 import Outputable
44 import FastString
45 import DynFlags
46 import Util
47 import MonadUtils
48
49 import Control.Monad
50 import Data.Maybe
51 import Data.List
52 import TcRnMonad (goptM)
53 import DynFlags
54 import Util
55
56
57 -- Main entry point to vectorise expressions -----------------------------------
58
-- |Vectorise a polymorphic expression that is the right-hand side of a /non-recursive/ top-level
-- binding.
--
-- The result is 'Nothing' if the whole expression ended up encapsulated as a scalar computation
-- (i.e., it is tagged 'VIEncaps'). Otherwise, we return a triple consisting of
--
-- * a flag that is 'True' iff the expression is parallel (i.e., tagged as 'VIParr'),
-- * the inlining hint for the vectorised binding (computed by 'computeInline'), and
-- * the vectorised expression.
--
-- The non-recursive case is handled separately from 'vectTopExprs' as it doesn't need to compute
-- the vectorisation information twice.
--
vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
vectTopExpr var expr
  = do { annExpr <- vectAvoidInfo emptyVarSet (freeVars expr) >>= encapsulateScalars
       ; if isVIEncaps annExpr
           then return Nothing
           else Just <$> vectRhs annExpr
       }
  where
    -- vectorise the (not fully encapsulated) right-hand side in a closed environment
    vectRhs annExpr
      = do { vExpr  <- closedV . inBind var $ vectAnnPolyExpr False annExpr
           ; inline <- computeInline annExpr
           ; return (isVIParr annExpr, inline, vectorised vExpr)
           }
83
-- Compute the inlining hint for the right-hand side of a top-level binding.
--
-- Dictionaries are never inlined. A lambda abstraction (looking through ticks) gets an 'Inline'
-- hint whose arity is computed by 'polyArity' from the abstraction's leading type binders.
-- Everything else gets 'DontInline'.
--
computeInline :: CoreExprWithVectInfo -> VM Inline
computeInline aexpr
  = case aexpr of
      ((_, VIDict), _)    -> return DontInline
      (_, AnnTick _ body) -> computeInline body
      (_, AnnLam _ _)     -> Inline <$> polyArity tyBndrs
      _                   -> return DontInline
  where
    (tyBndrs, _) = collectAnnTypeBinders aexpr
93
-- |Vectorise a recursive group of top-level polymorphic expressions.
--
-- The result is 'Nothing' if every binding of the group ended up encapsulated as a scalar
-- computation. Otherwise, the 'Bool' component is 'True' iff at least one binding is parallel
-- (i.e., tagged as 'VIParr'), and the list contains, for each binding in input order, its
-- inlining hint and its vectorised right-hand side.
--
vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
vectTopExprs binds
  = do { firstPass <- mapM (annotate emptyVarSet) rhss
       ; if all isVIEncaps firstPass
           then return Nothing     -- all bindings are scalar => don't vectorise this group
           else do { let anyParallel = any isVIParr firstPass
                     -- if some binding is parallel, recompute the vectorisation information,
                     -- now passing the group's own binders to 'vectAvoidInfo'; otherwise, the
                     -- first pass is already good to go
                   ; annRhss <- if anyParallel
                                  then mapM (annotate (mkVarSet bndrs)) rhss
                                  else return firstPass
                   ; vRhss <- zipWithM vectRhs bndrs annRhss
                   ; return $ Just (anyParallel, vRhss)
                   }
       }
  where
    (bndrs, rhss) = unzip binds

    -- compute the vectorisation-avoidance information of a right-hand side (relative to the given
    -- variable set) and encapsulate its scalar subexpressions
    annotate pvs = encapsulateScalars <=< vectAvoidInfo pvs . freeVars

    -- vectorise one right-hand side of the group in a closed environment
    vectRhs bndr annRhs
      = do { vRhs   <- closedV . inBind bndr $
                         vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) annRhs
           ; inline <- computeInline annRhs
           ; return (inline, vectorised vRhs)
           }
133
-- |Vectorise a polymorphic expression annotated with vectorisation information.
--
-- The 'Bool' argument indicates whether the expression is the right-hand side of a loop-breaking
-- binding; it is passed on to 'vectFnExpr'.
--
-- The special case of dictionary functions is currently handled separately. (Would be neater to
-- integrate them, though!) Their right-hand sides are vectorised with 'vectDictExpr' and have no
-- lifted version; demanding that component is an error.
--
vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
    -- traverse through ticks
  = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
vectAnnPolyExpr loop_breaker expr
  | isVIDict expr
    -- special case the right-hand side of dictionary functions
  = (, noLiftedDict) <$> vectDictExpr (deAnnotate expr)
  | otherwise
    -- collect and vectorise type abstractions; then, descend into the body
  = polyAbstract tvs $ \args ->
      mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
  where
    (tvs, mono) = collectAnnTypeBinders expr

    -- dictionary code is never lifted; fail with a clear message (rather than a bare 'undefined')
    -- should anything ever demand the lifted component
    noLiftedDict = error "Vectorise.Exp.vectAnnPolyExpr: we don't lift dictionary expressions"
153
-- |Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
-- lambda abstraction over all its free variables followed by the corresponding application to those
-- variables. We can, then, avoid the vectorisation of the encapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
-- * The expression is sufficiently complex (to warrant special treatment). For now, that is
--   every expression that is not constant and contains at least one operation.
--
-- The user has an option to choose between aggressive and minimal vectorisation avoidance. With
-- minimal vectorisation avoidance, we only encapsulate individual scalar operations. With
-- aggressive vectorisation avoidance, we encapsulate subexpressions that are as big as possible.
--
-- Per constructor, the traversal below does the following:
--
-- * types, literals, and variables not tagged 'VISimple' are returned unchanged;
-- * variables tagged 'VISimple' are always encapsulated (via 'liftSimpleAndCase');
-- * ticks and casts are traversed, keeping their annotation;
-- * lambdas, case expressions, and (non-recursive and recursive) lets tagged 'VISimple' are
--   encapsulated only with aggressive vectorisation avoidance and only if all their free variables
--   (and, for lambdas, also their binders) are of scalar type; otherwise, we recurse into all
--   subexpressions;
-- * applications tagged 'VISimple' whose free variables are all of scalar type are encapsulated
--   with aggressive vectorisation avoidance or, with minimal avoidance, if they are a simple
--   application (see 'isSimpleApplication'); otherwise, we recurse into function and argument;
-- * any other constructor (presumably only coercions) triggers a panic.
--
encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
encapsulateScalars ce@(_, AnnType _ty)
  = return ce
encapsulateScalars ce@((_, VISimple), AnnVar _v)
      -- NB: diverts from the paper: encapsulate scalar variables (including functions)
  = liftSimpleAndCase ce
encapsulateScalars ce@(_, AnnVar _v)
  = return ce
encapsulateScalars ce@(_, AnnLit _)
  = return ce
encapsulateScalars ((fvs, vi), AnnTick tck expr)
  = do
    { encExpr <- encapsulateScalars expr
    ; return ((fvs, vi), AnnTick tck encExpr)
    }
encapsulateScalars ce@((fvs, vi), AnnLam bndr expr)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
        -- NB: diverts from the paper: we need to check the scalarness of bound variables as well,
        --     as 'vectScalarFun' will handle them just the same as those introduced for the 'fvs'
        --     by encapsulation.
    ; bndrsS    <- allScalarVarType bndrs
    ; case (vi, vectAvoid && varsS && bndrsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encExpr <- encapsulateScalars expr
                            ; return ((fvs, vi), AnnLam bndr encExpr)
                            }
    }
  where
    -- all value and type binders of the (possibly nested) abstraction
    (bndrs, _) = collectAnnBndrs ce
encapsulateScalars ce@((fvs, vi), AnnApp ce1 ce2)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, (vectAvoid || isSimpleApplication ce) && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encCe1 <- encapsulateScalars ce1
                            ; encCe2 <- encapsulateScalars ce2
                            ; return ((fvs, vi), AnnApp encCe1 encCe2)
                            }
    }
  where
    -- An application counts as /simple/ if (looking through ticks and casts) it is itself a simple
    -- expression, or if it applies a simple expression to a simple application.
    --
    -- NOTE(review): this accepts a nested application like 'f (g x)', but rejects a
    -- multi-argument application 'f x y' (whose function part 'f x' is not simple), although
    -- minimal avoidance is meant to encapsulate individual scalar operations; confirm whether the
    -- recursion should run down the function position instead.
    isSimpleApplication :: CoreExprWithVectInfo -> Bool
    isSimpleApplication (_, AnnTick _ ce)   = isSimpleApplication ce
    isSimpleApplication (_, AnnCast ce _)   = isSimpleApplication ce
    isSimpleApplication ce | isSimple ce    = True
    isSimpleApplication (_, AnnApp ce1 ce2) = isSimple ce1 && isSimpleApplication ce2
    isSimpleApplication _                   = False
    --
    -- A /simple/ expression is a type, a variable, or a literal (looking through ticks and casts).
    isSimple :: CoreExprWithVectInfo -> Bool
    isSimple (_, AnnType {})   = True
    isSimple (_, AnnVar {})    = True
    isSimple (_, AnnLit {})    = True
    isSimple (_, AnnTick _ ce) = isSimple ce
    isSimple (_, AnnCast ce _) = isSimple ce
    isSimple _                 = False
encapsulateScalars ce@((fvs, vi), AnnCase scrut bndr ty alts)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encScrut <- encapsulateScalars scrut
                            ; encAlts  <- mapM encAlt alts
                            ; return ((fvs, vi), AnnCase encScrut bndr ty encAlts)
                            }
    }
  where
    -- encapsulate within the right-hand side of an alternative
    encAlt (con, bndrs, expr) = (con, bndrs,) <$> encapsulateScalars expr
encapsulateScalars ce@((fvs, vi), AnnLet (AnnNonRec bndr expr1) expr2)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encExpr1 <- encapsulateScalars expr1
                            ; encExpr2 <- encapsulateScalars expr2
                            ; return ((fvs, vi), AnnLet (AnnNonRec bndr encExpr1) encExpr2)
                            }
    }
encapsulateScalars ce@((fvs, vi), AnnLet (AnnRec binds) expr)
  = do
    { vectAvoid <- isVectAvoidanceAggressive
    ; varsS     <- allScalarVarTypeSet fvs
    ; case (vi, vectAvoid && varsS) of
        (VISimple, True) -> liftSimpleAndCase ce
        _                -> do
                            { encBinds <- mapM encBind binds
                            ; encExpr  <- encapsulateScalars expr
                            ; return ((fvs, vi), AnnLet (AnnRec encBinds) encExpr)
                            }
    }
  where
    -- encapsulate within the right-hand side of one recursive binding
    encBind (bndr, expr) = (bndr,) <$> encapsulateScalars expr
encapsulateScalars ((fvs, vi), AnnCast expr coercion)
  = do
    { encExpr <- encapsulateScalars expr
    ; return ((fvs, vi), AnnCast encExpr coercion)
    }
encapsulateScalars _
  = panic "Vectorise.Exp.encapsulateScalars: unknown constructor"
275
-- Lambda-lift the given simple expression and apply it to the abstracted free variables.
--
-- A case expression whose scrutinee is of a scalar type (its type is tagged 'VISimple') is lifted
-- as a whole. If the scrutinee is of any other type, the case expression itself is kept; its
-- alternatives are lifted individually, and the case expression is re-annotated with the
-- vectorisation-avoidance information of the scrutinee's type.
--
liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimpleAndCase aexpr@((fvs, _vi), AnnCase scrut bndr ty alts)
  = do { scrutVI <- vectAvoidInfoTypeOf scrut
       ; if scrutVI == VISimple
           then liftSimple aexpr      -- scalar scrutinee => no special treatment needed
           else do { alts' <- mapM liftAlt alts
                   ; return ((fvs, scrutVI), AnnCase scrut bndr ty alts')
                   }
       }
  where
    -- lift the right-hand side of a single alternative (recursively handling nested cases)
    liftAlt (con, altBndrs, rhs) = (con, altBndrs,) <$> liftSimpleAndCase rhs
liftSimpleAndCase aexpr
  = liftSimple aexpr
294
-- Lambda-lift a simple expression: abstract over its free variables that are not bound at the top
-- level and immediately apply the resulting closed function to those same variables, i.e.,
--
-- > e  ==>  (\v1 .. vn -> e) v1 .. vn
--
-- Each abstraction is tagged 'VIEncaps' and each application 'VISimple'.
--
-- A variable that occurs in its own free-variable set and is not toplevel is returned unchanged
-- (to avoid producing '(\v -> v) v'). Any other expression that is not tagged 'VISimple' is
-- rejected with a panic.
--
liftSimple :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimple ((fvs, vi), AnnVar v)
    -- NB: if 'v' is not free or is toplevel, we must get the 'VIEncaps'
  | v `elemVarSet` fvs && not (isToplevel v)
  = return ((fvs, vi), AnnVar v)
liftSimple aexpr@((fvsOrig, VISimple), inner)
  = do { let lifted = foldl' applyToVar (abstractOver (reverse params) paramSet inner) params
       ; traceVt "encapsulate:" $ ppr (deAnnotate aexpr) $$ text "==>" $$ ppr (deAnnotate lifted)
       ; return lifted
       }
  where
    paramSet = filterVarSet (not . isToplevel) fvsOrig   -- only 'Id's that are not toplevel
    params   = varSetElems paramSet

    -- wrap one lambda per variable around the body, innermost binder first; each lambda body is
    -- annotated with the variables that are still free at that point, so the outermost
    -- abstraction is closed
    abstractOver :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
    abstractOver []       free body = ASSERT(isEmptyVarSet free)
                                      ((emptyVarSet, VIEncaps), body)
    abstractOver (p : ps) free body
      = abstractOver ps (free `delVarSet` p) (AnnLam p ((free, VIEncaps), body))

    -- apply an annotated function to one more variable
    applyToVar :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
    applyToVar fun@((funFvs, _), _) x
      = ((funFvs `extendVarSet` x, VISimple), AnnApp fun ((unitVarSet x, VISimple), AnnVar x))
liftSimple aexpr
  = pprPanic "Vectorise.Exp.liftSimple: not simple" $ ppr (deAnnotate aexpr)
326
-- Decide from its unfolding whether a variable is bound at the top level.
--
-- Type variables and 'Id's without an unfolding are not toplevel. 'Id's with an 'OtherCon' or a
-- 'DFunUnfolding' are. For a 'CoreUnfolding', the answer is its 'uf_is_top' flag.
--
-- NOTE(review): 'OtherCon' is classified as toplevel unconditionally; confirm that no local
-- binders reaching this point carry an 'OtherCon' unfolding.
--
isToplevel :: Var -> Bool
isToplevel v
  | not (isId v) = False
  | otherwise    = case realIdUnfolding v of
                     CoreUnfolding {uf_is_top = top} -> top
                     DFunUnfolding {}                -> True
                     OtherCon {}                     -> True
                     NoUnfolding                     -> False
334
-- |Vectorise an expression.
--
-- The equations below are tried in order:
--
-- * encapsulated ('VIEncaps') expressions are vectorised as scalar computations: those of
--   functional type via 'vectFnExpr', all others as a scalar constant via 'vectConst'
--   (if the expression is not encapsulated, the remaining equations are tried);
-- * variables, literals, and lambda abstractions;
-- * the special case 'patError @ty err';
-- * type applications and dictionary applications (both via 'vectPolyApp');
-- * a data constructor applied to a literal (vectorised as is and lifted with 'liftPD');
-- * value applications, where the vectorised function closure is applied with 'mkClosureApp';
-- * case expressions over algebraic types (via 'vectAlgCase');
-- * non-recursive and recursive let bindings, ticks, and types.
--
-- Anything else fails with 'cantVectorise'.
--
vectExpr :: CoreExprWithVectInfo -> VM VExpr

vectExpr aexpr
    -- encapsulated expression of functional type => try to vectorise as a scalar subcomputation
  | (isFunTy . annExprType $ aexpr) && isVIEncaps aexpr
  = vectFnExpr True False aexpr
    -- encapsulated constant => vectorise as a scalar constant
  | isVIEncaps aexpr
  = traceVt "vectExpr (encapsulated constant):" (ppr . deAnnotate $ aexpr) >>
    vectConst (deAnnotate aexpr)

vectExpr (_, AnnVar v)
  = vectVar v

vectExpr (_, AnnLit lit)
  = vectConst $ Lit lit

vectExpr aexpr@(_, AnnLam _ _)
  = traceVt "vectExpr [AnnLam]:" (ppr . deAnnotate $ aexpr) >>
    vectFnExpr True False aexpr

-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do
    { (vty, lty) <- vectAndLiftType ty
    ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
    }
  where
    err' = deAnnotate err

-- type application (handle multiple consecutive type applications simultaneously to ensure the
-- PA dictionaries are put at the right places)
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e

-- Lifted literal: a data constructor applied to a literal is used as is in the vectorised version
-- and lifted with 'liftPD'
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just _con <- isDataConId_maybe v
  = do
    { let vexpr = App (Var v) (Lit lit)
    ; lexpr <- liftPD vexpr
    ; return (vexpr, lexpr)
    }

-- value application (dictionary or user value)
vectExpr e@(_, AnnApp fn arg)
  | isPredTy arg_ty   -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise         -- user value
  = do
    {   -- vectorise the types
    ; varg_ty <- vectType arg_ty
    ; vres_ty <- vectType res_ty

        -- vectorise the function and argument expression
    ; vfn  <- vectExpr fn
    ; varg <- vectExpr arg

        -- the vectorised function is a closure; apply it to the vectorised argument
    ; mkClosureApp varg_ty vres_ty vfn varg
    }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- case expressions are only supported if the scrutinee has an algebraic type
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  | otherwise
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
        ppr scrut_ty
    }
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
    { traceVt "let binding (non-recursive)" empty
    ; vrhs <- localV $
                inBind bndr $
                  vectAnnPolyExpr False rhs
    ; traceVt "let body (non-recursive)" empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vLet (vNonRec vbndr vrhs) vbody
    }

-- all binders of a recursive group are in scope in all right-hand sides and in the body
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
    { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $
                                    do
                                    { traceVt "let bindings (recursive)" empty
                                    ; vrhss <- zipWithM vect_rhs bndrs rhss
                                    ; traceVt "let body (recursive)" empty
                                    ; vbody <- vectExpr body
                                    ; return (vrhss, vbody)
                                    }
    ; return $ vLet (vRec vbndrs vrhss) vbody
    }
  where
    (bndrs, rhss) = unzip bs

    -- loop-breaking binders are flagged to 'vectAnnPolyExpr'
    vect_rhs bndr rhs = localV $
                          inBind bndr $
                            vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs

vectExpr (_, AnnTick tickish expr)
  = vTick tickish <$> vectExpr expr

vectExpr (_, AnnType ty)
  = vType <$> vectType ty

vectExpr e
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
    }
459
-- |Vectorise an expression that *may* have an outer lambda abstraction. If the expression is marked
-- as encapsulated ('VIEncaps'), vectorise it as a scalar computation (using a generalised scalar
-- zip).
--
-- We do not handle type variables at this point, as they will already have been stripped off by
-- 'vectAnnPolyExpr'. We also only have to worry about one set of dictionary arguments as we (1) only
-- deal with Haskell 2010 and (2) class selectors are vectorised elsewhere.
--
-- For an outer lambda:
--
-- * a type lambda is rejected with 'cantVectorise';
-- * a predicate (dictionary) abstraction stays an ordinary abstraction with a vectorised binder,
--   and we recurse into its body;
-- * any other encapsulated abstraction is vectorised with 'vectScalarFun';
-- * any other abstraction is vectorised with 'vectLam'.
--
-- An expression without an outer lambda is vectorised with 'vectScalarFun' if it is an
-- encapsulated expression of functional type, and with 'vectExpr' otherwise.
--
vectFnExpr :: Bool                  -- ^If we process the RHS of a binding, whether that binding
                                    --  should be inlined
           -> Bool                  -- ^Whether the binding is a loop breaker
           -> CoreExprWithVectInfo  -- ^Expression to vectorise (with or without an outer 'AnnLam')
           -> VM VExpr
vectFnExpr inline loop_breaker aexpr@(_ann, AnnLam bndr body)
  | not (isId bndr)
    -- type lambdas should already have been stripped off
  = do { dflags <- getDynFlags
       ; cantVectorise dflags "Vectorise.Exp.vectFnExpr: Unexpected type lambda" $
           ppr (deAnnotate aexpr)
       }
  | isPredTy (idType bndr)
    -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
  = do { vBndr <- vectBndr bndr
       ; mapVect (mkLams [vectorised vBndr]) <$> vectFnExpr inline loop_breaker body
       }
  | isVIEncaps aexpr
    -- encapsulated non-predicate abstraction: vectorise as a scalar computation
  = vectScalarFun (deAnnotate aexpr)
  | otherwise
    -- non-predicate abstraction: vectorise as a non-scalar computation
  = vectLam inline loop_breaker aexpr
vectFnExpr _ _ aexpr
  | isFunTy (annExprType aexpr) && isVIEncaps aexpr
    -- encapsulated function: vectorise as a scalar computation
  = vectScalarFun (deAnnotate aexpr)
  | otherwise
    -- not an abstraction: vectorise as a non-scalar vanilla expression
    -- NB: we can get here due to the recursion in the predicate case above and from
    --     'vectAnnPolyExpr'
  = vectExpr aexpr
502
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
-- involve two sets of type variables and dictionaries. Consider,
--
-- > class C a where
-- >   m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-- The application is peeled from the outside in (see the 'where' clause): outer dictionaries,
-- outer types, inner dictionaries, and inner types. What remains must be a variable; otherwise,
-- we report that higher-rank types are not supported.
--
vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do {   -- get the vectorised form of the variable
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)

                  -- vectorise type and dictionary arguments
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter   <- mapM vectType     tysOuter
              ; vTysInner   <- mapM vectType     tysInner

                  -- apply a (vectorised) head to the outer types (via 'polyApply') and then to
                  -- the outer dictionaries
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter

              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )    -- local vars cannot be class selectors
                          ; traceVt " LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var                      -- dictionary computation
                    -> do {   -- in a dictionary computation, the innermost, non-empty set of
                              -- arguments are non-vectorised arguments, where no 'PA' dictionaries
                              -- are needed for the type variables
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt " GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise                           -- non-dictionary computation
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt " GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- if there is only one set of variables or dictionaries, it will be the outer set
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter)   = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner)   = collectAnnTypeArgs e3
    --
    -- class-method selectors and dfuns are dictionary computations
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
565
-- |Vectorise the body of a dfun.
--
-- Dictionary computations are special for the following reasons. The application of dictionary
-- functions are always saturated, so there is no need to create closures. Dictionary computations
-- don't depend on array values, so they are always scalar computations whose result we can
-- replicate (instead of executing them in parallel).
--
-- NB: To keep things simple, we are not rewriting any of the bindings introduced in a dictionary
-- computation. Consequently, the variable case needs to deal with cases where binders are
-- in the vectoriser environments and where that is not the case.
--
-- Types (including the result types of case expressions) and the data constructors of case
-- alternatives are vectorised. Literals cause a panic; casts and coercions are not supported.
--
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr
  = case expr of
      Var var                 -> vectDictVar var
      Lit lit                 -> pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation"
                                          (ppr lit)
      Lam bndr body           -> Lam bndr <$> vectDictExpr body
      App fun arg             -> App <$> vectDictExpr fun <*> vectDictExpr arg
      Case scrut bndr ty alts -> Case <$> vectDictExpr scrut
                                      <*> pure bndr
                                      <*> vectType ty
                                      <*> mapM vectDictAlt alts
      Let bind body           -> Let <$> vectDictBind bind <*> vectDictExpr body
      Cast _ _                -> pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
      Tick tickish body       -> Tick tickish <$> vectDictExpr body
      Type ty                 -> Type <$> vectType ty
      Coercion coe            -> pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    -- binders from within the dictionary computation are not in the vectoriser environment and
    -- stay as they are; local and global variables are replaced by their vectorised versions
    vectDictVar var
      = do { mbScope <- lookupVar_maybe var
           ; return . Var $ case mbScope of
                              Nothing                -> var
                              Just (Local (vVar, _)) -> vVar
                              Just (Global vVar)     -> vVar
           }

    vectDictAlt (con, bndrs, rhs)
      = (,,) <$> vectDictAltCon con <*> pure bndrs <*> vectDictExpr rhs

    vectDictAltCon (DataAlt datacon) = DataAlt <$> maybeV (dataConErr datacon) (lookupDataCon datacon)
    vectDictAltCon (LitAlt lit)      = return $ LitAlt lit
    vectDictAltCon DEFAULT           = return DEFAULT

    dataConErr datacon = ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon

    vectDictBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectDictBind (Rec pairs)       = Rec <$> mapM vectDictPair pairs

    vectDictPair (bndr, rhs) = (bndr,) <$> vectDictExpr rhs
614
-- |Vectorise an expression of functional type, where all arguments and the result are of primitive
-- types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar' type class) and
-- which does not contain any subcomputations that involve parallel arrays. Such functionals do not
-- require the full blown vectorisation transformation; instead, they can be lifted by application
-- of a member of the zipWith family (i.e., 'map', 'zipWith', 'zipWith3', etc.)
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
-- instead they become dictionaries of vectorised methods). We treat them differently, though; see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
-- The argument and result types are read off the expression's type with 'splitFunTys'; code
-- generation is delegated to 'mkScalarFun'.
--
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
  = traceVt "vectScalarFun:" (ppr expr) >>
    uncurry mkScalarFun (splitFunTys (exprType expr)) expr
632
-- Generate code for a scalar function by generating a scalar closure. If the function is a
-- dictionary function (its result type is a predicate type), vectorise it as dictionary code
-- instead; dictionary code has no lifted version.
--
-- For a proper scalar function, we bind the function via 'hoistExpr', build a scalar closure
-- from it and from its application to the zip combinator obtained from 'zipScalars', bind that
-- closure via 'hoistExpr' as well (neither binding is inlined), and obtain the lifted version
-- with 'liftPD'.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = (, noLiftedDict) <$> vectDictExpr expr
  | otherwise
  = do { traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr (mkFunTys arg_tys res_ty)

       ; fnVar     <- hoistExpr (fsLit "fn") expr DontInline
       ; zipFn     <- zipScalars arg_tys res_ty
       ; closure   <- scalarClosure arg_tys res_ty (Var fnVar) (zipFn `App` Var fnVar)
       ; cloVar    <- hoistExpr (fsLit "clo") closure DontInline
       ; liftedClo <- liftPD (Var cloVar)
       ; return (Var cloVar, liftedClo)
       }
  where
    noLiftedDict = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
654
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'. The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
--     expression of the right-hand side of the dictionary function. We cannot proceed like this as
--     'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have access
--     to the Core code of the unvectorised dfun.
--
-- Here an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) #-}
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- > $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- > $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- > $v$dEqPair = /\a b -> \vdEqa :: V:Eq a -> \vdEqb :: V:Eq b ->
-- >                let dEqa = $(unVectDict [| vdEqa |])
-- >                    dEqb = $(unVectDict [| vdEqb |])
-- >                in
-- >                D:V:Eq $(vectScalarFun [| (==) @(a, b) ($dEqPair @a @b dEqa dEqb) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVectDict vdi)' sub-expressions between the different selectors, but
--   duplicate the application of the unvectorised dfun, to enable the dictionary selection rules
--   to fire.
--
-- If the class-dictionary data constructor has no vectorised counterpart, vectorisation fails
-- via 'maybeV' with a diagnostic (instead of an incomplete-pattern failure).
--
vectScalarDFun :: Var  -- ^ Original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do {   -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

           -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr

           -- vectorise superclass dictionaries and methods as scalar expressions
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps

           -- vectorised applications of the class-dictionary data constructor
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty      -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = ptext (sLit "Cannot vectorise data constructor:") <+> ppr dataCon
720
-- Build a value of the dictionary before vectorisation from original, unvectorised type and an
-- expression computing the vectorised dictionary.
--
-- Given the vectorised version of a dictionary 'vd :: V:C vt1..vtn', generate code that computes
-- the unvectorised version, thus:
--
-- > D:C op1 .. opm
-- > where
-- >   opi = $(fromVect opTyi [| vSeli @vt1..vtk vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary.
--
-- Panics with a descriptive message (instead of an opaque irrefutable-pattern failure) if the
-- type constructor of the given type is not a class dictionary type constructor, i.e., if
-- 'isDataProductTyCon_maybe' or 'tyConClass_maybe' yield 'Nothing' for it.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys  <- mapM vectType tys
       ; let meths = map (\sel -> Var sel `mkTyApps` vTys `mkApps` [e]) selIds
       ; scOps <- zipWithM fromVect methTys meths
       ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
       }
  where
    (tycon, tys) = splitTyConApp ty
    dataCon      = fromMaybe (notAClassDict "isDataProductTyCon_maybe")
                             (isDataProductTyCon_maybe tycon)
    cls          = fromMaybe (notAClassDict "tyConClass_maybe")
                             (tyConClass_maybe tycon)
    methTys      = dataConInstArgTys dataCon tys
    selIds       = classAllSelIds cls

    -- report which lookup failed, together with the offending type
    notAClassDict :: String -> a
    notAClassDict what
      = pprPanic ("Vectorise.Exp.unVectDict: " ++ what ++ " failed on") (ppr ty)
746
-- Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- All non-dictionary free variables go into the closure's environment, whereas the dictionary
-- variables are passed explicitly (as conventional arguments) into the body during closure
-- construction.
--
-- Only free variables found in the local vectoriser environment are considered. If the first
-- argument is 'True', the hoisted code gets an 'Inline' hint whose arity is the number of
-- non-dictionary free variables plus the number of value binders; otherwise, 'DontInline'.
--
vectLam :: Bool                   -- ^ Should the RHS of a binding be inlined?
        -> Bool                   -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithVectInfo   -- ^ Body of abstraction.
        -> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
  = do { traceVt "fully vectorise a lambda expression" (ppr . deAnnotate $ expr)

       ; let (bndrs, body) = collectAnnValBinders expr

           -- grab the in-scope type variables
       ; tyvars <- localTyVars

           -- collect and vectorise all /local/ free variables
       ; vfvs <- readLEnv $ \env ->
                   [ (var, fromJust mb_vv)
                   | var <- varSetElems fvs
                   , let mb_vv = lookupVarEnv (local_vars env) var
                   , isJust mb_vv         -- its local == is in local var env
                   ]
           -- separate dictionary from non-dictionary variables in the free variable set
       ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
             (_fvs_dict, vfvs_dict)      = unzip vvs_dict
             (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

           -- compute the type of the vectorised closure
       ; arg_tys <- mapM (vectType . idType) bndrs
       ; res_ty  <- vectType (exprType $ deAnnotate body)

       ; let arity      = length fvs_nondict + length bndrs
             vfvs_dict' = map vectorised vfvs_dict
       ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
         . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
         $ do {   -- generate the vectorised body of the lambda abstraction
              ; lc              <- builtin liftingContext
              ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body

              ; vbody' <- break_loop lc res_ty vbody
              ; return $ vLams lc vbndrs vbody'
              }
       }
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion termination test
    -- to the /lifted/ version of the function body. The termination tests checks if the lifting
    -- context is empty. If so, it returns an empty array of the (lifted) result type instead of
    -- executing the function body. This is the test from the last line (defining \mathcal{L}')
    -- in Figure 6 of HtM.
    --
    -- NB: the empty array is named 'emptyLifted' so as not to shadow 'Outputable.empty'.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { dflags      <- getDynFlags
           ; emptyLifted <- emptyPD ty
           ; lty         <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt dflags 0), [], emptyLifted)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
813
-- Vectorise an algebraic case expression.
--
-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type.  We also
-- have to handle the case where v is a wild var correctly.
--
-- Arguments: the scrutinee's type constructor (only consulted by the general
-- case), its type arguments (unused), the scrutinee, the case binder, the type
-- of the whole case expression (vectorised and lifted to obtain the result types
-- of the generated cases), and the alternatives.
--

-- FIXME: this is too lazy...is it?
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithVectInfo)]
            -> VM VExpr
-- Only a DEFAULT alternative: no constructor is inspected; the vectorised
-- scrutinee, binder, and body are combined by 'vCaseDEFAULT'.
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
    { traceVt "scrutinee (DEFAULT only)" empty
    ; vscrut         <- vectExpr scrut
    ; (vty, lty)     <- vectAndLiftType ty
    ; traceVt "alternative body (DEFAULT only)" empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
-- A single constructor alternative without binders: the constructor itself is
-- not needed, so this is handled exactly like the DEFAULT-only case above.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
    { traceVt "scrutinee (one shot w/o binders)" empty
    ; vscrut         <- vectExpr scrut
    ; (vty, lty)     <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/o binders)" empty
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vCaseDEFAULT vscrut vbndr vty lty vbody
    }
-- A single constructor alternative with binders: the vectorised scrutinee is
-- let-bound to the vectorised case binder; the vectorised and lifted scrutinees
-- are obtained from that binder via 'pdataUnwrapScrut', and each is taken apart
-- by a wildcard case on the vectorised data constructor and the PData
-- constructor, respectively.
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
    { traceVt "scrutinee (one shot w/ binders)" empty
    ; vexpr      <- vectExpr scrut
    ; (vty, lty) <- vectAndLiftType ty
    ; traceVt "alternative body (one shot w/ binders)" empty
    ; (vbndr, (vbndrs, (vect_body, lift_body)))
        <- vect_scrut_bndr
         . vectBndrsIn bndrs
         $ vectExpr body
    ; let (vect_bndrs, lift_bndrs) = unzip vbndrs
    ; (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
    ; vect_dc <- maybeV dataConErr (lookupDataCon dc)

    ; let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

    ; return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
    }
  where
    -- A dead case binder is replaced by a fresh variable named "scrut", as the
    -- scrutinee always needs to be bound here.
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- Single-alternative case with a wildcard binder; 'ty' is the result type.
    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

-- General case: alternatives are sorted by constructor tag and vectorised one by
-- one.  The vectorised version is an ordinary case on the vectorised scrutinee;
-- the lifted version binds the selector 'sel' (of type 'selTy arity') from the
-- PData scrutinee and merges the lifted alternative bodies with 'combinePD'.
vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
    { traceVt "scrutinee (general case)" empty
    ; vexpr <- vectExpr scrut

    ; vect_tc     <- vectTyCon tycon
    ; (vty, lty)  <- vectAndLiftType ty

    ; let arity = length (tyConDataCons vect_tc)
    ; sel_ty  <- builtin (selTy arity)
    ; sel_bndr <- newLocalVar (fsLit "sel") sel_ty
    ; let sel = Var sel_bndr

    ; traceVt "alternatives' body (general case)" empty
    ; (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
    ; let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

    ; (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)

    ; let (vect_bodies, lift_bodies) = unzip vbodies

    ; vdummy <- newDummyVar (exprType vect_scrut)
    ; ldummy <- newDummyVar (exprType lift_scrut)
    ; let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

    ; lc <- builtin liftingContext
    ; lbody <- combinePD vty (Var lc) sel lift_bodies
      -- the lifted case binds the selector followed by all alternatives' lifted binders
    ; let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

    ; return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
    }
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- Alternatives ordered by constructor tag, with DEFAULT first.
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    -- Any alternative other than 'DataAlt'/'DEFAULT' (e.g. a literal) is not
    -- expected here and panics.
    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- Vectorise one alternative, returning the vectorised data constructor, the
    -- vectorised and lifted binders, and the (vectorised, lifted) body.  In the
    -- lifted body, every vectorised local free variable of the alternative (other
    -- than its own binders) is packed with 'packByTagPD' for this alternative's
    -- tag.  The lifting context 'lc' is bound as the case binder of
    -- 'Case (elems `App` sel) ...', so inside the alternative, including in the
    -- packing bindings, 'Var lc' refers to 'elems `App` sel' (presumably the
    -- number of elements selected for this tag -- NOTE(review): confirm).
    --
    -- NB: only 'DataAlt' alternatives are handled; a DEFAULT alternative in the
    -- general case reaches the final equation below and panics.
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
      = do
          dflags <- getDynFlags
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag dflags vect_dc
              fvs  = fvs_body `delVarSetList` bndrs

          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
               { binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
               ; traceVt "case alternative:" (ppr . deAnnotate $ body)
               ; (ve, le) <- vectExpr body
               ; return (ve, Case (elems `App` sel) lc lty
                                  [(DEFAULT, [], (mkLets (concat binds) le))])
               }
                 -- Disabled variant that also short-circuits an empty context:
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- Pack a variable for a case alternative context *if* the variable is vectorised.  If it
    -- isn't, ignore it as scalar variables don't need to be packed.
    --
    -- For a vectorised local, a clone of its lifted variable is bound to the packed lifted
    -- value, the local environment is updated to map the variable to (vectorised var, clone),
    -- and that single binding is returned.
    pack_var len tags t v
      = do
        { r <- lookupVar_maybe v
        ; case r of
            Just (Local (vv, lv)) ->
              do
              { lv'  <- cloneVar lv
              ; expr <- packByTagPD (idType vv) (Var lv) len tags t
              ; updLEnv (\env -> env { local_vars = extendVarEnv (local_vars env) v (vv, lv') })
              ; return [(NonRec lv' expr)]
              }
            _ -> return []
        }
981
982
983 -- Support to compute information for vectorisation avoidance ------------------
984
-- Classification attached to every node of an annotated Core expression.  It records how the
-- node has to be treated by the vectoriser and, in particular, whether vectorising the
-- corresponding computation can be avoided.
--
data VectAvoidInfo
  = VIParr      -- ^ tree contains parallel computations
  | VISimple    -- ^ result type is scalar & no parallel subcomputation
  | VIComplex   -- ^ any result type, no parallel subcomputation
  | VIEncaps    -- ^ tree encapsulated by 'liftSimple'
  | VIDict      -- ^ dictionary computation (never parallel)
  deriving (Show, Eq)
994
-- Core expression annotated with free variables and vectorisation-specific information.
--
-- The annotation of each node pairs the node's free variables with its 'VectAvoidInfo'
-- classification (see 'vectAvoidInfoOf' for projecting the latter).
--
type CoreExprWithVectInfo = AnnExpr Id (VarSet, VectAvoidInfo)
998
-- The Core type of an annotated expression, computed on the expression with all of its
-- annotations stripped.
--
annExprType :: AnnExpr Var ann -> Type
annExprType annExpr = exprType (deAnnotate annExpr)
1003
-- Extract the 'VectAvoidInfo' stored in the root annotation of an annotated Core expression
-- (the second component of the annotation pair).
--
vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
vectAvoidInfoOf = snd . fst
1008
-- Does the root of the annotated expression carry 'VIParr'?
--
isVIParr :: CoreExprWithVectInfo -> Bool
isVIParr expr = case vectAvoidInfoOf expr of
                  VIParr -> True
                  _      -> False

-- Does the root of the annotated expression carry 'VIEncaps'?
--
isVIEncaps :: CoreExprWithVectInfo -> Bool
isVIEncaps expr = case vectAvoidInfoOf expr of
                    VIEncaps -> True
                    _        -> False

-- Does the root of the annotated expression carry 'VIDict'?
--
isVIDict :: CoreExprWithVectInfo -> Bool
isVIDict expr = case vectAvoidInfoOf expr of
                  VIDict -> True
                  _      -> False
1023
-- Combine two classifications: 'VIParr' wins whenever either side is 'VIParr'; in every other
-- situation the first classification is kept.
--
unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
unlessVIParr vi other
  | other == VIParr = VIParr
  | otherwise       = vi
1029
-- Like 'unlessVIParr', but the second classification is taken from the root of an annotated
-- expression.  Left-associative, so it can be chained across several subexpressions.
--
unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
infixl 9 `unlessVIParrExpr`
unlessVIParrExpr vi expr = unlessVIParr vi (vectAvoidInfoOf expr)
1036
-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
--
-- * The first argument is the set of free, local variables whose evaluation may entail
--   parallelism.
--
-- The annotation of every node of the result pairs the free variables of the input annotation
-- with the node's 'VectAvoidInfo'.
--
vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo

-- Variables are 'VIParr' if they are local or global parallel variables; otherwise, they are
-- classified by their type.
vectAvoidInfo pvs ce@(fvs, AnnVar v)
  = do
    { gpvs <- globalParallelVars
    ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
            then return VIParr
            else vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; when (vi == VIParr) $
        traceVt " reason:" $ if v `elemVarSet` pvs  then text "local"  else
                             if v `elemVarSet` gpvs then text "global" else text "parallel type"

    ; return ((fvs, vi), AnnVar v)
    }

-- Literals are classified by their type.
vectAvoidInfo _pvs ce@(fvs, AnnLit lit)
  = do
    { vi <- vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; return ((fvs, vi), AnnLit lit)
    }

-- Applications are classified by their type, unless function or argument is 'VIParr'.
vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
  = do
    { ceVI <- vectAvoidInfoTypeOf ce
    ; eVI1 <- vectAvoidInfo pvs e1
    ; eVI2 <- vectAvoidInfo pvs e2
    ; let vi = ceVI `unlessVIParrExpr` eVI1 `unlessVIParrExpr` eVI2
    -- ; viTrace ce vi [eVI1, eVI2]
    ; return ((fvs, vi), AnnApp eVI1 eVI2)
    }

-- Lambdas take the classification of their body, unless the binder's type is 'VIParr'.
vectAvoidInfo pvs (fvs, AnnLam var body)
  = do
    { bodyVI <- vectAvoidInfo pvs body
    ; varVI  <- vectAvoidInfoType $ varType var
    ; let vi = vectAvoidInfoOf bodyVI `unlessVIParr` varVI
    -- ; viTrace ce vi [bodyVI]
    ; return ((fvs, vi), AnnLam var bodyVI)
    }

-- A non-recursive binding of non-scalar type whose right-hand side is 'VIParr' makes the bound
-- variable parallel in the body and the whole let 'VIParr'.
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
  = do
    { ceVI       <- vectAvoidInfoTypeOf ce
    ; eVI        <- vectAvoidInfo pvs e
    ; isScalarTy <- isScalar $ varType var
    ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
        then do -- binding is parallel
        { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
        ; return (bodyVI, VIParr)
        }
        else do -- binding doesn't affect parallelism
        { bodyVI <- vectAvoidInfo pvs body
        ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
        }
    -- ; viTrace ce vi [eVI, bodyVI]
    ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
    }

-- Recursive bindings: every non-scalar binder whose right-hand side is 'VIParr' becomes a
-- parallel variable in the whole group and in the body.  As a binder may become parallel merely
-- by referring to a sibling that is parallel, the bindings are re-analysed until no further
-- binder turns parallel (see 'saturateParrBndrs').
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
  = do
    { ceVI      <- vectAvoidInfoTypeOf ce
    ; bndsVI    <- mapM (vectAvoidInfoBnd pvs) bnds
    ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI
    ; if not . null $ parrBndrs
      then do         -- body may trigger parallelism via at least one binding
      { (extendedPvs, bndsVI') <- saturateParrBndrs (pvs `extendVarSetList` parrBndrs)
      ; bodyVI <- vectAvoidInfo extendedPvs body
      -- ; viTrace ce VIParr (map snd bndsVI' ++ [bodyVI])
      ; return ((fvs, VIParr), AnnLet (AnnRec bndsVI') bodyVI)
      }
      else do         -- demanded bindings cannot trigger parallelism
      { bodyVI <- vectAvoidInfo pvs body
      ; let vi = ceVI `unlessVIParrExpr` bodyVI
      -- ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
      ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
      }
    }
  where
    vectAvoidInfoBnd pvs' (var, e) = (var,) <$> vectAvoidInfo pvs' e

    -- A binding is parallel if its right-hand side is 'VIParr' and its type is not scalar.
    isVIParrBnd (var, eVI)
      = do
        { isScalarTy <- isScalar (varType var)
        ; return $ isVIParr eVI && not isScalarTy
        }

    -- Re-analyse the group with the current parallel variables until no binder outside that
    -- set becomes parallel; returns the final set and the final annotated bindings.  Each round
    -- adds at least one binder of the (finite) group, so this terminates.
    saturateParrBndrs curPvs
      = do
        { bndsVI   <- mapM (vectAvoidInfoBnd curPvs) bnds
        ; newParrs <- filter (not . (`elemVarSet` curPvs)) . map fst
                        <$> filterM isVIParrBnd bndsVI
        ; if null newParrs
          then return (curPvs, bndsVI)
          else saturateParrBndrs (curPvs `extendVarSetList` newParrs)
        }

-- Case expressions are classified by their type, unless the scrutinee or any alternative is
-- 'VIParr'.  If the scrutinee is 'VIParr' and not all binders of an alternative are scalar, those
-- binders are parallel variables within that alternative.
vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
  = do
    { ceVI   <- vectAvoidInfoTypeOf ce
    ; eVI    <- vectAvoidInfo pvs e
    ; altsVI <- mapM (vectAvoidInfoAlt (isVIParr eVI)) alts
    ; let alteVIs = [altVI | (_, _, altVI) <- altsVI]
          vi      = foldl' unlessVIParrExpr ceVI (eVI : alteVIs) -- NB: same effect as in the paper
    -- ; viTrace ce vi (eVI : alteVIs)
    ; return ((fvs, vi), AnnCase eVI var ty altsVI)
    }
  where
    vectAvoidInfoAlt scrutIsPar (con, bndrs, altE)
      = do
        { allScalar <- allScalarVarType bndrs
        ; let altPvs | scrutIsPar && not allScalar = pvs `extendVarSetList` bndrs
                     | otherwise                   = pvs
        ; (con, bndrs,) <$> vectAvoidInfo altPvs altE
        }

-- Casts take the classification of the cast expression; the coercion is annotated 'VISimple'.
vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann))
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnCast eVI ((fvs_ann, VISimple), ann))
    }

-- Ticks take the classification of the ticked expression.
vectAvoidInfo pvs (fvs, AnnTick tick e)
  = do
    { eVI <- vectAvoidInfo pvs e
    ; return ((fvs, vectAvoidInfoOf eVI), AnnTick tick eVI)
    }

-- Types and coercions are always 'VISimple'.
vectAvoidInfo _pvs (fvs, AnnType ty)
  = return ((fvs, VISimple), AnnType ty)

vectAvoidInfo _pvs (fvs, AnnCoercion coe)
  = return ((fvs, VISimple), AnnCoercion coe)
1166
-- Compute vectorisation avoidance information for a type:
--
-- * predicate (dictionary) types are 'VIDict';
--
-- * function types are 'VISimple' if both argument and result are 'VISimple', 'VIDict' if the
--   result is 'VIDict', and otherwise 'VIComplex' unless the argument or result is 'VIParr'
--   (the argument is analysed before the result);
--
-- * all other types are 'VIParr' if 'maybeParrTy' holds, else 'VISimple' if they are scalar
--   and 'VIComplex' otherwise (scalarity is only checked for non-parallel types).
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | isPredTy ty
  = return VIDict
  | Just (argTy, resTy) <- splitFunTy_maybe ty
  = liftM2 combineFun (vectAvoidInfoType argTy) (vectAvoidInfoType resTy)
  | otherwise
  = do
    { isParr <- maybeParrTy ty
    ; if isParr
        then return VIParr
        else liftM scalarOrComplex (isScalar ty)
    }
  where
    -- NB: diverts from the paper: scalar functions are 'VISimple'
    combineFun VISimple VISimple = VISimple
    combineFun _        VIDict   = VIDict
    combineFun argVI    resVI    = VIComplex `unlessVIParr` argVI `unlessVIParr` resVI

    scalarOrComplex True  = VISimple
    scalarOrComplex False = VIComplex
1193
-- Classify the Core type of an annotated expression (annotations are stripped first) with
-- 'vectAvoidInfoType'.
--
vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
vectAvoidInfoTypeOf annExpr = vectAvoidInfoType (annExprType annExpr)
1198
-- Checks whether the type might be a parallel array type:
--
-- * if 'coreView' can look through the type (type synonyms), the expansion is classified with
--   'vectAvoidInfoType' and checked for 'VIParr';
--
-- * a type constructor application is parallel if the constructor is one of the global parallel
--   type constructors, or if any of its arguments might be parallel (all arguments are checked);
--
-- * a 'ForAllTy' is checked under the binder; anything else is not parallel.
--
maybeParrTy :: Type -> VM Bool
maybeParrTy ty
  = case coreView ty of
      Just expanded -> liftM (== VIParr) (vectAvoidInfoType expanded)
      Nothing       -> case splitTyConApp_maybe ty of
        Just (tc, tyArgs) -> do
          parallelTcs <- globalParallelTyCons
          if tyConName tc `elemNameSet` parallelTcs
            then return True
            else liftM or (mapM maybeParrTy tyArgs)
        Nothing -> case ty of
          ForAllTy _ inner -> maybeParrTy inner
          _                -> return False
1216
-- Are the types of all variables in the 'Scalar' class?  The type of every variable is checked
-- (there is no early exit at the first non-scalar type).
--
allScalarVarType :: [Var] -> VM Bool
allScalarVarType vars = liftM and (forM vars (isScalar . varType))
1221
-- Set variant of 'allScalarVarType': are the types of all members of the set in the 'Scalar'
-- class?
--
allScalarVarTypeSet :: VarSet -> VM Bool
allScalarVarTypeSet varSet = allScalarVarType (varSetElems varSet)
1226
-- Debugging support: trace the classification 'vi' computed for 'ce', followed (in brackets,
-- each followed by a space) by the classifications of the given annotated subexpressions, and
-- the de-annotated expression itself.
--
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
viTrace ce vi vTs
  = traceVt ("vect info: " ++ show vi ++ "[" ++ subInfos ++ "]")
            (ppr $ deAnnotate ce)
  where
    subInfos = concatMap ((++ " ") . show . vectAvoidInfoOf) vTs