Deal with closures/zipWith/unzip of size 4
authorGabriele Keller <keller@cse.unsw.edu.au>
Wed, 1 Feb 2012 05:34:43 +0000 (16:34 +1100)
committerGabriele Keller <keller@cse.unsw.edu.au>
Fri, 17 Feb 2012 02:11:59 +0000 (13:11 +1100)
dph-lifted-boxed/Data/Array/Parallel/Lifted/Closure.hs
dph-lifted-boxed/Data/Array/Parallel/PArray/PData.hs
dph-lifted-copy/Data/Array/Parallel/Lifted/Closure.hs
dph-lifted-copy/Data/Array/Parallel/Lifted/Combinators.hs
dph-lifted-copy/Data/Array/Parallel/Lifted/Scalar.hs
dph-lifted-copy/Data/Array/Parallel/PArray/PDataInstances.hs
dph-lifted-copy/Data/Array/Parallel/Prim.hs
dph-lifted-vseg/Data/Array/Parallel/PArray/Scalar.hs
dph-lifted-vseg/Data/Array/Parallel/Prim.hs

index 424a41f..f4b7986 100644 (file)
@@ -13,8 +13,8 @@ module Data.Array.Parallel.Lifted.Closure
         , ($:^), liftedApply
 
         -- * Closure Construction.
-        , closure1,  closure2,  closure3
-        , closure1', closure2', closure3')
+        , closure1,  closure2,  closure3,  closure4
+        , closure1', closure2', closure3', closure4')
 where
 import Data.Array.Parallel.PArray.PData
 import Data.Array.Parallel.PArray.PRepr
