9fdf3ba8f5ff1f939c7619a5ab7e7e449026fb57
[ghc.git] / compiler / vectorise / Vectorise / Builtins / Initialise.hs
1
2 module Vectorise.Builtins.Initialise (
3 -- * Initialisation
4 initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
5 initBuiltinPAs, initBuiltinPRs,
6 initBuiltinBoxedTyCons
7 ) where
8
9 import Vectorise.Builtins.Base
10 import Vectorise.Builtins.Modules
11
12 import BasicTypes
13 import PrelNames
14 import TysPrim
15 import DsMonad
16 import IfaceEnv
17 import InstEnv
18 import TysWiredIn
19 import DataCon
20 import TyCon
21 import Class
22 import CoreSyn
23 import Type
24 import Name
25 import Module
26 import Id
27 import FastString
28 import Outputable
29
30 import Control.Monad
31 import Data.Array
32
-- |Create the initial map of builtin types and functions.
--
-- Loads every module listed in 'dph_Orphans', then looks up each type
-- constructor, class and variable of the DPH library that the vectoriser
-- refers to and packages them into a 'Builtins' record.  Families of numbered
-- entities (combineNPD, SumN, SelN, closureN, scalar_zipWithN, ...) are
-- stored in arrays indexed by their number.
--
initBuiltins :: PackageId  -- ^ package id the builtins are in, eg dph-common
             -> DsM Builtins
