Make combine2 work with selectors
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Parallel / Combinators.hs
index 8d3a867..6dcef66 100644 (file)
@@ -18,8 +18,8 @@
 #include "fusion-phases.h"
 
 module Data.Array.Parallel.Unlifted.Parallel.Combinators (
-  mapUP, filterUP, packUP, zipWithUP, foldUP, fold1UP, foldl1UP,
-  scanUP
+  mapUP, filterUP, packUP, combineUP, combine2UP,
+  zipWithUP, foldUP, fold1UP, foldl1UP, scanUP
 ) where
 
 import Data.Array.Parallel.Base
@@ -28,7 +28,7 @@ import Data.Array.Parallel.Unlifted.Distributed
 
 mapUP :: (UA a, UA b) => (a -> b) -> UArr a -> UArr b
 {-# INLINE mapUP #-}
-mapUP f = splitJoinD theGang (mapD theGang (mapU f))
+mapUP f xs = splitJoinD theGang (mapD theGang (mapU f)) xs
 
 filterUP :: UA a => (a -> Bool) -> UArr a -> UArr a
 {-# INLINE filterUP #-}
@@ -44,27 +44,78 @@ packUP:: UA e => UArr e -> UArr Bool -> UArr e
 {-# INLINE_UP packUP #-}
 packUP xs flags = fstU . filterUP sndS $  zipU xs flags
 
+combineUP :: UA a => UArr Bool -> UArr a -> UArr a -> UArr a
+{-# INLINE_UP combineUP #-}
+combineUP flags !xs !ys = joinD theGang balanced
+                        . zipWithD theGang go (zipD is ns)
+                        $ splitD theGang balanced flags
+  where
+    ns = mapD theGang count
+       $ splitD theGang balanced flags
+
+    is = fstS $ scanD theGang add (0 :*: 0) ns
+
+    count bs = let ts = sumU (mapU fromBool bs)
+               in ts :*: (lengthU bs - ts)
+
+    add (x1 :*: y1) (x2 :*: y2) = (x1 + x2) :*: (y1 + y2)
+
+    go ((i :*: j) :*: (m :*: n)) bs = combineU bs (sliceU xs i m) (sliceU ys j n)
+
+combine2UP :: UA a => USel2 -> UArr a -> UArr a -> UArr a
+{-# INLINE_UP combine2UP #-}
+combine2UP sel !xs !ys = zipWithUP get (tagsUSel2 sel) (indicesUSel2 sel)
+  where
+    get 0 i = xs !: i
+    get _ i = ys !: i
+
+{-
+combine2UP tags !xs !ys = joinD theGang balanced
+                        $ zipWithD theGang go (zipD is ns)
+                        $ splitD theGang balanced tags
+  where
+    ns = mapD theGang count
+       $ splitD theGang balanced tags
+
+    count bs = let ones = sumU bs
+               in (lengthU bs - ones) :*: ones
+
+    is = fstS $ scanD theGang add (0 :*: 0) ns
+
+    add (x1 :*: y1) (x2 :*: y2) = (x1+x2) :*: (y1+y2)
+
+    go ((i :*: j) :*: (m :*: n)) ts = combine2ByTagU ts (sliceU xs i m)
+                                                        (sliceU ys j n)
+-}
 
 zipWithUP :: (UA a, UA b, UA c) => (a -> b -> c) -> UArr a -> UArr b -> UArr c
 {-# INLINE zipWithUP #-}
+zipWithUP f xs ys = splitJoinD theGang (mapD theGang (mapU (uncurryS f))) (zipU xs ys)
+{-
 zipWithUP f a b = joinD    theGang balanced
-                zipWithD theGang (zipWithU f)
+                 (zipWithD theGang (zipWithU f)
                     (splitD theGang balanced a)
-                    (splitD theGang balanced b)
+                    (splitD theGang balanced b))
+-}
 --zipWithUP f a b = mapUP (uncurryS f) (zipU a b)
 
 foldUP :: (UA a, DT a) => (a -> a -> a) -> a -> UArr a -> a
 {-# INLINE foldUP #-}
-foldUP f z = maybeS z (f z)
-           . foldD  theGang combine
-           . mapD   theGang (foldl1MaybeU f)
-           . splitD theGang unbalanced
+foldUP f !z xs = foldD  theGang f
+                (mapD   theGang (foldU f z)
+                (splitD theGang unbalanced xs))
+{-
+foldUP f z xs = maybeS z (f z)
+               (foldD  theGang combine
+               (mapD   theGang (foldl1MaybeU f)
+               (splitD theGang unbalanced
+                xs)))
   where
     combine (JustS x) (JustS y) = JustS (f x y)
     combine (JustS x) NothingS  = JustS x
     combine NothingS  (JustS y) = JustS y
     combine NothingS  NothingS  = NothingS
-
+-}
 
 -- |Array reduction proceeding from the left (requires associative combination)
 --