Vectoriser gets all DPH library identifiers from Data.Array.Parallel.Prim
[ghc.git] / compiler / vectorise / Vectorise / Utils / Closure.hs
1 -- |Utils concerning closure construction and application.
2
3 module Vectorise.Utils.Closure (
4 mkClosure,
5 mkClosureApp,
6 buildClosure,
7 buildClosures,
8 buildEnv
9 )
10 where
11
12 import Vectorise.Builtins
13 import Vectorise.Vect
14 import Vectorise.Monad
15 import Vectorise.Utils.Base
16 import Vectorise.Utils.PADict
17 import Vectorise.Utils.Hoisting
18
19 import CoreSyn
20 import Type
21 import MkCore
22 import CoreUtils
23 import TyCon
24 import DataCon
25 import MkId
26 import TysWiredIn
27 import BasicTypes( TupleSort(..) )
28 import FastString
29
30
-- | Make a closure.
--
-- Produces a pair of expressions.  The first applies the builtin
-- 'closureVar' to the vectorised component of the environment.  The second
-- applies the builtin 'liftedClosureVar' to the lifted component.  Both are
-- instantiated at the argument, result and environment types.  Both also
-- receive the dictionary that 'paDictOfType' returns for the environment type,
-- followed by the vectorised and lifted function.
mkClosure
  :: Type   -- ^ Type of the argument.
  -> Type   -- ^ Type of the result.
  -> Type   -- ^ Type of the environment.
  -> VExpr  -- ^ The function to apply.
  -> VExpr  -- ^ The environment to use.
  -> VM VExpr
mkClosure arg_ty res_ty env_ty (vfn, lfn) (venv, lenv)
  = do
      envDict <- paDictOfType env_ty
      vCloCon <- builtin closureVar
      lCloCon <- builtin liftedClosureVar
      let tyArgs        = [arg_ty, res_ty, env_ty]
          build con env = mkApps (mkTyApps (Var con) tyArgs)
                                 [envDict, vfn, lfn, env]
      return (build vCloCon venv, build lCloCon lenv)
46
47
-- | Make a closure application.
--
-- The vectorised component applies the builtin 'applyVar' to the vectorised
-- closure and argument.  The lifted component applies the builtin
-- 'liftedApplyVar' to the lifting context, the lifted closure and the lifted
-- argument.  Both application functions are instantiated at the argument and
-- result types.
mkClosureApp
  :: Type   -- ^ Type of the argument.
  -> Type   -- ^ Type of the result.
  -> VExpr  -- ^ Closure to apply.
  -> VExpr  -- ^ Argument to use.
  -> VM VExpr
mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
  = do
      vApplyFn <- builtin applyVar
      lApplyFn <- builtin liftedApplyVar
      ctx      <- builtin liftingContext
      let instantiate f = Var f `mkTyApps` [arg_ty, res_ty]
          vApp          = instantiate vApplyFn `mkApps` [vclo, varg]
          lApp          = instantiate lApplyFn `mkApps` [Var ctx, lclo, larg]
      return (vApp, lApp)
62
63
-- | Build a curried chain of closures, one per argument type.
--
-- With no argument types the body action is returned as is.  With exactly one,
-- a single closure is built with 'buildClosure'.  With several, the outer
-- closure takes the first argument.  Its body is a function hoisted with
-- 'hoistPolyVExpr'.  That function abstracts over the captured variables plus
-- the new argument and builds the closures for the remaining argument types.
buildClosures
  :: [TyVar]   -- ^ Type variables passed on to 'hoistPolyVExpr'.
  -> [VVar]    -- ^ Variables captured in the environment.
  -> [Type]    -- ^ Type of the arguments.
  -> Type      -- ^ Type of result.
  -> VM VExpr  -- ^ Action producing the innermost body.
  -> VM VExpr
buildClosures tvs vars arg_tys res_ty mk_body
  = case arg_tys of
      [] -> mk_body
      [only_ty] -> buildClosure tvs vars only_ty res_ty mk_body
      (first_ty : rest_tys) -> do
        inner_res_ty <- mkClosureTypes rest_tys res_ty
        x            <- newLocalVVar (fsLit "x") first_ty
        let captured = vars ++ [x]
            mk_inner = do
              lc  <- builtin liftingContext
              clo <- buildClosures tvs captured rest_tys res_ty mk_body
              return (vLams lc captured clo)
        buildClosure tvs vars first_ty inner_res_ty
          (hoistPolyVExpr tvs (Inline (length vars + 1)) mk_inner)
