vectoriser: formatting to PAMethods and start adding PDatas cases
authorBen Lippmeier <benl@ouroborus.net>
Mon, 14 Nov 2011 04:48:58 +0000 (15:48 +1100)
committerBen Lippmeier <benl@ouroborus.net>
Mon, 14 Nov 2011 04:48:58 +0000 (15:48 +1100)
compiler/vectorise/Vectorise/Generic/PADict.hs
compiler/vectorise/Vectorise/Generic/PAMethods.hs

index 9f4b425..b4c6931 100644 (file)
@@ -20,6 +20,7 @@ import Id
 import Var
 import Name
 
+
 -- debug                = False
 -- dtrace s x   = if debug then pprTrace "Vectoris.Type.PADict" s x else x
 
@@ -59,7 +60,8 @@ buildPADict vect_tc prepr_tc arr_tc repr
       ; let dfun_name = mkLocalisedOccName mod mkPADFunOcc vect_tc_name
       
           -- Get ids for each of the methods in the dictionary, including superclass
-      ; method_ids <- mapM (method args dfun_name) buildPAScAndMethods
+      ; paMethodBuilders <- buildPAScAndMethods
+      ; method_ids       <- mapM (method args dfun_name) paMethodBuilders
 
           -- Expression to build the dictionary.
       ; pa_dc  <- builtin paDataCon
index cbc782e..832c839 100644 (file)
@@ -20,21 +20,15 @@ import MkId
 import FastString
 import MonadUtils
 import Control.Monad
-
-
-mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
-mk_fam_inst fam_tc arg_tc
-  = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
+import Data.Maybe
 
 
 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
 buildPReprTyCon orig_tc vect_tc repr
-  = do
-      name     <- mkLocalisedName mkPReprTyConOcc (tyConName orig_tc)
-      -- rhs_ty   <- buildPReprType vect_tc
-      rhs_ty   <- sumReprType repr
-      prepr_tc <- builtin preprTyCon
-      liftDs $ buildSynTyCon name
+ = do name      <- mkLocalisedName mkPReprTyConOcc (tyConName orig_tc)
+      rhs_ty    <- sumReprType repr
+      prepr_tc  <- builtin preprTyCon
+      liftDs    $  buildSynTyCon name
                              tyvars
                              (SynonymTyCon rhs_ty)
                              (typeKind rhs_ty)
@@ -44,20 +38,47 @@ buildPReprTyCon orig_tc vect_tc repr
     tyvars = tyConTyVars vect_tc
 
 
------------------------------------------------------
-buildPAScAndMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
--- buildPAScandmethods says how to build the PR superclass and methods of PA
+mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
+mk_fam_inst fam_tc arg_tc
+  = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
+
+
+
+-- buildPAScAndMethods --------------------------------------------------------
+
+-- | This says how to build the PR superclass and methods of PA
+--   Recall the definition of the PA class:
+--
+--   @
 --    class class PR (PRepr a) => PA a where
