Use TH to generate instance PR (Wrap a)
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Mon, 9 Nov 2009 12:44:22 +0000 (12:44 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Mon, 9 Nov 2009 12:44:22 +0000 (12:44 +0000)
dph-common/Data/Array/Parallel/Lifted/Repr.hs
dph-common/Data/Array/Parallel/Lifted/TH/Repr.hs

index b73257c..20db73c 100644 (file)
@@ -82,53 +82,7 @@ newtype Wrap a = Wrap { unWrap :: a }
 
 newtype instance PData (Wrap a) = PWrap (PData a)
 
-instance PA a => PR (Wrap a) where
-  {-# INLINE emptyPR #-}
-  emptyPR = PWrap emptyPD
-
-  {-# INLINE replicatePR #-}
-  replicatePR n# (Wrap x) = PWrap (replicatePD n# x)
-
-  {-# INLINE replicatelPR #-}
-  replicatelPR segd (PWrap xs) = PWrap (replicatelPD segd xs)
-
-  {-# INLINE repeatPR #-}
-  repeatPR m# n# (PWrap xs) = PWrap (repeatPD m# n# xs)
-
-  {-# INLINE repeatcPR #-}
-  repeatcPR n# ns segd (PWrap xs) = PWrap (repeatcPD n# ns segd xs)
-
-  {-# INLINE indexPR #-}
-  indexPR (PWrap xs) i# = Wrap (indexPD xs i#)
-
-  {-# INLINE extractPR #-}
-  extractPR (PWrap xs) i# n# = PWrap (extractPD xs i# n#)
-
-  {-# INLINE bpermutePR #-}
-  bpermutePR (PWrap xs) n# is = PWrap (bpermutePD xs n# is)
-
-  {-# INLINE appPR #-}
-  appPR (PWrap xs) (PWrap ys) = PWrap (appPD xs ys)
-
-  {-# INLINE applPR #-}
-  applPR xsegd (PWrap xs) ysegd (PWrap ys)
-    = PWrap (applPD xsegd xs ysegd ys)
-
-  {-# INLINE packPR #-}
-  packPR (PWrap xs) n# bs = PWrap (packPD xs n# bs)
-
-  {-# INLINE packByTagPR #-}
-  packByTagPR (PWrap xs) n# tags t# = PWrap (packByTagPD xs n# tags t#)
-
-  {-# INLINE combine2PR #-}
-  combine2PR n# sel (PWrap xs) (PWrap ys)
-    = PWrap (combine2PD n# sel xs ys)
-
-  {-# INLINE fromListPR #-}
-  fromListPR n# xs = PWrap (fromListPD n# (map unWrap xs))
-
-  {-# INLINE nfPR #-}
-  nfPR (PWrap xs) = nfPD xs
+$(wrapPRInstance ''Wrap 'Wrap 'unWrap 'PWrap)
 
 ------------
 -- Tuples --
index fd1843b..3fd4376 100644 (file)
@@ -1,6 +1,6 @@
 {-# LANGUAGE TemplateHaskell, Rank2Types #-}
 module Data.Array.Parallel.Lifted.TH.Repr (
-  primInstances, tupleInstances, voidPRInstance, unitPRInstance
+  primInstances, tupleInstances, voidPRInstance, unitPRInstance, wrapPRInstance
 ) where
 
 import qualified Data.Array.Parallel.Unlifted as U
@@ -126,15 +126,16 @@ data Arg = RecArg   [ExpQ] [ExpQ]
 
 data Gen = Gen {
              recursiveCalls :: Int
+           , recursiveName  :: Name -> Name
            , split          :: ArgVal -> (Split, Arg)
-           , join           :: Name -> [Arg] -> Val -> [ExpQ] -> ExpQ
+           , join           :: Val -> [ExpQ] -> ExpQ
            }
 
 recursiveMethod :: Gen -> Name -> [ArgVal] -> Val -> DecQ
-recursiveMethod gen meth avs res
-  = simpleFunD (mkName $ nameBase meth) (map pat splits)
+recursiveMethod gen name avs res
+  = simpleFunD (mkName $ nameBase name) (map pat splits)
   $ foldr mk_case
-    (join gen meth args res
+    (join gen res
      . recurse (recursiveCalls gen)
      . trans
      $ map expand args)
@@ -161,8 +162,10 @@ recursiveMethod gen meth avs res
     trans (xs : yss) = zipWith (:) xs (trans yss)
 
     recurse 0 _ = []
-    recurse n [] = replicate n (varE meth)
-    recurse n args = [varE meth `appEs` es| es <- take n args]
+    recurse n [] = replicate n (varE rec_name)
+    recurse n args = [varE rec_name `appEs` es| es <- take n args]
+
+    rec_name = recursiveName gen name
 
 nameGens =
   [
@@ -314,7 +317,49 @@ unitMethod punit meth avs res
     seq_val Nothing  e = e
     seq_val (Just f) e = f e
 
+-- ----
+-- Wrap
+-- ----
+
+wrapPRInstance :: Name -> Name -> Name -> Name -> Q [Dec]
+wrapPRInstance ty wrap unwrap pwrap
+  = do
+      methods <- genPR_methods (recursiveMethod (wrapGen wrap unwrap pwrap))
+      return [InstanceD [ClassP ''PA [a]]
+                        (ConT ''PR `AppT` (ConT ty `AppT` a))
+                        methods]
+  where
+    a = VarT (mkName "a")
+
+wrapGen :: Name -> Name -> Name -> Gen
+wrapGen wrap unwrap pwrap = Gen { recursiveCalls = 1
+                                , recursiveName  = recursiveName
+                                , split          = split
+                                , join           = join }
+  where
+    recursiveName = mkName . replace . nameBase
+      where
+        replace s = init s ++ "D"
+
+    split (ScalarVal, gen)
+      = (PatSplit (conP wrap [varP x]), RecArg [] [varE x])
+      where
+        x = mkName (gen "x")
+
+    split (PDataVal, gen)
+      = (PatSplit (conP pwrap [varP xs]), RecArg [] [varE xs])
+      where
+        xs = mkName (gen "xs")
+
+    split (ListVal, gen)
+      = (PatSplit (varP xs),
+         RecArg [] [varE 'map `appEs` [varE unwrap, varE xs]])
+      where
+        xs = mkName (gen "xs")
 
+    join ScalarVal [x]  = conE wrap `appE` x
+    join PDataVal  [xs] = conE pwrap `appE` xs
+    join UnitVal   [x]  = x
 
 -- ------
 -- Tuples
@@ -350,6 +395,7 @@ instance_PR_tup arity
 
 tupGen :: Int -> Gen
 tupGen arity = Gen { recursiveCalls = arity
+                   , recursiveName  = id
                    , split          = split
                    , join           = join }
   where
@@ -375,9 +421,9 @@ tupGen arity = Gen { recursiveCalls = arity
         unzip | arity == 2 = mkName "unzip"
               | otherwise  = mkName ("unzip" ++ show arity)
 
-    join _ _ ScalarVal xs = tupE xs
-    join _ _ PDataVal  xs = conE (pdataTupCon arity) `appEs` xs
-    join _ _ UnitVal   xs = foldl1 (\x y -> varE 'seq `appEs` [x,y]) xs
+    join ScalarVal xs = tupE xs
+    join PDataVal  xs = conE (pdataTupCon arity) `appEs` xs
+    join UnitVal   xs = foldl1 (\x y -> varE 'seq `appEs` [x,y]) xs
 
     vs  = take arity [[c] | c <- ['a' ..]]
     pvs = take arity [c : "s" | c <- ['a' ..]]