Makefiles need real tab chars, ffs.
[packages/dph.git] / dph-lifted-copy / Data / Array / Parallel / Lifted / TH / Repr.hs
1 {-# LANGUAGE TemplateHaskell, Rank2Types #-}
2 module Data.Array.Parallel.Lifted.TH.Repr (
3 scalarInstances, tupleInstances,
4 voidPRInstance, unitPRInstance, wrapPRInstance
5 ) where
6
7 import qualified Data.Array.Parallel.Unlifted as U
8 import Data.Array.Parallel.Lifted.PArray
9 import Data.Array.Parallel.Base.DTrace (traceFn)
10
11 import Language.Haskell.TH
12 import Data.List (intercalate)
13
-- | The variable bound by a type-variable binder; any kind annotation is
-- discarded.
tyBndrVar :: TyVarBndr -> Name
tyBndrVar bndr =
  case bndr of
    PlainTV name    -> name
    KindedTV name _ -> name
17
-- | Apply a type to a list of argument types, nesting to the left:
-- @mkAppTs f [a, b] == AppT (AppT f a) b@.
mkAppTs :: Type -> [Type] -> Type
mkAppTs fun []           = fun
mkAppTs fun (arg : rest) = mkAppTs (AppT fun arg) rest
20
-- | Turn each name into a (quoted) type-variable type.
varTs :: [Name] -> [TypeQ]
varTs names = [varT n | n <- names]
23
-- | Quoted counterpart of 'mkAppTs': apply a quoted type to quoted argument
-- types, nesting to the left.
appTs :: TypeQ -> [TypeQ] -> TypeQ
appTs fun []           = fun
appTs fun (arg : rest) = appTs (appT fun arg) rest
26
-- | Turn each name into a (quoted) variable expression.
varEs :: [Name] -> [ExpQ]
varEs names = [varE n | n <- names]
29
-- | Apply a quoted function expression to quoted arguments, nesting to the
-- left: @appEs f [x, y]@ builds @f x y@.
appEs :: ExpQ -> [ExpQ] -> ExpQ
appEs fun []           = fun
appEs fun (arg : rest) = appEs (appE fun arg) rest
32
-- | A case alternative @pat -> body@ with no guards and no @where@ bindings.
normalMatch :: PatQ -> ExpQ -> MatchQ
normalMatch p body = match p rhs noBinds
  where
    rhs     = normalB body
    noBinds = []
35
-- | Turn each name into a (quoted) variable pattern.
varPs :: [Name] -> [PatQ]
varPs names = [varP n | n <- names]
38
-- | A plain (prefix, non-record) data constructor whose fields all use
-- @notStrict@ strictness.
vanillaC :: Name -> [TypeQ] -> ConQ
vanillaC con tys = normalC con fields
  where
    fields = [strictType notStrict t | t <- tys]
41
42
-- | A function declaration with a single unguarded clause and no @where@
-- bindings: @name pats = body@.
simpleFunD :: Name -> [PatQ] -> ExpQ -> DecQ
simpleFunD name pats body = funD name [onlyClause]
  where
    onlyClause = clause pats (normalB body) []
46
47
-- | An inline pragma for @name@ with no phase control.
--
-- The flags are passed to @inlineSpecNoPhase@ as @True False@. Under the
-- template-haskell version this file targets, these presumably mean
-- "INLINE" and "not CONLIKE"; confirm against that version's API.
inlineD :: Name -> DecQ
inlineD name = pragInlD name spec
  where
    spec = inlineSpecNoPhase True False
50
51
-- | Declare @data instance PData (tycon tyargs...)@ with a single
-- 'vanillaC' constructor @con@ whose field types are @tys@.
--
-- The instance has an empty context and an empty final (deriving) list.
instance_PData :: TypeQ -> [Name] -> Name -> [TypeQ] -> DecQ
instance_PData tycon tyargs con tys =
  dataInstD (cxt []) ''PData [instHead] [vanillaC con tys] []
  where
    instHead = appTs tycon (varTs tyargs)
57
58
-- | Declare @newtype instance PData (tycon tyargs...) = con ty@.
--
-- The constructor is built with 'vanillaC'. The instance has an empty
-- context and an empty final (deriving) list.
newtype_instance_PData :: Name -> [Name] -> Name -> TypeQ -> DecQ
newtype_instance_PData tycon tyargs con ty =
  newtypeInstD (cxt []) ''PData [instHead] (vanillaC con [ty]) []
  where
    instHead = appTs (conT tycon) (varTs tyargs)
64
65
-- | Split a type into its head and the arguments applied to it.
--
-- This succeeds only if the head is a named type constructor, a tuple
-- constructor, the list constructor or the function arrow. Any other head
-- (e.g. a type variable) yields 'Nothing'.
splitConAppTy :: Type -> Maybe (Type, [Type])
splitConAppTy = peel []
  where
    peel acc (AppT fun arg) = peel (arg : acc) fun
    peel acc hd
      | isConHead hd = Just (hd, acc)
      | otherwise    = Nothing

    isConHead (ConT _)   = True
    isConHead (TupleT _) = True
    isConHead ListT      = True
    isConHead ArrowT     = True
    isConHead _          = False
75
76
-- | Expand type synonyms at the head of a type.
--
-- If the head is a named type constructor that 'reify' reports as a type
-- synonym, the synonym's parameters are replaced by the supplied arguments
-- in its right-hand side.
--
-- Arguments beyond the synonym's parameter count (an over-applied synonym,
-- e.g. @F Int@ with @type F = Maybe@) are re-applied to the expansion
-- instead of being dropped.
--
-- The expansion is normalised again, so a synonym that expands to another
-- synonym is fully unfolded. This terminates because Haskell rejects
-- recursive type synonyms.
--
-- Any other type is returned unchanged. Note that 'substTy' fails if the
-- synonym's right-hand side contains a @forall@.
normaliseTy :: Type -> Q Type
normaliseTy ty =
  case splitConAppTy ty of
    Just (ConT tycon, args) -> do
      info <- reify tycon
      case info of
        TyConI (TySynD _ bndrs rhs) -> do
          let params        = map tyBndrVar bndrs
              -- 'used' instantiates the parameters; 'extra' are surplus
              -- arguments that must be applied to the expanded type.
              (used, extra) = splitAt (length params) args
          normaliseTy (substTy (zip params used) rhs `mkAppTs` extra)
        _ -> return ty
    _ -> return ty
