Add VECTORISE [SCALAR] type pragma
[ghc.git] / compiler / vectorise / Vectorise / Builtins / Initialise.hs
1 -- Set up the data structures provided by 'Vectorise.Builtins'.
2
3 module Vectorise.Builtins.Initialise (
4 -- * Initialisation
5 initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
6 initBuiltinPAs, initBuiltinPRs,
7 initBuiltinBoxedTyCons
8 ) where
9
10 import Vectorise.Builtins.Base
11 import Vectorise.Builtins.Modules
12
13 import BasicTypes
14 import PrelNames
15 import TysPrim
16 import DsMonad
17 import IfaceEnv
18 import InstEnv
19 import TysWiredIn
20 import DataCon
21 import TyCon
22 import Class
23 import CoreSyn
24 import Type
25 import Name
26 import Module
27 import Id
28 import FastString
29 import Outputable
30
31 import Control.Monad
32 import Data.Array
33
-- | Create the initial map of builtin types and functions.
--
-- Loads the DPH orphan modules, then looks up every type constructor, class
-- and variable from the DPH library that the vectoriser refers to, and
-- packages them into a 'Builtins' record.
--
initBuiltins :: PackageId  -- ^ package id the builtins are in, eg dph-common
             -> DsM Builtins
initBuiltins pkg
  = do mapM_ load dph_Orphans

       -- From dph-common:Data.Array.Parallel.PArray.PData
       -- PData is a type family that maps an element type onto the type
       -- we use to hold an array of those elements.
       pdataTyCon <- externalTyCon dph_PArray_PData (fsLit "PData")

       -- PR is a type class that holds the primitive operators we can
       -- apply to array data. Its functions take arrays in terms of PData types.
       prClass <- externalClass dph_PArray_PData (fsLit "PR")
       let prTyCon     = classTyCon prClass
           [prDataCon] = tyConDataCons prTyCon

       -- From dph-common:Data.Array.Parallel.PArray.PRepr
       preprTyCon <- externalTyCon dph_PArray_PRepr (fsLit "PRepr")
       paClass    <- externalClass dph_PArray_PRepr (fsLit "PA")
       let paTyCon     = classTyCon paClass
           [paDataCon] = tyConDataCons paTyCon
           paPRSel     = classSCSelId paClass 0

       replicatePDVar <- externalVar dph_PArray_PRepr (fsLit "replicatePD")
       emptyPDVar     <- externalVar dph_PArray_PRepr (fsLit "emptyPD")
       packByTagPDVar <- externalVar dph_PArray_PRepr (fsLit "packByTagPD")
       combines       <- mapM (externalVar dph_PArray_PRepr)
                              [ mkFastString ("combine" ++ show i ++ "PD")
                              | i <- [2..mAX_DPH_COMBINE] ]

       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines

       -- From dph-common:Data.Array.Parallel.PArray.Scalar
       -- Scalar is the class of scalar values.
       -- The dictionary contains functions to coerce U.Arrays of scalars
       -- to and from the PData representation.
       scalarClass <- externalClass dph_PArray_Scalar (fsLit "Scalar")

       -- From dph-common:Data.Array.Parallel.Lifted.PArray
       -- A PArray (Parallel Array) holds the array length and some array elements
       -- represented by the PData type family.
       parrayTyCon <- externalTyCon dph_PArray_Base (fsLit "PArray")
       let [parrayDataCon] = tyConDataCons parrayTyCon

       -- From dph-common:Data.Array.Parallel.PArray.Types
       voidTyCon   <- externalTyCon dph_PArray_Types (fsLit "Void")
       voidVar     <- externalVar   dph_PArray_Types (fsLit "void")
       fromVoidVar <- externalVar   dph_PArray_Types (fsLit "fromVoid")
       wrapTyCon   <- externalTyCon dph_PArray_Types (fsLit "Wrap")
       sum_tcs     <- mapM (externalTyCon dph_PArray_Types) (numbered "Sum" 2 mAX_DPH_SUM)

       -- from dph-common:Data.Array.Parallel.PArray.PDataInstances
       pvoidVar <- externalVar dph_PArray_PDataInstances (fsLit "pvoid")
       punitVar <- externalVar dph_PArray_PDataInstances (fsLit "punit")

       closureTyCon <- externalTyCon dph_Closure (fsLit ":->")

       -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
       sel_tys        <- mapM (externalType dph_Unboxed)
                              (numbered "Sel" 2 mAX_DPH_SUM)

       sel_replicates <- mapM (externalFun dph_Unboxed)
                              (numbered_hash "replicateSel" 2 mAX_DPH_SUM)

       sel_picks      <- mapM (externalFun dph_Unboxed)
                              (numbered_hash "pickSel" 2 mAX_DPH_SUM)

       sel_tags       <- mapM (externalFun dph_Unboxed)
                              (numbered "tagsSel" 2 mAX_DPH_SUM)

       sel_els        <- mapM mk_elements
                              [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]

       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
           selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
           selEls        = array ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs

       closureVar       <- externalVar dph_Closure (fsLit "closure")
       applyVar         <- externalVar dph_Closure (fsLit "$:")
       liftedClosureVar <- externalVar dph_Closure (fsLit "liftedClosure")
       liftedApplyVar   <- externalVar dph_Closure (fsLit "liftedApply")

       scalar_map  <- externalVar dph_Scalar (fsLit "scalar_map")
       scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
       scalar_zips <- mapM (externalVar dph_Scalar)
                           (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)

       -- Indexed by arity: 1 -> scalar_map, 2 -> scalar_zipWith, 3.. -> scalar_zipWithN.
       let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
                                  (scalar_map : scalar_zip2 : scalar_zips)

       -- closure1 .. closureN are looked up for arities 1..mAX_DPH_SCALAR_ARGS,
       -- so the array must use exactly those bounds. 'listArray' silently
       -- ignores surplus elements and leaves missing indices undefined, so
       -- any other upper bound (previously mAX_DPH_COMBINE) would either drop
       -- closure constructors or expose undefined slots.
       closures <- mapM (externalVar dph_Closure)
                        (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)

       let closureCtrFuns = listArray (1, mAX_DPH_SCALAR_ARGS) closures

       liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
                               newUnique

       return $ Builtins
              { dphModules       = mods
              , parrayTyCon      = parrayTyCon
              , parrayDataCon    = parrayDataCon
              , pdataTyCon       = pdataTyCon
              , paClass          = paClass
              , paTyCon          = paTyCon
              , paDataCon        = paDataCon
              , paPRSel          = paPRSel
              , preprTyCon       = preprTyCon
              , prClass          = prClass
              , prTyCon          = prTyCon
              , prDataCon        = prDataCon
              , voidTyCon        = voidTyCon
              , wrapTyCon        = wrapTyCon
              , selTys           = selTys
              , selReplicates    = selReplicates
              , selPicks         = selPicks
              , selTagss         = selTagss
              , selEls           = selEls
              , sumTyCons        = sumTyCons
              , closureTyCon     = closureTyCon
              , voidVar          = voidVar
              , pvoidVar         = pvoidVar
              , fromVoidVar      = fromVoidVar
              , punitVar         = punitVar
              , closureVar       = closureVar
              , applyVar         = applyVar
              , liftedClosureVar = liftedClosureVar
              , liftedApplyVar   = liftedApplyVar
              , replicatePDVar   = replicatePDVar
              , emptyPDVar       = emptyPDVar
              , packByTagPDVar   = packByTagPDVar
              , combinePDVars    = combinePDVars
              , scalarClass      = scalarClass
              , scalarZips       = scalarZips
              , closureCtrFuns   = closureCtrFuns
              , liftingContext   = liftingContext
              }
  where
    -- Extract out all the modules we'll use.
    -- These are the modules from the DPH base library that contain
    -- the primitive array types and functions that vectorised code uses.
    mods@(Modules
          { dph_PArray_Base           = dph_PArray_Base
          , dph_PArray_Scalar         = dph_PArray_Scalar
          , dph_PArray_PRepr          = dph_PArray_PRepr
          , dph_PArray_PData          = dph_PArray_PData
          , dph_PArray_PDataInstances = dph_PArray_PDataInstances
          , dph_PArray_Types          = dph_PArray_Types
          , dph_Closure               = dph_Closure
          , dph_Scalar                = dph_Scalar
          , dph_Unboxed               = dph_Unboxed
          })
      = dph_Modules pkg

    -- Load one DPH module, selected from 'mods' by the given field accessor.
    load get_mod = dsLoadModule doc mod
      where
        mod = get_mod mods
        doc = ppr mod <+> ptext (sLit "is a DPH module")

    -- Make a list of numbered strings in some range, eg foo3, foo4, foo5
    numbered :: String -> Int -> Int -> [FastString]
    numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]

    -- Like 'numbered', but with a trailing '#', eg foo3#, foo4#
    numbered_hash :: String -> Int -> Int -> [FastString]
    numbered_hash pfx m n = [mkFastString (pfx ++ show i ++ "#") | i <- [m..n]]

    -- Look up elementsSel<i>_<j># and tag it with its (i,j) array index.
    mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
    mk_elements (i,j)
      = do v <- externalVar dph_Unboxed
                  $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
           return ((i,j), Var v)
