Major refactoring of CoAxioms
[ghc.git] / compiler / vectorise / Vectorise / Generic / PAMethods.hs
1
2 -- | Generate methods for the PA class.
3 --
4 -- TODO: there is a large amount of redundancy here between the
5 -- a, PData a, and PDatas a forms. See if we can factor some of this out.
6 --
7 module Vectorise.Generic.PAMethods
8 ( buildPReprTyCon
9 , buildPAScAndMethods
10 ) where
11
12 import Vectorise.Utils
13 import Vectorise.Monad
14 import Vectorise.Builtins
15 import Vectorise.Generic.Description
16 import CoreSyn
17 import CoreUtils
18 import FamInstEnv
19 import MkCore ( mkWildCase )
20 import TyCon
21 import Type
22 import OccName
23 import Coercion
24 import MkId
25
26 import FastString
27 import MonadUtils
28 import Control.Monad
29 import Outputable
30
31
-- | Build the 'PRepr' type family instance for a vectorised type constructor:
--   the instance head is the vectorised 'TyCon' applied to its own type
--   variables, and the right-hand side is the type computed from the
--   'SumRepr' by 'sumReprType'.
buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM FamInst
buildPReprTyCon orig_tc vect_tc repr
  = do { inst_name <- mkLocalisedName mkPReprTyConOcc (tyConName orig_tc)
       ; repr_ty   <- sumReprType repr
       ; fam_tc    <- builtin preprTyCon
       ; return (mkSynFamInst inst_name tvs fam_tc [lhs_ty] repr_ty)
       }
  where
    tvs    = tyConTyVars vect_tc
    lhs_ty = mkTyConApp vect_tc (mkTyVarTys tvs)
41
-- buildPAScAndMethods --------------------------------------------------------

-- | A builder for one part of a 'PA' instance: either the 'PR' superclass
--   dictionary or the definition of one of the 'PA' methods.
--
-- Recall the definition of the PA class:
--
-- @
--   class PR (PRepr a) => PA a where
--     toPRepr       :: a -> PRepr a
--     fromPRepr     :: PRepr a -> a
--
--     toArrPRepr    :: PData a -> PData (PRepr a)
--     fromArrPRepr  :: PData (PRepr a) -> PData a
--
--     toArrPReprs   :: PDatas a -> PDatas (PRepr a)
--     fromArrPReprs :: PDatas (PRepr a) -> PDatas a
-- @
--
type PAInstanceBuilder
        =  TyCon        -- ^ Vectorised TyCon
        -> CoAxiom      -- ^ Axiom for the representation type instance
        -> TyCon        -- ^ 'PData'  TyCon
        -> TyCon        -- ^ 'PDatas' TyCon
        -> SumRepr      -- ^ Description of generic representation.
        -> VM CoreExpr  -- ^ Instance function.
66
67
-- | The 'PR' superclass and every 'PA' method, each paired with its name and
--   the builder that generates its definition.
buildPAScAndMethods :: VM [(String, PAInstanceBuilder)]
buildPAScAndMethods = return builders
  where
    builders =
      [ ("PR"           , buildPRDict)
      , ("toPRepr"      , buildToPRepr)
      , ("fromPRepr"    , buildFromPRepr)
      , ("toArrPRepr"   , buildToArrPRepr)
      , ("fromArrPRepr" , buildFromArrPRepr)
      , ("toArrPReprs"  , buildToArrPReprs)
      , ("fromArrPReprs", buildFromArrPReprs)
      ]
77
78
-- | Build the 'PR' superclass dictionary of the 'PA' instance by delegating
--   to 'prDictOfPReprInstTyCon' with the vectorised type applied to its own
--   type variables.
buildPRDict :: PAInstanceBuilder
buildPRDict vect_tc prepr_ax _ _ _
  = let tys     = mkTyVarTys (tyConTyVars vect_tc)
        self_ty = mkTyConApp vect_tc tys
    in  prDictOfPReprInstTyCon self_ty prepr_ax tys
85
86
-- buildToPRepr ---------------------------------------------------------------

