0764c3b2557d547189e2d3463921d6fcab41a2a0
[ghc.git] / compiler / vectorise / Vectorise / Exp.hs
1 {-# LANGUAGE TupleSections #-}
2
3 -- |Vectorisation of expressions.
4
5 module Vectorise.Exp
6 ( -- * Vectorise polymorphic expressions with special cases for right-hand sides of particular
7 -- variable bindings
8 vectPolyExpr
9 , vectDictExpr
10 , vectScalarFun
11 , vectScalarDFun
12 )
13 where
14
15 #include "HsVersions.h"
16
17 import Vectorise.Type.Type
18 import Vectorise.Var
19 import Vectorise.Convert
20 import Vectorise.Vect
21 import Vectorise.Env
22 import Vectorise.Monad
23 import Vectorise.Builtins
24 import Vectorise.Utils
25
26 import CoreUtils
27 import MkCore
28 import CoreSyn
29 import CoreFVs
30 import Class
31 import DataCon
32 import TyCon
33 import TcType
34 import Type
35 import PrelNames
36 import Var
37 import VarEnv
38 import VarSet
39 import Id
40 import BasicTypes( isStrongLoopBreaker )
41 import Literal
42 import TysWiredIn
43 import TysPrim
44 import Outputable
45 import FastString
46 import Control.Monad
47 import Control.Applicative
48 import Data.Maybe
49 import Data.List
50 import TcRnMonad (doptM)
51 import DynFlags (DynFlag(Opt_AvoidVect))
52
53
-- For prototyping, the 'VITree' is a separate data structure with the same shape as the
-- corresponding expression tree.  This will become part of the annotation.

-- |Classification attached to each node of an expression by 'vectInfo'.
data VectInfo
  = VIParr      -- ^ the node's type may involve a parallel array, or a subexpression is 'VIParr'
  | VISimple    -- ^ the node's type passes 'isSimpleType' (also used for type and coercion nodes)
  | VIComplex   -- ^ neither parallel nor simple
  | VIEncaps    -- ^ node introduced by 'encaps' when wrapping a subexpression in lambdas
  deriving (Eq, Show)

-- |A node's classification paired with the trees of its subexpressions.
data VITree
  = VITNode VectInfo [VITree]
  deriving (Show)
65
-- |Emit a vectoriser trace line showing the classification of an expression together with the
-- classifications of its immediate subtrees.
viTrace :: CoreExprWithFVs -> VectInfo -> [VITree] -> VM ()
viTrace ce vi vTs
  = traceVt ("vitrace " ++ show vi ++ "[" ++ childInfos ++ "]") (ppr $ deAnnotate ce)
  where
    -- one entry per child, each followed by a single space
    childInfos = concatMap (\(VITNode childVi _) -> show childVi ++ " ") vTs
70
-- |Is the root of at least one of the given trees classified as 'VIParr'?
viOr :: [VITree] -> Bool
viOr = any rootIsParr
  where
    rootIsParr (VITNode info _) = info == VIParr
73
-- TODO: free scalar vars don't actually need to be passed through, since encapsulation makes
--       sure that there are no free variables in encapsulated lambda expressions.

-- |Compute the 'VITree' of an expression.
--
-- The resulting tree mirrors the shape of the expression: leaves ('AnnVar', 'AnnLit', 'AnnType',
-- 'AnnCoercion') have no children, wrappers ('AnnLam', 'AnnCast', 'AnnTick') have one child,
-- applications and non-recursive lets have two, a case has the scrutinee followed by the
-- alternatives, and a recursive let has the body followed by the right-hand sides.
--
-- A compound node is 'VIParr' as soon as one of its children is; otherwise it is classified by
-- its own type via 'vectInfoType'.
vectInfo :: CoreExprWithFVs -> VM VITree
vectInfo ce@(_, AnnVar v)
  = do { vi <- vectInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectInfo AnnVar" ((ppr v) <+> (ppr $ exprType $ deAnnotate ce))
       ; return $ VITNode vi []
       }

vectInfo ce@(_, AnnLit _)
  = do { vi <- vectInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi []
       ; traceVt "vectInfo AnnLit" (ppr $ exprType $ deAnnotate ce)
       ; return $ VITNode vi []
       }

vectInfo ce@(_, AnnApp e1 e2)
  = do { vt1 <- vectInfo e1
       ; vt2 <- vectInfo e2
       ; vi  <- if viOr [vt1, vt2]
                  then return VIParr
                  else vectInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vt1, vt2]
       ; return $ VITNode vi [vt1, vt2]
       }

-- A lambda is never 'VISimple': it is either parallel (if its body is) or complex.
vectInfo ce@(_, AnnLam _var body)
  = do { vt@(VITNode vi _) <- vectInfo body
       ; viTrace ce vi [vt]
       ; if (vi == VIParr)
           then return $ VITNode vi [vt]
           else return $ VITNode VIComplex [vt]
       }

vectInfo ce@(_, AnnLet (AnnNonRec _var expr) body)
  = do { vtE <- vectInfo expr
       ; vtB <- vectInfo body
       ; vi  <- if viOr [vtE, vtB]
                  then return VIParr
                  else vectInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi [vtE, vtB]
       ; return $ VITNode vi [vtE, vtB]
       }

-- The right-hand sides are analysed exactly once; the children are ordered body first, then
-- the right-hand sides (the order 'encapsulateScalar' and 'vectExpr' match against).
vectInfo ce@(_, AnnLet (AnnRec bnds) body)
  = do { vtBnds <- mapM (vectInfo . snd) bnds
       ; vtB    <- vectInfo body
       ; vi     <- if viOr (vtB : vtBnds)
                     then return VIParr
                     else vectInfoType $ exprType $ deAnnotate ce
       ; viTrace ce vi (vtB : vtBnds)
       ; return $ VITNode vi (vtB : vtBnds)
       }

vectInfo ce@(_, AnnCase expr _var _ty alts)
  = do { vtExpr <- vectInfo expr
       ; vtAlts <- mapM (\(_, _, e) -> vectInfo e) alts
       ; ni     <- if viOr (vtExpr : vtAlts)
                     then return VIParr
                     else vectInfoType $ exprType $ deAnnotate ce
       ; viTrace ce ni (vtExpr : vtAlts)
       ; return $ VITNode ni (vtExpr : vtAlts)
       }

-- Casts and ticks inherit the classification of the wrapped expression.
vectInfo (_, AnnCast expr _)
  = do { vt@(VITNode vi _) <- vectInfo expr
       ; return $ VITNode vi [vt]
       }

vectInfo (_, AnnTick _ expr)
  = do { vt@(VITNode vi _) <- vectInfo expr
       ; return $ VITNode vi [vt]
       }

vectInfo (_, AnnType {})
  = return $ VITNode VISimple []

vectInfo (_, AnnCoercion {})
  = return $ VITNode VISimple []
162
163
164
-- |Classify a type: 'VIParr' if it may be a parallel array type (see 'maybeParrTy'), otherwise
-- 'VISimple' or 'VIComplex' depending on the outcome of 'isSimpleType'.
vectInfoType :: Type -> VM VectInfo
vectInfoType ty
  | maybeParrTy ty = return VIParr
  | otherwise      = classify <$> isSimpleType ty
  where
    classify isSimple = if isSimple then VISimple else VIComplex
174
175
-- |Check whether a type might be a parallel array type.
--
-- The type is first expanded with 'coreView'.  A type constructor application counts as
-- possibly parallel if its head is the parallel array tycon or a synonym family (which we
-- conservatively assume may reduce to a parallel array), or if any of its arguments does.
-- Everything else is considered not parallel.
maybeParrTy :: Type -> Bool
maybeParrTy ty
  | Just ty' <- coreView ty = maybeParrTy ty'
  | otherwise
  = case splitTyConApp_maybe ty of
      Just (tyCon, ts) -> isPArrTyCon tyCon || isSynFamilyTyCon tyCon || any maybeParrTy ts
      Nothing          -> False