@@ -151,6 +151,33 @@ closure3 fv fl
    in   Clo fv_1 fl_1 ()
 {-# INLINE_CLOSURE closure3 #-}
 
+-- | Construct an arity-4 closure
+--   from lifted and unlifted versions of a primitive function.
+closure4 
+        :: forall a b c d e. (PA a, PA b, PA c)
+        => (a -> b -> c -> d -> e)
+        -> (Int -> PData a -> PData b -> PData c -> PData d -> PData e)
+        -> (a :-> b :-> c :-> d :-> e)
+        
+closure4 fv fl
+ = let  fv_1   _ xa = Clo   fv_2 fl_2 xa
+        fl_1 _ _ xs = AClo  fv_2 fl_2 xs
+
+        -----
+        fv_2 xa yb   = Clo  fv_3 fl_3 (xa, yb)
+        fl_2 _ xs ys = AClo fv_3 fl_3 (PTuple2 xs ys)
+
+        -----
+        fv_3 (xa, yb) zc           = Clo  fv_4 fl_4 ((xa, yb), zc)
+        fl_3 n (PTuple2 xs ys) zs  = AClo fv_4 fl_4 (PTuple2 (PTuple2 xs ys) zs)
+
+        -----
+        fv_4 ((xa, yb), zc)   wd        = fv xa yb zc wd
+        fl_4 n (PTuple2 (PTuple2 xs ys) zs)  ws  = fl n xs ys zs ws
+
+   in   Clo fv_1 fl_1 ()
+{-# INLINE_CLOSURE closure4 #-}
+
 
 -- Closure constructors that take PArrays -------------------------------------
 -- These versions are useful when defining prelude functions such as in 
@@ -205,6 +232,22 @@ closure3' fv fl
 {-# INLINE_CLOSURE closure3' #-}
 
 
+-- | Construct an arity-3 closure.
+closure4'
+        :: forall a b c d e. (PA a, PA b, PA c) 
+        => (a -> b -> c -> d -> e)
+        -> (PArray a -> PArray b -> PArray c -> PArray d -> PArray e)
+        -> (a :-> b :-> c :-> d :-> e) 
+
+closure4' fv fl 
+ = let  {-# INLINE fl' #-}
+        fl' (I# n#) pdata1 pdata2 pdata3 pdata4
+         = case fl (PArray n# pdata1) (PArray n# pdata2) (PArray n# pdata3) (PArray n# pdata4) of
+                 PArray _ pdata' -> pdata'
+   in   closure4 fv fl'
+{-# INLINE_CLOSURE closure4' #-}
+
+
 -- PData instance for closures ------------------------------------------------
 -- This needs to be here instead of in a module D.A.P.PArray.PData.Closure
 -- to break an import loop.
index e9e3280..6889d43 100644 (file)
@@ -102,4 +102,3 @@ instance (PR a, PR b) => PR (a, b) where
    = V.zip (toVectorPR as) (toVectorPR bs)
 
 
-
index 828ec7b..e17f9b4 100644 (file)
@@ -4,7 +4,7 @@ module Data.Array.Parallel.Lifted.Closure (
   mkClosure, mkClosureP, ($:), ($:^),
   closure, liftedClosure, liftedApply,
 
-  closure1, closure2, closure3, closure4
+  closure1, closure2, closure3, closure4, closure5,
 ) where
 import Data.Array.Parallel.PArray.PReprInstances ()
 import Data.Array.Parallel.PArray.PDataInstances
@@ -242,3 +242,30 @@ closure4 fv fl = mkClosure fv_1 fl_1 ()
 
     fv_4 (x, y, z) v = fv x y z v
     fl_4 ps vs = case unzip3PA# ps of (xs, ys, zs) -> fl xs ys zs vs
+
+
+-- | Arity-5 closures.
+closure5 :: (PA a, PA b, PA c, PA d, PA e)
+         => (a -> b -> c -> d -> e -> f)
+         -> (PArray a -> PArray b -> PArray c -> PArray d -> PArray e -> PArray f)
+         -> (a :-> b :-> c :-> d :-> e :-> f)
+
+{-# INLINE closure5 #-}
+closure5 fv fl = mkClosure fv_1 fl_1 ()
+  where
+    fv_1 _  x  = mkClosure  fv_2 fl_2 x
+    fl_1 _  xs = mkClosureP fv_2 fl_2 xs
+
+    fv_2 x  y  = mkClosure  fv_3 fl_3 (x, y)
+    fl_2 xs ys = mkClosureP fv_3 fl_3 (zipPA# xs ys)
+
+    fv_3 (x, y)  z  = mkClosure  fv_4 fl_4 (x, y, z)
+    fl_3 xys     zs = case unzipPA# xys of (xs, ys) -> mkClosureP fv_4 fl_4 (zip3PA# xs ys zs)
+
+    fv_4 (w, x, y) z = mkClosure fv_5 fl_5 (w, x, y, z) 
+    fl_4 ps zs = case unzip3PA# ps of (ws, xs, ys) -> mkClosureP fv_5 fl_5  (zip4PA# ws xs ys zs)
+
+    fv_5 (v, w, x, y) z = fv v w x y z
+    fl_5 ps zs = case unzip4PA# ps of (vs, ws, xs, ys) -> fl vs ws xs ys zs
+
+   
\ No newline at end of file
index 59a2783..0c25845 100644 (file)
@@ -23,7 +23,7 @@
 --
 module Data.Array.Parallel.Lifted.Combinators (
   lengthPA, replicatePA, singletonPA, mapPA, crossMapPA,
-  zipPA, zip3PA, zipWithPA, zipWith3PA, unzipPA, unzip3PA, 
+  zipPA, zip3PA, zip4PA, zipWithPA, zipWith3PA, unzipPA, unzip3PA, unzip4PA ,
   packPA, filterPA, combine2PA, indexPA, concatPA, appPA, enumFromToPA_Int,
   indexedPA, slicePA, updatePA, bpermutePA,
 
@@ -191,6 +191,22 @@ zip3PA_l (PArray n# (PNested segd xs)) (PArray _ (PNested _ ys)) (PArray _ (PNes
   = PArray n# (PNested segd (P_3 xs ys zs))
 
 
+
+zip4PA :: (PA a, PA b, PA c, PA d) => PArray a :-> PArray b :-> PArray c :-> PArray d :-> PArray (a, b, c, d)
+{-# INLINE zip4PA #-}
+zip4PA = closure4 zip4PA_v zip4PA_l
+
+zip4PA_v :: (PA a, PA b, PA c, PA d) => PArray a -> PArray b -> PArray c -> PArray d -> PArray (a, b, c, d)
+{-# INLINE_PA zip4PA_v #-}
+zip4PA_v xs ys = zip4PA# xs ys
+
+zip4PA_l :: (PA a, PA b, PA c, PA d)
+         => PArray (PArray a) -> PArray (PArray b) -> PArray (PArray c) -> PArray (PArray d) -> PArray (PArray (a, b, c, d))
+{-# INLINE_PA zip4PA_l #-}
+zip4PA_l (PArray n# (PNested segd ws)) (PArray _ (PNested _ xs)) (PArray _ (PNested _ ys)) (PArray _ (PNested _ zs))
+  = PArray n# (PNested segd (P_4 ws xs ys zs))
+
+
 -- zipWith --------------------------------------------------------------------
 -- |Map a function over multiple arrays at once.
 
@@ -233,6 +249,26 @@ zipWith3PA_l fs ass bss css
       (replicatelPA# (segdPA# ass) fs $:^ concatPA# ass $:^ concatPA# bss $:^ concatPA# css)
 
 
+zipWith4PA :: (PA a, PA b, PA c, PA d, PA e)
+           => (a :-> b :-> c :-> d :-> e) :-> PArray a :-> PArray b :-> PArray c :-> PArray d :-> PArray e
+{-# INLINE zipWith4PA #-}
+zipWith4PA = closure5 zipWith4PA_v zipWith4PA_l
+
+zipWith4PA_v :: (PA a, PA b, PA c, PA d, PA e)
+             => (a :-> b :-> c :-> d :-> e) -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e
+{-# INLINE_PA zipWith4PA_v #-}
+zipWith4PA_v f as bs cs ds = replicatePA# (lengthPA# as) f $:^ as $:^ bs $:^ cs $:^ ds
+
+zipWith4PA_l :: (PA a, PA b, PA c, PA d, PA e)
+             => PArray (a :-> b :-> c :-> d :-> e) 
+             -> PArray (PArray a) -> PArray (PArray b) -> PArray (PArray c)
+             -> PArray (PArray d) -> PArray (PArray e)
+{-# INLINE_PA zipWith4PA_l #-}
+zipWith4PA_l fs ass bss css dss
+  = copySegdPA# ass
+      (replicatelPA# (segdPA# ass) fs $:^ concatPA# ass $:^ concatPA# bss $:^ concatPA# css $:^ concatPA# dss)
+
+
 -- unzip ----------------------------------------------------------------------
 -- |Transform an array of tuples into a tuple of arrays.
 
@@ -265,6 +301,21 @@ unzip3PA_l xyzss = zip3PA# (copySegdPA# xyzss xs) (copySegdPA# xyzss ys) (copySe
     (xs, ys, zs) = unzip3PA# (concatPA# xyzss)
 
 
+unzip4PA :: (PA a, PA b, PA c, PA d) => PArray (a, b, c, d) :-> (PArray a, PArray b, PArray c, PArray d)
+{-# INLINE unzip4PA #-}
+unzip4PA = closure1 unzip4PA_v unzip4PA_l
+
+unzip4PA_v :: (PA a, PA b, PA c, PA d) => PArray (a, b, c, d) -> (PArray a, PArray b, PArray c, PArray d)
+{-# INLINE_PA unzip4PA_v #-}
+unzip4PA_v abs' = unzip4PA# abs'
+
+unzip4PA_l :: (PA a, PA b, PA c) => PArray (PArray (a, b, c, d)) -> PArray (PArray a, PArray b, PArray c, PArray d)
+{-# INLINE_PA unzip4PA_l #-}
+unzip4PA_l wxyzss = zip4PA# (copySegdPA# wxyzss ws) (copySegdPA# wxyzss xs) (copySegdPA# wxyzss ys) (copySegdPA# wxyzss zs) 
+  where
+    (ws, xs, ys, zs) = unzip4PA# (concatPA# wxyzss) 
+
+
 -- packPA ---------------------------------------------------------------------
 -- | Select the elements of an array that have their tag set as True.
 --   
index 89de967..54c2e70 100644 (file)
@@ -132,6 +132,19 @@ scalar_zipWith3 f xs ys zs
         $ U.zipWith3 f (toUArray xs) (toUArray ys) (toUArray zs)
 
 
+
+
+-- | Zip four arrays, yielding a new array.
+scalar_zipWith4
+        :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e)
+        => (a -> b -> c -> d -> e) -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e
+
+{-# INLINE_PA scalar_zipWith4 #-}
+scalar_zipWith4 f ws xs ys zs 
+        = fromUArray' (prim_lengthPA ws)
+        $ U.zipWith4 f (toUArray ws) (toUArray xs) (toUArray ys) (toUArray zs)
+
+
 -- | Left fold over an array.
 scalar_fold 
         :: Scalar a
index 4e9b95d..7146ec2 100644 (file)
@@ -10,7 +10,7 @@ module Data.Array.Parallel.PArray.PDataInstances(
   punit,
 
   -- * Operators on arrays of tuples
-  zipPA#,  unzipPA#, zip3PA#, unzip3PA#,
+  zipPA#,  unzipPA#, zip3PA#, unzip3PA#, unzip4PA#,
   zip4PA#, zip5PA#, 
   
   -- * Operators on nested arrays
@@ -275,6 +275,11 @@ unzip3PA# :: PArray (a, b, c) -> (PArray a, PArray b, PArray c)
 unzip3PA# (PArray n# (P_3 xs ys zs))
   = (PArray n# xs, PArray n# ys, PArray n# zs)
 
+unzip4PA# :: PArray (a, b, c, d) -> (PArray a, PArray b, PArray c, PArray d)
+{-# INLINE_PA unzip4PA# #-}
+unzip4PA# (PArray n# (P_4 ws xs ys zs))
+  = (PArray n# ws, PArray n# xs, PArray n# ys, PArray n# zs)
+
 
 zip4PA# :: PArray a -> PArray b -> PArray c -> PArray d -> PArray (a, b, c, d)
 {-# INLINE_PA zip4PA# #-}
index 32402e2..22b6b3f 100644 (file)
@@ -17,11 +17,11 @@ module Data.Array.Parallel.Prim (
   PData, PDatas(..), PRepr, PA(..), PR(..),
   replicatePD, emptyPD, packByTagPD, combine2PD,
   Scalar(..),
-  scalar_map, scalar_zipWith, scalar_zipWith3,
+  scalar_map, scalar_zipWith, scalar_zipWith3, scalar_zipWith4,
   Void, Sum2(..), Sum3(..), Wrap(..),
   void, fromVoid, pvoid, pvoids#, punit,
   (:->)(..), 
-  closure, liftedClosure, ($:), liftedApply, closure1, closure2, closure3,
+  closure, liftedClosure, ($:), liftedApply, closure1, closure2, closure3, closure4,
   Sel2,  replicateSel2#, tagsSel2, elementsSel2_0#, elementsSel2_1#,
   Sels2, lengthSels2#,
   replicatePA_Int#, replicatePA_Double#,
@@ -45,14 +45,14 @@ import Data.Array.Parallel.PArray.PReprInstances  ( {-we required instances-} )
 import Data.Array.Parallel.PArray.PData           (PData, PDatas, PR(..))
 import Data.Array.Parallel.PArray.PDataInstances  (pvoid, punit, Sels2)
 import Data.Array.Parallel.Lifted.Closure         ((:->)(..), closure, liftedClosure, ($:),
-                                                   liftedApply, closure1, closure2, closure3)
+                                                   liftedApply, closure1, closure2, closure3, closure4)
 import Data.Array.Parallel.Lifted.Unboxed         (Sel2, replicateSel2#, tagsSel2, elementsSel2_0#,
                                                    elementsSel2_1#,
                                                    replicatePA_Int#, replicatePA_Double#,
                                                    emptyPA_Int#, emptyPA_Double#,
                                                    {- packByTagPA_Int#, packByTagPA_Double# -}
                                                    combine2PA_Int#, combine2PA_Double#)
-import Data.Array.Parallel.Lifted.Scalar          (scalar_map, scalar_zipWith, scalar_zipWith3)
+import Data.Array.Parallel.Lifted.Scalar          (scalar_map, scalar_zipWith, scalar_zipWith3, scalar_zipWith4)
 import Data.Array.Parallel.Prelude.Tuple          (tup2, tup3, tup4)
 import GHC.Exts
 
index fb211e2..37b9845 100644 (file)
@@ -24,6 +24,7 @@ module Data.Array.Parallel.PArray.Scalar
         , map
         , zipWith
         , zipWith3
+        , zipWith4
         
         -- * Folds
         , fold,         folds
@@ -200,6 +201,15 @@ zipWith3
 zipWith3 f (PArray len xs) (PArray _ ys) (PArray _ zs)
         = PArray len $ to $ U.zipWith3 f (from xs) (from ys) (from zs)
 
+-- | Zip four arrays, yielding a new array.
+{-# INLINE_PA zipWith4 #-}
+zipWith4
+        :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e)
+        => (a -> b -> c -> d -> e) -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e
+
+zipWith4 f (PArray len ws) (PArray _ xs) (PArray _ ys) (PArray _ zs)
+        = PArray len $ to $ U.zipWith4 f (from ws) (from xs) (from ys) (from zs)
+
 
 -- Folds ----------------------------------------------------------------------
 -- | Left fold over an array.
index cec1e9e..b7bc179 100644 (file)
@@ -19,6 +19,7 @@ module Data.Array.Parallel.Prim
         , scalar_map
         , scalar_zipWith
         , scalar_zipWith3
+        , scalar_zipWith4
 
         -- Types used in the generic representation
         , Void, void, fromVoid, pvoid, pvoids#
@@ -30,7 +31,7 @@ module Data.Array.Parallel.Prim
         , (:->)(..)
         , closure,              ($:)
         , liftedClosure,        liftedApply
-        , closure1, closure2, closure3
+        , closure1, closure2, closure3, closure4
         
         -- Selectors
         , Sel2
@@ -207,6 +208,19 @@ closure3 fv fl
    in   C.closure3 fv fl'
 {-# INLINE_CLOSURE closure3 #-}
 
+{-# INLINE_CLOSURE closure4 #-}
+closure4 :: forall a b c d e.  (PA a, PA b, PA c)
+         => (a -> b -> c -> d -> e)
+         -> (PArray a -> PArray b -> PArray c -> PArray d -> PArray e)
+         -> (a :-> b :-> c :-> d :-> e)
+closure4 fv fl
+ = let  fl' :: Int -> PData a -> PData b -> PData c -> PData d -> PData e
+        fl' (I# c#) pdata1 pdata2 pdata3 pdata4
+         = case fl (PArray c# pdata1) (PArray c# pdata2) (PArray c# pdata3) (PArray c# pdata4) of
+                 PArray _ pdata' -> pdata'
+                
+   in   C.closure4 fv fl'
+
 
 -- Selector functions ---------------------------------------------------------
 -- The vectoriser wants versions of these that take unboxed integers
@@ -281,6 +295,13 @@ scalar_zipWith3
 scalar_zipWith3 = Scalar.zipWith3
 {-# INLINE scalar_zipWith3 #-}
 
+{-# INLINE scalar_zipWith4 #-}
+scalar_zipWith4
+        :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e)
+        => (a -> b -> c -> d -> e) -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e
+
+scalar_zipWith4 = Scalar.zipWith4
+
 
 -- Int functions --------------------------------------------------------------
 type PArray_Int# = U.Array Int