Merge /Users/benl/devel/ghc/ghc-head-devel
[ghc.git] / compiler / vectorise / Vectorise / Generic / PAMethods.hs
1
2 -- | Generate methods for the PA class.
3 --
4 -- TODO: there is a large amount of redundancy here between the
5 -- a, PData a, and PDatas a forms. See if we can factor some of this out.
6 --
7 module Vectorise.Generic.PAMethods
8 ( buildPReprTyCon
9 , buildPAScAndMethods
10 ) where
11
12 import Vectorise.Utils
13 import Vectorise.Monad
14 import Vectorise.Builtins
15 import Vectorise.Generic.Description
16 import CoreSyn
17 import CoreUtils
18 import MkCore ( mkWildCase )
19 import TyCon
20 import Type
21 import BuildTyCl
22 import OccName
23 import Coercion
24 import MkId
25
26 import FastString
27 import MonadUtils
28 import Control.Monad
29
30
-- | Build the 'PRepr' type synonym instance for a vectorised type constructor.
--
--   The synonym's name is derived from the original tycon's name, it takes
--   the vectorised tycon's type variables, and its right-hand side is the
--   generic representation type computed from the 'SumRepr'.  It is
--   registered as an instance of the builtin 'PRepr' family at the
--   vectorised tycon applied to its own type variables.
buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
buildPReprTyCon orig_tc vect_tc repr = do
    syn_name  <- mkLocalisedName mkPReprTyConOcc (tyConName orig_tc)
    repr_ty   <- sumReprType repr
    prepr_fam <- builtin preprTyCon
    let fam_inst = mk_fam_inst prepr_fam vect_tc
    liftDs $
      buildSynTyCon syn_name
                    (tyConTyVars vect_tc)
                    (SynonymTyCon repr_ty)
                    (typeKind repr_ty)
                    NoParentTyCon
                    (Just fam_inst)
44
45
-- | Pair a type family with a single instance argument: the given tycon
--   applied to its own type variables.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [saturated])
  where
    saturated = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))
49
50
51
-- buildPAScAndMethods --------------------------------------------------------

-- | A builder for one component of a generated PA instance: either the PR
--   superclass dictionary or one of the PA methods.  For reference, the PA
--   class is:
--
-- @
-- class PR (PRepr a) => PA a where
--   toPRepr       :: a                -> PRepr a
--   fromPRepr     :: PRepr a          -> a
--
--   toArrPRepr    :: PData a          -> PData (PRepr a)
--   fromArrPRepr  :: PData (PRepr a)  -> PData a
--
--   toArrPReprs   :: PDatas a         -> PDatas (PRepr a)
--   fromArrPReprs :: PDatas (PRepr a) -> PDatas a
-- @
--
type PAInstanceBuilder
        =  TyCon        -- ^ Vectorised TyCon
        -> TyCon        -- ^ Representation TyCon (the 'PRepr' family instance)
        -> TyCon        -- ^ 'PData' TyCon
        -> TyCon        -- ^ 'PDatas' TyCon
        -> SumRepr      -- ^ Description of generic representation.
        -> VM CoreExpr  -- ^ Instance function.
76
77
-- | The builders for each component of a generated PA instance, keyed by
--   the name of the component: the PR superclass dictionary first, then the
--   six PA methods.
buildPAScAndMethods :: VM [(String, PAInstanceBuilder)]
buildPAScAndMethods = return builders
  where
    builders =
      [ ("PR",            buildPRDict)
      , ("toPRepr",       buildToPRepr)
      , ("fromPRepr",     buildFromPRepr)
      , ("toArrPRepr",    buildToArrPRepr)
      , ("fromArrPRepr",  buildFromArrPRepr)
      , ("toArrPReprs",   buildToArrPReprs)
      , ("fromArrPReprs", buildFromArrPReprs)
      ]
87
88
-- | Build the PR superclass dictionary of the PA instance, obtained via
--   'prDictOfPReprInstTyCon' for the vectorised type applied to its own type
--   variables.  The PData/PDatas tycons and the representation description
--   are not needed here.
buildPRDict :: PAInstanceBuilder
buildPRDict vect_tc prepr_tc _pdata_tc _pdatas_tc _repr =
  let tys         = mkTyVarTys (tyConTyVars vect_tc)
      instance_ty = mkTyConApp vect_tc tys
  in  prDictOfPReprInstTyCon instance_ty prepr_tc tys
