Rename *NDP* -> *DPH*
[ghc.git] / compiler / vectorise / VectBuiltIn.hs
1 module VectBuiltIn (
2 Builtins(..), sumTyCon, prodTyCon,
3 combinePAVar,
4 initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
5 initBuiltinPAs, initBuiltinPRs,
6 initBuiltinBoxedTyCons,
7
8 primMethod, primPArray
9 ) where
10
11 import DsMonad
12 import IfaceEnv ( lookupOrig )
13
14 import Module
15 import DataCon ( DataCon, dataConName, dataConWorkId )
16 import TyCon ( TyCon, tyConName, tyConDataCons )
17 import Var ( Var )
18 import Id ( mkSysLocal )
19 import Name ( Name, getOccString )
20 import NameEnv
21 import OccName
22
23 import TypeRep ( funTyCon )
24 import Type ( Type, mkTyConApp )
25 import TysPrim
26 import TysWiredIn ( unitTyCon, unitDataCon,
27 tupleTyCon, tupleCon,
28 intTyCon, intTyConName,
29 doubleTyCon, doubleTyConName,
30 boolTyCon, boolTyConName, trueDataCon, falseDataCon,
31 parrTyConName )
32 import PrelNames ( gHC_PARR )
33 import BasicTypes ( Boxity(..) )
34
35 import FastString
36 import Outputable
37
38 import Data.Array
39 import Control.Monad ( liftM, zipWithM )
40 import Data.List ( unzip4 )
41
-- Upper bounds on the arities for which builtins are looked up in the
-- DPH library:
--
--   * 'mAX_DPH_PROD'    - largest tuple arity ('prodTyCon', @dPA_n@, @dPR_n@)
--   * 'mAX_DPH_SUM'     - largest @SumN@ representation type ('sumTyCon')
--   * 'mAX_DPH_COMBINE' - largest @n@ with a @combine<n>PA#@ ('combinePAVar')
mAX_DPH_PROD, mAX_DPH_SUM, mAX_DPH_COMBINE :: Int
mAX_DPH_PROD    = 5
mAX_DPH_SUM     = 3
mAX_DPH_COMBINE = 2
50
-- | The modules of the DPH library from which the vectoriser looks up its
-- builtin names.  'dph_Modules' fills these in for a given package.
data Modules = Modules
  { dph_PArray         :: Module  -- ^ "Data.Array.Parallel.Lifted.PArray"
  , dph_Repr           :: Module  -- ^ "Data.Array.Parallel.Lifted.Repr"
  , dph_Closure        :: Module  -- ^ "Data.Array.Parallel.Lifted.Closure"
  , dph_Unboxed        :: Module  -- ^ "Data.Array.Parallel.Lifted.Unboxed"
  , dph_Instances      :: Module  -- ^ "Data.Array.Parallel.Lifted.Instances"
  , dph_Combinators    :: Module  -- ^ "Data.Array.Parallel.Lifted.Combinators"
  , dph_Prelude_PArr   :: Module  -- ^ "Data.Array.Parallel.Prelude.Base.PArr"
  , dph_Prelude_Int    :: Module  -- ^ "Data.Array.Parallel.Prelude.Base.Int"
  , dph_Prelude_Double :: Module  -- ^ "Data.Array.Parallel.Prelude.Base.Double"
  , dph_Prelude_Bool   :: Module  -- ^ "Data.Array.Parallel.Prelude.Base.Bool"
  , dph_Prelude_Tuple  :: Module  -- ^ "Data.Array.Parallel.Prelude.Base.Tuple"
  }