219
-- | Get the mapping of names in the Prelude to names in the DPH library.
--
--   The workers of 'True', 'False' and @()@ are mapped to themselves. The
--   boxed 2- and 3-tuple constructor workers are mapped to @tup2@ / @tup3@,
--   looked up in the module stored in the 'dph_Prelude_Tuple' field.
--
initBuiltinVars :: Builtins -> DsM [(Var, Var)]
initBuiltinVars (Builtins { dphModules = mods })
  = do dphTupVars <- zipWithM externalVar tupModules tupNames
       let selfMapped = [ (w, w) | w <- map dataConWorkId defaultDataConWorkers ]
           remapped   = zip (map dataConWorkId tupCons) dphTupVars
       return (selfMapped ++ remapped)
  where
    (tupCons, tupModules, tupNames) = unzip3 (preludeDataCons mods)

    defaultDataConWorkers :: [DataCon]
    defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]

    -- One entry per tuple arity: the Prelude constructor, the DPH module
    -- holding its replacement, and the replacement's name (tupN).
    preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
    preludeDataCons (Modules { dph_Prelude_Tuple = tupleMod })
      = [ (tupleCon Boxed arity, tupleMod, mkFastString ("tup" ++ show arity))
        | arity <- [2..3] ]
239
-- | Get a list of names to `TyCon`s in the mock prelude.
--
--   The function-arrow 'TyCon' maps to the closure type and
--   'parrTyConName' maps to 'PArray'. The basic types (Int, Bool, Float,
--   Double and Word8) map to themselves.
--
initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinTyCons bi
  = do -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
       word8 <- dsLookupTyCon word8TyConName
       let scalarTyCons = [intTyCon, boolTyCon, floatTyCon, doubleTyCon, word8]
           identities   = map (\tc -> (tyConName tc, tc)) scalarTyCons
           parray       = parrayTyCon bi
       return $ (tyConName funTyCon, closureTyCon bi)
              : (parrTyConName,      parray)
                -- FIXME: temporary
              : (tyConName parray,   parray)
              : identities