---      toPRepr      :: a -> PRepr a
---      fromPRepr    :: PRepr a -> a
---      toArrPRepr   :: PData a -> PData (PRepr a)
---      fromArrPRepr :: PData (PRepr a) -> PData a
+--      toPRepr       :: a                -> PRepr a
+--      fromPRepr     :: PRepr a          -> a
+--
+--      toArrPRepr    :: PData a          -> PData (PRepr a)
+--      fromArrPRepr  :: PData (PRepr a)  -> PData a
+--
+--      toArrPReprs   :: PDatas a         -> PDatas (PRepr a)    (optional)
+--      fromArrPReprs :: PDatas (PRepr a) -> PDatas a            (optional)
+--   @
+--
+--  Not all lifted backends use the 'toArrPReprs' and 'fromArrPReprs' methods, 
+--  so we only generate these if the 'PDatas' type family is defined.
+--
+buildPAScAndMethods :: VM [( String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
+buildPAScAndMethods
+ = do   hasPDatas <- liftM isJust $ builtin pdatasTyCon
+        return 
+         $    [ ("PR",            buildPRDict)
+              , ("toPRepr",       buildToPRepr)
+              , ("fromPRepr",     buildFromPRepr)
+              , ("toArrPRepr",    buildToArrPRepr)
+              , ("fromArrPRepr",  buildFromArrPRepr)]
+         ++ (if hasPDatas then
+              [ ("toArrPReprs",   buildToArrPReprs)
+              , ("fromArrPReprs", buildFromArrPReprs)]
+              else [])
+             
 
-buildPAScAndMethods = [("PR",           buildPRDict),
-                       ("toPRepr",      buildToPRepr),
-                       ("fromPRepr",    buildFromPRepr),
-                       ("toArrPRepr",   buildToArrPRepr),
-                       ("fromArrPRepr", buildFromArrPRepr)]
 
 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
 buildPRDict vect_tc prepr_tc _ _
@@ -66,35 +87,44 @@ buildPRDict vect_tc prepr_tc _ _
     arg_tys = mkTyVarTys (tyConTyVars vect_tc)
     inst_ty = mkTyConApp vect_tc arg_tys
 
+
+-- buildToPRepr ---------------------------------------------------------------
+-- | Build the 'toRepr' method of the PA class.
 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
 buildToPRepr vect_tc repr_tc _ repr
-  = do
-      let arg_ty = mkTyConApp vect_tc ty_args
+ = do let arg_ty = mkTyConApp vect_tc ty_args
+
+      -- Get the representation type of the argument.
       res_ty <- mkPReprType arg_ty
+
+      -- Var to bind the argument
       arg    <- newLocalVar (fsLit "x") arg_ty
+
+      -- Build the expression to convert the argument to the generic representation.
       result <- to_sum (Var arg) arg_ty res_ty repr
+
       return $ Lam arg result
   where
-    ty_args = mkTyVarTys (tyConTyVars vect_tc)
+    ty_args        = mkTyVarTys (tyConTyVars vect_tc)
 
     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
 
+    -- CoreExp to convert the given argument to the generic representation.
+    -- We start by doing a case branch on the possible data constructors.
+    to_sum :: CoreExpr -> Type -> Type -> SumRepr -> VM CoreExpr
     to_sum _ _ _ EmptySum
-      = do
-          void <- builtin voidVar
+     = do void <- builtin voidVar
           return $ wrap_repr_inst $ Var void
 
     to_sum arg arg_ty res_ty (UnarySum r)
-      = do
-          (pat, vars, body) <- con_alt r
+     = do (pat, vars, body) <- con_alt r
           return $ mkWildCase arg arg_ty res_ty
                    [(pat, vars, wrap_repr_inst body)]
 
     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
                                   , repr_con_tys = tys
                                   , repr_cons    =  cons })
-      = do
-          alts <- mapM con_alt cons
+     = do alts <- mapM con_alt cons
           let alts' = [(pat, vars, wrap_repr_inst
                                    $ mkConApp sum_con (map Type tys ++ [body]))
                         | ((pat, vars, body), sum_con)
@@ -102,37 +132,38 @@ buildToPRepr vect_tc repr_tc _ repr
           return $ mkWildCase arg arg_ty res_ty alts'
 
     con_alt (ConRepr con r)
-      = do
-          (vars, body) <- to_prod r
+     = do (vars, body) <- to_prod r
           return (DataAlt con, vars, body)
 
+    -- CoreExp to convert data constructor fields to the generic representation.
+    to_prod :: ProdRepr -> VM ([Var], CoreExpr)
     to_prod EmptyProd
-      = do
-          void <- builtin voidVar
+     = do void <- builtin voidVar
           return ([], Var void)
 
     to_prod (UnaryProd comp)
-      = do
-          var  <- newLocalVar (fsLit "x") (compOrigType comp)
+     = do var  <- newLocalVar (fsLit "x") (compOrigType comp)
           body <- to_comp (Var var) comp
           return ([var], body)
 
-    to_prod(Prod { repr_tup_tc   = tup_tc
-                 , repr_comp_tys = tys
-                 , repr_comps    = comps })
-      = do
-          vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
+    to_prod (Prod { repr_tup_tc   = tup_tc
+                  , repr_comp_tys = tys
+                  , repr_comps    = comps })
+     = do vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
           exprs <- zipWithM to_comp (map Var vars) comps
+          let [tup_con] = tyConDataCons tup_tc
           return (vars, mkConApp tup_con (map Type tys ++ exprs))
-      where
-        [tup_con] = tyConDataCons tup_tc
 
+    -- CoreExp to convert a data constructor component to the generic representation.
+    to_comp :: CoreExpr -> CompRepr -> VM CoreExpr
     to_comp expr (Keep _ _) = return expr