64
-- | The 'Modules' record for the DPH library that lives in package @pkg@.
dph_Modules :: PackageId -> Modules
dph_Modules pkg
  = Modules
    { dph_PArray         = mk "Data.Array.Parallel.Lifted.PArray"
    , dph_Repr           = mk "Data.Array.Parallel.Lifted.Repr"
    , dph_Closure        = mk "Data.Array.Parallel.Lifted.Closure"
    , dph_Unboxed        = mk "Data.Array.Parallel.Lifted.Unboxed"
    , dph_Instances      = mk "Data.Array.Parallel.Lifted.Instances"
    , dph_Combinators    = mk "Data.Array.Parallel.Lifted.Combinators"

    , dph_Prelude_PArr   = mk "Data.Array.Parallel.Prelude.Base.PArr"
    , dph_Prelude_Int    = mk "Data.Array.Parallel.Prelude.Base.Int"
    , dph_Prelude_Double = mk "Data.Array.Parallel.Prelude.Base.Double"
    , dph_Prelude_Bool   = mk "Data.Array.Parallel.Prelude.Base.Bool"
    , dph_Prelude_Tuple  = mk "Data.Array.Parallel.Prelude.Base.Tuple"
    }
  where
    -- Turn a fully qualified module name into a 'Module' of package @pkg@.
    mk :: String -> Module
    mk name = mkModule pkg (mkModuleNameFS (fsLit name))
82
83
-- | Everything 'initBuiltins' looks up in the DPH library once, up front,
-- for later use by the vectoriser.
data Builtins = Builtins
  { dphModules            :: Modules
      -- ^ the DPH modules the other fields were looked up in
  , parrayTyCon           :: TyCon   -- ^ @PArray@ ("Data.Array.Parallel.Lifted.PArray")
  , paTyCon               :: TyCon   -- ^ @PA@ ("Data.Array.Parallel.Lifted.PArray")
  , paDataCon             :: DataCon -- ^ the sole data constructor of 'paTyCon'
  , preprTyCon            :: TyCon   -- ^ @PRepr@ ("Data.Array.Parallel.Lifted.PArray")
  , prTyCon               :: TyCon   -- ^ @PR@ ("Data.Array.Parallel.Lifted.PArray")
  , prDataCon             :: DataCon -- ^ the sole data constructor of 'prTyCon'
  , intPrimArrayTy        :: Type    -- ^ @PArray_Int#@ ("Data.Array.Parallel.Lifted.Unboxed")
  , voidTyCon             :: TyCon   -- ^ @Void@ ("Data.Array.Parallel.Lifted.Repr")
  , wrapTyCon             :: TyCon   -- ^ @Wrap@ ("Data.Array.Parallel.Lifted.Repr"); arity-1 case of 'prodTyCon'
  , enumerationTyCon      :: TyCon   -- ^ @Enumeration@ ("Data.Array.Parallel.Lifted.Repr")
  , sumTyCons             :: Array Int TyCon
      -- ^ @Sum2@ .. @Sum<mAX_DPH_SUM>@, indexed by arity; see 'sumTyCon'
  , closureTyCon          :: TyCon   -- ^ @:->@ ("Data.Array.Parallel.Lifted.Closure")
  , voidVar               :: Var     -- ^ @void@
  , mkPRVar               :: Var     -- ^ @mkPR@
  , mkClosureVar          :: Var     -- ^ @mkClosure@
  , applyClosureVar       :: Var     -- ^ @$:@
  , mkClosurePVar         :: Var     -- ^ @mkClosureP@
  , applyClosurePVar      :: Var     -- ^ @$:^@
  , replicatePAIntPrimVar :: Var     -- ^ @replicatePA_Int#@
  , upToPAIntPrimVar      :: Var     -- ^ @upToPA_Int#@
  , selectPAIntPrimVar    :: Var     -- ^ @selectPA_Int#@
  , truesPABoolPrimVar    :: Var     -- ^ @truesPA_Bool#@
  , lengthPAVar           :: Var     -- ^ @lengthPA#@
  , replicatePAVar        :: Var     -- ^ @replicatePA#@
  , emptyPAVar            :: Var     -- ^ @emptyPA@
  , packPAVar             :: Var     -- ^ @packPA#@
  , combinePAVars         :: Array Int Var
      -- ^ @combine2PA#@ .. @combine<mAX_DPH_COMBINE>PA#@, indexed by n;
      -- see 'combinePAVar'
  , liftingContext        :: Var
      -- ^ a fresh local variable named @lc@ of type @Int#@
  }
