Vectoriser: fix vectorisation avoidance for case expressions
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1 {-# LANGUAGE TupleSections #-}
2
3 -- |Vectorisation of expressions.
4
5 module Vectorise.Exp
6 ( -- * Vectorise right-hand sides of toplevel bindings
7 vectTopExpr
8 , vectTopExprs
9 , vectScalarFun
10 , vectScalarDFun
11 )
12 where
13
14 #include "HsVersions.h"
15
16 import Vectorise.Type.Type
17 import Vectorise.Var
18 import Vectorise.Convert
19 import Vectorise.Vect
20 import Vectorise.Env
21 import Vectorise.Monad
22 import Vectorise.Builtins
23 import Vectorise.Utils
24
25 import CoreUtils
26 import MkCore
27 import CoreSyn
28 import CoreFVs
29 import Class
30 import DataCon
31 import TyCon
32 import TcType
33 import Type
34 import TypeRep
35 import Var
36 import VarEnv
37 import VarSet
38 import NameSet
39 import Id
40 import BasicTypes( isStrongLoopBreaker )
41 import Literal
42 import TysPrim
43 import Outputable
44 import FastString
45 import DynFlags
46 import Util
47 import MonadUtils
48
49 import Control.Monad
50 import Data.Maybe
51 import Data.List
52
53
54 -- Main entry point to vectorise expressions -----------------------------------
55
-- |Vectorise a polymorphic expression that forms a *non-recursive* binding.
--
-- Yields 'Nothing' when the vectorisation-avoidance analysis encapsulates the whole expression
-- (i.e., it is scalar).  Otherwise, yields a triple of
--
--   * whether the expression is parallel (tagged 'VIParr'),
--   * the inlining hint computed by 'computeInline', and
--   * the vectorised right-hand side.
--
-- The non-recursive case is special-cased because, unlike a recursive group, it does not need to
-- compute the vectorisation information twice.
--
vectTopExpr :: Var -> CoreExpr -> VM (Maybe (Bool, Inline, CoreExpr))
vectTopExpr var expr
  = do annExpr <- encapsulateScalars =<< vectAvoidInfo emptyVarSet (freeVars expr)
       if isVIEncaps annExpr
         then return Nothing
         else do vRhs   <- closedV . inBind var $ vectAnnPolyExpr False annExpr
                 inline <- computeInline annExpr
                 return $ Just (isVIParr annExpr, inline, vectorised vRhs)
80
-- Compute the inlining hint for the right-hand side of a top-level binding.
--
--   * Dictionary computations ('VIDict') are never inlined.
--   * Ticks are looked through.
--   * A lambda abstraction gets an 'Inline' hint whose arity is computed by 'polyArity' from its
--     leading type binders.
--   * Anything else is not inlined.
--
computeInline :: CoreExprWithVectInfo -> VM Inline
computeInline aexpr
  = case aexpr of
      ((_, VIDict), _) -> return DontInline
      (_, AnnTick _ e) -> computeInline e
      (_, AnnLam _ _)  -> Inline <$> polyArity (fst $ collectAnnTypeBinders aexpr)
      _                -> return DontInline
90
-- |Vectorise a recursive group of top-level polymorphic expressions.
--
-- Yields 'Nothing' if the expression group is scalar.  Otherwise, the first component of the result
-- (of type 'Bool') indicates whether any of the expressions is parallel (i.e., tagged 'VIParr').
--
-- How the two analysis passes differ:
--
--   * The scalarity check runs 'vectAvoidInfo' with an empty variable set.
--   * Only when some right-hand side is not fully encapsulated are the right-hand sides analysed
--     again.  This second pass passes the group's own binders as that variable set (presumably
--     the set of variables to be treated as parallel — see 'vectAvoidInfo').  Its results are
--     what get vectorised.
--
vectTopExprs :: [(Var, CoreExpr)] -> VM (Maybe (Bool, [(Inline, CoreExpr)]))
vectTopExprs binds
  = do scalarCheck <- mapM (analyse emptyVarSet) rhss
       if all isVIEncaps scalarCheck
         then return Nothing
         else do (parrFlags, vExprs) <- unzip <$> mapM vectBinding binds
                 return $ Just (or parrFlags, vExprs)
  where
    (bndrs, rhss) = unzip binds
    groupVars     = mkVarSet bndrs

    -- vectorisation-avoidance analysis followed by encapsulation of scalar subexpressions
    analyse varSet rhs = vectAvoidInfo varSet (freeVars rhs) >>= encapsulateScalars

    -- vectorise one binding of the group; loop breakers get special treatment in 'vectLam'
    vectBinding (var, rhs)
      = do annRhs <- analyse groupVars rhs
           vRhs   <- closedV . inBind var $
                       vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo var) annRhs
           inline <- computeInline annRhs
           return (isVIParr annRhs, (inline, vectorised vRhs))
123
-- |Vectorise a polymorphic expression annotated with vectorisation information.
--
-- The cases, in order:
--
--   * Ticks are traversed.
--   * The right-hand sides of dictionary functions (tagged 'VIDict') are vectorised with
--     'vectDictExpr' and have no lifted form.
--   * In every other case, the leading type abstractions are collected and the remaining body is
--     vectorised with 'vectFnExpr'.
--
-- The special case of dictionary functions is currently handled separately.  (Would be neater to
-- integrate them, though!)
--
vectAnnPolyExpr :: Bool -> CoreExprWithVectInfo -> VM VExpr
vectAnnPolyExpr loop_breaker (_, AnnTick tickish expr)
    -- traverse through ticks
  = vTick tickish <$> vectAnnPolyExpr loop_breaker expr