88
89
-- | Replace type variables according to @env@.
--
-- The substitution descends through type applications and kind signatures.
-- Variables not in @env@, and all other type forms, are left unchanged.
-- A @forall@ anywhere in the type is rejected with an error, because
-- binders are not handled.
substTy :: [(Name, Type)] -> Type -> Type
substTy env ty =
  case ty of
    ForallT {} -> error "DPH gen: can't substitute in forall ty"
    VarT v     -> maybe (VarT v) id (lookup v env)
    AppT t u   -> AppT (go t) (go u)
    SigT t k   -> SigT (go t) k
    _          -> ty
  where
    go = substTy env
100
101
-- | Split a function type @a1 -> ... -> an -> r@ into @([a1, ..., an], r)@.
--
-- A type that is not a function type yields no arguments and itself as the
-- result.
splitFunTy :: Type -> ([Type], Type)
splitFunTy = go []
  where
    go acc t =
      case splitConAppTy t of
        Just (ArrowT, [arg, rest]) -> go (arg : acc) rest
        _                          -> (reverse acc, t)
107
-- | Classification of a PR method argument or result relative to the
-- class's type variable (call it @a@), as computed by 'methodVals'.
data Val
  = ScalarVal  -- ^ @a@ itself
  | PDataVal   -- ^ @PData a@
  | ListVal    -- ^ a list of @a@
  | UnitVal    -- ^ @()@
  | OtherVal   -- ^ anything else

-- | Builds the variable name used for an argument from a base name
-- (see 'nameGens').
type NameGen = String -> String