-- | Build the 'toPRepr' method of the PA class.
--
--   The generated lambda binds an argument of the vectorised type, cases on
--   its data constructors (if it has any) and rebuilds the fields in the
--   generic sum-of-products representation described by the 'SumRepr'. The
--   result is wrapped with the representation axiom so that it has type
--   @PRepr (T as)@.
buildToPRepr :: PAInstanceBuilder
buildToPRepr vect_tc repr_ax _ _ repr
  = do let arg_ty = mkTyConApp vect_tc ty_args

       -- Get the representation type of the argument.
       res_ty <- mkPReprType arg_ty

       -- Var to bind the argument
       arg    <- newLocalVar (fsLit "x") arg_ty

       -- Build the expression to convert the argument to the generic representation.
       result <- to_sum (Var arg) arg_ty res_ty repr

       return $ Lam arg result
  where
    ty_args = mkTyVarTys (tyConTyVars vect_tc)

    -- Wrap a body of the representation type with the representation axiom
    -- (via 'wrapTypeFamInstBody').
    wrap_repr_inst = wrapTypeFamInstBody repr_ax ty_args

    -- CoreExpr to convert the given argument to the generic representation.
    -- We start by doing a case branch on the possible data constructors.
    to_sum :: CoreExpr -> Type -> Type -> SumRepr -> VM CoreExpr
    to_sum _ _ _ EmptySum
      = do void_var <- builtin voidVar
           return $ wrap_repr_inst $ Var void_var

    to_sum arg arg_ty res_ty (UnarySum r)
      = do (pat, vars, body) <- con_alt r
           return $ mkWildCase arg arg_ty res_ty
                      [(pat, vars, wrap_repr_inst body)]

    to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
                                  , repr_con_tys = tys
                                  , repr_cons    = cons })
      = do alts <- mapM con_alt cons
           -- Each constructor's product is injected into the matching
           -- constructor of the sum type.
           let alts' = [ (pat, vars, wrap_repr_inst
                                       $ mkConApp sum_con (map Type tys ++ [body]))
                       | ((pat, vars, body), sum_con)
                             <- zip alts (tyConDataCons sum_tc) ]
           return $ mkWildCase arg arg_ty res_ty alts'

    -- Case alternative for one data constructor of the vectorised type.
    con_alt (ConRepr con r)
      = do (vars, body) <- to_prod r
           return (DataAlt con, vars, body)

    -- CoreExpr to convert data constructor fields to the generic representation.
    to_prod :: ProdRepr -> VM ([Var], CoreExpr)
    to_prod EmptyProd
      = do void_var <- builtin voidVar
           return ([], Var void_var)

    to_prod (UnaryProd comp)
      = do var  <- newLocalVar (fsLit "x") (compOrigType comp)
           body <- to_comp (Var var) comp
           return ([var], body)

    to_prod (Prod { repr_tup_tc   = tup_tc
                  , repr_comp_tys = tys
                  , repr_comps    = comps })
      = do vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
           exprs <- zipWithM to_comp (map Var vars) comps
           let [tup_con] = tyConDataCons tup_tc
           return (vars, mkConApp tup_con (map Type tys ++ exprs))

    -- CoreExpr to convert a data constructor component to the generic representation.
    to_comp :: CoreExpr -> CompRepr -> VM CoreExpr
    to_comp expr (Keep _ _) = return expr
    to_comp expr (Wrap ty)  = wrapNewTypeBodyOfWrap expr ty
157
158
-- buildFromPRepr -------------------------------------------------------------