-    to_comp expr (Wrap ty)  = do
-                                wrap_tc <- builtin wrapTyCon
-                                return $ wrapNewTypeBody wrap_tc [ty] expr
+    to_comp expr (Wrap ty)  
+     = do wrap_tc <- builtin wrapTyCon
+          return $ wrapNewTypeBody wrap_tc [ty] expr
 
 
+-- buildFromPRepr -------------------------------------------------------------
+-- | Build the 'fromPRepr' method of the PA class.
 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
 buildFromPRepr vect_tc repr_tc _ repr
   = do
@@ -147,16 +178,14 @@ buildFromPRepr vect_tc repr_tc _ repr
     res_ty  = mkTyConApp vect_tc ty_args
 
     from_sum _ EmptySum
-      = do
-          dummy <- builtin fromVoidVar
+     = do dummy <- builtin fromVoidVar
           return $ Var dummy `App` Type res_ty
 
     from_sum expr (UnarySum r) = from_con expr r
     from_sum expr (Sum { repr_sum_tc  = sum_tc
                        , repr_con_tys = tys
                        , repr_cons    = cons })
-      = do
-          vars  <- newLocalVars (fsLit "x") tys
+     = do vars  <- newLocalVars (fsLit "x") tys
           es    <- zipWithM from_con (map Var vars) cons
           return $ mkWildCase expr (exprType expr) res_ty
                    [(DataAlt con, [var], e)
@@ -167,21 +196,18 @@ buildFromPRepr vect_tc repr_tc _ repr
 
     from_prod _ con EmptyProd = return con
     from_prod expr con (UnaryProd r)
-      = do
-          e <- from_comp expr r
+     = do e <- from_comp expr r
           return $ con `App` e
      
     from_prod expr con (Prod { repr_tup_tc   = tup_tc
                              , repr_comp_tys = tys
                              , repr_comps    = comps
                              })
-      = do
-          vars <- newLocalVars (fsLit "y") tys
+     = do vars <- newLocalVars (fsLit "y") tys
           es   <- zipWithM from_comp (map Var vars) comps
+          let [tup_con] = tyConDataCons tup_tc
           return $ mkWildCase expr (exprType expr) res_ty
                    [(DataAlt tup_con, vars, con `mkApps` es)]
-      where
-        [tup_con] = tyConDataCons tup_tc  
 
     from_comp expr (Keep _ _) = return expr
     from_comp expr (Wrap ty)
@@ -190,10 +216,11 @@ buildFromPRepr vect_tc repr_tc _ repr
           return $ unwrapNewTypeBody wrap [ty] expr
 
 
+-- buildToArrRepr -------------------------------------------------------------
+-- | Build the 'toArrRepr' method of the PA class.
 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
 buildToArrPRepr vect_tc prepr_tc pdata_tc r
-  = do
-      arg_ty <- mkPDataType el_ty
+ = do arg_ty <- mkPDataType el_ty
       res_ty <- mkPDataType =<< mkPReprType el_ty
       arg    <- newLocalVar (fsLit "xs") arg_ty
 
@@ -217,17 +244,18 @@ buildToArrPRepr vect_tc prepr_tc pdata_tc r
     [pdata_dc] = tyConDataCons pdata_tc
 
 
-    to_sum EmptySum = do
-                        pvoid <- builtin pvoidVar
-                        return ([], Var pvoid)
+    to_sum EmptySum 
+     = do pvoid <- builtin pvoidVar
+          return ([], Var pvoid)
+
     to_sum (UnarySum r) = to_con r
+
     to_sum (Sum { repr_psum_tc = psum_tc
                 , repr_sel_ty  = sel_ty
                 , repr_con_tys = tys
                 , repr_cons    = cons
                 })
-      = do
-          (vars, exprs) <- mapAndUnzipM to_con cons
+     = do (vars, exprs) <- mapAndUnzipM to_con cons
           sel <- newLocalVar (fsLit "sel") sel_ty
           return (sel : concat vars, mk_result (Var sel) exprs)
       where
@@ -238,12 +266,12 @@ buildToArrPRepr vect_tc prepr_tc pdata_tc r
 
     to_con (ConRepr _ r) = to_prod r
 
-    to_prod EmptyProd = do
-                          pvoid <- builtin pvoidVar
-                          return ([], Var pvoid)
+    to_prod EmptyProd
+     = do pvoid <- builtin pvoidVar
+          return ([], Var pvoid)
+
     to_prod (UnaryProd r)
-      = do
-          pty  <- mkPDataType (compOrigType r)
+     = do pty  <- mkPDataType (compOrigType r)
           var  <- newLocalVar (fsLit "x") pty
           expr <- to_comp (Var var) r
           return ([var], expr)
@@ -251,8 +279,7 @@ buildToArrPRepr vect_tc prepr_tc pdata_tc r
     to_prod (Prod { repr_ptup_tc  = ptup_tc
                   , repr_comp_tys = tys
                   , repr_comps    = comps })
-      = do
-          ptys <- mapM (mkPDataType . compOrigType) comps
+     = do ptys <- mapM (mkPDataType . compOrigType) comps
           vars <- newLocalVars (fsLit "x") ptys
           es   <- zipWithM to_comp (map Var vars) comps
           return (vars, mk_result es)
@@ -272,10 +299,11 @@ buildToArrPRepr vect_tc prepr_tc pdata_tc r
           return $ wrapNewTypeBody pwrap_tc [ty] expr
 
 
+-- buildFromArrPRepr ----------------------------------------------------------
+-- | Build the 'fromArrPRepr' method for the PA class.
 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
-  = do
-      arg_ty <- mkPDataType =<< mkPReprType el_ty
+ = do arg_ty <- mkPDataType =<< mkPReprType el_ty
       res_ty <- mkPDataType el_ty
       arg    <- newLocalVar (fsLit "xs") arg_ty
 
@@ -294,15 +322,7 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r
                      from_sum res_ty (mk_result args) scrut r
 
       return $ Lam arg expr
-    
-      -- (args, mk) <- from_sum res_ty scrut r
-      
-      -- let result = wrapFamInstBody pdata_tc var_tys
-      --           . mkConApp pdata_dc
-      --           $ map Type var_tys ++ args
-
-      -- return $ Lam arg (mk result)
-  where
+ where
     var_tys = mkTyVarTys $ tyConTyVars vect_tc
     el_ty   = mkTyConApp vect_tc var_tys
 
@@ -314,8 +334,7 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r
                                   , repr_sel_ty  = sel_ty
                                   , repr_con_tys = tys
                                   , repr_cons    = cons })