95
96
-- buildToPRepr ---------------------------------------------------------------
-- | Build the 'toPRepr' method of the PA class.
--
--   The result is a lambda over a value of the vectorised type (applied to
--   its own type variables).  It cases on the value's data constructor and
--   rebuilds the constructor's fields as the generic sum-of-products
--   representation described by the 'SumRepr', wrapped with
--   'wrapFamInstBody' for the representation tycon.
buildToPRepr :: PAInstanceBuilder
buildToPRepr vect_tc repr_tc _ _ repr = do
    prepr_ty <- mkPReprType self_ty
    x        <- newLocalVar (fsLit "x") self_ty
    body     <- to_sum (Var x) self_ty prepr_ty repr
    return (Lam x body)
  where
    tys     = mkTyVarTys (tyConTyVars vect_tc)
    self_ty = mkTyConApp vect_tc tys

    -- Wrap a representation-level expression in the representation
    -- tycon's family instance.
    to_inst :: CoreExpr -> CoreExpr
    to_inst = wrapFamInstBody repr_tc tys

    -- The builtin void value, used for both empty sums and empty products.
    void_expr :: VM CoreExpr
    void_expr = liftM Var (builtin voidVar)

    -- Case on the argument and inject each constructor's fields into the
    -- matching alternative of the representation sum.
    to_sum :: CoreExpr -> Type -> Type -> SumRepr -> VM CoreExpr
    to_sum scrut scrut_ty res_ty sum_repr =
      case sum_repr of
        EmptySum ->
          liftM to_inst void_expr

        UnarySum con_repr -> do
          (pat, bndrs, rhs) <- con_alt con_repr
          return (mkWildCase scrut scrut_ty res_ty [(pat, bndrs, to_inst rhs)])

        Sum { repr_sum_tc = sum_tc, repr_con_tys = alt_tys, repr_cons = cons } -> do
          alts <- mapM con_alt cons
          let inject (pat, bndrs, rhs) alt_con =
                (pat, bndrs, to_inst (mkConApp alt_con (map Type alt_tys ++ [rhs])))
          return (mkWildCase scrut scrut_ty res_ty
                    (zipWith inject alts (tyConDataCons sum_tc)))

    -- One case alternative: the constructor pattern, its field binders and
    -- the representation of those fields.
    con_alt (ConRepr con prod_repr) = do
      (bndrs, rhs) <- to_prod prod_repr
      return (DataAlt con, bndrs, rhs)

    -- Bind the fields of a constructor and build their representation.
    to_prod :: ProdRepr -> VM ([Var], CoreExpr)
    to_prod prod_repr =
      case prod_repr of
        EmptyProd ->
          liftM (\v -> ([], v)) void_expr

        UnaryProd comp -> do
          fld <- newLocalVar (fsLit "x") (compOrigType comp)
          rhs <- to_comp (Var fld) comp
          return ([fld], rhs)

        Prod { repr_tup_tc = tup_tc, repr_comp_tys = tup_tys, repr_comps = comps } -> do
          flds <- newLocalVars (fsLit "x") (map compOrigType comps)
          args <- zipWithM to_comp (map Var flds) comps
          let [tup_con] = tyConDataCons tup_tc
          return (flds, mkConApp tup_con (map Type tup_tys ++ args))

    -- A 'Keep' field is used as-is; a 'Wrap' field is wrapped in the builtin
    -- 'Wrap' newtype.
    to_comp :: CoreExpr -> CompRepr -> VM CoreExpr
    to_comp fld (Keep _ _) = return fld
    to_comp fld (Wrap ty)  = do
      wrap_tc <- builtin wrapTyCon
      return (wrapNewTypeBody wrap_tc [ty] fld)
