Fix #6080 & house keeping in Vectorise.Exp
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1 {-# LANGUAGE TupleSections #-}
2
3 -- |Vectorisation of expressions.
4
5 module Vectorise.Exp
6 ( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular
7 -- variable bindings
8 vectPolyExpr
9 , vectDictExpr
10 , vectScalarFun
11 , vectScalarDFun
12 )
13 where
14
15 #include "HsVersions.h"
16
17 import Vectorise.Type.Type
18 import Vectorise.Var
19 import Vectorise.Convert
20 import Vectorise.Vect
21 import Vectorise.Env
22 import Vectorise.Monad
23 import Vectorise.Builtins
24 import Vectorise.Utils
25
26 import CoreUtils
27 import MkCore
28 import CoreSyn
29 import CoreFVs
30 import Class
31 import DataCon
32 import TyCon
33 import TcType
34 import Type
35 import PrelNames
36 import Var
37 import VarEnv
38 import VarSet
39 import Id
40 import BasicTypes( isStrongLoopBreaker )
41 import Literal
42 import TysWiredIn
43 import TysPrim
44 import Outputable
45 import FastString
46 import Control.Monad
47 import Control.Applicative
48 import Data.Maybe
49 import Data.List
50 import TcRnMonad (doptM)
51 import DynFlags (DynFlag(Opt_AvoidVect))
52
53
54 -- Main entry point to vectorise expressions -----------------------------------
55
-- |Vectorise a polymorphic expression.
--
-- The 'Maybe VITree' argument is the vectorisation avoidance information for the expression. If it
-- is 'Nothing', that information is computed first (with 'vectAvoidInfo'); if, in addition, the
-- vectorisation avoidance optimisation ('Opt_AvoidVect') is enabled, the information is also used
-- to encapsulate subexpressions that do not need to be vectorised (see 'encapsulateScalars').
--
-- The result consists of the inlining information for the vectorised binding, a flag indicating
-- whether the expression was vectorised as a scalar function, and the vectorised expression.
--
vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs -> Maybe VITree
             -> VM (Inline, Bool, VExpr)

-- no vectorisation avoidance information yet: compute it (possibly encapsulating scalar
-- subexpressions) and start over with that information
vectPolyExpr loop_breaker recFns expr Nothing
  = do { avoidVect <- liftDs $ doptM Opt_AvoidVect
       ; viTree    <- vectAvoidInfo expr
       ; (expr', viTree') <- if avoidVect
                               then encapsulate viTree
                               else return (expr, viTree)
       ; vectPolyExpr loop_breaker recFns expr' (Just viTree')
       }
  where
    encapsulate vit
      = do { (encExpr, encVit) <- encapsulateScalars vit expr
           ; traceVt "vectPolyExpr encapsulated:" (ppr $ deAnnotate encExpr)
           ; return (encExpr, encVit)
           }

-- look through ticks, re-attaching them to the vectorised expression
vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr) (Just (VITNode _ [vit]))
  = do { (inlineInfo, scalarFn, vExpr) <- vectPolyExpr loop_breaker recFns expr (Just vit)
       ; return (inlineInfo, scalarFn, vTick tickish vExpr)
       }

-- collect the type abstractions and vectorise them with 'polyAbstract'; then, descend into the body
vectPolyExpr loop_breaker recFns expr (Just vit)
  = do { arity <- polyArity tvs
       ; polyAbstract tvs $ \args ->
           do { (inlineInfo, scalarFn, vMono) <- vectFnExpr False loop_breaker recFns mono monoVit
              ; return ( addInlineArity inlineInfo arity
                       , scalarFn
                       , mapVect (mkLams $ tvs ++ args) vMono
                       )
              }
       }
  where
    (tvs, mono) = collectAnnTypeBinders expr
    monoVit     = stripLevels (length tvs) vit

    -- every type binder corresponds to one single-child level of the tree
    stripLevels :: Int -> VITree -> VITree
    stripLevels 0 t                = t
    stripLevels n (VITNode _ [t]) = stripLevels (n - 1) t
    stripLevels _ t                = pprPanic "vectPolyExpr: stripLevels:" (text (show t))
103
-- |Encapsulate every purely sequential subexpression of a (potentially) parallel expression into a
-- lambda abstraction over all its free variables, followed by the corresponding application to
-- those variables. We can then avoid the vectorisation of the encapsulated subexpressions.
--
-- Preconditions:
--
-- * All free variables and the result type must be /simple/ types.
-- * The expression is sufficiently complex (to warrant special treatment). For now, that is
--   every expression that is not constant and contains at least one operation.
--
-- The 'VITree' must have the same shape as the expression (one subtree per subexpression); the
-- returned tree corresponds to the returned expression. A shape mismatch is a panic.
--
-- NB: When a subexpression is lifted, 'liftSimple' must receive the tree of that *whole*
-- subexpression, as it rebuilds the tree around the lambdas and applications it introduces.
--
encapsulateScalars :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
encapsulateScalars vit ce@(_, AnnType _ty)
  = return (ce, vit)

encapsulateScalars vit ce@(_, AnnVar _v)
  = return (ce, vit)

encapsulateScalars vit ce@(_, AnnLit _)
  = return (ce, vit)

encapsulateScalars (VITNode vi [vit]) (fvs, AnnTick tck expr)
  = do { (extExpr, vit') <- encapsulateScalars vit expr
       ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
       }

encapsulateScalars _ (_fvs, AnnTick _tck _expr)
  = panic "encapsulateScalars AnnTick doesn't match up"

encapsulateScalars vt@(VITNode vi [vit]) ce@(fvs, AnnLam bndr expr)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           -- lift the whole abstraction: pass the tree of the abstraction, not that of its body
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extExpr, vit') <- encapsulateScalars vit expr
                                  ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
                                  }
       }