-- | An argument's classification together with its name generator.
type ArgVal = (Val, NameGen)
115
-- | Generate declarations for every method of class 'PR'.
--
-- The class is reified. An INLINE pragma is emitted for each method
-- signature it declares, followed by one definition per method.
--
-- Each definition is built by @mk_method@ from the method name and its
-- classified arguments (from 'methodVals', paired with that method's
-- 'nameGens' entry). It also receives the classified result.
--
-- A method without a 'nameGens' entry is a generator bug and raises an
-- error.
genPR_methods :: (Name -> [ArgVal] -> Val -> DecQ) -> Q [Dec]
genPR_methods mk_method = do
    ClassI (ClassD _ _ _ _ decs) _ <- reify ''PR
    let sigs = [(name, ty) | SigD name ty <- decs]
    pragmas <- mapM (inlineD . mkName . nameBase . fst) sigs
    bodies  <- mapM gen sigs
    return (pragmas ++ bodies)
  where
    gen (name, ty) =
      case lookup name nameGens of
        Nothing   -> error $ "DPH gen: no name generator for " ++ show name
        Just gens -> do
          (argVals, resVal) <- methodVals ty
          mk_method name (zip argVals gens) resVal
130
131
-- | Classify the arguments and result of a PR method's type.
--
-- The type must be a @forall@ whose first binder is the class's type
-- variable. That binder may be plain or kind-annotated; 'reify' can
-- produce either form. The body is first normalised with 'normaliseTy',
-- so a method type given as a type synonym is expanded before it is split
-- into arguments and result.
--
-- Any other shape of type is a generator error.
methodVals :: Type -> Q ([Val], Val)
methodVals (ForallT (bndr : _) _ ty) = do
    ty' <- normaliseTy ty
    let (args, res) = splitFunTy ty'
    return (map (val vv) args, val vv res)
  where
    -- The class's type variable, whatever binder form it was reified with.
    vv = tyBndrVar bndr

    val v (VarT n) | v == n = ScalarVal
    -- Guards that both fail fall through to the later equations.
    val v (AppT (ConT c) (VarT n))
      | c == ''PData && v == n = PDataVal
      | c == ''[]    && v == n = ListVal
    val v (AppT ListT (VarT n)) | v == n = ListVal
    val _ (ConT c) | c == ''() = UnitVal
    val _ (TupleT 0) = UnitVal
    val _ _ = OtherVal
methodVals tt =
  error $ "DPH gen: methodVals: no match for " ++ show tt
153
154
-- | How a generated method binds one of its arguments (see 'recursiveMethod').
data Split
  = PatSplit PatQ             -- ^ bind the argument with this pattern
  | CaseSplit PatQ ExpQ PatQ  -- ^ bind with the first pattern, then @case@ on
                              --   the expression and match the second pattern

-- | What an argument contributes to the recursive calls.
data Arg
  = RecArg [ExpQ] [ExpQ]  -- ^ the second list holds one expression per
                          --   recursive call; 'recursiveMethod' ignores
                          --   the first list
  | OtherArg ExpQ         -- ^ the same expression for every recursive call
160
-- | Describes how 'recursiveMethod' builds PR methods for a type whose
-- representation is defined via other PR instances.
data Gen = Gen
  { recursiveCalls :: Int
    -- ^ how many recursive calls each method makes
  , recursiveName  :: Name -> Name
    -- ^ maps the method being defined to the function called recursively
  , split          :: ArgVal -> (Split, Arg)
    -- ^ destructures every argument not classified as 'OtherVal'
  , join           :: Val -> [ExpQ] -> ExpQ
    -- ^ combines the recursive results according to the result 'Val'
  , typeName       :: String
    -- ^ type name passed to 'traceFn'
  }
