Add kind equalities to GHC.
[ghc.git] / compiler / vectorise / Vectorise / Generic / PADict.hs
1
2 module Vectorise.Generic.PADict
3 ( buildPADict
4 ) where
5
6 import Vectorise.Monad
7 import Vectorise.Builtins
8 import Vectorise.Generic.Description
9 import Vectorise.Generic.PAMethods ( buildPAScAndMethods )
10 import Vectorise.Utils
11
12 import BasicTypes
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold
16 import Module
17 import TyCon
18 import CoAxiom
19 import Type
20 import Id
21 import Var
22 import Name
23 import FastString
24
25
-- | Build the @PA@ dictionary function for a vectorised type constructor and
-- hoist it to the top level, together with one hoisted top-level binding per
-- dictionary method.
--
-- The @PA@ dictionary holds the functions that convert values to and from
-- their vectorised representations.  Recall the class:
--
-- > class PR (PRepr a) => PA a where
-- >   toPRepr       :: a -> PRepr a
-- >   fromPRepr     :: PRepr a -> a
-- >   toArrPRepr    :: PData a -> PData (PRepr a)
-- >   fromArrPRepr  :: PData (PRepr a) -> PData a
-- >   toArrPReprs   :: PDatas a -> PDatas (PRepr a)
-- >   fromArrPReprs :: PDatas (PRepr a) -> PDatas a
--
-- For a polymorphic tycon @T a@ the generated code looks roughly like this
-- (binder names are illustrative; each method binding is named
-- @\<dfun\>$\<method\>@):
--
-- > dfun :: forall a. PR (PRepr (T a)) -> PA a -> PA (T a)
-- > dfun = /\a. \(c :: PR (PRepr (T a))) (d :: PA a).
-- >          MkPA (T a) c (dfun$toPRepr a c d) (dfun$fromPRepr a c d) ...
-- >
-- > dfun$toPRepr :: forall a. PR (PRepr (T a)) -> PA a -> T a -> PRepr (T a)
-- > dfun$toPRepr = ...
--
-- The method right-hand sides (the @...@) come from the builders returned by
-- 'buildPAScAndMethods'.  When the tycon has no type variables, the @PR@
-- superclass dictionary is not an argument; a constant dictionary is used
-- instead.
buildPADict
  :: TyCon               -- ^ Tycon of the type being vectorised.
  -> CoAxiom Unbranched  -- ^ Coercion between the type and its vectorised
                         --   representation.
  -> TyCon               -- ^ @PData@ instance tycon.
  -> TyCon               -- ^ @PDatas@ instance tycon.
  -> SumRepr             -- ^ Representation used for the type being vectorised.
  -> VM Var              -- ^ The hoisted top-level dictionary function.

buildPADict vect_tc prepr_ax pdata_tc pdatas_tc repr
  = polyAbstract tvs $ \paDicts -> do
      -- 'paDicts' are the dictionaries we lambda-abstract over.  polyAbstract
      -- puts them in the environment, so a needed (PA a) can be found there.
      -- They do not include the silent superclass argument; that is added below.
      thisMod <- liftDs getModule
      let dfunOcc = mkLocalisedOccName thisMod mkPADFunOcc (getName vect_tc)

      -- The PR superclass dictionary is a (silent) lambda-bound argument when
      -- the tycon is polymorphic, and a constant dictionary otherwise.
      (superVars, superConsts) <-
        if null tvs
          then do
            constDict <- prDictOfPReprInstTyCon instTy prepr_ax []
            return ([], [constDict])
          else do
            superTy  <- superclassPred
            superVar <- newLocalVar (fsLit "pr") superTy
            return ([superVar], [])

      -- Value arguments of the dfun (superclass first), then all binders.
      let valArgs    = superVars ++ paDicts
          allBinders = tvs ++ valArgs

      -- Build and hoist one top-level binding per dictionary method.
      builders  <- buildPAScAndMethods
      methodIds <- mapM (hoistMethod valArgs dfunOcc) builders

      -- The dictionary itself: the PA data constructor applied to the
      -- instance type, the superclass dictionary and the method calls.
      paCon <- builtin paDataCon
      let fields   = Type instTy
                   : map Var superVars  -- superclass dictionary: lambda-bound...
                  ++ superConsts        -- ...or constant
                  ++ map (applyMethod valArgs) methodIds
          dictExpr = mkLams allBinders (mkConApp paCon fields)

      -- Type of the dfun: forall tvs. <value-arg types> -> PA (T tvs)
      paCls <- builtin paClass
      let dfunTy = mkInvForAllTys tvs
                 $ mkFunTys (map varType valArgs) (mkClassPred paCls [instTy])

      -- Attach a DFun unfolding and the dfun inline pragma for the inliner.
      rawDfun <- newExportedVar dfunOcc dfunTy
      let dfunUnfolding = mkDFunUnfolding allBinders paCon fields
          dfun          = (rawDfun `setIdUnfolding` dfunUnfolding)
                            `setInlinePragma` dfunInlinePragma

      -- Add the new binding to the top-level environment.
      hoistBinding dfun dictExpr
      return dfun
  where
    tvs    = tyConTyVars vect_tc
    tyArgs = mkTyVarTys tvs
    instTy = mkTyConApp vect_tc tyArgs

    -- The superclass constraint PR (PRepr (T tvs)).
    superclassPred = do
      reprTy <- mkPReprType instTy
      prCls  <- builtin prClass
      return (mkClassPred prCls [reprTy])

    -- Run one method builder, abstract its result over the dfun's binders and
    -- hoist it as an always-inline top-level binding named <dfun>$<method>.
    hoistMethod valArgs dfunOcc (methodName, builder)
      = localV $ do
          rhs    <- builder vect_tc prepr_ax pdata_tc pdatas_tc repr
          let lam = mkLams (tvs ++ valArgs) rhs
          rawVar <- newExportedVar (mkVarOcc (occNameString dfunOcc ++ ('$' : methodName)))
                                   (exprType lam)
          let methodVar = (rawVar `setIdUnfolding` mkInlineUnfolding (Just (length valArgs)) lam)
                            `setInlinePragma` alwaysInlinePragma
          hoistBinding methodVar lam
          return methodVar

    -- Apply a hoisted method to the dfun's type variables and value arguments.
    applyMethod valArgs methodId
      = mkApps (Var methodId) (map Type tyArgs ++ map Var valArgs)