encapsulateScalars _ (_fvs, AnnLam _bndr _expr)
  = panic "encapsulateScalars AnnLam doesn't match up"

encapsulateScalars vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (etaCe1, vit1') <- encapsulateScalars vit1 ce1
                                  ; (etaCe2, vit2') <- encapsulateScalars vit2 ce2
                                  ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
                                  }
       }

encapsulateScalars _ (_fvs, AnnApp _ce1 _ce2)
  = panic "encapsulateScalars AnnApp doesn't match up"

encapsulateScalars vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extScrut, scrutVit') <- encapsulateScalars scrutVit scrut
                                  ; extAltsVits <- zipWithM expAlt altVits alts
                                  ; let (extAlts, altVits') = unzip extAltsVits
                                  ; return ( (fvs, AnnCase extScrut bndr ty extAlts)
                                           , VITNode vi (scrutVit' : altVits')
                                           )
                                  }
       }
  where
    -- the first subtree belongs to the scrutinee, the remaining ones to the alternatives
    expAlt altVt (con, bndrs, rhs)
      = do { (extRhs, altVt') <- encapsulateScalars altVt rhs
           ; return ((con, bndrs, extRhs), altVt')
           }

encapsulateScalars _ (_fvs, AnnCase _scrut _bndr _ty _alts)
  = panic "encapsulateScalars AnnCase doesn't match up"

encapsulateScalars vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { (extExpr1, vt1') <- encapsulateScalars vt1 expr1
                                  ; (extExpr2, vt2') <- encapsulateScalars vt2 expr2
                                  ; return ( (fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2)
                                           , VITNode vi [vt1', vt2']
                                           )
                                  }
       }

encapsulateScalars _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
  = panic "encapsulateScalars AnnLet nonrec doesn't match up"

encapsulateScalars vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ liftSimple vt ce
           _                -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
                                  ; let (extBnds, vtBnds') = unzip extBndsVts
                                  ; (extExpr, vtB') <- encapsulateScalars vtB expr
                                  ; let vt' = VITNode vi (vtB' : vtBnds')
                                  ; return ((fvs, AnnLet (AnnRec extBnds) extExpr), vt')
                                  }
       }
  where
    -- the first subtree belongs to the body, the remaining ones to the bindings
    expBndg bndVt (bndr, rhs)
      = do { (extRhs, bndVt') <- encapsulateScalars bndVt rhs
           ; return ((bndr, extRhs), bndVt')
           }

encapsulateScalars _ (_fvs, AnnLet (AnnRec _) _expr2)
  = panic "encapsulateScalars AnnLet rec doesn't match up"

encapsulateScalars (VITNode vi [vit]) (fvs, AnnCast expr coercion)
  = do { (extExpr, vit') <- encapsulateScalars vit expr
       ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
       }

encapsulateScalars _ (_fvs, AnnCast _expr _coercion)
  = panic "encapsulateScalars AnnCast doesn't match up"

encapsulateScalars _ _
  = panic "encapsulateScalars case not handled"
225
-- |Lambda-lift the given expression over its free variables and apply the resulting abstraction to
-- those same variables.
--
-- If the expression is a case expression whose scrutinee's type is headed by a type constructor
-- other than 'Bool', 'Int', 'Double', or 'Float', the case itself is kept and each alternative is
-- lifted individually instead.
--
-- The returned tree describes the rebuilt expression: every introduced lambda gets a 'VIEncaps'
-- node (as does the original expression's root), and every introduced application a 'VISimple'
-- node whose argument is a 'VISimple' leaf.
--
liftSimple :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
liftSimple (VITNode vi (scrutVit : altVits)) (fvs, AnnCase expr bndr t alts)
  | Just (tc, _) <- splitTyConApp_maybe (exprType (deAnnotate expr))
  , not (tc `elem` [boolTyCon, intTyCon, doubleTyCon, floatTyCon])    -- FIXME: shouldn't be hardcoded
  = ((fvs, AnnCase expr bndr t liftedAlts), VITNode vi (scrutVit : liftedVits))
  where
    (liftedAlts, liftedVits) = unzip (zipWith liftAlt alts altVits)

    liftAlt (con, binders, rhs) altVit
      = case liftSimple altVit rhs of
          (rhs', altVit') -> ((con, binders, rhs'), altVit')

liftSimple viTree ae@(fvs, _annEx)
  = (mkAnnApps (mkAnnLams ae vars) vars, appsTree)
  where
    vars = varSetElems fvs

    -- the original root, relabelled as encapsulated, below one 'VIEncaps' node per variable
    lamsTree = foldr (\_ t -> VITNode VIEncaps [t]) (relabel viTree) vars
    relabel (VITNode _ vits) = VITNode VIEncaps vits

    -- one application node per variable on top of the abstraction
    appsTree = foldr (\_ t -> VITNode VISimple [t, VITNode VISimple []]) lamsTree vars
251
-- |Wrap a lambda binder around an annotated body; the caller supplies the annotation of the
-- resulting node.
mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
mkAnnLam bndr ce = AnnLam bndr ce

-- |Abstract an annotated expression over the given variables. The first variable is bound by the
-- innermost lambda, so 'mkAnnLams e [v1, ..., vn]' is '\vn -> ... -> \v1 -> e' (matching the
-- argument order of 'mkAnnApps').
--
-- The body keeps its own free-variable annotation; each new lambda node is annotated with the free
-- variables of its body minus the variable it binds. Hence, when abstracting over all free
-- variables of the expression, the outermost annotation is empty.
--
mkAnnLams :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
mkAnnLams body []              = body
mkAnnLams body@(fv, _) (v:vs) = mkAnnLams (delVarSet fv v, mkAnnLam v body) vs

-- |Apply an annotated expression to a variable (annotated with just that variable).
mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
mkAnnApp aex v = AnnApp aex (unitVarSet v, (AnnVar v))

-- |Apply an annotated expression to the given variables in reverse order:
-- 'mkAnnApps e [v1, ..., vn]' is 'e vn ... v1'. Each application node is annotated with the free
-- variables of its function part extended by the argument variable.
--
mkAnnApps :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
mkAnnApps ae []     = ae
mkAnnApps ae (v:vs)
  = let fun@(fv, _) = mkAnnApps ae vs
    in (extendVarSet fv v, mkAnnApp fun v)
268
-- |Vectorise an expression.
--
-- The 'VITree' argument must have the same shape as the expression: equations that descend into
-- subexpressions match on the tree's shape and pass the corresponding subtrees down. Equations
-- are tried in order; anything not covered ends in 'cantVectorise'.
--
vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr
-- vectExpr e vi | not (checkTree vi (deAnnotate e))
--   = pprPanic "vectExpr" (ppr $ deAnnotate e)

-- variables
vectExpr (_, AnnVar v) _
  = vectVar v

-- literals are vectorised as constants
vectExpr (_, AnnLit lit) _
  = vectConst $ Lit lit

-- value abstractions are handed to 'vectFnExpr' (only the vectorised expression is kept); type
-- abstractions are rejected here
vectExpr e@(_, AnnLam bndr _) vt
  | isId bndr = (\(_, _, ve) -> ve) <$> vectFnExpr True False [] e vt
  | otherwise = cantVectorise "Unexpected type lambda (vectExpr)" (ppr (deAnnotate e))

-- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
-- its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
-- happy.
-- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _
  | v == pAT_ERROR_ID
  = do { (vty, lty) <- vectAndLiftType ty
       ; return (mkCoreApps (Var v) [Type vty, err'], mkCoreApps (Var v) [Type lty, err'])
       }
  where
    err' = deAnnotate err

-- type application (handle multiple consecutive type applications simultaneously to ensure the
-- PA dictionaries are put at the right places)
vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
  | isAnnTypeArg arg
  = vectPolyApp e

-- 'Int', 'Float', or 'Double' literal
-- FIXME: this needs to be generalised
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
  | Just con <- isDataConId_maybe v
  , is_special_con con
  = do
      let vexpr = App (Var v) (Lit lit)
      lexpr <- liftPD vexpr
      return (vexpr, lexpr)
  where
    is_special_con con = con `elem` [intDataCon, floatDataCon, doubleDataCon]

-- value application (dictionary or user value)
vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2])
  | isPredTy arg_ty   -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise         -- user value
  = do { -- vectorise the types
       ; varg_ty <- vectType arg_ty
       ; vres_ty <- vectType res_ty

         -- vectorise the function and argument expression
       ; vfn  <- vectExpr fn vit1
       ; varg <- vectExpr arg vit2

         -- the vectorised function is a closure; apply it to the vectorised argument
       ; mkClosureApp varg_ty vres_ty vfn varg
       }
  where
    (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn

-- case expressions: only scrutinees of an algebraic type are supported ('vectAlgCase')
vectExpr (_, AnnCase scrut bndr ty alts) vt
  | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
  , isAlgTyCon tycon
  = vectAlgCase tycon ty_args scrut bndr ty alts vt
  | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
  where
    scrut_ty = exprType (deAnnotate scrut)

-- non-recursive let: the right-hand side goes through 'vectPolyExpr' (not flagged as a loop
-- breaker), the body is vectorised with the vectorised binder in scope
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2])
  = do
      vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (Just vt1)
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
      return $ vLet (vNonRec vbndr vrhs) vbody

-- recursive let: the first subtree belongs to the body, the remaining ones to the bindings
vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
  = do
      (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                  $ liftM2 (,)
                                      (zipWith3M vect_rhs bndrs rhss vtBnds)
                                      (vectExpr body vtB)
      return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    -- a binder whose occurrence info marks it as a strong loop breaker is passed on as such
    -- ('vectLam' uses this flag to add a recursion termination test)
    vect_rhs bndr rhs vt = localV
                         . inBind bndr
                         . liftM (\(_,_,z)->z)
                         $ vectPolyExpr (isStrongLoopBreaker $ idOccInfo bndr) [] rhs (Just vt)
    zipWith3M f xs ys zs = zipWithM (\x -> \(y,z) -> (f x y z)) xs (zip ys zs)

-- ticks are kept around the vectorised expression
vectExpr (_, AnnTick tickish expr) (VITNode _ [vit])
  = liftM (vTick tickish) (vectExpr expr vit)

-- types are vectorised as types
vectExpr (_, AnnType ty) _
  = liftM vType (vectType ty)

-- anything else (including expression/tree shape mismatches) is not supported
vectExpr e vit = cantVectorise "Can't vectorise expression (vectExpr)" (ppr (deAnnotate e) $$ text (" " ++ show vit))
371
-- |Vectorise an expression that *may* have an outer lambda abstraction.
--
-- Type variables are not handled here, as 'vectPolyExpr' strips them off before calling this
-- function. Only one set of dictionary arguments needs to be considered, as we (1) only deal with
-- Haskell 2011 and (2) class selectors are vectorised elsewhere.
--
-- A value abstraction is first tried as a scalar function (which only succeeds if its tree marks
-- it as encapsulated); otherwise, it is vectorised into closures by 'vectLam'. Anything that is
-- not a value abstraction is vectorised as a vanilla expression.
--
vectFnExpr :: Bool              -- ^ If we process the RHS of a binding, whether that binding should
                                --   be inlined
           -> Bool              -- ^ Whether the binding is a loop breaker
           -> [Var]             -- ^ Names of function in same recursive binding group
           -> CoreExprWithFVs   -- ^ Expression to vectorise; must have an outer `AnnLam`
           -> VITree
           -> VM (Inline, Bool, VExpr)
-- vectFnExpr _ _ _ e vi | not (checkTree vi (deAnnotate e))
--   = pprPanic "vectFnExpr" (ppr $ deAnnotate e)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [bodyVt])
  | isId bndr, isPredTy (idType bndr)
    -- predicate abstraction: keep an ordinary abstraction, but over the vectorised predicate
  = do { vBndr <- vectBndr bndr
       ; (bodyInline, isScalarFn, vBody) <- vectFnExpr inline loop_breaker recFns body bodyVt
       ; return (bodyInline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vBody)
       }
  | isId bndr
    -- non-predicate abstraction: try scalar vectorisation first, then fall back to closures
  = asScalar `orElseV` asClosure
  where
    asScalar  = mark DontInline True  (vectScalarFunMaybe (deAnnotate expr) vt)
    asClosure = mark inlineMe   False (vectLam inline loop_breaker expr vt)
vectFnExpr _ _ _ e vt
  -- not a value abstraction: vectorise as a vanilla expression
  = mark DontInline False $ vectExpr e vt

-- |Run a vectorisation computation and pair its result with inlining information and a flag
-- saying whether it was vectorised as a scalar function.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark inl isScalarFn comp = comp >>= \res -> return (inl, isScalarFn, res)
406
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
-- involve two sets of type variables and dictionaries. Consider,
--
-- > class C a where
-- >   m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-- Arguments are peeled off the application from the outside in: first a set of dictionary
-- arguments ('dictsOuter'), then type arguments ('tysOuter'), then a second set of dictionaries
-- ('dictsInner') and types ('tysInner'). What remains must be a variable; anything else is
-- reported with 'pprSorry' as a higher-rank program.
--
vectPolyApp :: CoreExprWithFVs -> VM VExpr
vectPolyApp e0
  = case e4 of
      (_, AnnVar var)
        -> do { -- get the vectorised form of the variable
              ; vVar <- lookupVar var
              ; traceVt "vectPolyApp of" (ppr var)

                -- vectorise type and dictionary arguments
              ; vDictsOuter <- mapM vectDictExpr (map deAnnotate dictsOuter)
              ; vDictsInner <- mapM vectDictExpr (map deAnnotate dictsInner)
              ; vTysOuter   <- mapM vectType tysOuter
              ; vTysInner   <- mapM vectType tysInner

                -- re-apply the outer (vectorised) type and dictionary arguments
              ; let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter

              ; case vVar of
                  Local (vv, lv)
                    -> do { MASSERT( null dictsInner )    -- local vars cannot be class selectors
                          ; traceVt " LOCAL" (text "")
                          ; (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)
                          }
                  Global vv
                    | isDictComp var                      -- dictionary computation
                    -> do { -- in a dictionary computation, the innermost, non-empty set of
                            -- arguments are non-vectorised arguments, where no 'PA'dictionaries
                            -- are needed for the type variables
                          ; ve <- if null dictsInner
                                  then
                                    return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                                  else
                                    reconstructOuter
                                      (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                          ; traceVt " GLOBAL (dict):" (ppr ve)
                          ; vectConst ve
                          }
                    | otherwise                           -- non-dictionary computation
                    -> do { MASSERT( null dictsInner )
                          ; ve <- reconstructOuter (Var vv)
                          ; traceVt " GLOBAL (non-dict):" (ppr ve)
                          ; vectConst ve
                          }
              }
      _ -> pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- if there is only one set of variables or dictionaries, it will be the outer set
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter)   = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner)   = collectAnnTypeArgs e3
    --
    -- class method selectors and dfuns are treated as dictionary computations
    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var
469
-- |Vectorise the body of a dfun.
--
-- Dictionary computations are special: applications of dictionary functions are always saturated,
-- so no closures need to be created; and they don't depend on array values, so they are always
-- scalar computations whose result we can replicate (instead of executing them in parallel).
--
-- NB: To keep things simple, the bindings introduced within a dictionary computation are not
-- rewritten. Consequently, a variable may or may not be in the vectoriser environments, and the
-- variable case has to cope with both situations.
--
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr (Var var)
  = lookupVar_maybe var >>= \mbScope ->
      case mbScope of
        Nothing              -> return (Var var)   -- binder from within the dict. computation
        Just (Local (vv, _)) -> return (Var vv)    -- local vectorised variable
        Just (Global vv)     -> return (Var vv)    -- global vectorised variable
vectDictExpr (Lit lit)
  = pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)
vectDictExpr (Lam bndr body)
  = liftM (Lam bndr) (vectDictExpr body)
vectDictExpr (App fn arg)
  = do { fn'  <- vectDictExpr fn
       ; arg' <- vectDictExpr arg
       ; return $ App fn' arg'
       }
vectDictExpr (Case scrut bndr ty alts)
  = do { scrut' <- vectDictExpr scrut
       ; ty'    <- vectType ty
       ; alts'  <- mapM vectAlt alts
       ; return $ Case scrut' bndr ty' alts'
       }
  where
    -- alternative binders are kept; constructors are replaced by their vectorised versions
    vectAlt (con, binders, rhs)
      = do { con' <- vectAltCon con
           ; rhs' <- vectDictExpr rhs
           ; return (con', binders, rhs')
           }

    vectAltCon (DataAlt datacon) = DataAlt <$> maybeV dataConErr (lookupDataCon datacon)
      where
        dataConErr = ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon
    vectAltCon (LitAlt lit)      = return $ LitAlt lit
    vectAltCon DEFAULT           = return DEFAULT
vectDictExpr (Let bind body)
  = do { bind' <- vectBind bind
       ; body' <- vectDictExpr body
       ; return $ Let bind' body'
       }
  where
    vectBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectBind (Rec pairs)       = Rec <$> mapM (\(bndr, rhs) -> (,) bndr <$> vectDictExpr rhs) pairs
vectDictExpr e@(Cast _e _coe)
  = pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr e)
vectDictExpr (Tick tickish body)
  = liftM (Tick tickish) (vectDictExpr body)
vectDictExpr (Type ty)
  = liftM Type (vectType ty)
vectDictExpr (Coercion coe)
  = pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
518
-- |Vectorise an expression as a scalar function (with 'vectScalarFun'), but only if its
-- vectorisation avoidance information marks it as encapsulated ('VIEncaps'). Otherwise, fail with
-- 'noV', so that the caller can fall back to full vectorisation (e.g., via 'orElseV').
--
-- A scalar function is an expression of functional type whose arguments and result are of
-- primitive types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar' type
-- class) and which does not contain any subcomputations that involve parallel arrays. Such
-- functions do not require the full blown vectorisation transformation; instead, they can be
-- lifted by application of a member of the zipWith family (i.e., 'map', 'zipWith', 'zipWith3',
-- etc.)
--
vectScalarFunMaybe :: CoreExpr   -- ^ Expression to be vectorised
                   -> VITree     -- ^ Vectorisation information
                   -> VM VExpr
vectScalarFunMaybe expr vit
  = case vit of
      VITNode VIEncaps _ -> vectScalarFun expr
      _                  -> noV $ ptext (sLit "not a scalar function")
534
-- |Vectorise an expression of functional type by lifting it by an application of a member of the
-- zipWith family (i.e., 'map', 'zipWith', 'zipWith3', etc.) This is only a valid strategy if the
-- function does not contain parallel subcomputations and has only 'Scalar' types in its result and
-- arguments — this is a precondition for calling this function.
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not vectorised,
-- instead they become dictionaries of vectorised methods). We treat them differently, though; see
-- "Note [Scalar dfuns]" in 'Vectorise'.
--
vectScalarFun :: CoreExpr -> VM VExpr
vectScalarFun expr
  = traceVt "vectScalarFun" (ppr expr) >> mkScalarFun argTys resTy expr
  where
    (argTys, resTy) = splitFunTys (exprType expr)
551
-- |Generate code for a scalar function by generating a scalar closure. If the function's result is
-- a dictionary (a predicate type), vectorise it as dictionary code instead ('vectDictExpr').
--
-- In the dictionary case, the lifted component of the result is an 'error' thunk and must never be
-- demanded.
--
-- Otherwise, the function and its scalar closure are hoisted to top-level bindings (named after
-- "fn" and "clo"), and the closure is returned together with its 'liftPD' version.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  = if isPredTy res_ty then dictionaryFun else scalarFun
  where
    dictionaryFun
      = do { vDict <- vectDictExpr expr
           ; return (vDict, unused)
           }
    unused = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"

    scalarFun
      = do { traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr (mkFunTys arg_tys res_ty)
           ; fnVar  <- hoistExpr (fsLit "fn") expr DontInline
           ; zipFn  <- zipScalars arg_tys res_ty
           ; clo    <- scalarClosure arg_tys res_ty (Var fnVar) (zipFn `App` Var fnVar)
           ; cloVar <- hoistExpr (fsLit "clo") clo DontInline
           ; lifted <- liftPD (Var cloVar)
           ; return (Var cloVar, lifted)
           }
573
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'. The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
-- expression of the right-hand side of the dictionary function. We cannot proceed like this as
-- 'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have access
-- to the Core code of the unvectorised dfun.
--
-- Here an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) #-}
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- > $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- > $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- > $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- >                D:V:Eq $(vectScalarFun
-- >                           [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVect di)' sub-expressions between the different selectors, but duplicate
--   the application of the unvectorised dfun, to enable the dictionary selection rules to fire.
--
vectScalarDFun :: Var           -- ^ Original dfun
               -> VM CoreExpr
vectScalarDFun var
  = do { -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

         -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr

         -- vectorise superclass dictionaries and methods as scalar expressions
         -- (the unvectorised argument dictionaries are let-bound to 'thetaVars', so they are
         -- shared by all selector applications)
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun e) scsOps

         -- vectorised applications of the class-dictionary data constructor
         -- NOTE(review): if 'lookupDataCon' yields 'Nothing', this pattern fails through the monad's
         -- pattern-match failure path rather than a descriptive 'maybeV' error as used elsewhere
       ; Just vDataCon <- lookupDataCon dataCon
       ; vTys          <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty      -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls
639
-- |Build a value of the dictionary before vectorisation from the original, unvectorised type and
-- an expression computing the vectorised dictionary.
--
-- Given the vectorised version of a dictionary 'vd :: V:C vt1..vtn', generate code that computes
-- the unvectorised version, thus:
--
-- >   D:C @t1..tn op1 .. opm
-- >   where
-- >     opi = $(fromVect opTyi [| vSeli @vt1..vtn vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary. The
-- original type must be a class dictionary type; otherwise, this panics.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do { vTys  <- mapM vectType tys
       ; scOps <- zipWithM fromVect methTys [Var sel `mkTyApps` vTys `mkApps` [e] | sel <- selIds]
       ; return $ mkCoreConApps dataCon (map Type tys ++ scOps)
       }
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty
    selIds  = classAllSelIds (fromMaybe noClass (tyConClass_maybe tycon))
    noClass = panic "Vectorise.Exp.unVectDict: no class"
665
-- |Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- All non-dictionary free variables go into the closure's environment, whereas the dictionary
-- variables are passed explicitly (as conventional arguments) into the body during closure
-- construction. Only free variables that are bound in the local vectorisation environment are
-- considered.
--
-- If the binding is a loop breaker, the lifted body gets a recursion termination test (see
-- 'break_loop' below).
--
vectLam :: Bool             -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool             -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs  -- ^ Body of abstraction.
        -> VITree
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
  = do { let (bndrs, body) = collectAnnValBinders expr

         -- grab the in-scope type variables
       ; tyvars <- localTyVars

         -- collect and vectorise all /local/ free variables (those in the local var env)
       ; vfvs <- readLEnv $ \env ->
                   [ (var, vv)
                   | var     <- varSetElems fvs
                   , Just vv <- [lookupVarEnv (local_vars env) var]
                   ]

         -- separate dictionary from non-dictionary variables in the free variable set
       ; let (vvs_dict, vvs_nondict)     = partition (isPredTy . varType . fst) vfvs
             (_fvs_dict, vfvs_dict)      = unzip vvs_dict
             (fvs_nondict, vfvs_nondict) = unzip vvs_nondict

         -- compute the type of the vectorised closure
       ; arg_tys <- mapM (vectType . idType) bndrs
       ; res_ty  <- vectType (exprType $ deAnnotate body)

       ; let arity      = length fvs_nondict + length bndrs
             vfvs_dict' = map vectorised vfvs_dict
       ; buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
         . hoistPolyVExpr tyvars vfvs_dict' (maybe_inline arity)
         $ do { -- generate the vectorised body of the lambda abstraction
              ; lc <- builtin liftingContext
              ; let viBody = stripLams expr vi
              ; (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)

              ; vbody' <- break_loop lc res_ty vbody
              ; return $ vLams lc vbndrs vbody'
              }
       }
  where
    -- descend the tree in step with the outer lambdas of the expression
    stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt
    stripLams _ vit                            = vit

    maybe_inline n | inline    = Inline n
                   | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion termination test
    -- to the /lifted/ version of the function body. The termination test checks if the lifting
    -- context is empty. If so, it returns an empty array of the (lifted) result type instead of
    -- executing the function body. This is the test from the last line (defining \mathcal{L}')
    -- in Figure 6 of HtM.
    break_loop lc ty (ve, le)
      | loop_breaker
      = do { lempty <- emptyPD ty
           ; lty    <- mkPDataType ty
           ; return (ve, mkWildCase (Var lc) intPrimTy lty
                           [(DEFAULT, [], le),
                            (LitAlt (mkMachInt 0), [], lempty)])
           }
      | otherwise = return (ve, le)
vectLam _ _ _ _ = panic "vectLam"
735
-- |Vectorise an algebraic case expression.
--
-- We convert
--
--   case e :: t of v { ... }
--
-- to
--
--   V:    let v' = e in case v' of _ { ... }
--   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--
-- The 'VITree' has one subtree for the scrutinee followed by one per alternative.
--

-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs-> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
            -> VM VExpr

-- a single DEFAULT alternative: vectorise the scrutinee, the case binder, and the body
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut         <- vectExpr scrut scrutVit
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

-- a single alternative for a constructor without fields: handled exactly like a DEFAULT alternative
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut         <- vectExpr scrut scrutVit
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

-- a single alternative for a constructor with fields: bind the scrutinee, unwrap its vectorised
-- and lifted representations, and match on the vectorised constructor and the 'PData' constructor
-- respectively
vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut scrutVit
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body altVit
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      vect_dc <- maybeV dataConErr (lookupDataCon dc)

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- the scrutinee is always let-bound (see 'vLet' above), so a dead case binder is replaced by a
    -- fresh variable named after "scrut"
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
794
795 vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
796 = do
797 vect_tc <- maybeV tyConErr (lookupTyCon tycon)
798 (vty, lty) <- vectAndLiftType ty
799
800 let arity = length (tyConDataCons vect_tc)
801 sel_ty <- builtin (selTy arity)
802 sel_bndr <- newLocalVar (fsLit "sel") sel_ty
803 let sel = Var sel_bndr
804
805 (vbndr, valts) <- vect_scrut_bndr
806 $ mapM (proc_alt arity sel vty lty) (zip alts' altVits)
807 let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts
808
809 vexpr <- vectExpr scrut scrutVit
810 (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
811
812 let (vect_bodies, lift_bodies) = unzip vbodies
813
814 vdummy <- newDummyVar (exprType vect_scrut)
815 ldummy <- newDummyVar (exprType lift_scrut)
816 let vect_case = Case vect_scrut vdummy vty
817 (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)
818
819 lc <- builtin liftingContext
820 lbody <- combinePD vty (Var lc) sel lift_bodies
821 let lift_case = Case lift_scrut ldummy lty
822 [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
823 lbody)]
824
825 return . vLet (vNonRec vbndr vexpr)
826 $ (vect_case, lift_case)
827 where
828 tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)
829
830 vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
831 | otherwise = vectBndrIn bndr
832
833 alts' = sortBy (\(alt1, _, _) (alt2, _, _) -> cmp alt1 alt2) alts
834
835 cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
836 cmp DEFAULT DEFAULT = EQ
837 cmp DEFAULT _ = LT
838 cmp _ DEFAULT = GT
839 cmp _ _ = panic "vectAlgCase/cmp"
840
841 proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
842 = do
843 vect_dc <- maybeV dataConErr (lookupDataCon dc)
844 let ntag = dataConTagZ vect_dc
845 tag = mkDataConTag vect_dc
846 fvs = freeVarsOf body `delVarSetList` bndrs
847
848 sel_tags <- liftM (`App` sel) (builtin (selTags arity))
849 lc <- builtin liftingContext
850 elems <- builtin (selElements arity ntag)
851
852 (vbndrs, vbody)
853 <- vectBndrsIn bndrs
854 . localV
855 $ do
856 binds <- mapM (pack_var (Var lc) sel_tags tag)
857 . filter isLocalId
858 $ varSetElems fvs
859 (ve, le) <- vectExpr body vi
860 return (ve, Case (elems `App` sel) lc lty
861 [(DEFAULT, [], (mkLets (concat binds) le))])
862 -- empty <- emptyPD vty
863 -- return (ve, Case (elems `App` sel) lc lty
864 -- [(DEFAULT, [], Let (NonRec flags_var flags_expr)
865 -- $ mkLets (concat binds) le),
866 -- (LitAlt (mkMachInt 0), [], empty)])
867 let (vect_bndrs, lift_bndrs) = unzip vbndrs
868 return (vect_dc, vect_bndrs, lift_bndrs, vbody)
869 where
870 dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)
871
872
873 proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"
874
875 mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)
876
877 pack_var len tags t v
878 = do
879 r <- lookupVar v
880 case r of
881 Local (vv, lv) ->
882 do
883 lv' <- cloneVar lv
884 expr <- packByTagPD (idType vv) (Var lv) len tags t
885 updLEnv (\env -> env { local_vars = extendVarEnv
886 (local_vars env) v (vv, lv') })
887 return [(NonRec lv' expr)]
888
889 _ -> return []
890
891 vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ _)
892 = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)
893
894
895 -- Support to compute information for vectorisation avoidance ------------------
896
-- |Annotation for Core AST nodes that describes how each node is to be treated during
-- vectorisation, and in particular whether vectorising the corresponding computation can be
-- avoided.
--
data VectAvoidInfo
  = VIParr     -- ^ tree contains parallel computations
  | VISimple   -- ^ result type is scalar & no parallel subcomputation
  | VIComplex  -- ^ any result type, no parallel subcomputation
  | VIEncaps   -- ^ tree encapsulated by 'liftSimple'
  deriving (Eq, Show)
905
-- |Vectorisation avoidance information is not stored inside the Core expression itself; instead,
-- it lives in this separate tree, which structurally mirrors the Core expression it annotates.
--
data VITree
  = VITNode VectAvoidInfo [VITree]
  deriving Show
911
-- |Is the root of at least one of the given trees annotated with 'VIParr'?
--
anyVIPArr :: [VITree] -> Bool
anyVIPArr = any isParrRoot
  where
    isParrRoot (VITNode info _) = info == VIParr
916
-- |Compute Core annotations to determine for which subexpressions we can avoid vectorisation.
--
-- The resulting 'VITree' mirrors the expression:
--
--   * variables and literals are leaves annotated from their type;
--   * type and coercion arguments are 'VISimple' leaves;
--   * a lambda, cast or tick has a single child for its body;
--   * an application has the function and the argument as children;
--   * a non-recursive let has the bound expression and the body;
--   * a recursive let has the body followed by the right-hand sides, in binding order;
--   * a case has the scrutinee followed by the alternatives, in their original order.
--
-- A node is 'VIParr' whenever one of its children is; otherwise its annotation is derived from its
-- type (see 'vectAvoidInfoType'), except that lambdas are 'VIComplex' and casts and ticks inherit
-- the annotation of their body.
--
-- FIXME: free scalar vars don't actually need to be passed through, since encapsulation makes sure
--        that there are no free variables in encapsulated lambda expressions
vectAvoidInfo :: CoreExprWithFVs -> VM VITree
vectAvoidInfo ce@(_, AnnVar v)
  = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectAvoidInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
       ; return $ VITNode vi []
       }

vectAvoidInfo ce@(_, AnnLit _)
  = do { vi <- vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectAvoidInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
       ; return $ VITNode vi []
       }

vectAvoidInfo ce@(_, AnnApp e1 e2)
  = do { vt1 <- vectAvoidInfo e1
       ; vt2 <- vectAvoidInfo e2
       ; vi <- if anyVIPArr [vt1, vt2]
                 then return VIParr
                 else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vt1, vt2]
       ; return $ VITNode vi [vt1, vt2]
       }

vectAvoidInfo ce@(_, AnnLam _var body)
  = do { vt@(VITNode vi _) <- vectAvoidInfo body
       ; let resultVI | vi == VIParr = VIParr
                      | otherwise    = VIComplex
         -- trace the annotation of this node (not that of its body)
       ; viTrace ce resultVI [vt]
       ; return $ VITNode resultVI [vt]
       }

vectAvoidInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
  = do { vtE <- vectAvoidInfo expr
       ; vtB <- vectAvoidInfo body
       ; vi <- if anyVIPArr [vtE, vtB]
                 then return VIParr
                 else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vtE, vtB]
       ; return $ VITNode vi [vtE, vtB]
       }

-- Each right-hand side is analysed exactly once; re-analysing the bindings when one of them is
-- parallel would double the work at every level of nested recursive lets.
vectAvoidInfo ce@(_, AnnLet (AnnRec bnds) body)
  = do { vtBnds <- mapM (vectAvoidInfo . snd) bnds
       ; vtB    <- vectAvoidInfo body
       ; ni <- if anyVIPArr (vtB : vtBnds)
                 then return VIParr
                 else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtB : vtBnds)
       ; return $ VITNode ni (vtB : vtBnds)
       }

vectAvoidInfo ce@(_, AnnCase expr _var _ty alts)
  = do { vtExpr <- vectAvoidInfo expr
       ; vtAlts <- mapM (\(_, _, e) -> vectAvoidInfo e) alts
       ; ni <- if anyVIPArr (vtExpr : vtAlts)
                 then return VIParr
                 else vectAvoidInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtExpr : vtAlts)
       ; return $ VITNode ni (vtExpr : vtAlts)
       }