vectAnnPolyExpr loop_breaker expr
  | isVIDict expr
    -- special case the right-hand side of dictionary functions; dictionary code is never lifted,
    -- so the lifted component is a descriptive error instead of a bare 'undefined' (matching
    -- 'mkScalarFun')
  = (, noLiftedDict) <$> vectDictExpr (deAnnotate expr)
  | otherwise
    -- collect and vectorise type abstractions; then, descend into the body
  = polyAbstract tvs $ \args ->
      mapVect (mkLams $ tvs ++ args) <$> vectFnExpr False loop_breaker mono
  where
    (tvs, mono)  = collectAnnTypeBinders expr
    noLiftedDict = error "Vectorise.Exp.vectAnnPolyExpr: we don't lift dictionary expressions"
143
-- Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
-- lambda abstraction over all its free variables followed by the corresponding application to those
-- variables.  We can, then, avoid the vectorisation of the encapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
-- * The expression is sufficiently complex (to warrant special treatment).  For now, that is
--   every expression that is not constant and contains at least one operation.
--
-- Handling per node kind:
--
--   * Lambdas, applications, case expressions and lets are lifted as a whole (with
--     'liftSimpleAndCase') when they are tagged 'VISimple' and 'allScalarVarTypeSet' holds for
--     their free variables.  Otherwise their subexpressions are processed recursively.
--   * A 'VISimple' variable of function type is lifted, too.
--   * Types, literals and all other variables are returned unchanged.
--   * Ticks and casts are always traversed.
--
encapsulateScalars :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
encapsulateScalars ce@(ann, node)
  = case node of
      AnnType _ty -> return ce

      AnnVar v
          -- NB: diverts from the paper: encapsulate scalar function types
        | vi == VISimple && isFunTy (varType v) -> liftSimpleAndCase ce
        | otherwise                             -> return ce

      AnnLit _ -> return ce

      AnnTick tck expr
        -> reannotate . AnnTick tck <$> encapsulateScalars expr

      AnnLam bndr expr
        -> liftOrDescend $ reannotate . AnnLam bndr <$> encapsulateScalars expr

      AnnApp ce1 ce2
        -> liftOrDescend $
             do { encCe1 <- encapsulateScalars ce1
                ; encCe2 <- encapsulateScalars ce2
                ; return . reannotate $ AnnApp encCe1 encCe2
                }

      AnnCase scrut bndr ty alts
        -> liftOrDescend $
             do { encScrut <- encapsulateScalars scrut
                ; encAlts  <- mapM (\(con, bndrs, rhs) -> (con, bndrs,) <$> encapsulateScalars rhs)
                                   alts
                ; return . reannotate $ AnnCase encScrut bndr ty encAlts
                }

      AnnLet (AnnNonRec bndr expr1) expr2
        -> liftOrDescend $
             do { encExpr1 <- encapsulateScalars expr1
                ; encExpr2 <- encapsulateScalars expr2
                ; return . reannotate $ AnnLet (AnnNonRec bndr encExpr1) encExpr2
                }

      AnnLet (AnnRec binds) expr
        -> liftOrDescend $
             do { encBinds <- mapM (\(bndr, rhs) -> (bndr,) <$> encapsulateScalars rhs) binds
                ; encExpr  <- encapsulateScalars expr
                ; return . reannotate $ AnnLet (AnnRec encBinds) encExpr
                }

      AnnCast expr coercion
        -> reannotate . (`AnnCast` coercion) <$> encapsulateScalars expr

      _ -> panic "Vectorise.Exp.encapsulateScalars: unknown constructor"
  where
    (fvs, vi) = ann

    -- put the node's original annotation back on a rebuilt node
    reannotate node' = (ann, node')

    -- lift the whole expression if it is simple and all its free variables have scalar types;
    -- otherwise, run the given traversal of its subexpressions
    liftOrDescend descend
      = do { allScalar <- allScalarVarTypeSet fvs
           ; if vi == VISimple && allScalar then liftSimpleAndCase ce else descend
           }