-      = do
-          sel  <- newLocalVar (fsLit "sel") sel_ty
+     = do sel  <- newLocalVar (fsLit "sel") sel_ty
           ptys <- mapM mkPDataType tys
           vars <- newLocalVars (fsLit "xs") ptys
           (res', args) <- fold from_con res_ty res (map Var vars) cons
@@ -329,14 +348,14 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r
 
     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
 
-    from_prod _ res _ EmptyProd = return (res, [])
+    from_prod _ res _ EmptyProd
+      = return (res, [])
     from_prod res_ty res expr (UnaryProd r)
       = from_comp res_ty res expr r
     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
                                     , repr_comp_tys = tys
                                     , repr_comps    = comps })
-      = do
-          ptys <- mapM mkPDataType tys
+     = do ptys <- mapM mkPDataType tys
           vars <- newLocalVars (fsLit "ys") ptys
           (res', args) <- fold from_comp res_ty res (map Var vars) comps
           let scrut = unwrapFamInstScrut ptup_tc tys expr
@@ -348,14 +367,19 @@ buildFromArrPRepr vect_tc prepr_tc pdata_tc r
 
     from_comp _ res expr (Keep _ _) = return (res, [expr])
     from_comp _ res expr (Wrap ty)
-      = do
-          wrap_tc  <- builtin wrapTyCon
+     = do wrap_tc       <- builtin wrapTyCon
           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
           return (res, [unwrapNewTypeBody pwrap_tc [ty]
                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
 
     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
       where
-        f' (expr, r) (res, args) = do
-                                     (res', args') <- f res_ty res expr r
-                                     return (res', args' ++ args)
+        f' (expr, r) (res, args) 
+         = do (res', args') <- f res_ty res expr r
+              return (res', args' ++ args)
+
+-- buildToArrPReprs -----------------------------------------------------------
+buildToArrPReprs        = error "buildToArrPReprs not done yet"
+
+-- buildFromArrPReprs ---------------------------------------------------------
+buildFromArrPReprs      = error "buildFromArrPReprs not done yet"
\ No newline at end of file