vectAvoidInfo (_, AnnCast expr _)
  = do { vt@(VITNode vi _) <- vectAvoidInfo expr
       ; return $ VITNode vi [vt]
       }

vectAvoidInfo (_, AnnTick _ expr)
  = do { vt@(VITNode vi _) <- vectAvoidInfo expr
       ; return $ VITNode vi [vt]
       }

vectAvoidInfo (_, AnnType {})
  = return $ VITNode VISimple []

vectAvoidInfo (_, AnnCoercion {})
  = return $ VITNode VISimple []
1006
-- |Compute vectorisation avoidance information for a type: 'VIParr' if the type might involve a
-- parallel array, 'VISimple' if it is one of the simple scalar types, and 'VIComplex' otherwise.
--
vectAvoidInfoType :: Type -> VM VectAvoidInfo
vectAvoidInfoType ty
  | maybeParrTy ty = return VIParr
  | otherwise      = liftM fromSimple (isSimpleType ty)
  where
    fromSimple True  = VISimple
    fromSimple False = VIComplex
1018
-- |Checks whether the type might be, or might contain, a parallel array type.  In particular, if a
-- type constructor is a type family, we conservatively assume that it may be a parallel array
-- type.  Quantified types and applications of type variables are searched too, so that a parallel
-- array nested inside them (e.g. @forall a. [:a:]@ or @m [:Int:]@) is not missed.
--
maybeParrTy :: Type -> Bool
maybeParrTy ty
    -- look through synonyms
  | Just ty'         <- coreView ty            = maybeParrTy ty'
    -- constructor applications (including function types)
  | Just (tyCon, ts) <- splitTyConApp_maybe ty = isPArrTyCon tyCon
                                                 || isSynFamilyTyCon tyCon
                                                 || any maybeParrTy ts
    -- polymorphic types: inspect the body
  | Just (_, body)   <- splitForAllTy_maybe ty = maybeParrTy body
    -- applications whose head is not a constructor, e.g. a type variable
  | Just (fun, arg)  <- splitAppTy_maybe ty    = maybeParrTy fun || maybeParrTy arg