115
-- | The @SumN@ representation type constructor for arity @n@.
-- Panics unless @2 <= n <= mAX_DPH_SUM@.
sumTyCon :: Int -> Builtins -> TyCon
sumTyCon n bi
  | n < 2 || n > mAX_DPH_SUM = pprPanic "sumTyCon" (ppr n)
  | otherwise                = sumTyCons bi ! n
120
-- | The product representation type constructor for arity @n@: 'wrapTyCon'
-- for a single component, and the boxed tuple type constructor for any
-- other arity in @[0 .. mAX_DPH_PROD]@.  Panics outside that range.
prodTyCon :: Int -> Builtins -> TyCon
prodTyCon n bi
  = case n of
      1 -> wrapTyCon bi
      _ | n < 0 || n > mAX_DPH_PROD -> pprPanic "prodTyCon" (ppr n)
        | otherwise                 -> tupleTyCon Boxed n
126
-- | The @combine<n>PA#@ variable for @n@ in @[2 .. mAX_DPH_COMBINE]@;
-- panics for any other @n@.
combinePAVar :: Int -> Builtins -> Var
combinePAVar n bi
  | inRange (2, mAX_DPH_COMBINE) n = combinePAVars bi ! n
  | otherwise                      = pprPanic "combinePAVar" (ppr n)