234
-- Lambda-lift the given simple expression and apply it to the abstracted free variables.
--
-- Case expressions are treated specially:
--
--   * If the scrutinee has a scalar type (according to 'vectAvoidInfoTypeOf'), the case
--     expression is lifted as a whole.
--   * Otherwise, each alternative is lifted individually and the scrutinee is left untouched.
--
-- NOTE(review): in the per-alternative case, the rebuilt case node is annotated with the scrutinee
-- type's 'VectAvoidInfo' instead of its original annotation — confirm that this is intended.
--
liftSimpleAndCase :: CoreExprWithVectInfo -> VM CoreExprWithVectInfo
liftSimpleAndCase caseExpr@((fvs, _vi), AnnCase scrut bndr ty alts)
  = do
    { scrutVI <- vectAvoidInfoTypeOf scrut
    ; if scrutVI /= VISimple
      then do
        { alts' <- mapM liftAlt alts
        ; return ((fvs, scrutVI), AnnCase scrut bndr ty alts')
        }
      else
        return $ liftSimple caseExpr    -- scalar scrutinee: no special treatment needed
    }
  where
    liftAlt (con, bndrs, rhs) = (con, bndrs,) <$> liftSimpleAndCase rhs
liftSimpleAndCase other = return $ liftSimple other
253
-- |Lambda-lift the given simple expression and apply it to its free variables.
--
-- An expression 'e' with free variables 'x1 .. xn' becomes '(\x1 .. xn -> e) x1 .. xn'.  The
-- lambda abstraction is closed and tagged 'VIEncaps'; the applications are tagged 'VISimple'.
--
-- Why the variable list is reversed for the lambdas:
--
--   * 'mkAnnLams' binds the FIRST variable of its list innermost.
--   * 'mkAnnApps' passes the first variable as the FIRST argument.
--   * Without the reversal we would build '(\xn .. x1 -> e) x1 .. xn', which swaps the free
--     variables whenever there are two or more of them (e.g., 'x - y' would turn into 'y - x').
--
liftSimple :: CoreExprWithVectInfo -> CoreExprWithVectInfo
liftSimple ((fvs, vi), expr)
  = ASSERT(vi == VISimple)
    mkAnnApps (mkAnnLams (reverse vars) fvs expr) vars
  where
    vars = varSetElems fvs
260
-- |Wrap an unannotated expression in one 'AnnLam' per variable of the list.
--
-- Every new node is annotated with its free variables and 'VIEncaps'.
--
-- Binding order: lambdas are built from the inside out, so the FIRST variable of the list is bound
-- by the INNERMOST lambda and the LAST variable by the outermost one.  A caller that wants the
-- first variable as the outermost binder must pass the list reversed.
--
-- The given set is the free-variable set of the body.  After removing all listed variables it
-- must be empty (checked with an assertion), so the result is a closed term.
--
mkAnnLams :: [Var] -> VarSet -> AnnExpr' Var (VarSet, VectAvoidInfo) -> CoreExprWithVectInfo
mkAnnLams vars fvs0 body0
  = close (foldl' wrapLam (fvs0, body0) vars)
  where
    -- bind one more variable around the current body, whose free variables are 'fvs'
    wrapLam (fvs, body) v = (fvs `delVarSet` v, AnnLam v ((fvs, VIEncaps), body))

    -- once all variables are bound, nothing may remain free
    close (fvsLeft, lam) = ASSERT(isEmptyVarSet fvsLeft)
                           ((emptyVarSet, VIEncaps), lam)
265
-- |Apply an annotated expression to the given variables in list order: the first variable becomes
-- the first argument.
--
mkAnnApps :: CoreExprWithVectInfo -> [Var] -> CoreExprWithVectInfo
mkAnnApps = foldl' mkAnnApp
269
-- |Apply an annotated expression to a single variable.
--
-- Both the argument and the application are tagged 'VISimple'.  The application's free variables
-- are those of the function plus the argument variable.
--
mkAnnApp :: CoreExprWithVectInfo -> Var -> CoreExprWithVectInfo
mkAnnApp fun@((funFvs, _), _) arg
  = (appAnn, AnnApp fun argExpr)
  where
    argExpr = ((unitVarSet arg, VISimple), AnnVar arg)
    appAnn  = (funFvs `extendVarSet` arg, VISimple)
273
-- |Vectorise an expression.
--
-- The result is the pair of the vectorised and the lifted form ('VExpr').
--
-- The equations below are tried top to bottom and their order matters:
--
--   * The 'pAT_ERROR_ID' special case must precede the general type-application case.
--   * Type applications and data constructors applied to literals must be recognised before the
--     general value-application case.
--
vectExpr :: CoreExprWithVectInfo -> VM VExpr

-- variables: use their vectorised form
vectExpr (_, AnnVar v)
  = vectVar v

-- literals: vectorised as constants
vectExpr (_, AnnLit lit)
  = vectConst $ Lit lit

-- value abstractions go through 'vectFnExpr'; type abstractions are expected to have been stripped
-- already (cf. 'vectAnnPolyExpr') and are rejected here
vectExpr e@(_, AnnLam bndr _)
  | isId bndr = vectFnExpr True False e
  | otherwise
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Unexpected type lambda (vectExpr)" $ ppr (deAnnotate e)
    }

-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err)
  | v == pAT_ERROR_ID
  = do
    { (vty, lty) <- vectAndLiftType ty
    ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
    }
  where
    err' = deAnnotate err

-- type application (handle multiple consecutive type applications simultaneously to ensure the
-- PA dictionaries are put at the right places)
vectExpr e@(_, AnnApp _ arg)
  | isAnnTypeArg arg
  = vectPolyApp e

-- Lifted literal: a data constructor applied to a literal is kept unchanged in the vectorised
-- form; the lifted form is obtained with 'liftPD'
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit))
  | Just _con <- isDataConId_maybe v
  = do
    { let vexpr = App (Var v) (Lit lit)
    ; lexpr <- liftPD vexpr
    ; return (vexpr, lexpr)
    }

-- value application (dictionary or user value)
vectExpr e@(_, AnnApp fn arg)
  | isPredTy arg_ty   -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise         -- user value
  = do
    { -- vectorise the types
    ; varg_ty <- vectType arg_ty
    ; vres_ty <- vectType res_ty

      -- vectorise the function and argument expression
    ; vfn  <- vectExpr fn
    ; varg <- vectExpr arg

      -- the vectorised function is a closure; apply it to the vectorised argument
    ; mkClosureApp varg_ty vres_ty vfn varg
    }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- case expressions: only scrutinees whose type is headed by an algebraic type constructor are
-- supported (see 'vectAlgCase')
vectExpr (_, AnnCase scrut bndr ty alts)
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts
  | otherwise
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (no algebraic type constructor)" $
        ppr scrut_ty
    }
  where
    scrut_ty = exprType (deAnnotate scrut)