-- | Build the 'fromPRepr' method of the PA class.
--
--   The generated lambda takes a value of type @PRepr (T as)@, unwraps the
--   representation axiom from it, and cases through the generic sums and
--   products to reapply the original data constructor of @T@.
buildFromPRepr :: PAInstanceBuilder
buildFromPRepr vect_tc repr_ax _ _ repr
  = do { prepr_ty <- mkPReprType out_ty
       ; arg_var  <- newLocalVar (fsLit "x") prepr_ty
       ; let scrut = unwrapTypeFamInstScrut repr_ax ty_args (Var arg_var)
       ; body     <- from_sum scrut repr
       ; return (Lam arg_var body)
       }
  where
    ty_args = mkTyVarTys (tyConTyVars vect_tc)
    out_ty  = mkTyConApp vect_tc ty_args

    -- Case on the generic sum and hand each alternative to 'from_con'.
    from_sum scrut sum_repr
      = case sum_repr of
          EmptySum
            -> do { from_void <- builtin fromVoidVar
                  ; return (App (Var from_void) (Type out_ty)) }

          UnarySum con_repr
            -> from_con scrut con_repr

          Sum { repr_sum_tc  = sum_tc
              , repr_con_tys = alt_tys
              , repr_cons    = con_reprs }
            -> do { binders <- newLocalVars (fsLit "x") alt_tys
                  ; rhss    <- zipWithM from_con (map Var binders) con_reprs
                  ; let alts = zipWith3 (\dc b rhs -> (DataAlt dc, [b], rhs))
                                        (tyConDataCons sum_tc) binders rhss
                  ; return (mkWildCase scrut (exprType scrut) out_ty alts) }

    -- Start from the data constructor applied to the vectorised type's tyvars
    -- and let 'from_prod' supply its fields.
    from_con scrut (ConRepr dc prod_repr)
      = from_prod scrut (mkConApp dc (map Type ty_args)) prod_repr

    from_prod scrut con_app prod_repr
      = case prod_repr of
          EmptyProd
            -> return con_app

          UnaryProd comp
            -> liftM (App con_app) (from_comp scrut comp)

          Prod { repr_tup_tc   = tup_tc
               , repr_comp_tys = field_tys
               , repr_comps    = comps }
            -> do { binders <- newLocalVars (fsLit "y") field_tys
                  ; fields  <- zipWithM from_comp (map Var binders) comps
                  ; let [tup_dc] = tyConDataCons tup_tc
                  ; return (mkWildCase scrut (exprType scrut) out_ty
                              [(DataAlt tup_dc, binders, mkApps con_app fields)]) }

    -- Undo the newtype wrapping applied to 'Wrap' components.
    from_comp scrut (Keep _ _) = return scrut
    from_comp scrut (Wrap ty)  = unwrapNewTypeBodyOfWrap scrut ty
210
211
-- buildToArrPRepr ------------------------------------------------------------

-- | Build the 'toArrPRepr' method of the PA class.
--
--   The generated lambda takes a @PData (T as)@, cases on the single data
--   constructor of the 'PData' instance for @T@, and rebuilds its fields in
--   the generic 'PData' representation. The result is cast to the result
--   type @PData (PRepr (T as))@.
buildToArrPRepr :: PAInstanceBuilder
buildToArrPRepr vect_tc repr_ax pdata_tc _ repr
  = do arg_ty <- mkPDataType el_ty
       res_ty <- mkPDataType =<< mkPReprType el_ty
       arg    <- newLocalVar (fsLit "xs") arg_ty

       -- Coercion used to cast the rebuilt generic representation to
       -- 'res_ty': 'PData' applied to the symmetric representation axiom.
       pdata_co <- mkBuiltinCo pdataTyCon
       let co    = mkAppCo pdata_co
                 . mkSymCo
                 $ mkAxInstCo repr_ax ty_args

           scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)

       (vars, result) <- to_sum repr

       return . Lam arg
              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
                  [(DataAlt pdata_dc, vars, mkCast result co)]
  where
    ty_args    = mkTyVarTys $ tyConTyVars vect_tc
    el_ty      = mkTyConApp vect_tc ty_args
    [pdata_dc] = tyConDataCons pdata_tc

    -- Each of the following returns the binders to bind in the 'PData'
    -- alternative together with the rebuilt representation expression.
    to_sum ss
      = case ss of
          EmptySum
            -> builtin pvoidVar >>= \pvoid -> return ([], Var pvoid)

          UnarySum con_repr
            -> to_con con_repr

          Sum{}
            -> do let psum_tc    = repr_psum_tc ss
                      [psum_con] = tyConDataCons psum_tc
                  (vars, exprs) <- mapAndUnzipM to_con (repr_cons ss)
                  sel           <- newLocalVar (fsLit "sel") (repr_sel_ty ss)
                  return ( sel : concat vars
                         , wrapFamInstBody psum_tc (repr_con_tys ss)
                         $ mkConApp psum_con
                         $ map Type (repr_con_tys ss) ++ (Var sel : exprs))

    to_prod ss
      = case ss of
          EmptyProd
            -> builtin pvoidVar >>= \pvoid -> return ([], Var pvoid)

          UnaryProd comp
            -> do pty  <- mkPDataType (compOrigType comp)
                  var  <- newLocalVar (fsLit "x") pty
                  expr <- to_comp (Var var) comp
                  return ([var], expr)

          Prod{}
            -> do let ptup_tc    = repr_ptup_tc ss
                      [ptup_con] = tyConDataCons ptup_tc
                  ptys  <- mapM (mkPDataType . compOrigType) (repr_comps ss)
                  vars  <- newLocalVars (fsLit "x") ptys
                  exprs <- zipWithM to_comp (map Var vars) (repr_comps ss)
                  return ( vars
                         , wrapFamInstBody ptup_tc (repr_comp_tys ss)
                         $ mkConApp ptup_con
                         $ map Type (repr_comp_tys ss) ++ exprs)

    to_con (ConRepr _ prod_repr) = to_prod prod_repr

    to_comp expr (Keep _ _) = return expr
    to_comp expr (Wrap ty)  = wrapNewTypeBodyOfPDataWrap expr ty