131
-- | Look up, in the DPH library of package @pkg@, the type constructors,
-- data constructors and functions the vectoriser refers to directly, and
-- collect them in a 'Builtins' record.  Also allocates a fresh local
-- variable @lc :: Int#@ used as the lifting context.
initBuiltins :: PackageId -> DsM Builtins
initBuiltins pkg
  = do
      parrayTyCon <- externalTyCon dph_PArray (fsLit "PArray")
      paTyCon     <- externalTyCon dph_PArray (fsLit "PA")
      let paDataCon = singleDataCon paTyCon
      preprTyCon  <- externalTyCon dph_PArray (fsLit "PRepr")
      prTyCon     <- externalTyCon dph_PArray (fsLit "PR")
      let prDataCon = singleDataCon prTyCon
      intPrimArrayTy <- externalType dph_Unboxed (fsLit "PArray_Int#")
      closureTyCon   <- externalTyCon dph_Closure (fsLit ":->")

      voidTyCon        <- externalTyCon dph_Repr (fsLit "Void")
      wrapTyCon        <- externalTyCon dph_Repr (fsLit "Wrap")
      enumerationTyCon <- externalTyCon dph_Repr (fsLit "Enumeration")
      sum_tcs <- mapM (externalTyCon dph_Repr)
                      [mkFastString ("Sum" ++ show i) | i <- [2..mAX_DPH_SUM]]

      -- Indexed by arity, so 'sumTyCon' can use the arity directly.
      let sumTyCons = listArray (2, mAX_DPH_SUM) sum_tcs

      voidVar               <- externalVar dph_Repr    (fsLit "void")
      mkPRVar               <- externalVar dph_PArray  (fsLit "mkPR")
      mkClosureVar          <- externalVar dph_Closure (fsLit "mkClosure")
      applyClosureVar       <- externalVar dph_Closure (fsLit "$:")
      mkClosurePVar         <- externalVar dph_Closure (fsLit "mkClosureP")
      applyClosurePVar      <- externalVar dph_Closure (fsLit "$:^")
      replicatePAIntPrimVar <- externalVar dph_Unboxed (fsLit "replicatePA_Int#")
      upToPAIntPrimVar      <- externalVar dph_Unboxed (fsLit "upToPA_Int#")
      selectPAIntPrimVar    <- externalVar dph_Unboxed (fsLit "selectPA_Int#")
      truesPABoolPrimVar    <- externalVar dph_Unboxed (fsLit "truesPA_Bool#")
      lengthPAVar           <- externalVar dph_PArray  (fsLit "lengthPA#")
      replicatePAVar        <- externalVar dph_PArray  (fsLit "replicatePA#")
      emptyPAVar            <- externalVar dph_PArray  (fsLit "emptyPA")
      packPAVar             <- externalVar dph_PArray  (fsLit "packPA#")

      combines <- mapM (externalVar dph_PArray)
                       [mkFastString ("combine" ++ show i ++ "PA#")
                          | i <- [2..mAX_DPH_COMBINE]]
      -- Indexed by n, so 'combinePAVar' can use n directly.
      let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines

      -- A fresh local variable named "lc" of type Int#.
      liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
                              newUnique

      return $ Builtins {
                 dphModules            = modules
               , parrayTyCon           = parrayTyCon
               , paTyCon               = paTyCon
               , paDataCon             = paDataCon
               , preprTyCon            = preprTyCon
               , prTyCon               = prTyCon
               , prDataCon             = prDataCon
               , intPrimArrayTy        = intPrimArrayTy
               , voidTyCon             = voidTyCon
               , wrapTyCon             = wrapTyCon
               , enumerationTyCon      = enumerationTyCon
               , sumTyCons             = sumTyCons
               , closureTyCon          = closureTyCon
               , voidVar               = voidVar
               , mkPRVar               = mkPRVar
               , mkClosureVar          = mkClosureVar
               , applyClosureVar       = applyClosureVar
               , mkClosurePVar         = mkClosurePVar
               , applyClosurePVar      = applyClosurePVar
               , replicatePAIntPrimVar = replicatePAIntPrimVar
               , upToPAIntPrimVar      = upToPAIntPrimVar
               , selectPAIntPrimVar    = selectPAIntPrimVar
               , truesPABoolPrimVar    = truesPABoolPrimVar
               , lengthPAVar           = lengthPAVar
               , replicatePAVar        = replicatePAVar
               , emptyPAVar            = emptyPAVar
               , packPAVar             = packPAVar
               , combinePAVars         = combinePAVars
               , liftingContext        = liftingContext
               }
  where
    modules@(Modules { dph_PArray  = dph_PArray
                     , dph_Repr    = dph_Repr
                     , dph_Closure = dph_Closure
                     , dph_Unboxed = dph_Unboxed
                     })
      = dph_Modules pkg

    -- PA and PR are required to have exactly one data constructor each.
    -- If the DPH library disagrees, panic with a message naming the type
    -- constructor instead of an anonymous irrefutable-pattern failure.
    -- Still bound lazily with 'let', so it only fires when forced, as before.
    singleDataCon :: TyCon -> DataCon
    singleDataCon tc
      = case tyConDataCons tc of
          [dc] -> dc
          dcs  -> pprPanic "initBuiltins: expected exactly one data constructor"
                           (ppr tc <+> parens (int (length dcs)))
214
215
-- | Initial (original, replacement) variable pairs for the vectoriser:
--
--   * the worker of each of 'defaultDataConWorkers', paired with itself;
--   * the worker of each data constructor in 'preludeDataCons', paired with
--     the DPH variable listed alongside it;
--   * each variable on the left of 'preludeVars', paired with the variable
--     on its right.
initBuiltinVars :: Builtins -> DsM [(Var, Var)]
initBuiltinVars (Builtins { dphModules = modules })
  = do
      let lookupIn (getMod, fs) = externalVar (getMod modules) fs
      origVars <- mapM lookupIn (zip umods ufs)
      vectVars <- mapM lookupIn (zip vmods vfs)
      conVars  <- mapM lookupIn (zip cmods cfs)
      let selfPairs = [ (w, w) | w <- map dataConWorkId defaultDataConWorkers ]
          conPairs  = zip (map dataConWorkId cons) conVars
      return (selfPairs ++ conPairs ++ zip origVars vectVars)
  where
    (umods, ufs, vmods, vfs) = unzip4 preludeVars
    (cons, cmods, cfs)       = unzip3 preludeDataCons