initBuiltins pkg
  = do mapM_ load dph_Orphans

       -- From dph-common:Data.Array.Parallel.PArray.PData
       -- PData is a type family that maps an element type onto the type
       -- we use to hold an array of those elements.
       pdataTyCon  <- externalTyCon dph_PArray_PData (fsLit "PData")

       -- PR is a type class that holds the primitive operators we can
       -- apply to array data. Its functions take arrays in terms of PData types.
       prClass     <- externalClass dph_PArray_PData (fsLit "PR")
       let prTyCon     = classTyCon prClass
           [prDataCon] = tyConDataCons prTyCon

       -- From dph-common:Data.Array.Parallel.PArray.PRepr
       preprTyCon  <- externalTyCon dph_PArray_PRepr (fsLit "PRepr")
       paClass     <- externalClass dph_PArray_PRepr (fsLit "PA")
       let paTyCon     = classTyCon paClass
           [paDataCon] = tyConDataCons paTyCon
           -- Superclass selector number 0 of the PA class.
           paPRSel     = classSCSelId paClass 0

       replicatePDVar <- externalVar dph_PArray_PRepr (fsLit "replicatePD")
       emptyPDVar     <- externalVar dph_PArray_PRepr (fsLit "emptyPD")
       packByTagPDVar <- externalVar dph_PArray_PRepr (fsLit "packByTagPD")
       -- combine2PD .. combine<mAX_DPH_COMBINE>PD, indexed by arity.
       combines       <- mapM (externalVar dph_PArray_PRepr)
                              [ mkFastString ("combine" ++ show i ++ "PD")
                              | i <- [2..mAX_DPH_COMBINE] ]
       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines

       -- From dph-common:Data.Array.Parallel.PArray.Scalar
       -- Scalar is the class of scalar values.
       -- The dictionary contains functions to coerce U.Arrays of scalars
       -- to and from the PData representation.
       scalarClass <- externalClass dph_PArray_Scalar (fsLit "Scalar")

       -- From dph-common:Data.Array.Parallel.Lifted.PArray
       -- A PArray (Parallel Array) holds the array length and some array
       -- elements represented by the PData type family.
       parrayTyCon <- externalTyCon dph_PArray_Base (fsLit "PArray")
       let [parrayDataCon] = tyConDataCons parrayTyCon

       -- From dph-common:Data.Array.Parallel.PArray.Types
       voidTyCon   <- externalTyCon dph_PArray_Types (fsLit "Void")
       voidVar     <- externalVar   dph_PArray_Types (fsLit "void")
       fromVoidVar <- externalVar   dph_PArray_Types (fsLit "fromVoid")
       wrapTyCon   <- externalTyCon dph_PArray_Types (fsLit "Wrap")
       sum_tcs     <- mapM (externalTyCon dph_PArray_Types)
                           (numbered "Sum" 2 mAX_DPH_SUM)

       -- From dph-common:Data.Array.Parallel.PArray.PDataInstances
       pvoidVar    <- externalVar dph_PArray_PDataInstances (fsLit "pvoid")
       punitVar    <- externalVar dph_PArray_PDataInstances (fsLit "punit")

       closureTyCon <- externalTyCon dph_Closure (fsLit ":->")

       -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
       sel_tys        <- mapM (externalType dph_Unboxed)
                              (numbered "Sel" 2 mAX_DPH_SUM)
       sel_replicates <- mapM (externalFun dph_Unboxed)
                              (numbered_hash "replicateSel" 2 mAX_DPH_SUM)
       sel_picks      <- mapM (externalFun dph_Unboxed)
                              (numbered_hash "pickSel" 2 mAX_DPH_SUM)
       sel_tags       <- mapM (externalFun dph_Unboxed)
                              (numbered "tagsSel" 2 mAX_DPH_SUM)
       sel_els        <- mapM mk_elements
                              [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]

       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
           selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
           -- Only the indices (i,j) with j < i are supplied by sel_els;
           -- the remaining slots inside these bounds are left undefined.
           selEls        = array ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs

       closureVar       <- externalVar dph_Closure (fsLit "closure")
       applyVar         <- externalVar dph_Closure (fsLit "$:")
       liftedClosureVar <- externalVar dph_Closure (fsLit "liftedClosure")
       liftedApplyVar   <- externalVar dph_Closure (fsLit "liftedApply")

       -- Indexed by arity: 1 -> scalar_map, 2 -> scalar_zipWith,
       -- n >= 3 -> scalar_zipWith<n>, up to mAX_DPH_SCALAR_ARGS.
       scalar_map  <- externalVar dph_Scalar (fsLit "scalar_map")
       scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
       scalar_zips <- mapM (externalVar dph_Scalar)
                           (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
       let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
                                  (scalar_map : scalar_zip2 : scalar_zips)

       -- closure1 .. closure<mAX_DPH_SCALAR_ARGS>, indexed by arity.  The
       -- array bounds must match the arities looked up here (the same range
       -- as scalarZips): listArray silently ignores any list elements beyond
       -- the upper bound, which would drop the higher-arity constructors.
       closures <- mapM (externalVar dph_Closure)
                        (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
       let closureCtrFuns = listArray (1, mAX_DPH_SCALAR_ARGS) closures

       -- A fresh system-local variable "lc" of type Int# (intPrimTy).
       liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
                               newUnique

       return $ Builtins
              { dphModules       = mods
              , parrayTyCon      = parrayTyCon
              , parrayDataCon    = parrayDataCon
              , pdataTyCon       = pdataTyCon
              , paClass          = paClass
              , paTyCon          = paTyCon
              , paDataCon        = paDataCon
              , paPRSel          = paPRSel
              , preprTyCon       = preprTyCon
              , prClass          = prClass
              , prTyCon          = prTyCon
              , prDataCon        = prDataCon
              , voidTyCon        = voidTyCon
              , wrapTyCon        = wrapTyCon
              , selTys           = selTys
              , selReplicates    = selReplicates
              , selPicks         = selPicks
              , selTagss         = selTagss
              , selEls           = selEls
              , sumTyCons        = sumTyCons
              , closureTyCon     = closureTyCon
              , voidVar          = voidVar
              , pvoidVar         = pvoidVar
              , fromVoidVar      = fromVoidVar
              , punitVar         = punitVar
              , closureVar       = closureVar
              , applyVar         = applyVar
              , liftedClosureVar = liftedClosureVar
              , liftedApplyVar   = liftedApplyVar
              , replicatePDVar   = replicatePDVar
              , emptyPDVar       = emptyPDVar
              , packByTagPDVar   = packByTagPDVar
              , combinePDVars    = combinePDVars
              , scalarClass      = scalarClass
              , scalarZips       = scalarZips
              , closureCtrFuns   = closureCtrFuns
              , liftingContext   = liftingContext
              }
  where
    -- Extract out all the modules we'll use.
    -- These are the modules from the DPH base library that contain
    -- the primitive array types and functions that vectorised code uses.
    mods@(Modules
            { dph_PArray_Base           = dph_PArray_Base
            , dph_PArray_Scalar         = dph_PArray_Scalar
            , dph_PArray_PRepr          = dph_PArray_PRepr
            , dph_PArray_PData          = dph_PArray_PData
            , dph_PArray_PDataInstances = dph_PArray_PDataInstances
            , dph_PArray_Types          = dph_PArray_Types
            , dph_Closure               = dph_Closure
            , dph_Scalar                = dph_Scalar
            , dph_Unboxed               = dph_Unboxed
            })
      = dph_Modules pkg

    -- Load (via dsLoadModule) the module picked out of 'mods' by the given
    -- selector.
    load get_mod = dsLoadModule doc mod
      where
        mod = get_mod mods
        doc = ppr mod <+> ptext (sLit "is a DPH module")

    -- Make a list of numbered strings in some range, eg foo3, foo4, foo5
    numbered :: String -> Int -> Int -> [FastString]
    numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]

    -- Like 'numbered', but each string ends in '#', eg foo2#, foo3#
    numbered_hash :: String -> Int -> Int -> [FastString]
    numbered_hash pfx m n = [mkFastString (pfx ++ show i ++ "#") | i <- [m..n]]

    -- Look up elementsSel<i>_<j># in the Unboxed module, paired with its
    -- (i,j) index so the results can be passed straight to 'array'.
    mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
    mk_elements (i,j)
      = do v <- externalVar dph_Unboxed
                  $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
           return ((i,j), Var v)
218
-- | Get the mapping of names in the Prelude to names in the DPH library.
--
-- The workers of True, False and () are mapped to themselves.  The workers
-- of the boxed 2- and 3-tuple constructors are mapped to the variables
-- @tup2@ and @tup3@ of the DPH Prelude tuple module.
--
initBuiltinVars :: Builtins -> DsM [(Var, Var)]
initBuiltinVars (Builtins { dphModules = mods })
  = do
      tupleVars <- zipWithM externalVar tupleMods tupleNames
      let identities = [ (w, w) | w <- map dataConWorkId defaultDataConWorkers ]
          renamed    = zip (map dataConWorkId tupleCons) tupleVars
      return (identities ++ renamed)
  where
    (tupleCons, tupleMods, tupleNames) = unzip3 (preludeDataCons mods)

    defaultDataConWorkers :: [DataCon]
    defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]

    -- One entry per tuple arity: the constructor, its DPH module and the
    -- name of the replacement variable.
    preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
    preludeDataCons (Modules { dph_Prelude_Tuple = tupleMod })
      = [ (tupleCon Boxed arity, tupleMod, mkFastString ("tup" ++ show arity))
        | arity <- [2..3] ]