275
276
-- buildFromArrPRepr ----------------------------------------------------------

-- | Build the 'fromArrPRepr' method for the PA class.
--
--   The generated lambda casts its @PData (PRepr (T as))@ argument with the
--   representation axiom. It then cases through the nested generic sums and
--   products, binding each array it finds. The innermost expression applies
--   the 'PData' data constructor for @T@ to those arrays. The list of arrays
--   is only known once the whole representation has been traversed, so it is
--   fed back into the construction with 'fixV'.
buildFromArrPRepr :: PAInstanceBuilder
buildFromArrPRepr vect_tc repr_ax pdata_tc _ repr
  = do arg_ty <- mkPDataType =<< mkPReprType el_ty
       res_ty <- mkPDataType el_ty
       arg    <- newLocalVar (fsLit "xs") arg_ty

       pdata_co <- mkBuiltinCo pdataTyCon
       let co    = mkAppCo pdata_co
                 $ mkAxInstCo repr_ax var_tys

           scrut = mkCast (Var arg) co

           -- The final 'PData' value for the vectorised type, built from the
           -- arrays bound by the enclosing case expressions.
           mk_result args
                 = wrapFamInstBody pdata_tc var_tys
                 $ mkConApp pdata_con
                 $ map Type var_tys ++ args

       -- Tie the knot: 'args' is produced by the same traversal that consumes
       -- 'mk_result args', so it must be matched lazily (irrefutable pattern).
       (expr, _) <- fixV $ \ ~(_, args) ->
                      from_sum res_ty (mk_result args) scrut repr

       return $ Lam arg expr
  where
    var_tys     = mkTyVarTys $ tyConTyVars vect_tc
    el_ty       = mkTyConApp vect_tc var_tys
    [pdata_con] = tyConDataCons pdata_tc

    -- The traversal functions take the result type, the expression to place
    -- innermost, the scrutinee to take apart and the representation. They
    -- return the wrapped expression together with the arguments for the
    -- 'PData' constructor that they bound.
    from_sum res_ty res expr ss
      = case ss of
          EmptySum
            -> return (res, [])

          UnarySum con_repr
            -> from_con res_ty res expr con_repr

          Sum {}
            -> do let psum_tc    = repr_psum_tc ss
                      [psum_con] = tyConDataCons psum_tc
                  sel          <- newLocalVar (fsLit "sel") (repr_sel_ty ss)
                  ptys         <- mapM mkPDataType (repr_con_tys ss)
                  vars         <- newLocalVars (fsLit "xs") ptys
                  (res', args) <- fold from_con res_ty res (map Var vars) (repr_cons ss)
                  let scrut = unwrapFamInstScrut psum_tc (repr_con_tys ss) expr
                      body  = mkWildCase scrut (exprType scrut) res_ty
                                [(DataAlt psum_con, sel : vars, res')]
                  return (body, Var sel : args)

    from_prod res_ty res expr ss
      = case ss of
          EmptyProd
            -> return (res, [])

          UnaryProd comp
            -> from_comp res_ty res expr comp

          Prod {}
            -> do let ptup_tc    = repr_ptup_tc ss
                      [ptup_con] = tyConDataCons ptup_tc
                  ptys         <- mapM mkPDataType (repr_comp_tys ss)
                  vars         <- newLocalVars (fsLit "ys") ptys
                  (res', args) <- fold from_comp res_ty res (map Var vars) (repr_comps ss)
                  let scrut = unwrapFamInstScrut ptup_tc (repr_comp_tys ss) expr
                      body  = mkWildCase scrut (exprType scrut) res_ty
                                [(DataAlt ptup_con, vars, res')]
                  return (body, args)

    from_con res_ty res expr (ConRepr _ prod_repr)
      = from_prod res_ty res expr prod_repr

    from_comp _ res expr (Keep _ _) = return (res, [expr])
    from_comp _ res expr (Wrap ty)
      = do expr' <- unwrapNewTypeBodyOfPDataWrap expr ty
           return (res, [expr'])

    -- Thread the result expression right-to-left through 'f' over the paired
    -- scrutinees and representations, concatenating the argument lists in
    -- order.
    fold f res_ty res0 exprs rs
      = foldrM step (res0, []) (zip exprs rs)
      where
        step (expr, r) (res, args)
          = do (res', args') <- f res_ty res expr r
               return (res', args' ++ args)
351
352
-- buildToArrPReprs -----------------------------------------------------------

-- | Build the 'toArrPReprs' method of the PA class.
--   This converts a 'PDatas' of elements into the generic representation.
buildToArrPReprs :: PAInstanceBuilder
buildToArrPReprs vect_tc repr_ax _ pdatas_tc repr
  = do -- The argument type of the instance.
       -- eg: 'PDatas (Tree a b)'
       arg_ty <- mkPDatasType el_ty

       -- The result type.
       -- eg: 'PDatas (PRepr (Tree a b))'
       res_ty <- mkPDatasType =<< mkPReprType el_ty

       -- Variable to bind the argument to the instance
       -- eg: (xss :: PDatas (Tree a b))
       varg   <- newLocalVar (fsLit "xss") arg_ty

       -- Coercion used to cast the rebuilt generic representation to
       -- 'res_ty': 'PDatas' applied to the symmetric representation axiom.
       pdatas_co <- mkBuiltinCo pdatasTyCon
       let co    = mkAppCo pdatas_co
                 . mkSymCo
                 $ mkAxInstCo repr_ax ty_args

           scrut = unwrapFamInstScrut pdatas_tc ty_args (Var varg)

       (vars, result) <- to_sum repr

       return $ Lam varg
              $ mkWildCase scrut (mkTyConApp pdatas_tc ty_args) res_ty
                  [(DataAlt pdatas_dc, vars, mkCast result co)]
  where
    -- The element type of the argument.
    -- eg: 'Tree a b'.
    ty_args     = mkTyVarTys $ tyConTyVars vect_tc
    el_ty       = mkTyConApp vect_tc ty_args

    -- PDatas data constructor
    [pdatas_dc] = tyConDataCons pdatas_tc

    to_sum ss
      = case ss of
          -- We can't convert data types with no data.
          -- See Note [Empty PDatas].
          EmptySum
            -> return ([], errorEmptyPDatas el_ty)

          -- With a single constructor there is no selector to take a length
          -- from; the error is only demanded if the product has no fields.
          UnarySum con_repr
            -> to_con (errorEmptyPDatas el_ty) con_repr

          Sum{}
            -> do let psums_tc    = repr_psums_tc ss
                      [psums_con] = tyConDataCons psums_tc
                  sels <- newLocalVar (fsLit "sels") (repr_sels_ty ss)

                  -- Take the number of selectors to serve as the length of
                  -- any PDatas Void arrays in the product.
                  -- See Note [Empty PDatas].
                  let xSums = App (repr_selsLength_v ss) (Var sels)

                  (vars, exprs) <- mapAndUnzipM (to_con xSums) (repr_cons ss)
                  return ( sels : concat vars
                         , wrapFamInstBody psums_tc (repr_con_tys ss)
                         $ mkConApp psums_con
                         $ map Type (repr_con_tys ss) ++ (Var sels : exprs))

    -- 'xSums' is the expression for the length of any 'PDatas Void' arrays.
    to_prod xSums ss
      = case ss of
          EmptyProd
            -> do pvoids <- builtin pvoidsVar
                  return ([], App (Var pvoids) xSums)

          UnaryProd comp
            -> do pty  <- mkPDatasType (compOrigType comp)
                  var  <- newLocalVar (fsLit "x") pty
                  expr <- to_comp (Var var) comp
                  return ([var], expr)

          Prod{}
            -> do let ptups_tc    = repr_ptups_tc ss
                      [ptups_con] = tyConDataCons ptups_tc
                  ptys  <- mapM (mkPDatasType . compOrigType) (repr_comps ss)
                  vars  <- newLocalVars (fsLit "x") ptys
                  exprs <- zipWithM to_comp (map Var vars) (repr_comps ss)
                  return ( vars
                         , wrapFamInstBody ptups_tc (repr_comp_tys ss)
                         $ mkConApp ptups_con
                         $ map Type (repr_comp_tys ss) ++ exprs)

    to_con xSums (ConRepr _ prod_repr)
      = to_prod xSums prod_repr

    to_comp expr (Keep _ _) = return expr
    to_comp expr (Wrap ty)  = wrapNewTypeBodyOfPDatasWrap expr ty
442
443
-- buildFromArrPReprs ---------------------------------------------------------

-- | Build the 'fromArrPReprs' method of the PA class.
--   This converts a 'PDatas' of the generic representation back into a
--   'PDatas' of the vectorised type. It follows the same structure as
--   'buildFromArrPRepr', working on 'PDatas' instead of 'PData'.
buildFromArrPReprs :: PAInstanceBuilder
buildFromArrPReprs vect_tc repr_ax _ pdatas_tc repr
  = do -- The argument type of the instance.
       -- eg: 'PDatas (PRepr (Tree a b))'
       arg_ty <- mkPDatasType =<< mkPReprType el_ty

       -- The result type.
       -- eg: 'PDatas (Tree a b)'
       res_ty <- mkPDatasType el_ty

       -- Variable to bind the argument to the instance
       -- eg: (xss :: PDatas (PRepr (Tree a b)))
       varg   <- newLocalVar (fsLit "xss") arg_ty

       -- Build the coercion between PRepr and the instance type
       pdatas_co <- mkBuiltinCo pdatasTyCon
       let co    = mkAppCo pdatas_co
                 $ mkAxInstCo repr_ax var_tys

           scrut = mkCast (Var varg) co

           -- The final 'PDatas' value for the vectorised type, built from
           -- the arrays bound by the enclosing case expressions.
           mk_result args
                 = wrapFamInstBody pdatas_tc var_tys
                 $ mkConApp pdatas_con
                 $ map Type var_tys ++ args

       -- Tie the knot: 'args' is produced by the same traversal that consumes
       -- 'mk_result args', so it must be matched lazily (irrefutable pattern).
       (expr, _) <- fixV $ \ ~(_, args) ->
                      from_sum res_ty (mk_result args) scrut repr

       return $ Lam varg expr
  where
    -- The type arguments and the element type of the argument.
    -- eg: 'Tree a b'.
    var_tys      = mkTyVarTys $ tyConTyVars vect_tc
    el_ty        = mkTyConApp vect_tc var_tys

    -- PDatas data constructor
    [pdatas_con] = tyConDataCons pdatas_tc

    from_sum res_ty res expr ss
      = case ss of
          -- We can't convert data types with no data.
          -- See Note [Empty PDatas].
          EmptySum
            -> return (res, errorEmptyPDatas el_ty)

          UnarySum con_repr
            -> from_con res_ty res expr con_repr

          Sum {}
            -> do let psums_tc    = repr_psums_tc ss
                      [psums_con] = tyConDataCons psums_tc
                  sel          <- newLocalVar (fsLit "sels") (repr_sels_ty ss)
                  ptys         <- mapM mkPDatasType (repr_con_tys ss)
                  vars         <- newLocalVars (fsLit "xs") ptys
                  (res', args) <- fold from_con res_ty res (map Var vars) (repr_cons ss)
                  let scrut = unwrapFamInstScrut psums_tc (repr_con_tys ss) expr
                      body  = mkWildCase scrut (exprType scrut) res_ty
                                [(DataAlt psums_con, sel : vars, res')]
                  return (body, Var sel : args)

    from_prod res_ty res expr ss
      = case ss of
          EmptyProd
            -> return (res, [])

          UnaryProd comp
            -> from_comp res_ty res expr comp

          Prod {}
            -> do let ptups_tc    = repr_ptups_tc ss
                      [ptups_con] = tyConDataCons ptups_tc
                  ptys         <- mapM mkPDatasType (repr_comp_tys ss)
                  vars         <- newLocalVars (fsLit "ys") ptys
                  (res', args) <- fold from_comp res_ty res (map Var vars) (repr_comps ss)
                  let scrut = unwrapFamInstScrut ptups_tc (repr_comp_tys ss) expr
                      body  = mkWildCase scrut (exprType scrut) res_ty
                                [(DataAlt ptups_con, vars, res')]
                  return (body, args)

    from_con res_ty res expr (ConRepr _ prod_repr)
      = from_prod res_ty res expr prod_repr

    from_comp _ res expr (Keep _ _) = return (res, [expr])
    from_comp _ res expr (Wrap ty)
      = do expr' <- unwrapNewTypeBodyOfPDatasWrap expr ty
           return (res, [expr'])

    -- Thread the result expression right-to-left through 'f' over the paired
    -- scrutinees and representations, concatenating the argument lists in
    -- order.
    fold f res_ty res0 exprs rs
      = foldrM step (res0, []) (zip exprs rs)
      where
        step (expr, r) (res, args)
          = do (res', args') <- f res_ty res expr r
               return (res', args' ++ args)
533
534
535 -- Notes ----------------------------------------------------------------------
536 {-
537 Note [Empty PDatas]
538 ~~~~~~~~~~~~~~~~~~~
539 We don't support "empty" data types like the following:
540
541 data Empty0
542 data Empty1 = MkEmpty1
543 data Empty2 = MkEmpty2 Empty0
544 ...
545
546 There is no parallel data associated with these types, so there is nowhere
547 to store the length of the PDatas array with our standard representation.
548
549 Enumerations like the following are ok:
550 data Bool = True | False
551
552 The native and generic representations are:
553 type instance (PDatas Bool) = VPDs:Bool Sels2
554 type instance (PDatas (Repr Bool)) = PSum2s Sels2 (PDatas Void) (PDatas Void)
555
556 To take the length of a (PDatas Bool) we take the length of the contained Sels2.
557 When converting a (PDatas Bool) to a (PDatas (Repr Bool)) we use this length to
558 initialise the two (PDatas Void) arrays.
559
560 However, with this:
561 data Empty1 = MkEmpty1
562
563 The native and generic representations would be:
564 type instance (PDatas Empty1) = VPDs:Empty1
565 type instance (PDatas (Repr Empty1)) = PVoids Int
566
567 The 'Int' argument of PVoids is supposed to store the length of the PDatas
568 array. When converting the (PDatas Empty1) to a (PDatas (Repr Empty1)) we
569 need to come up with a value for it, but there isn't one.
570
571 To fix this we'd need to add an Int field to VPDs:Empty1 as well, but that's
572 too much hassle and there's no point running a parallel computation on no
573 data anyway.
574 -}
-- | Abort vectorisation of a data type that has no parallel data in which to
--   store the length of its 'PDatas' arrays. See Note [Empty PDatas].
--
--   The argument is the vectorised type, used only in the error message.
errorEmptyPDatas :: Type -> a
errorEmptyPDatas ty
  = cantVectorise "Vectorise.Generic.PAMethods"
  $ vcat [ text "Cannot vectorise data type with no parallel data " <> quotes (ppr ty)
         , text "Data types to be vectorised must contain at least one constructor"
         , text "with at least one field." ]