229
-- | Data constructors whose worker 'Var's 'initBuiltinVars' pairs with
-- themselves.
--
-- This is the same list 'initBuiltinDataCons' uses.  Define it in terms of
-- 'defaultDataCons' rather than repeating the constructors, so the two
-- lists cannot drift apart.
defaultDataConWorkers :: [DataCon]
defaultDataConWorkers = defaultDataCons
232
-- | The boxed tuple data constructors of arity 2 and 3, each paired with
-- the @tup<n>@ variable of the DPH tuple prelude module.
preludeDataCons :: [(DataCon, Modules -> Module, FastString)]
preludeDataCons = map tupEntry [2 .. 3]
  where
    tupEntry :: Int -> (DataCon, Modules -> Module, FastString)
    tupEntry arity = ( tupleCon Boxed arity
                     , dph_Prelude_Tuple
                     , mkFastString ("tup" ++ show arity)
                     )
238
-- | Table of (module, name) of a source-level function together with
-- (module, name) of the DPH variable 'initBuiltinVars' pairs it with.
-- Entries are listed in the order they are looked up.
preludeVars :: [(Modules -> Module, FastString, Modules -> Module, FastString)]
preludeVars
  =  [ fromParr "mapP"       dph_Combinators "mapPA"
     , fromParr "zipWithP"   dph_Combinators "zipWithPA"
     , fromParr "zipP"       dph_Combinators "zipPA"
     , fromParr "unzipP"     dph_Combinators "unzipPA"
     , fromParr "filterP"    dph_Combinators "filterPA"
     , fromParr "lengthP"    dph_Combinators "lengthPA"
     , fromParr "replicateP" dph_Combinators "replicatePA"
     , fromParr "!:"         dph_Combinators "indexPA"
     , fromParr "crossMapP"  dph_Combinators "crossMapPA"
     , fromParr "singletonP" dph_Combinators "singletonPA"
     , fromParr "concatP"    dph_Combinators "concatPA"
     , fromParr "+:+"        dph_Combinators "appPA"
     , fromParr "emptyP"     dph_PArray      "emptyPA"
     ]
  ++ map (within dph_Prelude_Int)
       ( [ ("plus",          "plusV")
         , ("minus",         "minusV")
         , ("mult",          "multV")
         , ("intDiv",        "intDivV")
         , ("intMod",        "intModV")
         , ("intSquareRoot", "intSquareRootV")
         , ("intSumP",       "intSumPA")
         , ("enumFromToP",   "enumFromToPA")
         , ("upToP",         "upToPA")
         ]
         ++ comparisons )
  ++ map (within dph_Prelude_Double)
       ( [ ("plus",       "plusV")
         , ("minus",      "minusV")
         , ("mult",       "multV")
         , ("divide",     "divideV")
         , ("squareRoot", "squareRootV")
         , ("doubleSumP", "doubleSumPA")
         , ("minIndexP",  "minIndexPA")
         , ("maxIndexP",  "maxIndexPA")
         ]
         ++ comparisons )
  ++ map (within dph_Prelude_Bool)
       [ ("andP", "andPA")
       , ("orP",  "orPA")
       ]
  -- FIXME: temporary
  ++ map (within dph_Prelude_PArr)
       [ ("fromPArrayP",       "fromPArrayPA")
       , ("toPArrayP",         "toPArrayPA")
       , ("fromNestedPArrayP", "fromNestedPArrayPA")
       ]
  ++ [ (dph_Prelude_PArr, fsLit "combineP", dph_Combinators, fsLit "combine2PA") ]
  where
    -- A function from GHC.PArr paired with a variable in module @m@.
    fromParr f m v = (const gHC_PARR, fsLit f, m, fsLit v)

    -- A function and its counterpart, both living in module @m@.
    within m (f, v) = (m, fsLit f, m, fsLit v)

    -- Comparison operators shared by the Int and Double prelude modules.
    comparisons = [ ("eq",  "eqV")
                  , ("neq", "neqV")
                  , ("le",  "leV")
                  , ("lt",  "ltV")
                  , ("ge",  "geV")
                  , ("gt",  "gtV")
                  ]