169
170
-- buildFromPRepr -------------------------------------------------------------
-- | Build the 'fromPRepr' method of the PA class.
--
--   The result is a lambda over a value of the 'PRepr' type.  It unwraps the
--   value from the representation tycon's family instance
--   ('unwrapFamInstScrut'), cases on each layer of the generic
--   sum-of-products representation, and applies the original data
--   constructor to the recovered fields.
buildFromPRepr :: PAInstanceBuilder
buildFromPRepr vect_tc repr_tc _ _ repr = do
    prepr_ty <- mkPReprType self_ty
    x        <- newLocalVar (fsLit "x") prepr_ty
    body     <- from_sum (unwrapFamInstScrut repr_tc tys (Var x)) repr
    return (Lam x body)
  where
    tys     = mkTyVarTys (tyConTyVars vect_tc)
    self_ty = mkTyConApp vect_tc tys

    -- A case over one representation layer, returning the vectorised type.
    case_on scrut alts = mkWildCase scrut (exprType scrut) self_ty alts

    -- Take apart the sum layer: one alternative per original constructor.
    -- An empty sum becomes the builtin 'fromVoid' at the vectorised type.
    from_sum scrut sum_repr =
      case sum_repr of
        EmptySum -> do
          from_void <- builtin fromVoidVar
          return (App (Var from_void) (Type self_ty))

        UnarySum con_repr ->
          from_con scrut con_repr

        Sum { repr_sum_tc = sum_tc, repr_con_tys = alt_tys, repr_cons = cons } -> do
          xs   <- newLocalVars (fsLit "x") alt_tys
          rhss <- zipWithM from_con (map Var xs) cons
          let mk_alt dc v rhs = (DataAlt dc, [v], rhs)
          return (case_on scrut (zipWith3 mk_alt (tyConDataCons sum_tc) xs rhss))

    -- Start from the original constructor instantiated at the type variables.
    from_con scrut (ConRepr dc prod_repr) =
      from_prod scrut (mkConApp dc (map Type tys)) prod_repr

    -- Take apart the product layer, applying 'con_app' to each field.
    from_prod scrut con_app prod_repr =
      case prod_repr of
        EmptyProd ->
          return con_app

        UnaryProd comp ->
          liftM (App con_app) (from_comp scrut comp)

        Prod { repr_tup_tc = tup_tc, repr_comp_tys = tup_tys, repr_comps = comps } -> do
          ys   <- newLocalVars (fsLit "y") tup_tys
          args <- zipWithM from_comp (map Var ys) comps
          let [tup_con] = tyConDataCons tup_tc
          return (case_on scrut [(DataAlt tup_con, ys, mkApps con_app args)])

    -- A 'Keep' field is used as-is; a 'Wrap' field is unwrapped from the
    -- builtin 'Wrap' newtype.
    from_comp fld (Keep _ _) = return fld
    from_comp fld (Wrap ty)  =
      liftM (\wrap_tc -> unwrapNewTypeBody wrap_tc [ty] fld) (builtin wrapTyCon)