87
88
-- | Build a closure over the given captured variables.
--
-- The generated code has the shape
--
-- > (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
-- >   where
-- >     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
-- >     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
--
-- 'buildEnv' assembles the environment.  The function abstracting over
-- the environment and argument binders is hoisted with 'hoistPolyVExpr'.
-- 'mkClosure' packages the result.
buildClosure
  :: [TyVar]   -- ^ Type variables passed on to 'hoistPolyVExpr'.
  -> [VVar]    -- ^ Variables captured in the environment.
  -> Type      -- ^ Type of the argument.
  -> Type      -- ^ Type of the result.
  -> VM VExpr  -- ^ Action producing the body; the body is applied to the
               --   captured variables followed by the argument.
  -> VM VExpr
buildClosure tvs vars arg_ty res_ty mk_body
  = do
      (env_ty, env, unpack_env) <- buildEnv vars
      env_bndr <- newLocalVVar (fsLit "env") env_ty
      arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
      let mk_fn = do
            lc   <- builtin liftingContext
            body <- mk_body
            let call_body = vVarApps lc body (vars ++ [arg_bndr])
                unpacked  = unpack_env (vVar env_bndr) call_body
            return (vLams lc [env_bndr, arg_bndr] unpacked)
      fn <- hoistPolyVExpr tvs (Inline 2) mk_fn
      mkClosure arg_ty res_ty env_ty fn env
110
111
112 -- Environments ---------------------------------------------------------------
-- | Build the environment capturing the given variables.
--
-- Returns a triple of:
--
--   * the type of the environment,
--
--   * the environment expression pair (vectorised, lifted), and
--
--   * a binder which, given an environment expression pair and a body pair,
--     wraps the body so that the captured variables are bound from the
--     environment.
--
-- * An empty variable list uses the builtin void values, and the binder
--   ignores the environment.
-- * A single variable is its own environment and is bound with a let.
-- * Several variables are packed into a boxed tuple on the vectorised side.
--   On the lifted side they are packed with the single data constructor of
--   the tycon returned by 'pdataReprTyCon' for that tuple type.
buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
buildEnv []
  = do
      ty    <- voidType
      void  <- builtin voidVar
      pvoid <- builtin pvoidVar
      return (ty, vVar (void, pvoid), \_ body -> body)

buildEnv [v] = return (vVarType v, vVar v,
                       \env body -> vLet (vNonRec v env) body)

buildEnv vs
  = do (lenv_tc, lenv_tyargs) <- pdataReprTyCon ty

       let venv_con = tupleCon BoxedTuple (length vs)

           -- The representation tycon is expected to have exactly one data
           -- constructor.  The match stays lazy, so evaluation order is as
           -- before.  If the expectation is violated, this reports a
           -- descriptive error instead of an irrefutable-pattern failure.
           lenv_con = case tyConDataCons lenv_tc of
                        [con] -> con
                        cons  -> error ("Vectorise.Utils.Closure.buildEnv: "
                                        ++ "expected exactly one data constructor "
                                        ++ "for the PData representation tycon, found "
                                        ++ show (length cons))

           venv = mkCoreTup (map Var vvs)
           lenv = Var (dataConWrapId lenv_con)
                  `mkTyApps` lenv_tyargs
                  `mkApps`   map Var lvs

           -- Vectorised side: scrutinise the tuple and bind its components.
           vbind env body = mkWildCase env ty (exprType body)
                              [(DataAlt venv_con, vvs, body)]

           -- Lifted side: unwrap the family-instance scrutinee, then bind
           -- the fields of the representation constructor.
           lbind env body =
             let scrut = unwrapFamInstScrut lenv_tc lenv_tyargs env
             in
             mkWildCase scrut (exprType scrut) (exprType body)
               [(DataAlt lenv_con, lvs, body)]

           bind (venv', lenv') (vbody, lbody) = (vbind venv' vbody,
                                                 lbind lenv' lbody)

       return (ty, (venv, lenv), bind)
  where
    (vvs, lvs) = unzip vs
    tys        = map vVarType vs
    ty         = mkBoxedTupleTy tys