-- non-recursive let: the right-hand side is vectorised as a (possibly polymorphic) local binding
-- and the binder is brought into scope for the body
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
  = do
    { vrhs <- localV $
                inBind bndr $
                  vectAnnPolyExpr False rhs
    ; (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
    ; return $ vLet (vNonRec vbndr vrhs) vbody
    }

-- recursive let: all binders are in scope for the right-hand sides and the body; each binder's
-- loop-breaker status is passed on to 'vectAnnPolyExpr'
vectExpr (_, AnnLet (AnnRec bs) body)
  = do
    { (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                  $ liftM2 (,)
                                      (zipWithM vect_rhs bndrs rhss)
                                      (vectExpr body)
    ; return $ vLet (vRec vbndrs vrhss) vbody
    }
  where
    (bndrs, rhss) = unzip bs

    vect_rhs bndr rhs = localV $
                          inBind bndr $
                            vectAnnPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) rhs

vectExpr (_, AnnTick tickish expr)
  = vTick tickish <$> vectExpr expr

vectExpr (_, AnnType ty)
  = vType <$> vectType ty

-- anything not matched above (e.g., casts and coercions) cannot be vectorised
vectExpr e
  = do
    { dflags <- getDynFlags
    ; cantVectorise dflags "Can't vectorise expression (vectExpr)" $ ppr (deAnnotate e)
    }
388
-- |Vectorise an expression that *may* have an outer lambda abstraction.
--
-- Value abstractions are handled as follows:
--
--   * Predicate (dictionary) binders remain ordinary lambdas; only their types are vectorised.
--   * An abstraction marked as encapsulated ('VIEncaps') is vectorised as a scalar computation
--     using a generalised scalar zip ('vectScalarFun').
--   * Any other value abstraction becomes a closure ('vectLam').
--
-- Everything else, including type abstractions, is passed to 'vectExpr'.
--
-- We do not handle type variables at this point, as they will already have been stripped off by
-- 'vectAnnPolyExpr'.  We also only have to worry about one set of dictionary arguments as we (1)
-- only deal with Haskell 2010 and (2) class selectors are vectorised elsewhere.
--
vectFnExpr :: Bool                  -- ^If we process the RHS of a binding, whether that binding
                                    --  should be inlined
           -> Bool                  -- ^Whether the binding is a loop breaker
           -> CoreExprWithVectInfo  -- ^Expression to vectorise
           -> VM VExpr
vectFnExpr inline loop_breaker expr
  = case expr of
      (_, AnnLam bndr body)
          -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
        | isId bndr && isPredTy (idType bndr)
        -> do { vBndr <- vectBndr bndr
              ; vBody <- vectFnExpr inline loop_breaker body
              ; return $ mapVect (mkLams [vectorised vBndr]) vBody
              }
          -- encapsulated non-predicate abstraction: vectorise as a scalar computation
        | isId bndr && isVIEncaps expr
        -> vectScalarFun (deAnnotate expr)
          -- any other non-predicate abstraction: vectorise as a non-scalar computation
        | isId bndr
        -> vectLam inline loop_breaker expr
      -- not a value abstraction: vectorise as a vanilla expression
      _ -> vectExpr expr
420
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
-- involve two sets of type variables and dictionaries.  Consider,
--
-- > class C a where
-- >   m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-- Arguments are collected starting at the outermost application node ('e0'): first dictionaries,
-- then types, then again dictionaries and types.  The remaining head ('e4') must be a variable;
-- otherwise, the program is reported as unsupported (higher-rank).
--
vectPolyApp :: CoreExprWithVectInfo -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do { -- get the vectorised form of the variable
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)

                -- vectorise type and dictionary arguments
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter   <- mapM vectType tysOuter
              ; vTysInner   <- mapM vectType tysInner

                -- apply a head expression to the vectorised outer type arguments (via 'polyApply')
                -- and then to the vectorised outer dictionaries
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter

              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )    -- local vars cannot be class selectors
                          ; traceVt " LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var                      -- dictionary computation
                    -> do { -- in a dictionary computation, the innermost, non-empty set of
                            -- arguments are non-vectorised arguments, where no 'PA' dictionaries
                            -- are needed for the type variables
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt " GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise                           -- non-dictionary computation
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt " GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- if there is only one set of variables or dictionaries, it will be the outer set
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter)   = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner)   = collectAnnTypeArgs e3
    --
    -- class-method selectors and dictionary functions are dictionary computations
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
483
-- |Vectorise the body of a dfun.
--
-- Dictionary computations are special for the following reasons.  The application of dictionary
-- functions are always saturated, so there is no need to create closures.  Dictionary computations
-- don't depend on array values, so they are always scalar computations whose result we can
-- replicate (instead of executing them in parallel).
--
-- NB: To keep things simple, we are not rewriting any of the bindings introduced in a dictionary
--     computation.  Consequently, the variable case needs to deal with cases where binders are
--     in the vectoriser environments and where that is not the case.
--
-- Literals, casts and coercions are not supported.
--
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr
  = case expr of
      Var var -> do
        scope <- lookupVar_maybe var
        return . Var $ case scope of
          Nothing              -> var   -- binder from within the dict. computation
          Just (Local (vv, _)) -> vv    -- local vectorised variable
          Just (Global vv)     -> vv    -- global vectorised variable
      Lit lit
        -> pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
      Lam bndr body
        -> Lam bndr <$> vectDictExpr body
      App fn arg
        -> App <$> vectDictExpr fn <*> vectDictExpr arg
      Case scrut bndr ty alts
        -> Case <$> vectDictExpr scrut <*> pure bndr <*> vectType ty <*> mapM vectDictAlt alts
      Let bnd body
        -> Let <$> vectDictBind bnd <*> vectDictExpr body
      Cast _e _coe
        -> pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)
      Tick tickish body
        -> Tick tickish <$> vectDictExpr body
      Type ty
        -> Type <$> vectType ty
      Coercion coe
        -> pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    vectDictAlt (con, bs, rhs) = (,,) <$> vectDictAltCon con <*> pure bs <*> vectDictExpr rhs

    vectDictAltCon (DataAlt datacon) = DataAlt <$> maybeV dataConErr (lookupDataCon datacon)
      where
        dataConErr = ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon
    vectDictAltCon (LitAlt lit)      = return $ LitAlt lit
    vectDictAltCon DEFAULT           = return DEFAULT

    vectDictBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectDictBind (Rec bnds)        = Rec <$> mapM (\(bndr, rhs) -> (bndr,) <$> vectDictExpr rhs) bnds