260
-- | Get a list of names to `DataCon`s in the mock prelude.
--
--   Each of 'True', 'False' and @()@ is paired with its own name. The
--   'Builtins' argument is not used.
--
initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
initBuiltinDataCons _
  = map (\dc -> (dataConName dc, dc)) [trueDataCon, falseDataCon, unitDataCon]
269
-- | Get the names of all builtin instance functions for the PA class.
--
--   Looks up the @PA@ class in the DPH PRepr module and collects its
--   instances from the given instance environments.
--
initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
initBuiltinPAs (Builtins { dphModules = mods }) insts
  = do paCls <- externalClass (dph_PArray_PRepr mods) (fsLit "PA")
       return (initBuiltinDicts insts paCls)
275
-- | Get the names of all builtin instance functions for the PR class.
--
--   Looks up the @PR@ class in the DPH PData module and collects its
--   instances from the given instance environments.
--
initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
initBuiltinPRs (Builtins { dphModules = mods }) insts
  = do prCls <- externalClass (dph_PArray_PData mods) (fsLit "PR")
       return (initBuiltinDicts insts prCls)
281
-- | Get the names of all DPH instance functions for this class.
--
--   Every instance must have exactly one rough-match type constructor name.
--   That name is paired with the instance's dictionary function. Any other
--   shape is reported with 'pprPanic'.
--
initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
initBuiltinDicts insts cls
  = [ dictEntry inst | inst <- classInstances insts cls ]
  where
    dictEntry inst
      = case instanceRoughTcs inst of
          [Just tc] -> (tc, instanceDFunId inst)
          _         -> pprPanic "Invalid DPH instance" (ppr inst)
289
-- | Get a list of boxed `TyCons` in the mock prelude.
--
--   Only one pair is registered: the name of @Int#@ maps to the boxed 'Int'
--   'TyCon'. The 'Builtins' argument is not used.
--
initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinBoxedTyCons _
  = return [(tyConName intPrimTyCon, intTyCon)]
299
300
301 -- Auxiliary lookup functions ----------------
302
-- | Look up a global variable by its occurrence name in the given module.
--
externalVar :: Module -> FastString -> DsM Var
externalVar modl fs
  = do name <- lookupOrig modl (mkVarOccFS fs)
       dsLookupGlobalId name
308
-- | Like 'externalVar', but return the variable as a 'CoreExpr' ('Var').
--
externalFun :: Module -> FastString -> DsM CoreExpr
externalFun modl fs
  = liftM Var (externalVar modl fs)
315
-- | Look up a type constructor by its occurrence name in the given module.
--
externalTyCon :: Module -> FastString -> DsM TyCon
externalTyCon modl fs
  = do name <- lookupOrig modl (mkTcOccFS fs)
       dsLookupTyCon name
321
-- | Look up a type constructor in the given module and return it as a
--   'Type', applied to no arguments.
--
externalType :: Module -> FastString -> DsM Type
externalType modl fs
  = liftM (\tc -> mkTyConApp tc []) (externalTyCon modl fs)
328
-- | Look up a type class by its occurrence name in the given module.
--
externalClass :: Module -> FastString -> DsM Class
externalClass modl fs
  = do name <- lookupOrig modl (mkClsOccFS fs)
       dsLookupClass name