302
-- | Initial (type constructor name, replacement type constructor) pairs:
-- the function arrow maps to the closure type, the wired-in parallel-array
-- type ('parrTyConName') and @PArray@ itself map to @PArray@, and each of
-- 'defaultTyCons' maps to itself.
initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinTyCons bi
  -- (previously: parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr"))
  = return $
      [ (tyConName funTyCon,          closureTyCon bi)
      , (parrTyConName,               parrayTyCon bi)
        -- FIXME: temporary
      , (tyConName (parrayTyCon bi),  parrayTyCon bi)
      ]
      ++ [ (tyConName tc, tc) | tc <- defaultTyCons ]
314
-- | Type constructors that 'initBuiltinTyCons' maps to themselves.
defaultTyCons :: [TyCon]
defaultTyCons = intTyCon : boolTyCon : doubleTyCon : []
317
-- | Each of 'defaultDataCons', keyed by its own 'Name'.  The 'Builtins'
-- argument is not used.
initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
initBuiltinDataCons _ = map (\dc -> (dataConName dc, dc)) defaultDataCons
320
-- | Data constructors that 'initBuiltinDataCons' maps to themselves.
defaultDataCons :: [DataCon]
defaultDataCons = trueDataCon : falseDataCon : unitDataCon : []
323
-- | For each (key, module, variable name) triple, look the variable up in
-- the module and pair it with the key.  Lookups happen in list order.
initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)]
initBuiltinDicts ps
  = mapM lookupDict ps
  where
    lookupDict (key, mod, fs)
      = do
          dict <- externalVar mod fs
          return (key, dict)
331
-- | Look up the @dPA_*@ variables listed by 'builtinPAs'.
initBuiltinPAs :: Builtins -> DsM [(Name, Var)]
initBuiltinPAs bi = initBuiltinDicts (builtinPAs bi)
334
-- | (type constructor name, module, @dPA_*@ variable name) for each type
-- constructor with a builtin PA entry, including the boxed tuples of arity
-- 2 .. 'mAX_DPH_PROD' (@dPA_2@, @dPA_3@, ...).
builtinPAs :: Builtins -> [(Name, Module, FastString)]
builtinPAs bi@(Builtins { dphModules = mods })
  =  [ (tyConName (closureTyCon bi), dph_Closure mods, fsLit "dPA_Clo")
     , (tyConName (voidTyCon bi),    dph_Repr mods,    fsLit "dPA_Void")
     , (tyConName (parrayTyCon bi),  instances,        fsLit "dPA_PArray")
     , (unitTyConName,               instances,        fsLit "dPA_Unit")

     , (intTyConName,                instances,        fsLit "dPA_Int")
     , (doubleTyConName,             instances,        fsLit "dPA_Double")
     , (boolTyConName,               instances,        fsLit "dPA_Bool")
     ]
  ++ [ ( tyConName (tupleTyCon Boxed n)
       , instances
       , mkFastString ("dPA_" ++ show n) )
     | n <- [2 .. mAX_DPH_PROD] ]
  where
    instances = dph_Instances mods
355
-- | Look up the @dPR_*@ variables listed by 'builtinPRs'.
initBuiltinPRs :: Builtins -> DsM [(Name, Var)]
initBuiltinPRs bi = initBuiltinDicts (builtinPRs bi)
358
-- | (type constructor name, module, @dPR_*@ variable name) for each type
-- constructor with a builtin PR entry.  This includes @Sum2@ .. @SumN@
-- (@dPR_Sum<n>@) and the products of arity 2 .. 'mAX_DPH_PROD' (@dPR_<n>@),
-- all from the Repr module.
builtinPRs :: Builtins -> [(Name, Module, FastString)]
builtinPRs bi@(Builtins { dphModules = mods })
  =  [ (tyConName unitTyCon,              repr,              fsLit "dPR_Unit")
     , (tyConName (voidTyCon bi),         repr,              fsLit "dPR_Void")
     , (tyConName (wrapTyCon bi),         repr,              fsLit "dPR_Wrap")
     , (tyConName (enumerationTyCon bi),  repr,              fsLit "dPR_Enumeration")
     , (tyConName (closureTyCon bi),      dph_Closure mods,  fsLit "dPR_Clo")

       -- temporary
     , (intTyConName,                     dph_Instances mods, fsLit "dPR_Int")
     , (doubleTyConName,                  dph_Instances mods, fsLit "dPR_Double")
     ]
  ++ [ (tyConName (sumTyCon n bi), repr, mkFastString ("dPR_Sum" ++ show n))
     | n <- [2 .. mAX_DPH_SUM] ]
  ++ [ (tyConName (prodTyCon n bi), repr, mkFastString ("dPR_" ++ show n))
     | n <- [2 .. mAX_DPH_PROD] ]
  where
    repr = dph_Repr mods