532
-- |Vectorise an expression of functional type, where all arguments and the result are of primitive
-- types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar' type class) and
-- which does not contain any subcomputations that involve parallel arrays.  Such functionals do not
-- require the full blown vectorisation transformation; instead, they can be lifted by application
-- of a member of the zipWith family (i.e., 'map', 'zipWith', 'zipWith3', etc.)
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
-- instead they become dictionaries of vectorised methods).  We treat them differently, though; see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
-- The expression's type is split into argument and result types, which are handed to
-- 'mkScalarFun' together with the expression.
--
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
  = do traceVt "vectScalarFun" (ppr expr)
       uncurry mkScalarFun (splitFunTys (exprType expr)) expr
550
-- Generate code for a scalar function by generating a scalar closure.
--
-- If the result type is a predicate (i.e., the function is a dictionary function), the expression
-- is vectorised as dictionary code instead.  That code has no lifted form.
--
-- Otherwise:
--
--   1. The function is hoisted to a top-level binding ("fn").
--   2. A closure is built from it and the zip function for its argument/result types.
--   3. The closure is hoisted as well ("clo").
--   4. The lifted form is obtained by 'liftPD'.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty = dictCode
  | otherwise       = scalarClosureCode
  where
    dictCode
      = do vDict <- vectDictExpr expr
           return (vDict, unused)

    scalarClosureCode
      = do traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr (mkFunTys arg_tys res_ty)
           fnVar     <- hoistExpr (fsLit "fn") expr DontInline
           zipFn     <- zipScalars arg_tys res_ty
           closure   <- scalarClosure arg_tys res_ty (Var fnVar) (zipFn `App` Var fnVar)
           cloVar    <- hoistExpr (fsLit "clo") closure DontInline
           liftedClo <- liftPD (Var cloVar)
           return (Var cloVar, liftedClo)

    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
572
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'.  The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
--     expression of the right-hand side of the dictionary function.  We cannot proceed like this as
--     'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have access
--     to the Core code of the unvectorised dfun.
--
-- Here is an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) }
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- >   $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- >   $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- >   $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- >                  D:V:Eq $(vectScalarFun
-- >                             [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVect di)' sub-expressions between the different selectors, but duplicate
--   the application of the unvectorised dfun, to enable the dictionary selection rules to fire.
--
-- If the class data constructor has no vectorised counterpart, we fail with a descriptive message
-- via 'maybeV' (as 'vectAlgCase' and 'vectDictExpr' do) instead of relying on a monadic
-- pattern-match failure.
--
vectScalarDFun :: Var         -- ^ Original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do { -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

         -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr

         -- vectorise superclass dictionaries and methods as scalar expressions
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps

         -- vectorised applications of the class-dictionary data constructor
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy ty       -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty     -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
    dataConErr        = text "vectScalarDFun: class data constructor not vectorised:" <+> ppr dataCon
638
-- Build a value of the dictionary before vectorisation from the original, unvectorised type and an
-- expression computing the vectorised dictionary.
--
-- Given the vectorised version of a dictionary 'vd :: V:C vt1..vtn', generate code that computes
-- the unvectorised version, thus:
--
-- >   D:C op1 .. opm
-- >   where
-- >     opi = $(fromVect opTyi [| vSeli @vt1..vtk vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary.
--
-- Panics if the type constructor of the given type is not a class.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do vTys  <- mapM vectType tys
       scOps <- zipWithM fromVect methTys [ Var sel `mkTyApps` vTys `mkApps` [e] | sel <- selIds ]
       return . mkCoreConApps dataCon $ map Type tys ++ scOps
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
    selIds = classAllSelIds
           $ fromMaybe (panic "Vectorise.Exp.unVectDict: no class") (tyConClass_maybe tycon)
664
-- Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- All non-dictionary free variables go into the closure's environment, whereas the dictionary
-- variables are passed explicitly (as conventional arguments) into the body during closure
-- construction.
--
-- Only free variables bound in the local vectorisation environment are considered.  The hoisted
-- body is given an 'Inline' hint of arity (#non-dictionary free variables + #value binders) when
-- 'inline' is set.
--
vectLam :: Bool                   -- ^ Should the RHS of a binding be inlined?
        -> Bool                   -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithVectInfo   -- ^ Body of abstraction.
        -> VM VExpr
vectLam inline loop_breaker expr@((fvs, _vi), AnnLam _ _)
 = do { let (bndrs, body) = collectAnnValBinders expr

          -- grab the in-scope type variables
      ; tyvars <- localTyVars

          -- collect and vectorise all /local/ free variables
      ; vfvs <- readLEnv $ \env ->
                  [ (var, fromJust mb_vv)
                  | var <- varSetElems fvs
                  , let mb_vv = lookupVarEnv (local_vars env) var
                  , isJust mb_vv         -- keep only variables bound in the local var env
                  ]
          -- separate dictionary from non-dictionary variables in the free variable set
      ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
            (_fvs_dict, vfvs_dict)      = unzip vvs_dict
            (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

          -- compute the type of the vectorised closure
      ; arg_tys <- mapM (vectType . idType) bndrs
      ; res_ty  <- vectType (exprType $ deAnnotate body)

      ; let arity      = length fvs_nondict + length bndrs
            vfvs_dict' = map vectorised vfvs_dict
      ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
        . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
        $ do { -- generate the vectorised body of the lambda abstraction
             ; lc              <- builtin liftingContext
             ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) $ vectExpr body

             ; vbody' <- break_loop lc res_ty vbody
             ; return $ vLams lc vbndrs vbody'
             }
      }
  where
    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion termination test
    -- to the /lifted/ version of the function body.  The termination tests checks if the lifting
    -- context is empty.  If so, it returns an empty array of the (lifted) result type instead of
    -- executing the function body.  This is the test from the last line (defining \mathcal{L}')
    -- in Figure 6 of HtM.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { empty <- emptyPD ty
           ; lty   <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt 0), [], empty)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ = panic "Vectorise.Exp.vectLam: not a lambda"
728
-- Vectorise an algebraic case expression.
--
-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type.  We also
-- have to handle the case where v is a wild var correctly.
--
-- The equations, in order:
--
--   1. A single DEFAULT alternative: handled with 'vCaseDEFAULT'.
--   2. A single constructor alternative without fields: also handled with 'vCaseDEFAULT'.
--   3. A single constructor alternative with fields: the scrutinee is unwrapped with
--      'pdataUnwrapScrut'.
--   4. Anything else: a selector ('sel') is used.
--      * The alternatives are sorted by constructor tag, DEFAULT first.
--      * In each alternative, the free local variables are narrowed with 'packByTagPD'.
--      * The lifted bodies are merged with 'combinePD'.
--
-- NOTE(review): in the general case 'proc_alt' only accepts 'DataAlt' alternatives, so a DEFAULT
-- alternative alongside others reaches the "vectAlgCase/proc_alt" panic — confirm that such cases
-- are excluded before we get here.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithVectInfo -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithVectInfo)]
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)]
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)]
  = do
      vscrut         <- vectExpr scrut
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      vect_dc <- maybeV dataConErr (lookupDataCon dc)

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- a dead case binder still needs a (fresh) name, as the scrutinee is let-bound to it
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