168
-- | Build a PR method by delegating to other PR methods, as described by a
-- 'Gen'.
--
-- 1. Every argument is bound either by a plain variable (for 'OtherVal') or
--    as directed by the generator's 'split'. This yields a pattern, an
--    optional @case@ around the body, and the argument's per-call
--    expressions.
-- 2. The per-argument expression lists are regrouped so that call @i@
--    receives the @i@-th component of every argument. An 'OtherArg'
--    contributes the same expression to every call.
-- 3. 'recursiveCalls' calls are made to the function named by
--    'recursiveName', and their results are combined with 'join'.
--
-- The whole body is wrapped as @traceFn "<method>" "<typeName>" body@.
recursiveMethod :: Gen -> Name -> [ArgVal] -> Val -> DecQ
recursiveMethod gen name avs res =
    simpleFunD (mkName (nameBase name)) (map bindPat splits) traced
  where
    traced = appE tracer (foldr wrapCase combined splits)
    tracer = varE 'traceFn `appEs` [stringE (nameBase name), stringE (typeName gen)]

    combined =
      join gen res (mkCalls (recursiveCalls gen) (regroup (map components args)))

    (splits, args) = unzip (map splitArg avs)

    splitArg (OtherVal, g) = (PatSplit (varP v), OtherArg (varE v))
      where
        v = mkName (g "")
    splitArg av = split gen av

    bindPat s =
      case s of
        PatSplit p      -> p
        CaseSplit p _ _ -> p

    wrapCase (PatSplit _)           body = body
    wrapCase (CaseSplit _ scrut p)  body = caseE scrut [normalMatch p body]

    -- 'OtherArg' expands to an infinite list; the finite 'RecArg' lists
    -- bound the regrouping (zipWith truncates), and 'mkCalls' takes at most
    -- n groups.
    components (RecArg _ es) = es
    components (OtherArg e)  = repeat e

    -- Not Data.List.transpose: this must truncate to the shortest list and
    -- cope with infinite ones.
    regroup []         = []
    regroup [xs]       = map (: []) xs
    regroup (xs : xss) = zipWith (:) xs (regroup xss)

    mkCalls 0 _      = []
    mkCalls n []     = replicate n (varE recName)
    mkCalls n groups = [varE recName `appEs` g | g <- take n groups]

    recName = recursiveName gen name