maybeParrTy _ = False
1028
-- |Is the type an application of one of a fixed set of scalar type constructors (Bool, Int,
-- Word8, Double, Float)?  Any other type, including one that is not a constructor application,
-- is not simple.
--
-- FIXME: This should not be hardcoded.
isSimpleType :: Type -> VM Bool
isSimpleType ty
  = return $ case splitTyConApp_maybe ty of
      Just (c, _cs) -> tyConName c `elem` simpleTyConNames
      Nothing       -> False
  where
    simpleTyConNames = [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
    {- Intended replacement for the hardcoded list:
       = do { globals <- globalScalarTyCons
            ; traceVt ("isSimpleType " ++ (show (elemNameSet (tyConName c) globals ))) (ppr c)
            ; return (elemNameSet (tyConName c) globals )
            }
    -}
1044
-- |Do all variables in the set have a simple (scalar) type, as judged by 'isSimpleType'?
varsSimple :: VarSet -> VM Bool
varsSimple vs
  = liftM and (mapM (isSimpleType . varType) (varSetElems vs))
1050
-- |Emit a vectoriser trace message listing the annotation chosen for an expression, followed by
-- the root annotations of the given subtrees, together with the (de-annotated) expression.
viTrace :: CoreExprWithFVs -> VectAvoidInfo -> [VITree] -> VM ()
viTrace ce vi vTs
  = traceVt message (ppr $ deAnnotate ce)
  where
    message = "vitrace " ++ show vi ++ "[" ++ concatMap showRoot vTs ++ "]"
    showRoot (VITNode rootInfo _) = show rootInfo ++ " "
1055
1056
1057 {-
1058 ---- Sanity check of the tree, for debugging only
1059 checkTree :: VITree -> CoreExpr -> Bool
1060 checkTree (VITNode _ []) (Type _ty)
1061 = True
1062
1063 checkTree (VITNode _ []) (Var _v)
1064 = True
1065
1066 checkTree (VITNode _ []) (Lit _)
1067 = True
1068
1069 checkTree (VITNode _ [vit]) (Tick _ expr)
1070 = checkTree vit expr
1071
1072 checkTree (VITNode _ [vit]) (Lam _ expr)
1073 = checkTree vit expr
1074
1075 checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2)
1076 = (checkTree vit1 ce1) && (checkTree vit2 ce2)
1077
1078 checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts)
1079 = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts)
1080 where
1081 checkAlt vt (_, _, expr) = checkTree vt expr
1082
1083 checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2)
1084 = (checkTree vt1 expr1) && (checkTree vt2 expr2)
1085
1086 checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr)
1087 = (and $ zipWith checkBndr vtBnds bndngs) &&
1088 (checkTree vtB expr)
1089 where
1090 checkBndr vt (_, e) = checkTree vt e
1091
1092 checkTree (VITNode _ [vit]) (Cast expr _)
1093 = checkTree vit expr
1094
1095 checkTree _ _ = False
1096
1097 checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM ()
1098 checkTreeAnnM vi e =
1099 if not (checkTree vi $ deAnnotate e)
1100 then error ("checkTreeAnnM : \n " ++ show vi)
1101 else return ()
1102 -}