184
185
-- |Is the type headed by one of the hard-coded simple type constructors ('Bool', 'Int',
-- 'Word8', 'Double', 'Float')?  Types that are not type constructor applications are never
-- simple.
--
-- TODO: the hard-coded list should be replaced by a lookup in 'globalScalarTyCons', i.e.
--       @elemNameSet (tyConName c) \<$\> globalScalarTyCons@; this is why the result lives in 'VM'.
isSimpleType :: Type -> VM Bool
isSimpleType ty
  = return $ case splitTyConApp_maybe ty of
      Just (tc, _) -> tyConName tc `elem` simpleTyConNames
      Nothing      -> False
  where
    simpleTyConNames
      = [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
199
-- |Do all variables in the set have a type that passes 'isSimpleType'?  (Trivially true for the
-- empty set.)
varsSimple :: VarSet -> VM Bool
varsSimple vs
  = and <$> mapM (isSimpleType . varType) (varSetElems vs)
205
206
-- |Vectorise a polymorphic expression.
--
-- Ticks are peeled off and re-attached to the vectorised result.  For any other expression, the
-- 'VITree' is computed with 'vectInfo'; if the 'Opt_AvoidVect' flag is set, scalar
-- subexpressions are first wrapped by 'encapsulateScalar'.  Leading type binders are then
-- abstracted with 'polyAbstract' and the remaining expression is handed to 'vectFnExpr'.
vectPolyExpr :: Bool -> [Var] -> CoreExprWithFVs
             -> VM (Inline, Bool, VExpr)
vectPolyExpr loop_breaker recFns (_, AnnTick tickish expr)
  = do (inline, isScalarFn, vexpr) <- vectPolyExpr loop_breaker recFns expr
       return (inline, isScalarFn, vTick tickish vexpr)

vectPolyExpr loop_breaker recFns expr
  = do avoidVect <- liftDs $ doptM Opt_AvoidVect
       vit       <- vectInfo expr
       (expr', vit') <-
         if avoidVect
           then do encapsulated@(extExpr, _) <- encapsulateScalar vit expr
                   traceVt "vectPolyExpr extended:" (ppr $ deAnnotate extExpr)
                   return encapsulated
           else return (expr, vit)
       let (tvs, mono) = collectAnnTypeBinders expr'
       arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vit'
            return ( addInlineArity inline arity
                   , isScalarFn
                   , mapVect (mkLams $ tvs ++ args) mono' )
233
-- todo: clean this

-- |Variant of 'vectPolyExpr' for an expression whose 'VITree' has already been computed; it
-- neither recomputes the tree nor performs encapsulation.
--
-- (A sanity check @checkTree vi (deAnnotate e)@ used to guard this function; it is currently
-- disabled.)
vectPolyExprVT :: Bool -> [Var] -> CoreExprWithFVs -> VITree
               -> VM (Inline, Bool, VExpr)
vectPolyExprVT loop_breaker recFns (_, AnnTick tickish expr) (VITNode _ [vit])
  = do (inline, isScalarFn, vexpr) <- vectPolyExprVT loop_breaker recFns expr vit
       return (inline, isScalarFn, vTick tickish vexpr)

vectPolyExprVT loop_breaker recFns expr vi
  = do arity <- polyArity tvs
       polyAbstract tvs $ \args ->
         do (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono vi
            return ( addInlineArity inline arity
                   , isScalarFn
                   , mapVect (mkLams $ tvs ++ args) mono' )
  where
    (tvs, mono) = collectAnnTypeBinders expr
255
-- |Encapsulate every purely sequential subexpression with a simple return type of a
-- (potentially) parallel expression into a lambda abstraction over all its free variables,
-- followed by the corresponding application to those variables (see 'encaps').
--
-- Condition for encapsulating a node:
--
-- * the node is classified 'VISimple' and all its free variables are of simple type
--   ('varsSimple'), and
--
-- * the expression is \"complex enough\", which is, for now, every expression that is not a
--   variable, literal, type, or coercion.
--
-- Nodes that do not qualify are traversed recursively.  The returned 'VITree' always matches
-- the shape of the returned expression.  Panics if the shape of the incoming tree does not
-- match the expression.
encapsulateScalar :: VITree -> CoreExprWithFVs -> VM (CoreExprWithFVs, VITree)
encapsulateScalar vit ce@(_, AnnType {})
  = return (ce, vit)

-- 'vectInfo' produces a leaf for coercions, so accept them here just like types
encapsulateScalar vit ce@(_, AnnCoercion {})
  = return (ce, vit)

encapsulateScalar vit ce@(_, AnnVar _v)
  = return (ce, vit)

encapsulateScalar vit ce@(_, AnnLit _)
  = return (ce, vit)

encapsulateScalar (VITNode vi [vit]) (fvs, AnnTick tck expr)
  = do { (extExpr, vit') <- encapsulateScalar vit expr
       ; return ((fvs, AnnTick tck extExpr), VITNode vi [vit'])
       }

encapsulateScalar _ (_fvs, AnnTick _tck _expr)
  = panic "encapsulateScalar AnnTick doesn't match up"

encapsulateScalar vt@(VITNode vi [vit]) ce@(fvs, AnnLam bndr expr)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           -- pass the tree of the lambda itself (not of its body), so that tree and
           -- expression stay aligned, as in all other cases
           (VISimple, True) -> return $ encaps vt ce
           _                -> do { (extExpr, vit') <- encapsulateScalar vit expr
                                  ; return ((fvs, AnnLam bndr extExpr), VITNode vi [vit'])
                                  }
       }

encapsulateScalar _ (_fvs, AnnLam _bndr _expr)
  = panic "encapsulateScalar AnnLam doesn't match up"

encapsulateScalar vt@(VITNode vi [vit1, vit2]) ce@(fvs, AnnApp ce1 ce2)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ encaps vt ce
           _                -> do { (etaCe1, vit1') <- encapsulateScalar vit1 ce1
                                  ; (etaCe2, vit2') <- encapsulateScalar vit2 ce2
                                  ; return ((fvs, AnnApp etaCe1 etaCe2), VITNode vi [vit1', vit2'])
                                  }
       }

encapsulateScalar _ (_fvs, AnnApp _ce1 _ce2)
  = panic "encapsulateScalar AnnApp doesn't match up"

encapsulateScalar vt@(VITNode vi (scrutVit : altVits)) ce@(fvs, AnnCase scrut bndr ty alts)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ encaps vt ce
           _                -> do { (extScrut, scrutVit') <- encapsulateScalar scrutVit scrut
                                  ; extAltsVits <- zipWithM expAlt altVits alts
                                  ; let (extAlts, altVits') = unzip extAltsVits
                                  ; return ( (fvs, AnnCase extScrut bndr ty extAlts)
                                           , VITNode vi (scrutVit' : altVits') )
                                  }
       }
  where
    expAlt altVt (con, bndrs, expr)
      = do { (extExpr, altVt') <- encapsulateScalar altVt expr
           ; return ((con, bndrs, extExpr), altVt')
           }

encapsulateScalar _ (_fvs, AnnCase _scrut _bndr _ty _alts)
  = panic "encapsulateScalar AnnCase doesn't match up"

encapsulateScalar vt@(VITNode vi [vt1, vt2]) ce@(fvs, AnnLet (AnnNonRec bndr expr1) expr2)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ encaps vt ce
           _                -> do { (extExpr1, vt1') <- encapsulateScalar vt1 expr1
                                  ; (extExpr2, vt2') <- encapsulateScalar vt2 expr2
                                  ; return ( (fvs, AnnLet (AnnNonRec bndr extExpr1) extExpr2)
                                           , VITNode vi [vt1', vt2'] )
                                  }
       }

encapsulateScalar _ (_fvs, AnnLet (AnnNonRec _bndr _expr1) _expr2)
  = panic "encapsulateScalar AnnLet nonrec doesn't match up"

-- children of a recursive let: the body's tree first, then one tree per right-hand side
encapsulateScalar vt@(VITNode vi (vtB : vtBnds)) ce@(fvs, AnnLet (AnnRec bndngs) expr)
  = do { varsS <- varsSimple fvs
       ; case (vi, varsS) of
           (VISimple, True) -> return $ encaps vt ce
           _                -> do { extBndsVts <- zipWithM expBndg vtBnds bndngs
                                  ; let (extBnds, vtBnds') = unzip extBndsVts
                                  ; (extExpr, vtB') <- encapsulateScalar vtB expr
                                  ; return ( (fvs, AnnLet (AnnRec extBnds) extExpr)
                                           , VITNode vi (vtB' : vtBnds') )
                                  }
       }
  where
    expBndg vit (bndr, rhs)
      = do { (extRhs, vit') <- encapsulateScalar vit rhs
           ; return ((bndr, extRhs), vit')
           }

encapsulateScalar _ (_fvs, AnnLet (AnnRec _) _expr2)
  = panic "encapsulateScalar AnnLet rec doesn't match up"

encapsulateScalar (VITNode vi [vit]) (fvs, AnnCast expr coercion)
  = do { (extExpr, vit') <- encapsulateScalar vit expr
       ; return ((fvs, AnnCast extExpr coercion), VITNode vi [vit'])
       }

encapsulateScalar _ (_fvs, AnnCast _expr _coercion)
  = panic "encapsulateScalar AnnCast rec doesn't match up"
379
380
381
382
-- Reminder of the annotated expression types involved:
--
--   CoreExprWithFVs     = AnnExpr Id VarSet
--   AnnExpr bndr VarSet = (annot, AnnExpr' bndr VarSet)
--   AnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet

-- |Encapsulate an expression: abstract it over all its free variables and immediately apply
-- the abstraction to those same variables.  The tree is rebuilt accordingly, with 'VIEncaps'
-- nodes for the introduced lambdas (and for the root of the wrapped expression) and 'VISimple'
-- nodes for the introduced applications and their variable arguments.
--
-- Exception: a case expression whose scrutinee type is headed by a tycon other than 'Bool',
-- 'Int', 'Double', or 'Float' is not wrapped as a whole; instead, each alternative's
-- right-hand side is encapsulated individually.
encaps :: VITree -> CoreExprWithFVs -> (CoreExprWithFVs, VITree)
encaps (VITNode vi (scrutVit : altVits)) (fvs, AnnCase scrut bndr ty alts)
  | Just (scrutTc, _) <- splitTyConApp_maybe (exprType (deAnnotate scrut))
  , scrutTc `notElem` [boolTyCon, intTyCon, doubleTyCon, floatTyCon]  -- TODO: globalScalarTyCons
  = ((fvs, AnnCase scrut bndr ty alts'), VITNode vi (scrutVit : altVits'))
  where
    (alts', altVits') = unzip (zipWith encapsAlt alts altVits)

    encapsAlt (con, bndrs, rhs) rhsVit
      = let (rhs', rhsVit') = encaps rhsVit rhs
        in ((con, bndrs, rhs'), rhsVit')

encaps viTree ae@(fvs, _)
  = (mkAnnApps (mkAnnLams ae freeVars') freeVars', appsTree)
  where
    freeVars' = varSetElems fvs
    appsTree  = wrapApps (wrapLams viTree freeVars') freeVars'

    -- one 'VIEncaps' node per abstracted variable; the innermost one takes over the children
    -- of the original root
    wrapLams (VITNode _ subTrees) []     = VITNode VIEncaps subTrees
    wrapLams tree                 (_:vs) = VITNode VIEncaps [wrapLams tree vs]

    -- one application node per variable, each with a leaf for the variable argument
    wrapApps tree []     = tree
    wrapApps tree (_:vs) = VITNode VISimple [wrapApps tree vs, VITNode VISimple []]
408
-- |Build an (unannotated) lambda node from a binder and an annotated body.
mkAnnLam :: bndr -> AnnExpr bndr VarSet -> AnnExpr' bndr VarSet
mkAnnLam = AnnLam
411
-- |Abstract an annotated expression over the given variables.  The first variable in the list
-- becomes the innermost binder, i.e. @mkAnnLams e [a, b]@ yields @\\b -> \\a -> e@.
--
-- The free-variable annotations are maintained along the way: each body keeps its own
-- annotation, whereas the lambda wrapped around it is annotated with the body's free
-- variables minus the bound variable.  If the list contains all free variables of the
-- expression, the annotation of the result is empty.
mkAnnLams :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
mkAnnLams ae             []     = ae
mkAnnLams body@(fv, _)   (v:vs) = mkAnnLams (delVarSet fv v, mkAnnLam v body) vs
415
-- |Build an (unannotated) application of an annotated expression to a variable; the variable
-- argument is annotated with the singleton set containing just that variable.
mkAnnApp :: (AnnExpr bndr VarSet) -> Var -> (AnnExpr' bndr VarSet)
mkAnnApp fun v = AnnApp fun varArg
  where
    varArg = (unitVarSet v, AnnVar v)
418
-- |Apply an annotated expression to the given variables.  The last variable in the list is
-- applied first, i.e. @mkAnnApps e [a, b]@ yields @(e b) a@, which is the order matching
-- 'mkAnnLams'.  Each applied variable is added to the free-variable annotation.
mkAnnApps :: CoreExprWithFVs -> [Var] -> CoreExprWithFVs
mkAnnApps = foldr applyTo
  where
    applyTo v inner = (extendVarSet (fst inner) v, mkAnnApp inner v)
425
426
427
428
429
-- |Vectorise an expression, guided by its 'VITree'.
--
-- The equations are tried in order; an expression (or expression/tree combination) that is
-- matched by none of them is rejected with 'cantVectorise'.
--
-- (A sanity check @checkTree vi (deAnnotate e)@ used to guard this function; it is currently
-- disabled.)
vectExpr :: CoreExprWithFVs -> VITree -> VM VExpr

vectExpr (_, AnnVar v) _
  = vectVar v

vectExpr (_, AnnLit lit) _
  = vectConst (Lit lit)

-- value lambda: handled by 'vectFnExpr'; only the vectorised expression is of interest here
vectExpr e@(_, AnnLam bndr _) vt
  | isId bndr
  = do (_, _, vexpr) <- vectFnExpr True False [] e vt
       return vexpr

  -- SPECIAL CASE: Vectorise/lift 'patError @ ty err' by only vectorising/lifting the type 'ty';
  --   its only purpose is to abort the program, but we need to adjust the type to keep CoreLint
  --   happy.
  -- FIXME: can't we do this with a VECTORISE pragma on 'pAT_ERROR_ID' now?
vectExpr (_, AnnApp (_, AnnApp (_, AnnVar v) (_, AnnType ty)) err) _
  | v == pAT_ERROR_ID
  = do (vty, lty) <- vectAndLiftType ty
       return (patErrorAt vty, patErrorAt lty)
  where
    patErrorAt t = mkCoreApps (Var v) [Type t, deAnnotate err]

  -- type application (handle multiple consecutive type applications simultaneously to ensure
  -- the PA dictionaries are put at the right places)
vectExpr e@(_, AnnApp _ arg) (VITNode _ [_, _])
  | isAnnTypeArg arg
  = vectPolyApp e

  -- 'Int', 'Float', or 'Double' literal
  -- FIXME: this needs to be generalised
vectExpr (_, AnnApp (_, AnnVar v) (_, AnnLit lit)) _
  | Just con <- isDataConId_maybe v
  , con `elem` [intDataCon, floatDataCon, doubleDataCon]
  = do let boxedLit = App (Var v) (Lit lit)
       liftedLit <- liftPD boxedLit
       return (boxedLit, liftedLit)

  -- value application (dictionary or user value)
vectExpr e@(_, AnnApp fn arg) (VITNode _ [vit1, vit2])
  | isPredTy arg_ty                 -- dictionary application (whose result is not a dictionary)
  = vectPolyApp e
  | otherwise                       -- user value
  = do -- vectorise the types
       varg_ty <- vectType arg_ty
       vres_ty <- vectType res_ty

       -- vectorise the function and argument expression
       vfn  <- vectExpr fn  vit1
       varg <- vectExpr arg vit2

       -- the vectorised function is a closure; apply it to the vectorised argument
       mkClosureApp varg_ty vres_ty vfn varg
  where
    (arg_ty, res_ty) = splitFunTy (exprType (deAnnotate fn))

  -- case expressions are only supported on scrutinees of algebraic type
vectExpr (_, AnnCase scrut bndr ty alts) vt
  = case splitTyConApp_maybe scrut_ty of
      Just (tycon, ty_args)
        | isAlgTyCon tycon -> vectAlgCase tycon ty_args scrut bndr ty alts vt
      _                    -> cantVectorise "Can't vectorise expression" (ppr scrut_ty)
  where
    scrut_ty = exprType (deAnnotate scrut)

vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) (VITNode _ [vt1, vt2])
  = do (_, _, vrhs)   <- localV . inBind bndr $ vectPolyExprVT False [] rhs vt1
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body vt2)
       return $ vLet (vNonRec vbndr vrhs) vbody

  -- tree children of a recursive let: body first, then the right-hand sides
vectExpr (_, AnnLet (AnnRec bs) body) (VITNode _ (vtB : vtBnds))
  = do (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $
                                     (,) <$> sequence (zipWith3 vectRhs bndrs rhss vtBnds)
                                         <*> vectExpr body vtB
       return $ vLet (vRec vbndrs vrhss) vbody
  where
    (bndrs, rhss) = unzip bs

    vectRhs bndr rhs vt
      = localV . inBind bndr $
          do (_, _, vrhs) <- vectPolyExprVT (isStrongLoopBreaker $ idOccInfo bndr) [] rhs vt
             return vrhs

vectExpr (_, AnnTick tickish expr) (VITNode _ [vit])
  = vTick tickish <$> vectExpr expr vit

vectExpr (_, AnnType ty) _
  = vType <$> vectType ty

vectExpr e _
  = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
531
-- |Vectorise an expression that *may* have an outer lambda abstraction.
--
-- We do not handle type variables at this point, as they will already have been stripped off
-- by 'vectPolyExpr'.  We also only have to worry about one set of dictionary arguments, as
-- class selectors are vectorised elsewhere.
--
-- The result consists of the inlining information, a flag indicating whether the expression
-- was vectorised as a scalar function, and the vectorised expression.
--
-- (A sanity check @checkTree vi (deAnnotate e)@ used to guard this function; it is currently
-- disabled.)
vectFnExpr :: Bool             -- ^ If we process the RHS of a binding, whether that binding
                               --   should be inlined
           -> Bool             -- ^ Whether the binding is a loop breaker
           -> [Var]            -- ^ Names of function in same recursive binding group
           -> CoreExprWithFVs  -- ^ Expression to vectorise; must have an outer `AnnLam`
           -> VITree
           -> VM (Inline, Bool, VExpr)
vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr body) vt@(VITNode _ [bodyVt])
  -- predicate abstraction: leave as a normal abstraction, but vectorise the predicate type
  | isId bndr
  , isPredTy (idType bndr)
  = do vBndr <- vectBndr bndr
       (bodyInline, isScalarFn, vbody) <- vectFnExpr inline loop_breaker recFns body bodyVt
       return (bodyInline, isScalarFn, mapVect (mkLams [vectorised vBndr]) vbody)

  -- non-predicate abstraction: try to vectorise as a scalar computation first and fall back to
  -- full vectorisation of the lambda if that fails
  | isId bndr
  = asScalarFun `orElseV` asLambda
  where
    asScalarFun = mark DontInline True  (vectScalarFunVT False recFns (deAnnotate expr) vt)
    asLambda    = mark inlineMe   False (vectLam inline loop_breaker expr vt)

-- not an abstraction: vectorise as a vanilla expression
vectFnExpr _ _ _ e vt
  = mark DontInline False (vectExpr e vt)
566
-- |Run a computation and tag its result with the given inlining information and
-- scalar-function flag.
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark inl isScalarFn action = (inl, isScalarFn,) <$> action
569
-- |Vectorise type and dictionary applications.
--
-- These are always headed by a variable (as we don't support higher-rank polymorphism), but may
-- involve two sets of type variables and dictionaries.  Consider,
--
-- > class C a where
-- >   m :: D b => b -> a
--
-- The type of 'm' is 'm :: forall a. C a => forall b. D b => b -> a'.
--
-- An application that is not headed by a variable is rejected with 'pprSorry'.
vectPolyApp :: CoreExprWithFVs -> VM VExpr
vectPolyApp e0
  | (_, AnnVar var) <- e4 = vectHeadedBy var
  | otherwise
  = pprSorry "Cannot vectorise programs with higher-rank types:" (ppr . deAnnotate $ e0)
  where
    -- peel the arguments off from the outside in; if there is only one set of variables or
    -- dictionaries, it will be the outer set
    (e1, dictsOuter) = collectAnnDictArgs e0
    (e2, tysOuter)   = collectAnnTypeArgs e1
    (e3, dictsInner) = collectAnnDictArgs e2
    (e4, tysInner)   = collectAnnTypeArgs e3

    isDictComp var = (isJust . isClassOpId_maybe $ var) || isDFunId var

    vectHeadedBy var
      = do -- get the vectorised form of the variable
           vVar <- lookupVar var
           traceVt "vectPolyApp of" (ppr var)

           -- vectorise type and dictionary arguments
           vDictsOuter <- mapM (vectDictExpr . deAnnotate) dictsOuter
           vDictsInner <- mapM (vectDictExpr . deAnnotate) dictsInner
           vTysOuter   <- mapM vectType tysOuter
           vTysInner   <- mapM vectType tysInner

           let reconstructOuter v = (`mkApps` vDictsOuter) <$> polyApply v vTysOuter

           case vVar of
             Local (vv, lv) ->
               do MASSERT( null dictsInner )    -- local vars cannot be class selectors
                  traceVt " LOCAL" (text "")
                  (,) <$> reconstructOuter (Var vv) <*> reconstructOuter (Var lv)

             Global vv
               | isDictComp var ->              -- dictionary computation
                   do -- in a dictionary computation, the innermost, non-empty set of
                      -- arguments are non-vectorised arguments, where no 'PA'dictionaries
                      -- are needed for the type variables
                      ve <- if null dictsInner
                              then return $ Var vv `mkTyApps` vTysOuter `mkApps` vDictsOuter
                              else reconstructOuter
                                     (Var vv `mkTyApps` vTysInner `mkApps` vDictsInner)
                      traceVt " GLOBAL (dict):" (ppr ve)
                      vectConst ve

               | otherwise ->                   -- non-dictionary computation
                   do MASSERT( null dictsInner )
                      ve <- reconstructOuter (Var vv)
                      traceVt " GLOBAL (non-dict):" (ppr ve)
                      vectConst ve
632
-- |Vectorise the body of a dfun.
--
-- Dictionary computations are special for the following reasons.  The application of dictionary
-- functions are always saturated, so there is no need to create closures.  Dictionary
-- computations don't depend on array values, so they are always scalar computations whose
-- result we can replicate (instead of executing them in parallel).
--
-- NB: To keep things simple, we are not rewriting any of the bindings introduced in a
--     dictionary computation.  Consequently, the variable case needs to deal with cases where
--     binders are in the vectoriser environments and where that is not the case.
--
-- Literals cause a panic; casts and coercions are not supported ('pprSorry').
vectDictExpr :: CoreExpr -> VM CoreExpr
vectDictExpr expr
  = case expr of
      Var var ->
        do mb_scope <- lookupVar_maybe var
           return . Var $ case mb_scope of
             Nothing                -> var    -- binder from within the dict. computation
             Just (Local (vVar, _)) -> vVar   -- local vectorised variable
             Just (Global vVar)     -> vVar   -- global vectorised variable

      Lit lit ->
        pprPanic "Vectorise.Exp.vectDictExpr: literal in dictionary computation" (ppr lit)

      Lam bndr body ->
        Lam bndr <$> vectDictExpr body

      App fn arg ->
        App <$> vectDictExpr fn <*> vectDictExpr arg

      Case scrut bndr ty alts ->
        Case <$> vectDictExpr scrut <*> pure bndr <*> vectType ty <*> mapM vectAlt alts

      Let bnd body ->
        Let <$> vectBind bnd <*> vectDictExpr body

      Cast _ _ ->
        pprSorry "Vectorise.Exp.vectDictExpr: cast" (ppr expr)

      Tick tickish inner ->
        Tick tickish <$> vectDictExpr inner

      Type ty ->
        Type <$> vectType ty

      Coercion coe ->
        pprSorry "Vectorise.Exp.vectDictExpr: coercion" (ppr coe)
  where
    vectAlt (con, bs, rhs) = (,,) <$> vectAltCon con <*> pure bs <*> vectDictExpr rhs

    -- data constructors are replaced by their vectorised counterparts
    vectAltCon (DataAlt datacon) = DataAlt <$> maybeV dataConErr (lookupDataCon datacon)
      where
        dataConErr = ptext (sLit "Cannot vectorise data constructor:") <+> ppr datacon
    vectAltCon (LitAlt lit)      = return $ LitAlt lit
    vectAltCon DEFAULT           = return DEFAULT

    -- binders are left untouched; only the right-hand sides are rewritten
    vectBind (NonRec bndr rhs) = NonRec bndr <$> vectDictExpr rhs
    vectBind (Rec bnds)        = Rec <$> mapM (\(bndr, rhs) -> (bndr,) <$> vectDictExpr rhs) bnds
681
-- |Vectorise an expression of functional type, where all arguments and the result are of
-- primitive types (i.e., 'Int', 'Float', 'Double' etc., which have instances of the 'Scalar'
-- type class) and which does not contain any subcomputations that involve parallel arrays.
-- Such functionals do not require the full blown vectorisation transformation; instead, they
-- can be lifted by application of a member of the zipWith family (i.e., 'map', 'zipWith',
-- 'zipWith3', etc.)
--
-- Dictionary functions are also scalar functions (as dictionaries themselves are not
-- vectorised, instead they become dictionaries of vectorised methods).  We treat them
-- differently, though see "Note [Scalar dfuns]" in 'Vectorise'.
--
-- This is the entry point for external callers: the function is taken to be scalar by
-- assertion, and 'vectScalarFunVT' is invoked with a dummy tree whose only relevant property
-- is that its root is *not* 'VIEncaps'.
vectScalarFun :: [Var]       -- ^ Functions names in same recursive binding group
              -> CoreExpr    -- ^ Expression to be vectorised
              -> VM VExpr
vectScalarFun recFns expr
  = vectScalarFunVT True recFns expr dummyTree
  where
    dummyTree = VITNode VISimple []
699
700
-- |Vectorise an expression of functional type as a scalar function (see 'vectScalarFun'),
-- provided that either the caller asserts that it is scalar or the root of the 'VITree' is
-- 'VIEncaps' (i.e., the expression was produced by encapsulation).  Otherwise, the computation
-- fails in the 'VM' monad via 'onlyIfV'.
--
-- The structural scalarness checks in the @where@ clause ('is_scalar_ty', 'is_scalar', 'uses')
-- are at present only used for the diagnostic trace message; they no longer influence the
-- decision.
vectScalarFunVT :: Bool       -- ^ Was the function marked as scalar by the user?
                -> [Var]      -- ^ Functions names in same recursive binding group
                -> CoreExpr   -- ^ Expression to be vectorised
                -> VITree
                -> VM VExpr
vectScalarFunVT forceScalar recFns expr (VITNode vi _)
  = do { gscalarVars  <- globalScalarVars
       ; scalarTyCons <- globalScalarTyCons
       ; let scalarVars        = gscalarVars `extendVarSetList` recFns
             (arg_tys, res_ty) = splitFunTys (exprType expr)
       ; MASSERT( not $ null arg_tys )
       ; traceVt ("vectScalarFun - not scalar? " ++
                  "\n\tall tycons scalar? : " ++ (show $ all (is_scalar_ty scalarTyCons) arg_tys) ++
                  "\n\tresult scalar? : " ++ (show $ is_scalar_ty scalarTyCons res_ty) ++
                  "\n\tscalar body? : " ++ (show $ is_scalar scalarVars (is_scalar_ty scalarTyCons) expr) ++
                  "\n\tuses vars? : " ++ (show $ uses scalarVars expr) ++
                  "\n\t is encaps? (same as & of all prev cond): " ++ (show vi)
                 )
                 (ppr expr)
       ; onlyIfV (ptext (sLit "not a scalar function"))
                 (forceScalar          -- user asserts the functions is scalar
                  ||
                  (vi == VIEncaps))    -- should only be true if all the following conditions hold:
                                       --      all (is_scalar_ty scalarTyCons) arg_tys
                                       --   && is_scalar_ty scalarTyCons res_ty
                                       --   && is_scalar scalarVars (is_scalar_ty scalarTyCons) expr
                                       --   && uses scalarVars expr
         $ do { traceVt "vectScalarFun - is scalar" (ppr expr)
              ; mkScalarFun arg_tys res_ty expr
              }
       }
  where
    -- !!!FIXME: We would like to allow scalar functions with arguments and results that can be
    --           any 'scalarTyCons', but can't at the moment, as those argument and result types
    --           need to be members of the 'Scalar' class (that in its current form would better
    --           be called 'Primitive'). *ALSO* the hardcoded list of types is ugly!
    is_scalar_ty _scalarTyCons ty
      | isPredTy ty                                -- dictionaries never get into the environment
      = True
      | Just (tycon, _) <- splitTyConApp_maybe ty  -- TODO: FIX THIS!
      = tyConName tycon `elem` [boolTyConName, intTyConName, word8TyConName, doubleTyConName, floatTyConName]
        -- FIXME: = tyConName tycon `elemNameSet` scalarTyCons
      | otherwise
      = False

    -- Checks whether an expression contain a non-scalar subexpression.
    --
    -- Precondition: The variables in the first argument are scalar.
    --
    -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
    -- them to the list of scalar variables) and then check them. If one of them turns out not to
    -- be scalar, the entire group is regarded as not being scalar.
    --
    -- The second argument is a predicate that checks whether a type is scalar.
    --
    is_scalar :: VarSet -> (Type -> Bool) -> CoreExpr -> Bool
    is_scalar scalars  _isScalarTC (Var v)         = v `elemVarSet` scalars
    is_scalar _scalars _isScalarTC (Lit _)         = True
    is_scalar scalars  isScalarTC  (App e1 e2)     = is_scalar scalars isScalarTC e1 &&
                                                     is_scalar scalars isScalarTC e2
    is_scalar scalars  isScalarTC  (Lam var body)
      | maybe_parr_ty (varType var)                = False
      | otherwise                                  = is_scalar (scalars `extendVarSet` var)
                                                               isScalarTC body
    is_scalar scalars  isScalarTC  (Let bind body) = bindsAreScalar &&
                                                     is_scalar scalars' isScalarTC body
      where
        (bindsAreScalar, scalars') = is_scalar_bind scalars isScalarTC bind
    is_scalar scalars  isScalarTC  (Case e var ty alts)
      | isScalarTC ty                              = is_scalar scalars' isScalarTC e &&
                                                     all (is_scalar_alt scalars' isScalarTC) alts
      | otherwise                                  = False
      where
        scalars' = scalars `extendVarSet` var
    is_scalar scalars  isScalarTC  (Cast e _coe)   = is_scalar scalars isScalarTC e
    is_scalar scalars  isScalarTC  (Tick _ e)      = is_scalar scalars isScalarTC e
    is_scalar _scalars _isScalarTC (Type {})       = True
    is_scalar _scalars _isScalarTC (Coercion {})   = True

    -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
    is_scalar_bind scalars isScalarTCs (NonRec var e) = (is_scalar scalars isScalarTCs e,
                                                         scalars `extendVarSet` var)
    is_scalar_bind scalars isScalarTCs (Rec bnds)     = (all (is_scalar scalars' isScalarTCs) es,
                                                         scalars')
      where
        (vars, es) = unzip bnds
        scalars'   = scalars `extendVarSetList` vars

    is_scalar_alt scalars isScalarTCs (_, vars, e) = is_scalar (scalars `extendVarSetList` vars)
                                                               isScalarTCs e

    -- Checks whether the type might be a parallel array type.  In particular, if the outermost
    -- constructor is a type family, we conservatively assume that it may be a parallel array type.
    maybe_parr_ty :: Type -> Bool
    maybe_parr_ty ty
      | Just ty'        <- coreView ty            = maybe_parr_ty ty'
      | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
    maybe_parr_ty _                               = False

    -- FIXME: I'm not convinced that this reasoning is (always) sound.  If the identity function
    --        is called by some other function that is otherwise scalar, it would be very bad
    --        that just this call to the identity makes it not be scalar.
    -- A scalar function has to actually compute something.  Without the check,
    -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
    -- (map (\x -> x)) which is very bad.  Normal lifting transforms it to
    -- (\n# x -> x) which is what we want.
    uses funs (Var v)       = v `elemVarSet` funs
    uses funs (App e1 e2)   = uses funs e1 || uses funs e2
    uses funs (Lam b body)  = uses (funs `extendVarSet` b) body
    uses funs (Let (NonRec _b letExpr) body)
                            = uses funs letExpr || uses funs body
    uses funs (Case e _eId _ty alts)
                            = uses funs e || any (uses_alt funs) alts
    uses _ _                = False

    uses_alt funs (_, _bs, e) = uses funs e
830
-- Generate code for a scalar function by generating a scalar closure. If the function is a
-- dictionary function, vectorise it as dictionary code.
--
-- In the dictionary case (the result type is a predicate type), only the vectorised component of
-- the result is meaningful; the lifted component is an error thunk. Otherwise, both the original
-- expression and the closure built from it are hoisted (without inlining), and the lifted
-- component is obtained from the hoisted closure with 'liftPD'.
--
mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
mkScalarFun arg_tys res_ty expr
  | isPredTy res_ty
  = do
      vDict <- vectDictExpr expr
      return (vDict, noLiftedDict)
  | otherwise
  = do
      traceVt "mkScalarFun: " $ ppr expr $$ ptext (sLit " ::") <+> ppr (mkFunTys arg_tys res_ty)

      fnVar      <- hoistExpr (fsLit "fn") expr DontInline
      zipper     <- zipScalars arg_tys res_ty
      closure    <- scalarClosure arg_tys res_ty (Var fnVar) (zipper `App` Var fnVar)
      closureVar <- hoistExpr (fsLit "clo") closure DontInline
      lifted     <- liftPD (Var closureVar)
      return (Var closureVar, lifted)
  where
    noLiftedDict = error "Vectorise.Exp.mkScalarFun: we don't lift dictionary expressions"
852
-- |Vectorise a dictionary function that has a 'VECTORISE SCALAR instance' pragma.
--
-- In other words, all methods in that dictionary are scalar functions — to be vectorised with
-- 'vectScalarFun'. The dictionary "function" itself may be a constant, though.
--
-- NB: You may think that we could implement this function guided by the structure of the Core
--     expression of the right-hand side of the dictionary function. We cannot proceed like this as
--     'vectScalarDFun' must also work for *imported* dfuns, where we don't necessarily have access
--     to the Core code of the unvectorised dfun.
--
-- Here an example — assume,
--
-- > class Eq a where { (==) :: a -> a -> Bool }
-- > instance (Eq a, Eq b) => Eq (a, b) where { (==) = ... }
-- > {-# VECTORISE SCALAR instance Eq (a, b) #-}
--
-- The unvectorised dfun for the above instance has the following signature:
--
-- > $dEqPair :: forall a b. Eq a -> Eq b -> Eq (a, b)
--
-- We generate the following (scalar) vectorised dfun (liberally using TH notation):
--
-- > $v$dEqPair :: forall a b. V:Eq a -> V:Eq b -> V:Eq (a, b)
-- > $v$dEqPair = /\a b -> \dEqa :: V:Eq a -> \dEqb :: V:Eq b ->
-- >                D:V:Eq $(vectScalarFun recFns
-- >                                       [| (==) @(a, b) ($dEqPair @a @b $(unVect dEqa) $(unVect dEqb)) |])
--
-- NB:
-- * '(,)' vectorises to '(,)' — hence, the type constructor in the result type remains the same.
-- * We share the '$(unVect di)' sub-expressions between the different selectors, but duplicate
--   the application of the unvectorised dfun, to enable the dictionary selection rules to fire.
--
-- If the data constructor of the class dictionary has no vectorised version, vectorisation
-- fails with a descriptive message (via 'maybeV'), in the same manner as 'vectAlgCase' reports
-- data constructors that are not vectorised.
--
vectScalarDFun :: Var        -- ^ Original dfun
               -> [Var]      -- ^ Functions names in same recursive binding group
               -> VM CoreExpr
vectScalarDFun var recFns
  = do {   -- bring the type variables into scope
       ; mapM_ defLocalTyVar tvs

           -- vectorise dictionary argument types and generate variables for them
       ; vTheta     <- mapM vectType theta
       ; vThetaBndr <- mapM (newLocalVar (fsLit "vd")) vTheta
       ; let vThetaVars = varsToCoreExprs vThetaBndr

           -- vectorise superclass dictionaries and methods as scalar expressions
       ; thetaVars  <- mapM (newLocalVar (fsLit "d")) theta
       ; thetaExprs <- zipWithM unVectDict theta vThetaVars
       ; let thetaDictBinds = zipWith NonRec thetaVars thetaExprs
             dict           = Var var `mkTyApps` (mkTyVarTys tvs) `mkVarApps` thetaVars
             scsOps         = map (\selId -> varToCoreExpr selId `mkTyApps` tys `mkApps` [dict])
                                  selIds
       ; vScsOps <- mapM (\e -> vectorised <$> vectScalarFun recFns e) scsOps

           -- vectorised applications of the class-dictionary data constructor; a missing
           -- vectorised data constructor is reported instead of failing a pattern match
       ; vDataCon <- maybeV dataConErr (lookupDataCon dataCon)
       ; vTys     <- mapM vectType tys
       ; let vBody = thetaDictBinds `mkLets` mkCoreConApps vDataCon (map Type vTys ++ vScsOps)

       ; return $ mkLams (tvs ++ vThetaBndr) vBody
       }
  where
    ty                = varType var
    (tvs, theta, pty) = tcSplitSigmaTy  ty        -- 'theta' is the instance context
    (cls, tys)        = tcSplitDFunHead pty       -- 'pty' is the instance head
    selIds            = classAllSelIds cls
    dataCon           = classDataCon cls

    dataConErr        = text "vectScalarDFun: class data constructor not vectorised" <+> ppr dataCon
919
-- Build a value of the dictionary before vectorisation from original, unvectorised type and an
-- expression computing the vectorised dictionary.
--
-- Given the vectorised version of a dictionary 'vd :: V:C vt1..vtn', generate code that computes
-- the unvectorised version, thus:
--
-- > D:C op1 .. opm
-- > where
-- >   opi = $(fromVect opTyi [| vSeli @vt1..vtk vd |])
--
-- where 'opTyi' is the type of the i-th superclass or op of the unvectorised dictionary.
--
-- Panics if the type constructor of the given type does not belong to a class.
--
unVectDict :: Type -> CoreExpr -> VM CoreExpr
unVectDict ty e
  = do
      vTys <- mapM vectType tys
      -- every class selector, applied to the vectorised type arguments and to the dictionary
      let selections = [ Var sel `mkTyApps` vTys `mkApps` [e] | sel <- selIds ]
      scOps <- zipWithM fromVect methTys selections
      return $ mkCoreConApps dataCon (map Type tys ++ scOps)
  where
    (tycon, tys, dataCon, methTys) = splitProductType "unVectDict: original type" ty

    selIds = classAllSelIds dictClass

    dictClass
      = case tyConClass_maybe tycon of
          Nothing    -> panic "Vectorise.Exp.unVectDict: no class"
          Just klass -> klass
945
-- |Vectorise an 'n'-ary lambda abstraction by building a set of 'n' explicit closures.
--
-- All non-dictionary free variables go into the closure's environment, whereas the dictionary
-- variables are passed explicit (as conventional arguments) into the body during closure
-- construction.
--
-- Panics if the expression is not a lambda abstraction.
--
vectLam :: Bool                 -- ^ When the RHS of a binding, whether that binding should be inlined.
        -> Bool                 -- ^ Whether the binding is a loop breaker.
        -> CoreExprWithFVs      -- ^ Body of abstraction.
        -> VITree
        -> VM VExpr
vectLam inline loop_breaker expr@(fvs, AnnLam _ _) vi
  = do
      let (bndrs, body) = collectAnnValBinders expr

      -- the type variables currently in scope
      tyvars <- localTyVars

      -- the /local/ free variables together with their vectorised versions; a variable is local
      -- exactly if it has an entry in the local variable environment
      vfvs <- readLEnv $ \env ->
                [ (var, vv)
                | var     <- varSetElems fvs
                , Just vv <- [lookupVarEnv (local_vars env) var]
                ]

      -- split the free variables into dictionaries and everything else
      let (dictPairs, otherPairs)     = partition (isPredTy . varType . fst) vfvs
          (_fvs_dict, vfvs_dict)      = unzip dictPairs
          (fvs_nondict, vfvs_nondict) = unzip otherPairs

      -- argument and result types of the vectorised closure
      arg_tys <- mapM (vectType . idType) bndrs
      res_ty  <- vectType (exprType $ deAnnotate body)

      let arity      = length fvs_nondict + length bndrs
          vfvs_dict' = map vectorised vfvs_dict

      buildClosures tyvars vfvs_dict' vfvs_nondict arg_tys res_ty
        . hoistPolyVExpr tyvars vfvs_dict' (inlineSpec arity)
        $ do
            -- the vectorised body of the lambda abstraction
            lc <- builtin liftingContext
            let viBody = stripLams expr vi
            (vbndrs, vbody) <- vectBndrsIn (fvs_nondict ++ bndrs) (vectExpr body viBody)

            vbody' <- break_loop lc res_ty vbody
            return $ vLams lc vbndrs vbody'
  where
    -- Descend through nested lambdas and single-child annotation nodes in lockstep, returning
    -- the annotation tree reached when either of the two no longer matches.
    stripLams (_, AnnLam _ e) (VITNode _ [vt]) = stripLams e vt
    stripLams _               vit              = vit

    inlineSpec n
      | inline    = Inline n
      | otherwise = DontInline

    -- If this is the body of a binding marked as a loop breaker, add a recursion termination test
    -- to the /lifted/ version of the function body. The termination tests checks if the lifting
    -- context is empty. If so, it returns an empty array of the (lifted) result type instead of
    -- executing the function body. This is the test from the last line (defining \mathcal{L}')
    -- in Figure 6 of HtM.
    break_loop lc ty (ve, le)
      | not loop_breaker = return (ve, le)
      | otherwise
      = do
          emptyArr <- emptyPD ty
          lty      <- mkPDataType ty
          return (ve, mkWildCase (Var lc) intPrimTy lty
                        [(DEFAULT, [], le),
                         (LitAlt (mkMachInt 0), [], emptyArr)])
vectLam _ _ _ _ = panic "vectLam"
1015
-- | Vectorise an algebraic case expression.
--
-- We convert
--
-- >   case e :: t of v { ... }
--
-- to
--
-- >   V:    let v' = e in case v' of _ { ... }
-- >   L:    let v' = e in case v' `cast` ... of _ { ... }
--
-- When lifting, we have to do it this way because v must have the type
-- [:V(T):] but the scrutinee must be cast to the representation type. We also
-- have to handle the case where v is a wild var correctly.
--
-- The 'VITree' argument is expected to have the annotation tree of the scrutinee as its first
-- child, followed by one tree per alternative, in the order in which the alternatives are
-- passed in. A node without any children leads to a panic.
--
-- The equations cover, in order: a single @DEFAULT@ alternative, a single alternative for a
-- constructor without bound variables, a single alternative for a constructor with bound
-- variables, and the general case of several alternatives.
--
-- FIXME: this is too lazy
vectAlgCase :: TyCon -> [Type] -> CoreExprWithFVs -> Var -> Type
            -> [(AltCon, [Var], CoreExprWithFVs)] -> VITree
            -> VM VExpr
vectAlgCase _tycon _ty_args scrut bndr ty [(DEFAULT, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut         <- vectExpr scrut scrutVit
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt _, [], body)] (VITNode _ (scrutVit : [altVit]))
  = do
      vscrut         <- vectExpr scrut scrutVit
      (vty, lty)     <- vectAndLiftType ty
      (vbndr, vbody) <- vectBndrIn bndr (vectExpr body altVit)
      return $ vCaseDEFAULT vscrut vbndr vty lty vbody

vectAlgCase _tycon _ty_args scrut bndr ty [(DataAlt dc, bndrs, body)] (VITNode _ (scrutVit : [altVit]))
  = do
      (vty, lty) <- vectAndLiftType ty
      vexpr      <- vectExpr scrut scrutVit
      (vbndr, (vbndrs, (vect_body, lift_body)))
         <- vect_scrut_bndr
          . vectBndrsIn bndrs
          $ vectExpr body altVit
      let (vect_bndrs, lift_bndrs) = unzip vbndrs
      (vscrut, lscrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)
      vect_dc <- maybeV dataConErr (lookupDataCon dc)

      let vcase = mk_wild_case vscrut vty vect_dc  vect_bndrs vect_body
          lcase = mk_wild_case lscrut lty pdata_dc lift_bndrs lift_body

      return $ vLet (vNonRec vbndr vexpr) (vcase, lcase)
  where
    -- a dead case binder gets a fresh name, as the scrutinee is always let-bound
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    mk_wild_case expr ty dc bndrs body
      = mkWildCase expr (exprType expr) ty [(DataAlt dc, bndrs, body)]

    dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

vectAlgCase tycon _ty_args scrut bndr ty alts (VITNode _ (scrutVit : altVits))
  = do
      vect_tc    <- maybeV tyConErr (lookupTyCon tycon)
      (vty, lty) <- vectAndLiftType ty

      let arity = length (tyConDataCons vect_tc)
      sel_ty   <- builtin (selTy arity)
      sel_bndr <- newLocalVar (fsLit "sel") sel_ty
      let sel = Var sel_bndr

      (vbndr, valts) <- vect_scrut_bndr
                      $ mapM (proc_alt arity sel vty lty) sortedAlts
      let (vect_dcs, vect_bndrss, lift_bndrss, vbodies) = unzip4 valts

      vexpr <- vectExpr scrut scrutVit
      (vect_scrut, lift_scrut, pdata_dc) <- pdataUnwrapScrut (vVar vbndr)

      let (vect_bodies, lift_bodies) = unzip vbodies

      vdummy <- newDummyVar (exprType vect_scrut)
      ldummy <- newDummyVar (exprType lift_scrut)
      let vect_case = Case vect_scrut vdummy vty
                           (zipWith3 mk_vect_alt vect_dcs vect_bndrss vect_bodies)

      lc    <- builtin liftingContext
      lbody <- combinePD vty (Var lc) sel lift_bodies
      let lift_case = Case lift_scrut ldummy lty
                           [(DataAlt pdata_dc, sel_bndr : concat lift_bndrss,
                             lbody)]

      return . vLet (vNonRec vbndr vexpr)
             $ (vect_case, lift_case)
  where
    tyConErr = (text "vectAlgCase: type constructor not vectorised" <+> ppr tycon)

    -- a dead case binder gets a fresh name, as the scrutinee is always let-bound
    vect_scrut_bndr | isDeadBinder bndr = vectBndrNewIn bndr (fsLit "scrut")
                    | otherwise         = vectBndrIn bndr

    -- The alternatives in constructor tag order. Every alternative is paired with its
    -- annotation tree /before/ sorting, so that it keeps its own tree even when the
    -- alternatives are not passed in sorted order. ('sortBy' is stable; hence, input that is
    -- already sorted is left as it is.)
    sortedAlts = sortBy (\((alt1, _, _), _) ((alt2, _, _), _) -> cmp alt1 alt2)
                        (zip alts altVits)

    cmp (DataAlt dc1) (DataAlt dc2) = dataConTag dc1 `compare` dataConTag dc2
    cmp DEFAULT       DEFAULT       = EQ
    cmp DEFAULT       _             = LT
    cmp _             DEFAULT       = GT
    cmp _             _             = panic "vectAlgCase/cmp"

    -- Vectorise a single alternative. In the lifted version, the local free variables of the
    -- right-hand side are rebound to versions packed by the tag of this alternative's
    -- constructor ('pack_var'). Only data constructor alternatives are supported; anything
    -- else leads to a panic.
    proc_alt arity sel _ lty ((DataAlt dc, bndrs, body), vi)
      = do
          vect_dc <- maybeV dataConErr (lookupDataCon dc)
          let ntag = dataConTagZ vect_dc
              tag  = mkDataConTag vect_dc
              fvs  = freeVarsOf body `delVarSetList` bndrs

          sel_tags <- liftM (`App` sel) (builtin (selTags arity))
          lc       <- builtin liftingContext
          elems    <- builtin (selElements arity ntag)

          (vbndrs, vbody)
            <- vectBndrsIn bndrs
             . localV
             $ do
                 binds    <- mapM (pack_var (Var lc) sel_tags tag)
                           . filter isLocalId
                           $ varSetElems fvs
                 (ve, le) <- vectExpr body vi
                 return (ve, Case (elems `App` sel) lc lty
                                  [(DEFAULT, [], (mkLets (concat binds) le))])
                 -- empty    <- emptyPD vty
                 -- return (ve, Case (elems `App` sel) lc lty
                 --             [(DEFAULT, [], Let (NonRec flags_var flags_expr)
                 --                             $ mkLets (concat binds) le),
                 --               (LitAlt (mkMachInt 0), [], empty)])
          let (vect_bndrs, lift_bndrs) = unzip vbndrs
          return (vect_dc, vect_bndrs, lift_bndrs, vbody)
      where
        dataConErr = (text "vectAlgCase: data constructor not vectorised" <+> ppr dc)

    proc_alt _ _ _ _ _ = panic "vectAlgCase/proc_alt"

    mk_vect_alt vect_dc bndrs body = (DataAlt vect_dc, bndrs, body)

    -- For a local variable, bind a clone of its lifted version to the packed lifted value and
    -- make the clone the lifted version in the local environment. Other variables produce no
    -- bindings.
    pack_var len tags t v
      = do
          r <- lookupVar v
          case r of
            Local (vv, lv) ->
              do
                lv'  <- cloneVar lv
                expr <- packByTagPD (idType vv) (Var lv) len tags t
                updLEnv (\env -> env { local_vars = extendVarEnv
                                                      (local_vars env) v (vv, lv') })
                return [(NonRec lv' expr)]

            _ -> return []

vectAlgCase tycon _ty_args _scrut _bndr _ty _alts (VITNode _ [])
  = pprPanic "vectAlgCase (mismatched node information)" (ppr tycon)
1172
1173 {-
1174 ---- Sanity check of the tree, for debugging only
1175 checkTree :: VITree -> CoreExpr -> Bool
1176 checkTree (VITNode _ []) (Type _ty)
1177 = True
1178
1179 checkTree (VITNode _ []) (Var _v)
1180 = True
1181
1182
1183 checkTree (VITNode _ []) (Lit _)
1184 = True
1185
1186
1187 checkTree (VITNode _ [vit]) (Tick _ expr)
1188 = checkTree vit expr
1189
1190
1191
1192 checkTree (VITNode _ [vit]) (Lam _ expr)
1193 = checkTree vit expr
1194
1195
1196 checkTree (VITNode _ [vit1, vit2]) (App ce1 ce2)
1197 = (checkTree vit1 ce1) && (checkTree vit2 ce2)
1198
1199
1200
1201 checkTree (VITNode _ (scrutVit : altVits)) (Case scrut _ _ alts)
1202 = (checkTree scrutVit scrut) && (and $ zipWith checkAlt altVits alts)
1203 where
1204 checkAlt vt (_, _, expr) = checkTree vt expr
1205
1206
1207 checkTree (VITNode _ [vt1, vt2]) (Let (NonRec _ expr1) expr2)
1208 = (checkTree vt1 expr1) && (checkTree vt2 expr2)
1209
1210
1211
1212 checkTree (VITNode _ (vtB : vtBnds)) (Let (Rec bndngs) expr)
1213 = (and $ zipWith checkBndr vtBnds bndngs) &&
1214 (checkTree vtB expr)
1215 where
1216 checkBndr vt (_, e) = checkTree vt e
1217
1218
1219 checkTree (VITNode _ [vit]) (Cast expr _)
1220 = checkTree vit expr
1221
1222 checkTree _ _ = False
1223
1224 checkTreeAnnM:: VITree -> CoreExprWithFVs -> VM ()
1225 checkTreeAnnM vi e =
1226 if not (checkTree vi $ deAnnotate e)
1227 then error ("checkTreeAnnM : \n " ++ show vi)
1228 else return ()
1229 -}