238
-- |Get a list of names to `TyCon`s in the mock prelude.
--
-- The function arrow is mapped to the DPH closure type constructor, and the
-- built-in parallel-array type to DPH's PArray.  PArray's own name is also
-- mapped to itself.  Int, Bool, Float, Double and Word8 map to themselves.
--
initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinTyCons bi
  = do
      word8 <- dsLookupTyCon word8TyConName
      let parray     = parrayTyCon bi
          selfMapped = [intTyCon, boolTyCon, floatTyCon, doubleTyCon, word8]
          special    = [ (tyConName funTyCon, closureTyCon bi)
                       , (parrTyConName, parray)
                         -- FIXME: temporary
                       , (tyConName parray, parray)
                       ]
      return $ special ++ [ (tyConName tc, tc) | tc <- selfMapped ]
259
-- |Get a list of names to `DataCon`s in the mock prelude.
--
-- True, False and () are each keyed by their own name.  The 'Builtins'
-- argument is not consulted.
--
initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
initBuiltinDataCons _
  = map (\dc -> (dataConName dc, dc)) [trueDataCon, falseDataCon, unitDataCon]
268
-- |Get the names of all builtin instance functions for the PA class.
--
-- Looks up the PA class in the PRepr module of the DPH library.  It then
-- pairs each instance's type-constructor name with that instance's
-- dictionary function (see 'initBuiltinDicts').
--
initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
initBuiltinPAs (Builtins { dphModules = mods }) insts
  = do paCls <- externalClass (dph_PArray_PRepr mods) (fsLit "PA")
       return (initBuiltinDicts insts paCls)
274
-- |Get the names of all builtin instance functions for the PR class.
--
-- Looks up the PR class in the PData module of the DPH library.  It then
-- pairs each instance's type-constructor name with that instance's
-- dictionary function (see 'initBuiltinDicts').
--
initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
initBuiltinPRs (Builtins { dphModules = mods }) insts
  = do prCls <- externalClass (dph_PArray_PData mods) (fsLit "PR")
       return (initBuiltinDicts insts prCls)
280
-- |Get the names of all DPH instance functions for this class.
--
-- Every instance of the class must have exactly one rough-match type
-- constructor, given as @[Just tc]@.  That name is paired with the
-- instance's dictionary-function Id.  Any other shape is a panic.
--
initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
initBuiltinDicts insts cls = [ entry inst | inst <- classInstances insts cls ]
  where
    entry inst
      = case instanceRoughTcs inst of
          [Just tc] -> (tc, instanceDFunId inst)
          _         -> pprPanic "Invalid DPH instance" (ppr inst)
288
-- |Get a list of boxed `TyCons` in the mock prelude. This is Int only.
--
-- Maps the name of the primitive Int# type constructor to the boxed Int
-- type constructor.  The 'Builtins' argument is ignored.
--
initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinBoxedTyCons _
  = return [(tyConName intPrimTyCon, intTyCon)]
298
299
-- Auxiliary look-up functions ----------------
301
-- Look up a global variable given its occurrence name and the module that
-- defines it.
--
externalVar :: Module -> FastString -> DsM Var
externalVar mod fs
  = do name <- lookupOrig mod (mkVarOccFS fs)
       dsLookupGlobalId name
307
-- Like `externalVar`, but return the variable wrapped as a `CoreExpr`
-- (a 'Var' node).
--
externalFun :: Module -> FastString -> DsM CoreExpr
externalFun mod fs = liftM Var (externalVar mod fs)
314
-- Look up a `TyCon` given its occurrence name and the module that
-- defines it.
--
externalTyCon :: Module -> FastString -> DsM TyCon
externalTyCon mod fs
  = do name <- lookupOrig mod (mkTcOccFS fs)
       dsLookupTyCon name
320
-- Look up a `Type` given its name and the module that contains it.  The
-- result is the named type constructor applied to no arguments.
--
externalType :: Module -> FastString -> DsM Type
externalType mod fs
  = liftM (\tc -> mkTyConApp tc []) (externalTyCon mod fs)
327
-- Look up a `Class` given its occurrence name and the module that
-- defines it.
--
externalClass :: Module -> FastString -> DsM Class
externalClass mod fs
  = do name <- lookupOrig mod (mkClsOccFS fs)
       dsLookupClass name