Make combine2 work with selectors
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Mon, 7 Jun 2010 07:57:14 +0000 (07:57 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Mon, 7 Jun 2010 07:57:14 +0000 (07:57 +0000)
12 files changed:
dph-common/Data/Array/Parallel/Lifted/PArray.hs
dph-common/Data/Array/Parallel/Lifted/Repr.hs
dph-prim-interface/Data/Array/Parallel/Unlifted.hs
dph-prim-interface/interface/DPH_Header.h
dph-prim-interface/interface/DPH_Interface.h
dph-prim-par/Data/Array/Parallel/Unlifted.hs
dph-prim-par/Data/Array/Parallel/Unlifted/Parallel.hs
dph-prim-par/Data/Array/Parallel/Unlifted/Parallel/Combinators.hs
dph-prim-seq/Data/Array/Parallel/Unlifted.hs
dph-prim-seq/Data/Array/Parallel/Unlifted/Sequential.hs
dph-prim-seq/Data/Array/Parallel/Unlifted/Sequential/Flat.hs
dph-prim-seq/Data/Array/Parallel/Unlifted/Sequential/Flat/Combinators.hs

index 3e6604f..8b5f545 100644 (file)
@@ -433,9 +433,9 @@ combine2PRScalar :: Scalar a => T_combine2PR a
 {-# INLINE combine2PRScalar #-}
 combine2PRScalar _ sel xs ys = traceF "combine2PRScalar"
                              $ toScalarPData
-                             $ U.combine2ByTag (U.tagsSel2 sel)
-                                               (fromScalarPData xs)
-                                               (fromScalarPData ys)
+                             $ U.combine2 sel
+                                          (fromScalarPData xs)
+                                          (fromScalarPData ys)
 
 updatePRScalar :: Scalar a => T_updatePR a
 {-# INLINE updatePRScalar #-}
index 3bea09d..7751155 100644 (file)
@@ -274,7 +274,7 @@ instance (PR a, PR b) => PR (Sum2 a b) where
       PSum2 sel' as bs
     where
       tags  = U.tagsSel2 sel
-      tags' = U.combine2ByTag tags (U.tagsSel2 sel1) (U.tagsSel2 sel2)
+      tags' = U.combine2 sel (U.tagsSel2 sel1) (U.tagsSel2 sel2)
       sel'  = U.tagsToSel2 tags'
 
       asel = U.tagsToSel2 (U.packByTag tags tags' 0)
@@ -386,7 +386,7 @@ instance PR a => PR (PArray a) where
       tags = U.tagsSel2 sel
 
       segd = U.lengthsToSegd
-           $ U.combine2ByTag tags (U.lengthsSegd xsegd) (U.lengthsSegd ysegd)
+           $ U.combine2 sel (U.lengthsSegd xsegd) (U.lengthsSegd ysegd)
 
       sel' = U.tagsToSel2
            $ U.replicate_s segd tags
index 99c188f..6d89242 100644 (file)
@@ -60,9 +60,11 @@ combine [] [] [] = []
 combine (True  : bs) (x : xs) ys       = x : combine bs xs ys
 combine (False : bs) xs       (y : ys) = y : combine bs xs ys
 
-combine2ByTag [] [] [] = []
-combine2ByTag (0 : bs) (x : xs) ys = x : combine2ByTag bs xs ys
-combine2ByTag (1 : bs) xs (y : ys) = y : combine2ByTag bs xs ys
+combine2 sel xs ys = go (tagsSel2 sel) xs ys
+  where
+    go [] [] [] = []
+    go (0 : bs) (x : xs) ys = x : go bs xs ys
+    go (1 : bs) xs (y : ys) = y : go bs xs ys
 
 map = P.map
 filter = P.filter
index 6dfa73e..85eedc4 100644 (file)
@@ -7,7 +7,7 @@ module Data.Array.Parallel.Unlifted (
   length,
   empty, replicate, repeat, (+:+), interleave,
   (!:), extract, drop, permute, mbpermute, bpermute, bpermuteDft, update,
-  pack, combine, combine2ByTag,
+  pack, combine, combine2,
   enumFromTo, enumFromThenTo, enumFromStepLen, enumFromStepLenEach,
   indexed,
   zip, zip3, unzip, unzip3, fsts, snds,
index 58a6a9b..685719f 100644 (file)
@@ -65,8 +65,8 @@ pack :: Elt a => Array a -> Array Bool -> Array a
 combine :: Elt a => Array Bool -> Array a -> Array a -> Array a
 {-# INLINE_BACKEND combine #-}
 
-combine2ByTag :: Elt a => Array Int -> Array a -> Array a -> Array a
-{-# INLINE_BACKEND combine2ByTag #-}
+combine2 :: Elt a => Sel2 -> Array a -> Array a -> Array a
+{-# INLINE_BACKEND combine2 #-}
 
 map :: (Elt a, Elt b) => (a -> b) -> Array a -> Array b
 {-# INLINE_BACKEND map #-}
index 270129e..0a8bc5f 100644 (file)
@@ -32,7 +32,7 @@ update = updateUP
 interleave = interleaveUP
 pack = packUP
 combine = combineUP
-combine2ByTag = combine2ByTagUP
+combine2 = combine2UP
 map = mapUP
 filter = filterUP
 zip = zipU
index 3234606..73958d2 100644 (file)
@@ -18,7 +18,7 @@ module Data.Array.Parallel.Unlifted.Parallel (
 
   enumFromToUP, enumFromThenToUP, enumFromStepLenUP, enumFromStepLenEachUP,
 
-  mapUP, filterUP, packUP, combineUP, combine2ByTagUP,
+  mapUP, filterUP, packUP, combineUP, combine2UP,
   zipWithUP, foldUP, scanUP,
 
   andUP, sumUP,
index dc2aadf..6dcef66 100644 (file)
@@ -18,7 +18,7 @@
 #include "fusion-phases.h"
 
 module Data.Array.Parallel.Unlifted.Parallel.Combinators (
-  mapUP, filterUP, packUP, combineUP, combine2ByTagUP,
+  mapUP, filterUP, packUP, combineUP, combine2UP,
   zipWithUP, foldUP, fold1UP, foldl1UP, scanUP
 ) where
 
@@ -62,11 +62,17 @@ combineUP flags !xs !ys = joinD theGang balanced
 
     go ((i :*: j) :*: (m :*: n)) bs = combineU bs (sliceU xs i m) (sliceU ys j n)
 
-combine2ByTagUP :: UA a => UArr Int -> UArr a -> UArr a -> UArr a
-{-# INLINE_UP combine2ByTagUP #-}
-combine2ByTagUP tags !xs !ys = joinD theGang balanced
-                             $ zipWithD theGang go (zipD is ns)
-                             $ splitD theGang balanced tags
+combine2UP :: UA a => USel2 -> UArr a -> UArr a -> UArr a
+{-# INLINE_UP combine2UP #-}
+combine2UP sel !xs !ys = zipWithUP get (tagsUSel2 sel) (indicesUSel2 sel)
+  where
+    get 0 i = xs !: i
+    get _ i = ys !: i
+
+{-
+combine2UP tags !xs !ys = joinD theGang balanced
+                        $ zipWithD theGang go (zipD is ns)
+                        $ splitD theGang balanced tags
   where
     ns = mapD theGang count
        $ splitD theGang balanced tags
@@ -80,7 +86,7 @@ combine2ByTagUP tags !xs !ys = joinD theGang balanced
 
     go ((i :*: j) :*: (m :*: n)) ts = combine2ByTagU ts (sliceU xs i m)
                                                         (sliceU ys j n)
-
+-}
 
 zipWithUP :: (UA a, UA b, UA c) => (a -> b -> c) -> UArr a -> UArr b -> UArr c
 {-# INLINE zipWithUP #-}
index a50fe69..c1c26e0 100644 (file)
@@ -30,7 +30,7 @@ update = updateU
 interleave = interleaveU
 pack = packU
 combine = combineU
-combine2ByTag = combine2ByTagU
+combine2 = combine2U
 map = mapU
 filter = filterU
 zip = zipU
index 39f7566..0cafb4c 100644 (file)
@@ -52,7 +52,7 @@ module Data.Array.Parallel.Unlifted.Sequential (
   mapAccumLU,
 
   -- Segmented filter and combines
-  combineU, combine2ByTagU, combineSU,
+  combineU, combine2U, combineSU,
 
   -- * Searching
   elemU, notElemU,
index d31bebb..b00f98f 100644 (file)
@@ -47,7 +47,7 @@ module Data.Array.Parallel.Unlifted.Sequential.Flat (
   -- * Higher-order operations
   mapU, zipWithU, zipWith3U,
   filterU, packU, 
-  combineU, combine2ByTagU,
+  combineU, combine2U,
   foldlU, foldl1U, foldl1MaybeU,
   {-foldrU, foldr1U,-}
   foldU, fold1U, fold1MaybeU,
index 57dd2f3..62fefc7 100644 (file)
@@ -31,7 +31,7 @@ module Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators (
   mapAccumLU,
   zipU, zip3U, unzipU, unzip3U, fstU, sndU,
   zipWithU, zipWith3U, 
-  combineU, combine2ByTagU
+  combineU, combine2U
 ) where
 
 import Data.Array.Parallel.Base (
@@ -52,6 +52,7 @@ import Data.Array.Parallel.Unlifted.Sequential.Flat.Basics (
   lengthU, (!:))
 import Data.Array.Parallel.Unlifted.Sequential.Flat.Subarrays (
   sliceU)
+import Data.Array.Parallel.Unlifted.Sequential.Flat.USel
 
 import qualified GHC.Base
 
@@ -228,11 +229,11 @@ combineU f a1 a2 = checkEq (here "combineU")
   unstreamU (combineS (streamU f) (streamU a1) (streamU a2))
 
 
-combine2ByTagU :: UA a => UArr Int -> UArr a -> UArr a -> UArr a
-{-# INLINE_U combine2ByTagU #-}
-combine2ByTagU ts xs ys
+combine2U :: UA a => USel2 -> UArr a -> UArr a -> UArr a
+{-# INLINE_U combine2U #-}
+combine2U ts xs ys
   = checkEq (here "combine2ByTagU")
-            ("tags lengnth /= sum of args length")
-            (lengthU ts) (lengthU xs + lengthU ys)
-  $ unstreamU (combine2ByTagS (streamU ts) (streamU xs) (streamU ys))
+            ("sel length /= sum of args length")
+            (lengthUSel2 ts) (lengthU xs + lengthU ys)
+  $ unstreamU (combine2ByTagS (streamU (tagsUSel2 ts)) (streamU xs) (streamU ys))