Add combine2ByTagU
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Thu, 29 Oct 2009 13:49:50 +0000 (13:49 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Thu, 29 Oct 2009 13:49:50 +0000 (13:49 +0000)
dph-base/Data/Array/Parallel/Stream/Flat.hs
dph-base/Data/Array/Parallel/Stream/Flat/Combinators.hs
dph-prim-seq/Data/Array/Parallel/Unlifted/Sequential.hs
dph-prim-seq/Data/Array/Parallel/Unlifted/Sequential/Flat.hs
dph-prim-seq/Data/Array/Parallel/Unlifted/Sequential/Flat/Combinators.hs

index de091d3..626a173 100644 (file)
@@ -21,7 +21,7 @@ module Data.Array.Parallel.Stream.Flat (
   toStream, fromStream,
 
   mapS, filterS, foldS, fold1MaybeS, scanS, scan1S, mapAccumS,
-  zipWithS, zipWith3S, zipS, zip3S, combineS,
+  zipWithS, zipWith3S, zipS, zip3S, combineS, combine2ByTagS,
 
   findS, findIndexS,
 
index 843763c..dbf0c3c 100644 (file)
@@ -18,7 +18,7 @@
 
 module Data.Array.Parallel.Stream.Flat.Combinators (
   mapS, filterS, foldS, fold1MaybeS, scanS, scan1S, mapAccumS,
-  zipWithS, zipWith3S, zipS, zip3S, combineS
+  zipWithS, zipWith3S, zipS, zip3S, combineS, combine2ByTagS
 ) where
 
 import Data.Array.Parallel.Base (
@@ -148,7 +148,33 @@ combineS (Stream next1 s m c) (Stream nextS1 t1 n1 c1) (Stream nextS2 t2 n2 c2)
                                Done        -> error "combineS: stream 2 terminated unexpectedly" 
                                Skip t2'    -> Skip (s :*: t1 :*: t2')
                                Yield x t2' -> Yield x (s' :*: t1 :*: t2')
-               
+
+
+combine2ByTagS :: Stream Int -> Stream a -> Stream a -> Stream a
+{-# INLINE_STREAM combine2ByTagS #-}
+combine2ByTagS (Stream next_tag s m c) (Stream next0 s0 _ c1)
+                                       (Stream next1 s1 _ c2)
+  = Stream next (NothingS :*: s :*: s0 :*: s1) m ("combine2ByTagS" `sArgs` (c,c1,c2))
+  where
+    {-# INLINE next #-}
+    next (NothingS :*: s :*: s0 :*: s1)
+      = case next_tag s of
+          Done       -> Done
+          Skip    s' -> Skip (NothingS :*: s' :*: s0 :*: s1)
+          Yield t s' -> Skip (JustS t  :*: s' :*: s0 :*: s1)
+
+    next (JustS 0 :*: s :*: s0 :*: s1)
+      = case next0 s0 of
+          Done        -> error "combine2ByTagS: stream 1 too short"
+          Skip    s0' -> Skip    (JustS 0  :*: s :*: s0' :*: s1)
+          Yield x s0' -> Yield x (NothingS :*: s :*: s0' :*: s1)
+
+    next (JustS t :*: s :*: s0 :*: s1)
+      = case next1 s1 of
+          Done        -> error "combine2ByTagS: stream 2 too short"
+          Skip    s1' -> Skip    (JustS t  :*: s :*: s0 :*: s1')
+          Yield x s1' -> Yield x (NothingS :*: s :*: s0 :*: s1')
+
 -- | Zipping
 --
 
index 4b3e4f0..59dcb5b 100644 (file)
@@ -52,7 +52,7 @@ module Data.Array.Parallel.Unlifted.Sequential (
   mapAccumLU,
 
   -- Segmented filter and combines
-  combineU, combineSU,
+  combineU, combine2ByTagU, combineSU,
 
   -- * Searching
   elemU, notElemU,
index b3058a2..4e8e197 100644 (file)
@@ -47,7 +47,7 @@ module Data.Array.Parallel.Unlifted.Sequential.Flat (
   -- * Higher-order operations
   mapU, zipWithU, zipWith3U,
   filterU, packU, 
-  combineU, 
+  combineU, combine2ByTagU,
   foldlU, foldl1U, foldl1MaybeU,
   {-foldrU, foldr1U,-}
   foldU, fold1U, fold1MaybeU,
index dc447b2..e3e72d0 100644 (file)
@@ -31,7 +31,7 @@ module Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators (
   mapAccumLU,
   zipU, zip3U, unzipU, unzip3U, fstU, sndU,
   zipWithU, zipWith3U, 
-  combineU
+  combineU, combine2ByTagU
 ) where
 
 import Data.Array.Parallel.Base (
@@ -217,3 +217,12 @@ combineU f a1 a2 = checkEq (here "combineU")
 --  trace ("combineU:\n\t"  ++ show (lengthU f)  ++ "\n\t" ++ show (lengthU a1) ++ "\n\t" ++ show (lengthU a2) ++ "\n")
   unstreamU (combineS (streamU f) (streamU a1) (streamU a2))
 
+
+combine2ByTagU :: UA a => UArr Int -> UArr a -> UArr a -> UArr a
+{-# INLINE_U combine2ByTagU #-}
+combine2ByTagU ts xs ys
+  = checkEq (here "combine2ByTagU")
+            ("tags lengnth /= sum of args length")
+            (lengthU ts) (lengthU xs + lengthU ys)
+  $ unstreamU (combine2ByTagS (streamU ts) (streamU xs) (streamU ys))
+