383
-- | 'builtinBoxedTyCons', returned in 'DsM'.
initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinBoxedTyCons bi = return (builtinBoxedTyCons bi)
386
-- | Pairs a primitive type constructor's name with its boxed counterpart.
-- Currently only 'intPrimTyCon' is paired with 'intTyCon'.  The 'Builtins'
-- argument is ignored.
builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
builtinBoxedTyCons = const [(tyConName intPrimTyCon, intTyCon)]
390
-- | Look up the global identifier named @fs@ (variable namespace) in
-- module @mod@.
externalVar :: Module -> FastString -> DsM Var
externalVar mod fs
  = do
      name <- lookupOrig mod (mkVarOccFS fs)
      dsLookupGlobalId name
394
-- | Look up the type constructor named @fs@ (type-constructor namespace)
-- in module @mod@.
externalTyCon :: Module -> FastString -> DsM TyCon
externalTyCon mod fs
  = do
      name <- lookupOrig mod (mkOccNameFS tcName fs)
      dsLookupTyCon name
398
-- | The type formed by applying the type constructor @fs@ from module
-- @mod@ to no arguments.
externalType :: Module -> FastString -> DsM Type
externalType mod fs
  = liftM (\tc -> mkTyConApp tc []) (externalTyCon mod fs)
404
-- | The 'Name' of 'unitTyCon'.
unitTyConName :: Name
unitTyConName = tyConName unitTyCon
407
408
-- | For a primitive type constructor listed in 'prim_ty_cons', look up the
-- variable @method ++ suffix@ in the Unboxed DPH module.  Any other type
-- constructor yields 'Nothing'.
primMethod :: TyCon -> String -> Builtins -> DsM (Maybe Var)
primMethod tycon method (Builtins { dphModules = mods })
  = case lookupNameEnv prim_ty_cons (tyConName tycon) of
      Nothing     -> return Nothing
      Just suffix ->
        do
          name <- lookupOrig (dph_Unboxed mods) (mkVarOcc (method ++ suffix))
          var  <- dsLookupGlobalId name
          return (Just var)
417
-- | For a primitive type constructor listed in 'prim_ty_cons', look up the
-- type constructor @"PArray" ++ suffix@ in the Unboxed DPH module.  Any
-- other type constructor yields 'Nothing'.
primPArray :: TyCon -> Builtins -> DsM (Maybe TyCon)
primPArray tycon (Builtins { dphModules = mods })
  = case lookupNameEnv prim_ty_cons (tyConName tycon) of
      Nothing     -> return Nothing
      Just suffix ->
        do
          name <- lookupOrig (dph_Unboxed mods)
                             (mkOccName tcName ("PArray" ++ suffix))
          tc   <- dsLookupTyCon name
          return (Just tc)
426
-- | Maps each supported primitive type constructor (currently only
-- 'intPrimTyCon') to the suffix used by 'primMethod' and 'primPArray'.
-- The suffix is an underscore followed by the tycon's occurrence string.
prim_ty_cons :: NameEnv String
prim_ty_cons = mkNameEnv (map entry [intPrimTyCon])
  where
    entry tc = (tyConName tc, "_" ++ getOccString tc)
431