223
224
-- buildToArrPRepr ------------------------------------------------------------
-- | Build the 'toArrPRepr' method of the PA class.
--
--   The result is a lambda over a @PData@ of the vectorised type.  It cases
--   on that PData's single data constructor, rebuilds the bound components
--   as the PData of the generic representation, and casts the result so the
--   whole expression has type @PData (PRepr t)@.
buildToArrPRepr :: PAInstanceBuilder
buildToArrPRepr vect_tc prepr_tc pdata_tc _ r = do
    pdata_ty       <- mkPDataType self_ty
    pdata_prepr_ty <- mkPDataType =<< mkPReprType self_ty
    xs             <- newLocalVar (fsLit "xs") pdata_ty
    pdata_co       <- mkBuiltinCo pdataTyCon
    (bndrs, body)  <- to_sum r
    let Just repr_ax = tyConFamilyCoercion_maybe prepr_tc
        -- Built from the PRepr instance axiom; the rebuilt value is cast
        -- with it to reach the result type 'pdata_prepr_ty'.
        cast_co      = mkAppCo pdata_co (mkSymCo (mkAxInstCo repr_ax tys))
        scrut        = unwrapFamInstScrut pdata_tc tys (Var xs)
        alt          = (DataAlt pdata_dc, bndrs, mkCast body cast_co)
    return (Lam xs (mkWildCase scrut (mkTyConApp pdata_tc tys) pdata_prepr_ty [alt]))
  where
    tys        = mkTyVarTys (tyConTyVars vect_tc)
    self_ty    = mkTyConApp vect_tc tys
    [pdata_dc] = tyConDataCons pdata_tc

    -- The builtin empty PData value, for empty sums and empty products.
    empty_pdata = liftM (\pv -> ([], Var pv)) (builtin pvoidVar)

    -- Rebuild a sum layer, binding its selector and every alternative's
    -- components.
    to_sum ss =
      case ss of
        EmptySum       -> empty_pdata
        UnarySum con_r -> to_con con_r
        Sum{} -> do
          (bndrss, alt_exprs) <- mapAndUnzipM to_con (repr_cons ss)
          sel                 <- newLocalVar (fsLit "sel") (repr_sel_ty ss)
          let psum_tc    = repr_psum_tc ss
              alt_tys    = repr_con_tys ss
              [psum_con] = tyConDataCons psum_tc
              packed     = mkConApp psum_con (map Type alt_tys ++ (Var sel : alt_exprs))
          return (sel : concat bndrss, wrapFamInstBody psum_tc alt_tys packed)

    -- Rebuild a product layer, binding one PData variable per component.
    to_prod ps =
      case ps of
        EmptyProd -> empty_pdata
        UnaryProd comp -> do
          fld_ty <- mkPDataType (compOrigType comp)
          fld    <- newLocalVar (fsLit "x") fld_ty
          rhs    <- to_comp (Var fld) comp
          return ([fld], rhs)
        Prod{} -> do
          let comps      = repr_comps ps
              ptup_tc    = repr_ptup_tc ps
              tup_tys    = repr_comp_tys ps
              [ptup_con] = tyConDataCons ptup_tc
          fld_tys <- mapM (mkPDataType . compOrigType) comps
          flds    <- newLocalVars (fsLit "x") fld_tys
          args    <- zipWithM to_comp (map Var flds) comps
          return (flds, wrapFamInstBody ptup_tc tup_tys
                          (mkConApp ptup_con (map Type tup_tys ++ args)))

    to_con (ConRepr _ prod_r) = to_prod prod_r

    -- FIXME: this is bound to be wrong!
    to_comp fld (Keep _ _) = return fld
    to_comp fld (Wrap ty) = do
      wrap_tc       <- builtin wrapTyCon
      (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
      return (wrapNewTypeBody pwrap_tc [ty] fld)
292
293
-- buildFromArrPRepr ----------------------------------------------------------
-- | Build the 'fromArrPRepr' method for the PA class.
--
--   The result is a lambda over a @PData (PRepr t)@.  It casts the argument
--   with a coercion built from the PRepr instance axiom, cases on every layer
--   of the generic representation, and in the innermost alternative applies
--   the vectorised type's PData constructor to all of the leaf components.
--
--   The leaf components are only known once the whole representation has
--   been traversed, yet that final constructor application has to be placed
--   in the innermost alternative during the same traversal.  'fixV' with a
--   lazy pattern ties this knot; the traversal only embeds the final
--   expression and never inspects it.
buildFromArrPRepr :: PAInstanceBuilder
buildFromArrPRepr vect_tc prepr_tc pdata_tc _ r = do
    pdata_prepr_ty <- mkPDataType =<< mkPReprType self_ty
    pdata_ty       <- mkPDataType self_ty
    xs             <- newLocalVar (fsLit "xs") pdata_prepr_ty
    pdata_co       <- mkBuiltinCo pdataTyCon

    let Just repr_ax = tyConFamilyCoercion_maybe prepr_tc
        from_prepr   = mkAppCo pdata_co (mkAxInstCo repr_ax tys)
        scrut        = mkCast (Var xs) from_prepr

        -- The PData of the vectorised type, built from the leaf components.
        rebuild leaves
          = wrapFamInstBody pdata_tc tys
              (mkConApp pdata_con (map Type tys ++ leaves))

    (body, _) <- fixV $ \ ~(_, leaves) ->
                   from_sum pdata_ty (rebuild leaves) scrut r
    return (Lam xs body)
  where
    tys         = mkTyVarTys (tyConTyVars vect_tc)
    self_ty     = mkTyConApp vect_tc tys
    [pdata_con] = tyConDataCons pdata_tc

    -- A single-alternative case over one representation layer.
    one_alt res_ty dc bndrs rhs scrut
      = mkWildCase scrut (exprType scrut) res_ty [(DataAlt dc, bndrs, rhs)]

    -- Each 'from_*' helper takes the result type, the expression to put in
    -- the innermost alternative, the expression to take apart, and a piece
    -- of the representation.  It returns the resulting expression together
    -- with the leaf components it bound, in order.
    from_sum res_ty inner scrut ss =
      case ss of
        EmptySum       -> return (inner, [])
        UnarySum con_r -> from_con res_ty inner scrut con_r
        Sum {} -> do
          let psum_tc    = repr_psum_tc ss
              alt_tys    = repr_con_tys ss
              [psum_con] = tyConDataCons psum_tc
          sel              <- newLocalVar (fsLit "sel") (repr_sel_ty ss)
          alt_pdata_tys    <- mapM mkPDataType alt_tys
          alts             <- newLocalVars (fsLit "xs") alt_pdata_tys
          (inner', leaves) <- thread from_con res_ty inner (map Var alts) (repr_cons ss)
          let body = one_alt res_ty psum_con (sel : alts) inner'
                             (unwrapFamInstScrut psum_tc alt_tys scrut)
          return (body, Var sel : leaves)

    from_prod res_ty inner scrut ps =
      case ps of
        EmptyProd      -> return (inner, [])
        UnaryProd comp -> from_comp res_ty inner scrut comp
        Prod {} -> do
          let ptup_tc    = repr_ptup_tc ps
              tup_tys    = repr_comp_tys ps
              [ptup_con] = tyConDataCons ptup_tc
          fld_tys          <- mapM mkPDataType tup_tys
          flds             <- newLocalVars (fsLit "ys") fld_tys
          (inner', leaves) <- thread from_comp res_ty inner (map Var flds) (repr_comps ps)
          let body = one_alt res_ty ptup_con flds inner'
                             (unwrapFamInstScrut ptup_tc tup_tys scrut)
          return (body, leaves)

    from_con res_ty inner scrut (ConRepr _ prod_r)
      = from_prod res_ty inner scrut prod_r

    -- A 'Keep' component is a leaf as-is; a 'Wrap' component is first
    -- unwrapped from the PData of the builtin 'Wrap' type.
    from_comp _ inner fld (Keep _ _) = return (inner, [fld])
    from_comp _ inner fld (Wrap ty) = do
      wrap_tc       <- builtin wrapTyCon
      (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
      let unwrapped = unwrapNewTypeBody pwrap_tc [ty]
                        (unwrapFamInstScrut pwrap_tc [ty] fld)
      return (inner, [unwrapped])

    -- Apply 'step' to each (scrutinee, representation) pair with a right
    -- fold ('foldrM'), threading the inner expression through and
    -- concatenating the leaf components so they come out in pair order.
    thread step res_ty inner scruts reps
      = foldrM go (inner, []) (zip scruts reps)
      where
        go (scrut, rep) (acc, leaves) = do
          (acc', leaves') <- step res_ty acc scrut rep
          return (acc', leaves' ++ leaves)
369
370
-- buildToArrPReprs -----------------------------------------------------------
-- | Build the 'toArrPReprs' method of the PA class.
--   This converts a 'PDatas' of elements into the 'PDatas' of their generic
--   representation.
--
--   The result is a lambda over, e.g., @xss :: PDatas (Tree a b)@.  It cases
--   on the PDatas' single data constructor, rebuilds the bound components as
--   the PDatas of the generic representation, and casts the result so the
--   whole expression has type, e.g., @PDatas (PRepr (Tree a b))@.
buildToArrPReprs :: PAInstanceBuilder
buildToArrPReprs vect_tc prepr_tc _ pdatas_tc r = do
    -- Argument type, e.g. 'PDatas (Tree a b)'.
    pdatas_ty       <- mkPDatasType self_ty
    -- Result type, e.g. 'PDatas (PRepr (Tree a b))'.
    pdatas_prepr_ty <- mkPDatasType =<< mkPReprType self_ty
    -- The lambda's binder, e.g. (xss :: PDatas (Tree a b)).
    xss             <- newLocalVar (fsLit "xss") pdatas_ty
    pdatas_co       <- mkBuiltinCo pdatasTyCon
    (bndrs, body)   <- to_sum r
    let Just repr_ax = tyConFamilyCoercion_maybe prepr_tc
        -- Coercion, built from the PRepr instance axiom, used to cast the
        -- rebuilt value to the result type.
        cast_co      = mkAppCo pdatas_co (mkSymCo (mkAxInstCo repr_ax tys))
        scrut        = unwrapFamInstScrut pdatas_tc tys (Var xss)
        alt          = (DataAlt pdatas_dc, bndrs, mkCast body cast_co)
    return (Lam xss (mkWildCase scrut (mkTyConApp pdatas_tc tys) pdatas_prepr_ty [alt]))
  where
    tys         = mkTyVarTys (tyConTyVars vect_tc)
    -- Element type, e.g. 'Tree a b'.
    self_ty     = mkTyConApp vect_tc tys
    -- The single PDatas data constructor.
    [pdatas_dc] = tyConDataCons pdatas_tc

    -- The builtin empty PDatas value, for empty sums and empty products.
    empty_pdatas = liftM (\pv -> ([], Var pv)) (builtin pvoidsVar)

    -- Rebuild a sum layer, binding its selectors and every alternative's
    -- components.
    to_sum ss =
      case ss of
        EmptySum       -> empty_pdatas
        UnarySum con_r -> to_con con_r
        Sum{} -> do
          (bndrss, alt_exprs) <- mapAndUnzipM to_con (repr_cons ss)
          sels                <- newLocalVar (fsLit "sels") (repr_sels_ty ss)
          let psums_tc    = repr_psums_tc ss
              alt_tys     = repr_con_tys ss
              [psums_con] = tyConDataCons psums_tc
              packed      = mkConApp psums_con (map Type alt_tys ++ (Var sels : alt_exprs))
          return (sels : concat bndrss, wrapFamInstBody psums_tc alt_tys packed)

    -- Rebuild a product layer, binding one PDatas variable per component.
    to_prod ps =
      case ps of
        EmptyProd -> empty_pdatas
        UnaryProd comp -> do
          fld_ty <- mkPDatasType (compOrigType comp)
          fld    <- newLocalVar (fsLit "x") fld_ty
          rhs    <- to_comp (Var fld) comp
          return ([fld], rhs)
        Prod{} -> do
          let comps       = repr_comps ps
              ptups_tc    = repr_ptups_tc ps
              tup_tys     = repr_comp_tys ps
              [ptups_con] = tyConDataCons ptups_tc
          fld_tys <- mapM (mkPDatasType . compOrigType) comps
          flds    <- newLocalVars (fsLit "x") fld_tys
          args    <- zipWithM to_comp (map Var flds) comps
          return (flds, wrapFamInstBody ptups_tc tup_tys
                          (mkConApp ptups_con (map Type tup_tys ++ args)))

    to_con (ConRepr _ prod_r) = to_prod prod_r

    -- FIXME: this is bound to be wrong!
    to_comp fld (Keep _ _) = return fld
    to_comp fld (Wrap ty) = do
      wrap_tc       <- builtin wrapTyCon
      (pwrap_tc, _) <- pdatasReprTyCon (mkTyConApp wrap_tc [ty])
      return (wrapNewTypeBody pwrap_tc [ty] fld)
452
453
-- buildFromArrPReprs ---------------------------------------------------------
-- | Build the 'fromArrPReprs' method for the PA class.
--   This converts the 'PDatas' of a generic representation back into the
--   'PDatas' of the vectorised type.
--
--   The result is a lambda over, e.g., @xss :: PDatas (PRepr (Tree a b))@.
--   It casts the argument with a coercion built from the PRepr instance
--   axiom, cases on every layer of the generic representation, and in the
--   innermost alternative applies the vectorised type's PDatas constructor
--   to all of the leaf components.
--
--   As in 'buildFromArrPRepr', the leaf components are only known after the
--   traversal that must already place the final constructor application;
--   'fixV' with a lazy pattern ties this knot.
buildFromArrPReprs :: PAInstanceBuilder
buildFromArrPReprs vect_tc prepr_tc _ pdatas_tc r = do
    -- Argument type, e.g. 'PDatas (PRepr (Tree a b))'.
    pdatas_prepr_ty <- mkPDatasType =<< mkPReprType self_ty
    -- Result type, e.g. 'PDatas (Tree a b)'.
    pdatas_ty       <- mkPDatasType self_ty
    -- The lambda's binder, e.g. (xss :: PDatas (PRepr (Tree a b))).
    xss             <- newLocalVar (fsLit "xss") pdatas_prepr_ty
    -- Coercion between PDatas of the PRepr type and of the instance type.
    pdatas_co       <- mkBuiltinCo pdatasTyCon

    let Just repr_ax = tyConFamilyCoercion_maybe prepr_tc
        from_prepr   = mkAppCo pdatas_co (mkAxInstCo repr_ax tys)
        scrut        = mkCast (Var xss) from_prepr

        -- The PDatas of the vectorised type, built from the leaf components.
        rebuild leaves
          = wrapFamInstBody pdatas_tc tys
              (mkConApp pdatas_con (map Type tys ++ leaves))

    (body, _) <- fixV $ \ ~(_, leaves) ->
                   from_sum pdatas_ty (rebuild leaves) scrut r
    return (Lam xss body)
  where
    tys          = mkTyVarTys (tyConTyVars vect_tc)
    -- Element type, e.g. 'Tree a b'.
    self_ty      = mkTyConApp vect_tc tys
    [pdatas_con] = tyConDataCons pdatas_tc

    -- A single-alternative case over one representation layer.
    one_alt res_ty dc bndrs rhs scrut
      = mkWildCase scrut (exprType scrut) res_ty [(DataAlt dc, bndrs, rhs)]

    -- Each 'from_*' helper takes the result type, the expression to put in
    -- the innermost alternative, the expression to take apart, and a piece
    -- of the representation.  It returns the resulting expression together
    -- with the leaf components it bound, in order.
    from_sum res_ty inner scrut ss =
      case ss of
        EmptySum       -> return (inner, [])
        UnarySum con_r -> from_con res_ty inner scrut con_r
        Sum {} -> do
          let psums_tc    = repr_psums_tc ss
              alt_tys     = repr_con_tys ss
              [psums_con] = tyConDataCons psums_tc
          sels             <- newLocalVar (fsLit "sels") (repr_sels_ty ss)
          alt_pdatas_tys   <- mapM mkPDatasType alt_tys
          alts             <- newLocalVars (fsLit "xs") alt_pdatas_tys
          (inner', leaves) <- thread from_con res_ty inner (map Var alts) (repr_cons ss)
          let body = one_alt res_ty psums_con (sels : alts) inner'
                             (unwrapFamInstScrut psums_tc alt_tys scrut)
          return (body, Var sels : leaves)

    from_prod res_ty inner scrut ps =
      case ps of
        EmptyProd      -> return (inner, [])
        UnaryProd comp -> from_comp res_ty inner scrut comp
        Prod {} -> do
          let ptups_tc    = repr_ptups_tc ps
              tup_tys     = repr_comp_tys ps
              [ptups_con] = tyConDataCons ptups_tc
          fld_tys          <- mapM mkPDatasType tup_tys
          flds             <- newLocalVars (fsLit "ys") fld_tys
          (inner', leaves) <- thread from_comp res_ty inner (map Var flds) (repr_comps ps)
          let body = one_alt res_ty ptups_con flds inner'
                             (unwrapFamInstScrut ptups_tc tup_tys scrut)
          return (body, leaves)

    from_con res_ty inner scrut (ConRepr _ prod_r)
      = from_prod res_ty inner scrut prod_r

    -- A 'Keep' component is a leaf as-is; a 'Wrap' component is first
    -- unwrapped from the PDatas of the builtin 'Wrap' type.
    from_comp _ inner fld (Keep _ _) = return (inner, [fld])
    from_comp _ inner fld (Wrap ty) = do
      wrap_tc        <- builtin wrapTyCon
      (pwraps_tc, _) <- pdatasReprTyCon (mkTyConApp wrap_tc [ty])
      let unwrapped = unwrapNewTypeBody pwraps_tc [ty]
                        (unwrapFamInstScrut pwraps_tc [ty] fld)
      return (inner, [unwrapped])

    -- Apply 'step' to each (scrutinee, representation) pair with a right
    -- fold ('foldrM'), threading the inner expression through and
    -- concatenating the leaf components so they come out in pair order.
    thread step res_ty inner scruts reps
      = foldrM go (inner, []) (zip scruts reps)
      where
        go (scrut, rep) (acc, leaves) = do
          (acc', leaves') <- step res_ty acc scrut rep
          return (acc', leaves' ++ leaves)
543