vectAlgCase tycon _ty_args scrut bndr ty alts
  = do
      vect_tc     <- vectTyCon tycon
      (vty, lty)  <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty   <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) alts'
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut
      (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      -- the lifted case binds the selector together with all lifted constructor fields and merges
      -- the lifted alternative bodies with 'combinePD'
      lc <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- alternatives sorted by constructor tag, DEFAULT first
    alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- vectorise one constructor alternative; in the lifted body, every free local variable is
    -- restricted (via 'pack_var') to the elements selected for this constructor's tag
    proc_alt arity sel _ lty (DataAlt dc, bndrs, body@((fvs_body, _), _))
      = do
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = fvs_body `delVarSetList` bndrs

          sel_tags  <- liftM (`App` sel) (builtin (selTags arity))
          lc        <- builtin liftingContext
          elems     <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
                 (ve, le) <- vectExpr body
                 return (ve, Case (elems `App` sel) lc lty
                             [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --            [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- for a local variable, bind a clone of its lifted version to the result of 'packByTagPD' and
    -- point the local environment at the clone; non-local variables need no binding
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                      (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []
881
882
883 -- Support to compute information for vectorisation avoidance ------------------
884
-- How a Core AST node should be treated by the vectoriser; in particular, whether vectorisation
-- of the node's computation can be avoided.
--
data VectAvoidInfo
  = VIParr     -- ^ tree contains parallel computations
  | VISimple   -- ^ result type is scalar and there is no parallel subcomputation
  | VIComplex  -- ^ result type is arbitrary, but there is no parallel subcomputation
  | VIEncaps   -- ^ tree has been encapsulated by 'liftSimple'
  | VIDict     -- ^ dictionary computation (never parallel)
  deriving (Show, Eq)
894
-- A Core expression in which every node carries, as its annotation, both its set of free
-- variables and its vectorisation-avoidance information.
--
type CoreExprWithVectInfo
  = AnnExpr Id (VarSet, VectAvoidInfo)
898
-- The type of an annotated Core expression, computed on the expression with all of its
-- annotations stripped.
--
annExprType :: AnnExpr Var ann -> Type
annExprType annExpr = exprType (deAnnotate annExpr)
903
-- Extract the vectorisation-avoidance component from the annotation of a Core expression.
--
vectAvoidInfoOf :: CoreExprWithVectInfo -> VectAvoidInfo
vectAvoidInfoOf (ann, _) = snd ann
908
-- Is the node annotated with 'VIParr'?
--
isVIParr :: CoreExprWithVectInfo -> Bool
isVIParr expr = vectAvoidInfoOf expr == VIParr
913
-- Is the node annotated with 'VIEncaps'?
--
isVIEncaps :: CoreExprWithVectInfo -> Bool
isVIEncaps expr = vectAvoidInfoOf expr == VIEncaps
918
-- Is the node annotated with 'VIDict'?
--
isVIDict :: CoreExprWithVectInfo -> Bool
isVIDict expr = vectAvoidInfoOf expr == VIDict
923
-- Combine two pieces of vectorisation information: if the second one is 'VIParr', so is the
-- result; otherwise the first one is returned unchanged.  Hence the result is 'VIParr' whenever
-- either argument is.
--
unlessVIParr :: VectAvoidInfo -> VectAvoidInfo -> VectAvoidInfo
unlessVIParr vi other
  | other == VIParr = VIParr
  | otherwise       = vi
929
-- Like 'unlessVIParr', but the second piece of information is taken from an annotated
-- expression: the result is 'VIParr' if either the first argument or the expression's
-- information is 'VIParr'; otherwise, it is the first argument.  Declared left-associative so
-- that the information of several subexpressions can be chained onto one starting value.
--
unlessVIParrExpr :: VectAvoidInfo -> CoreExprWithVectInfo -> VectAvoidInfo
infixl `unlessVIParrExpr`
unlessVIParrExpr vi expr = vi `unlessVIParr` vectAvoidInfoOf expr
936
-- Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
--
-- * The first argument is the set of free, local variables whose evaluation may entail parallelism.
--
-- The result annotates every node with its free variables (taken from the input) and its
-- 'VectAvoidInfo'.
--
vectAvoidInfo :: VarSet -> CoreExprWithFVs -> VM CoreExprWithVectInfo

-- Variables: parallel if the variable is a known local or global parallel variable; otherwise,
-- the information is determined by the variable's type.
vectAvoidInfo pvs ce@(fvs, AnnVar v)
  = do
    { gpvs <- globalParallelVars
    ; vi <- if v `elemVarSet` pvs || v `elemVarSet` gpvs
            then return VIParr
            else vectAvoidInfoTypeOf ce
    ; viTrace ce vi []
    ; return ((fvs, vi), AnnVar v)
    }
955
-- Literals: the information follows solely from the literal's type.
vectAvoidInfo _pvs ce@(fvs, AnnLit lit)
  = vectAvoidInfoTypeOf ce >>= annotate
  where
    annotate litVI = viTrace ce litVI [] >> return ((fvs, litVI), AnnLit lit)
962
-- Applications: 'VIParr' if the function or the argument is 'VIParr'; otherwise, determined by
-- the type of the whole application.
vectAvoidInfo pvs ce@(fvs, AnnApp e1 e2)
  = do
    { appTyVI <- vectAvoidInfoTypeOf ce
    ; fun'    <- vectAvoidInfo pvs e1
    ; arg'    <- vectAvoidInfo pvs e2
    ; let vi = (appTyVI `unlessVIParrExpr` fun') `unlessVIParrExpr` arg'
    ; viTrace ce vi [fun', arg']
    ; return ((fvs, vi), AnnApp fun' arg')
    }
972
-- Lambda abstractions: 'VIParr' if the binder's type or the body is classified as 'VIParr';
-- otherwise, the body's information.
vectAvoidInfo pvs ce@(fvs, AnnLam var body)
  = do
    { body'  <- vectAvoidInfo pvs body
    ; bndrVI <- vectAvoidInfoType (varType var)
    ; let vi = unlessVIParr (vectAvoidInfoOf body') bndrVI
    ; viTrace ce vi [body']
    ; return ((fvs, vi), AnnLam var body')
    }
981
-- Non-recursive let: if the bound expression is parallel and the binder's type is not scalar,
-- uses of the binder in the body may trigger parallelism, so the binder is added to the
-- parallel variables 'pvs' (and the whole let is 'VIParr').  Otherwise, the body is analysed
-- under the unchanged 'pvs'.
--
-- NB: the body must be analysed relative to 'pvs'; the let's free variables 'fvs' are NOT
--     parallel variables (using them would mark every free variable as parallel).
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnNonRec var e) body)
  = do
    { ceVI       <- vectAvoidInfoTypeOf ce
    ; eVI        <- vectAvoidInfo pvs e
    ; isScalarTy <- isScalar $ varType var
    ; (bodyVI, vi) <- if isVIParr eVI && not isScalarTy
                      then do   -- binding is parallel
                        { bodyVI <- vectAvoidInfo (pvs `extendVarSet` var) body
                        ; return (bodyVI, VIParr)
                        }
                      else do   -- binding doesn't affect parallelism
                        { bodyVI <- vectAvoidInfo pvs body
                        ; return (bodyVI, ceVI `unlessVIParrExpr` bodyVI)
                        }
    ; viTrace ce vi [eVI, bodyVI]
    ; return ((fvs, vi), AnnLet (AnnNonRec var eVI) bodyVI)
    }
999
-- Recursive let: first analyse the bindings under 'pvs'.  If no binder of non-scalar type turns
-- out parallel, the group doesn't affect parallelism.  Otherwise, the parallel binders are added
-- to the parallel variables and the bindings are re-analysed until no further binder becomes
-- parallel (a binder may only become parallel through another binder of the same group), and the
-- body is analysed under the resulting set.
vectAvoidInfo pvs ce@(fvs, AnnLet (AnnRec bnds) body)
  = do
    { ceVI      <- vectAvoidInfoTypeOf ce
    ; bndsVI    <- mapM (vectAvoidInfoBnd pvs) bnds
    ; parrBndrs <- map fst <$> filterM isVIParrBnd bndsVI
    ; if not . null $ parrBndrs
      then do         -- body may trigger parallelism via at least one binding
        { (extendedPvs, fixBndsVI) <- parallelFixpoint (pvs `extendVarSetList` parrBndrs)
        ; bodyVI <- vectAvoidInfo extendedPvs body
        ; viTrace ce VIParr (map snd fixBndsVI ++ [bodyVI])
        ; return ((fvs, VIParr), AnnLet (AnnRec fixBndsVI) bodyVI)
        }
      else do         -- demanded bindings cannot trigger parallelism
        { bodyVI <- vectAvoidInfo pvs body
        ; let vi = ceVI `unlessVIParrExpr` bodyVI
        ; viTrace ce vi (map snd bndsVI ++ [bodyVI])
        ; return ((fvs, vi), AnnLet (AnnRec bndsVI) bodyVI)
        }
    }
  where
    -- Analyse one binding's right-hand side under the given parallel variables.
    vectAvoidInfoBnd bndPvs (var, e) = (var,) <$> vectAvoidInfo bndPvs e

    -- A binding is parallel if its right-hand side is 'VIParr' and its binder is not scalar.
    isVIParrBnd (var, eVI)
      = do
        { isScalarTy <- isScalar (varType var)
        ; return $ isVIParr eVI && not isScalarTy
        }

    -- Re-analyse all bindings until no binder outside 'curPvs' is parallel.  Terminates because
    -- the set only grows and every added element is one of the (finitely many) group binders.
    parallelFixpoint curPvs
      = do
        { curBndsVI   <- mapM (vectAvoidInfoBnd curPvs) bnds
        ; newParBndrs <- filter (not . (`elemVarSet` curPvs)) . map fst
                           <$> filterM isVIParrBnd curBndsVI
        ; if null newParBndrs
          then return (curPvs, curBndsVI)
          else parallelFixpoint (curPvs `extendVarSetList` newParBndrs)
        }
1029
-- Case expressions: if the scrutinee is parallel, the case binder (bound to the scrutinee's
-- value) is added to the parallel variables when its type is not scalar — just like the binder
-- of a parallel non-recursive let — and the binders of each alternative that has a non-scalar
-- binder are added for that alternative.  The whole case is 'VIParr' if the scrutinee or any
-- alternative is; otherwise, its information is determined by its type.
vectAvoidInfo pvs ce@(fvs, AnnCase e var ty alts)
  = do
    { ceVI      <- vectAvoidInfoTypeOf ce
    ; eVI       <- vectAvoidInfo pvs e
    ; varScalar <- isScalar (varType var)
    ; let scrutIsPar = isVIParr eVI
          casePvs | scrutIsPar && not varScalar = pvs `extendVarSet` var
                  | otherwise                   = pvs
    ; altsVI    <- mapM (vectAvoidInfoAlt scrutIsPar casePvs) alts
    ; let alteVIs = [altVI | (_, _, altVI) <- altsVI]
          vi      = foldl' unlessVIParrExpr ceVI (eVI : alteVIs)  -- NB: same effect as in the paper
    ; viTrace ce vi (eVI : alteVIs)
    ; return ((fvs, vi), AnnCase eVI var ty altsVI)
    }
  where
    -- Analyse one alternative; its binders are parallel if the scrutinee is parallel and not all
    -- of them are scalar.
    vectAvoidInfoAlt scrutIsPar basePvs (con, bndrs, rhs)
      = do
        { allScalar <- allScalarVarType bndrs
        ; let altPvs | scrutIsPar && not allScalar = basePvs `extendVarSetList` bndrs
                     | otherwise                   = basePvs
        ; (con, bndrs,) <$> vectAvoidInfo altPvs rhs
        }
1048
-- Casts inherit the information of the cast expression; the coercion is annotated 'VISimple'.
vectAvoidInfo pvs (fvs, AnnCast e (fvs_ann, ann))
  = rebuild <$> vectAvoidInfo pvs e
  where
    rebuild e' = ((fvs, vectAvoidInfoOf e'), AnnCast e' ((fvs_ann, VISimple), ann))
1054
-- Ticks inherit the information of the ticked expression.
vectAvoidInfo pvs (fvs, AnnTick tick e)
  = rewrap <$> vectAvoidInfo pvs e
  where
    rewrap e' = ((fvs, vectAvoidInfoOf e'), AnnTick tick e')
1060
-- Types and coercions are unconditionally annotated as 'VISimple'.
vectAvoidInfo _pvs (fvs, AnnType ty)      = return ((fvs, VISimple), AnnType ty)
vectAvoidInfo _pvs (fvs, AnnCoercion coe) = return ((fvs, VISimple), AnnCoercion coe)
1066
-- Classify a type for vectorisation avoidance:
--
-- * predicate (dictionary) types are 'VIDict';
-- * a function type is 'VISimple' if argument and result are both 'VISimple', 'VIDict' if the
--   result is 'VIDict', and otherwise 'VIParr' if either side is 'VIParr' or 'VIComplex' if not;
-- * any other type is 'VIParr' if 'maybeParrTy' holds for it, 'VISimple' if it is in the
--   'Scalar' class, and 'VIComplex' otherwise.
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | isPredTy ty
  = return VIDict
  | Just (arg, res) <- splitFunTy_maybe ty
  = do
    { argVI <- vectAvoidInfoType arg
    ; resVI <- vectAvoidInfoType res
    ; return $ funVI argVI resVI
    }
  | otherwise
  = do
    { parr <- maybeParrTy ty
    ; if parr
        then return VIParr
        else classifyScalar <$> isScalar ty
    }
  where
    funVI VISimple VISimple = VISimple    -- NB: diverts from the paper: scalar functions
    funVI _        VIDict   = VIDict
    funVI argVI    resVI    = (VIComplex `unlessVIParr` argVI) `unlessVIParr` resVI

    classifyScalar True  = VISimple
    classifyScalar False = VIComplex
1093
-- Classify the type of an annotated Core expression (see 'vectAvoidInfoType').
--
vectAvoidInfoTypeOf :: AnnExpr Var ann -> VM VectAvoidInfo
vectAvoidInfoTypeOf annExpr = vectAvoidInfoType (annExprType annExpr)
1098
-- Might the given type be (or contain) a parallel array type?
--
-- A type constructor application is regarded as possibly parallel if the constructor is one of
-- the global parallel type constructors or if any of its arguments might be parallel; foralls
-- are looked through; anything else is not parallel.
--
maybeParrTy :: Type -> VM Bool
maybeParrTy ty
  | Just expanded <- coreView ty                 -- expand the outer layer via 'coreView'
  = (VIParr ==) <$> vectAvoidInfoType expanded
  | Just (tc, args) <- splitTyConApp_maybe ty    -- decompose constructor applications
  = do
    { parTyCons <- globalParallelTyCons
    ; if tyConName tc `elemNameSet` parTyCons
        then return True
        else or <$> mapM maybeParrTy args
    }
  | ForAllTy _ body <- ty
  = maybeParrTy body
  | otherwise
  = return False
1116
-- Do the types of all the given variables belong to the 'Scalar' class?
--
allScalarVarType :: [Var] -> VM Bool
allScalarVarType = fmap and . mapM (isScalar . varType)
1121
-- Do the types of all variables in the given set belong to the 'Scalar' class?
--
allScalarVarTypeSet :: VarSet -> VM Bool
allScalarVarTypeSet vs = allScalarVarType (varSetElems vs)
1126
-- Debugging support: trace the vectorisation information computed for an expression, followed
-- (in brackets, each entry followed by a space) by the information of the given subexpressions,
-- together with the expression itself (annotations stripped).
--
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [CoreExprWithVectInfo] -> VM ()
viTrace ce vi vTs
  = traceVt ("vect info: " ++ show vi ++ "[" ++ concatMap showSub vTs ++ "]")
            (ppr $ deAnnotate ce)
  where
    showSub vT = show (vectAvoidInfoOf vT) ++ " "