206
207
-- | Argument name generators for each 'PR' method, in argument order.
--
-- The instance generators apply each generator to a base name, such as
-- @""@, @"x"@, @"xs"@ or a tuple component name. @const "..."@ fixes the
-- name outright. @(++ "1")@ and @(++ "2")@ keep apart two arguments of the
-- same classification.
--
-- Each list should presumably give one generator per argument: 'genPR_methods'
-- zips it with the method's arguments, so a short list silently drops
-- arguments.
nameGens :: [(Name, [NameGen])]
nameGens =
  [ ('emptyPR,      [])
  , ('replicatePR,  [const "n#", id])
  , ('replicatelPR, [const "segd", id])
  , ('repeatPR,     [const "n#", const "len#", id])
  , ('indexPR,      [id, const "i#"])
  , ('extractPR,    [id, const "i#", const "n#"])
  , ('bpermutePR,   [id, const "n#", const "ixs"])
  , ('appPR,        [(++"1"), (++"2")])
  , ('applPR,       [const "segd", const "ixs", (++"1"), const "jxs", (++"2")])
  , ('packByTagPR,  [id, const "n#", const "tags", const "t#"])
  , ('combine2PR,   [const "n#", const "sel", (++"1"), (++"2")])
  , ('updatePR,     [(++"1"), const "ixs", (++"2")])
  , ('fromListPR,   [const "n#", id])
  , ('nfPR,         [id])
  ]
226
227 -- ---------------
228 -- Scalar types
229 -- ---------------
230
-- | For each named scalar type, generate its @PData@ newtype instance, its
-- @Scalar@ instance and its @PR@ instance.
--
-- All PData instances come first, then all Scalar instances, then all PR
-- instances.
scalarInstances :: [Name] -> Q [Dec]
scalarInstances tys = do
  groups <- mapM (\mk -> mapM mk tys)
                 [instance_PData_scalar, instance_Scalar_scalar, instance_PR_scalar]
  return (concat groups)
238
-- | Name of the PData constructor for a scalar type: @P@ prefixed to the
-- type's base name (e.g. @Int@ gives @PInt@).
pdataScalarCon :: Name -> Name
pdataScalarCon n = mkName ('P' : nameBase n)
241
-- | @newtype instance PData T = PT (U.Array T)@ for a scalar type @T@.
instance_PData_scalar :: Name -> DecQ
instance_PData_scalar tycon =
  newtype_instance_PData tycon [] (pdataScalarCon tycon) arrTy
  where
    arrTy = appT (conT ''U.Array) (conT tycon)
246
-- | @instance Scalar T@ for a scalar type @T@. It converts between
-- @PData T@ and its underlying array by unwrapping or applying the
-- 'pdataScalarCon' constructor.
--
-- Both methods also get INLINE pragmas, which precede the definitions.
instance_Scalar_scalar :: Name -> DecQ
instance_Scalar_scalar ty =
  instanceD (cxt []) (conT ''Scalar `appT` conT ty) (pragmas ++ bodies)
  where
    con = pdataScalarCon ty
    xs  = mkName "xs"

    methods =
      [ ("fromScalarPData", fromDef)
      , ("toScalarPData",   toDef)
      ]

    pragmas = [inlineD (mkName m) | (m, _) <- methods]
    bodies  = [d | (_, d) <- methods]

    -- fromScalarPData (PT xs) = xs
    fromDef = simpleFunD (mkName "fromScalarPData") [conP con [varP xs]] (varE xs)
    -- toScalarPData = PT
    toDef   = simpleFunD (mkName "toScalarPData") [] (conE con)
263
-- | @instance PR T@ for a scalar type @T@, with every method defined by
-- 'scalarMethod'.
instance_PR_scalar :: Name -> DecQ
instance_PR_scalar ty =
    fmap (InstanceD [] prHead) (genPR_methods (scalarMethod ty))
  where
    prHead = ConT ''PR `AppT` ConT ty
271
-- | Define PR method @meth@ as the function with the same name plus a
-- @Scalar@ suffix (e.g. @emptyPR = emptyPRScalar@).
--
-- The type, argument classification and result are ignored.
scalarMethod :: Name -> Name -> [ArgVal] -> Val -> DecQ
scalarMethod _ meth _ _ =
    simpleFunD (mkName methName) [] (varE (mkName (methName ++ "Scalar")))
  where
    methName = nameBase meth
277
278 {-
279 = simpleFunD (mkName $ nameBase meth) pats
280 $ result res
281 $ varE impl `appEs` vals
282 where
283 pcon = pdataPrimCon ty
284 impl = mkName
285 $ nameBase meth ++ "Prim"
286
287 (pats, vals) = unzip [arg v g | (v,g) <- avs]
288
289 arg ScalarVal g = var (g "x")
290 arg PDataVal g = let v = mkName (g "xs")
291 in (conP pcon [varP v], varE v)
292 arg ListVal g = var (g "xs")
293 arg OtherVal g = var (g "")
294
295 var s = let v = mkName s in (varP v, varE v)
296
297 result ScalarVal e = e
298 result PDataVal e = conE pcon `appE` e
299 result UnitVal e = varE 'seq `appEs` [e, varE '()]
300 result OtherVal e = e
301 -}
302
303 -- ----
304 -- Void
305 -- ----
306
-- | @instance PR ty@ for the void type, with methods built by 'voidMethod'.
--
-- @void@ and @pvoid@ name the values returned for scalar and PData results.
voidPRInstance :: Name -> Name -> Name -> Q [Dec]
voidPRInstance ty void pvoid = do
    ms <- genPR_methods (voidMethod void pvoid)
    let instHead = ConT ''PR `AppT` ConT ty
    return [InstanceD [] instHead ms]
314
-- | A void PR method: every argument is ignored. The result is @void@,
-- @pvoid@ or @()@, depending on whether the method returns a scalar, a
-- PData or unit.
--
-- Any other result classification is a generator error.
voidMethod :: Name -> Name -> Name -> [ArgVal] -> Val -> DecQ
voidMethod void pvoid meth avs res =
    simpleFunD (mkName (nameBase meth)) [wildP | _ <- avs] body
  where
    body =
      case res of
        ScalarVal -> varE void
        PDataVal  -> varE pvoid
        UnitVal   -> conE '()
        _         -> error "DPH gen: voidMethod: no match"
324
325 -- --
326 -- ()
327 -- --
328
-- | @instance PR ()@, with methods built by 'unitMethod'.
--
-- @punit@ names the PData-of-unit constructor.
unitPRInstance :: Name -> Q [Dec]
unitPRInstance punit =
    fmap (\ms -> [InstanceD [] unitHead ms]) (genPR_methods (unitMethod punit))
  where
    unitHead = ConT ''PR `AppT` ConT ''()
336
-- | A PR method for the unit type.
--
-- Arguments are matched by classification:
--
-- * scalar arguments with @()@;
-- * PData arguments with the nullary constructor @punit@;
-- * list arguments with a variable, whose elements are all forced with
--   @seq@ (via @foldr seq@) before the result is produced;
-- * any other argument with a wildcard.
--
-- The result is @()@ or @punit@. Unit arguments and unsupported result
-- classifications are generator errors.
unitMethod :: Name -> Name -> [ArgVal] -> Val -> DecQ
unitMethod punit meth avs res =
    simpleFunD (mkName (nameBase meth)) pats body
  where
    (pats, forcers) = unzip [argPat v g | (v, g) <- avs]

    -- Wrap the result in each list-forcing expression, first argument
    -- outermost.
    body = foldr (\mf e -> maybe e ($ e) mf) resultE forcers

    argPat ScalarVal _ = (conP '() [], Nothing)
    argPat PDataVal  _ = (conP punit [], Nothing)
    argPat ListVal   g =
      let xs = mkName (g "xs")
      in (varP xs, Just (\e -> varE 'foldr `appEs` [varE 'seq, e, varE xs]))
    argPat OtherVal  _ = (wildP, Nothing)
    argPat _         _ = error "DPH gen: unitMethod/mkpat: no match"

    resultE =
      case res of
        ScalarVal -> conE '()
        PDataVal  -> conE punit
        UnitVal   -> conE '()
        _         -> error "DPH gen: unitMethod/result: no match"
361
362 -- ----
363 -- Wrap
364 -- ----
365
-- | @instance PA a => PR (ty a)@ for a wrapper type.
--
-- The methods are built by 'recursiveMethod' with 'wrapGen'. @wrap@,
-- @unwrap@ and @pwrap@ are passed through to 'wrapGen'.
wrapPRInstance :: Name -> Name -> Name -> Name -> Q [Dec]
wrapPRInstance ty wrap unwrap pwrap = do
    ms <- genPR_methods (recursiveMethod (wrapGen wrap unwrap pwrap))
    let context  = [ClassP ''PA [tyVar]]
        instHead = ConT ''PR `AppT` (ConT ty `AppT` tyVar)
    return [InstanceD context instHead ms]
  where
    tyVar = VarT (mkName "a")
375
-- | Generator for the PR methods of a wrapper type @Wrap a@.
--
-- Each method makes one recursive call. The called function has the
-- method's name with its last character replaced by @D@ (e.g. @emptyPR@
-- calls @emptyPD@).
--
-- Arguments are unwrapped before the call:
--
-- * scalar arguments by matching the @wrap@ constructor;
-- * PData arguments by matching @pwrap@;
-- * list arguments by mapping @unwrap@ over them.
--
-- Scalar and PData results are re-wrapped; a unit result is returned as is.
wrapGen :: Name -> Name -> Name -> Gen
wrapGen wrap unwrap pwrap =
    Gen { recursiveCalls = 1
        , recursiveName  = toPD
        , split          = splitW
        , join           = joinW
        , typeName       = "Wrap a"
        }
  where
    toPD n = mkName (init (nameBase n) ++ "D")

    splitW (val, g) =
      case val of
        ScalarVal ->
          let x = mkName (g "x")
          in (PatSplit (conP wrap [varP x]), RecArg [] [varE x])
        PDataVal ->
          let xs = mkName (g "xs")
          in (PatSplit (conP pwrap [varP xs]), RecArg [] [varE xs])
        ListVal ->
          let xs = mkName (g "xs")
          in ( PatSplit (varP xs)
             , RecArg [] [varE 'map `appEs` [varE unwrap, varE xs]] )
        _ -> error "DPH gen: split': no match"

    joinW val es =
      case (val, es) of
        (ScalarVal, [e]) -> conE wrap `appE` e
        (PDataVal,  [e]) -> conE pwrap `appE` e
        (UnitVal,   [e]) -> e
        _                -> error "DPH gen: wrapGen: no match"
412
413
414 -- ------
415 -- Tuples
416 -- ------
417
-- | For each tuple arity, generate its @PData@ data instance and its @PR@
-- instance. All PData instances precede all PR instances.
tupleInstances :: [Int] -> Q [Dec]
tupleInstances ns = do
  groups <- mapM (\mk -> mapM mk ns) [instance_PData_tup, instance_PR_tup]
  return (concat groups)
424
-- | Name of the PData constructor for tuples of arity @n@: @P_n@.
pdataTupCon :: Int -> Name
pdataTupCon n = mkName $ concat ["P_", show n]
427
-- | Declare @data instance PData (a, b, ...) = P_n (PData a) (PData b) ...@
-- for tuples of the given arity.
--
-- Type variables are single letters starting at @a@.
instance_PData_tup :: Int -> DecQ
instance_PData_tup arity =
    instance_PData (tupleT arity) tyVars (pdataTupCon arity) fieldTys
  where
    tyVars   = [mkName [c] | c <- take arity ['a' ..]]
    fieldTys = [conT ''PData `appT` varT v | v <- tyVars]
434
435
-- | @instance (PR a, PR b, ...) => PR (a, b, ...)@ for tuples of the given
-- arity.
--
-- The methods are built by 'recursiveMethod' with 'tupGen'.
instance_PR_tup :: Int -> DecQ
instance_PR_tup arity = do
    ms <- genPR_methods (recursiveMethod (tupGen arity))
    let context  = [ClassP ''PR [t] | t <- tyVars]
        instHead = ConT ''PR `AppT` mkAppTs (TupleT arity) tyVars
    return (InstanceD context instHead ms)
  where
    tyVars = [VarT (mkName [c]) | c <- take arity ['a' ..]]
445
-- | Generator for the PR methods of an @arity@-tuple.
--
-- Each method makes @arity@ recursive calls, one per component, to the
-- method of the same name.
--
-- Arguments are taken apart as follows:
--
-- * scalar arguments with a tuple pattern;
-- * PData arguments with the @P_n@ constructor;
-- * list arguments with a @case@ on @unzip@ (arity 2) or @unzipN@.
--
-- Results are rebuilt as a tuple or a @P_n@ value. A unit result is a
-- left-nested @seq@ chain of the component results.
tupGen :: Int -> Gen
tupGen arity =
    Gen { recursiveCalls = arity
        , recursiveName  = id
        , split          = splitT
        , join           = joinT
        , typeName       = "(" ++ intercalate "," elemNames ++ ")"
        }
  where
    letters    = take arity ['a' ..]
    elemNames  = [[c] | c <- letters]
    pdataNames = [c : "s" | c <- letters]

    names g = map (mkName . g)

    unzipName
      | arity == 2 = mkName "unzip"
      | otherwise  = mkName ("unzip" ++ show arity)

    splitT (val, g) =
      case val of
        ScalarVal ->
          let ns = names g elemNames
          in (PatSplit (tupP (varPs ns)), RecArg [] (varEs ns))
        PDataVal ->
          let ns = names g pdataNames
          in ( PatSplit (conP (pdataTupCon arity) (varPs ns))
             , RecArg [] (varEs ns) )
        ListVal ->
          let xs = mkName (g "xs")
              ns = names g pdataNames
          in ( CaseSplit (varP xs) (varE unzipName `appE` varE xs) (tupP (varPs ns))
             , RecArg [] (varEs ns) )
        _ -> error "DPH Gen: tupGen/split: no match"

    joinT val es =
      case val of
        ScalarVal -> tupE es
        PDataVal  -> conE (pdataTupCon arity) `appEs` es
        UnitVal   -> foldl1 (\x y -> varE 'seq `appEs` [x, y]) es
        _         -> error "DPH Gen: tupGen/join: no match"
488