dph-prim: export scattered segment count
authorBen Lippmeier <benl@ouroborus.net>
Thu, 20 Oct 2011 02:42:40 +0000 (13:42 +1100)
committerBen Lippmeier <benl@ouroborus.net>
Thu, 3 Nov 2011 05:26:53 +0000 (16:26 +1100)
dph-prim-interface/dph-prim-interface.cabal
dph-prim-interface/interface/DPH_Header.h
dph-prim-interface/interface/DPH_Interface.h
dph-prim-par/Data/Array/Parallel/Unlifted.hs
dph-prim-seq/Data/Array/Parallel/Unlifted.hs

index 95ebf4c..57f9d42 100644 (file)
@@ -31,5 +31,6 @@ Library
   Build-Depends: 
         base     == 4.4.*,
         random   == 1.0.*,
-        dph-base == 0.5.*
+        dph-base == 0.5.*,
+        vector   == 0.9
 
index f2e418d..a50eb8b 100644 (file)
@@ -41,26 +41,31 @@ module Data.Array.Parallel.Unlifted (
   zip3, unzip3,
     
   -- * Folds
-  fold, fold1,
-  and, sum, scan,
+  fold,  fold_s,  fold_ss, fold_r,
+  fold1, fold1_s, fold1_ss,
+  sum,   sum_s,   sum_ss,  sum_r,
+  count, count_s, count_ss,
+  scan,
+  and, 
 
   -- * Segmented Constructors
-  append_s, replicate_s, replicate_rs, 
+  append_s,
+  replicate_s, replicate_rs, 
 
   -- * Segmented Projections
   indices_s,
-
-  -- * Segmented Folds
-  fold_s, fold1_s, fold_r, sum_s,  sum_r,
-  
-  -- * Scattered Segmented Folds
-  fold_ss, fold1_ss,
-  
+    
   -- * Segment Descriptors
-  Segd, mkSegd, validSegd,
-  emptySegd, singletonSegd,
+  Segd,
+  mkSegd,
+  validSegd,
+  emptySegd,
+  singletonSegd,
   lengthsToSegd,
-  lengthSegd, lengthsSegd, indicesSegd, elementsSegd,
+  lengthSegd,
+  lengthsSegd,
+  indicesSegd,
+  elementsSegd,
   plusSegd, 
 
   -- * Scattered Segment Descriptors
@@ -111,9 +116,6 @@ module Data.Array.Parallel.Unlifted (
   -- * Packing and picking
   packByTag, pick,
   
-  -- * Counting
-  count, count_s,
-
   -- * Random arrays
   randoms, randomRs,
   
@@ -127,4 +129,4 @@ import System.IO                  (IO, Handle)
 import Data.Word                  (Word8)
 import qualified System.Random
 import qualified Prelude
-
+import qualified Data.Vector    as VV
index f03ab05..a60a8ff 100644 (file)
@@ -480,8 +480,8 @@ fold_r :: Elt a => (a -> a -> a) -> a -> Int -> Array a -> Array a
 {-# INLINE_BACKEND fold_r #-}
 
 sum_s :: (Num a, Elt a) => Segd -> Array a -> Array a
-{-# INLINE sum_s #-}
 sum_s = fold_s (Prelude.+) 0
+{-# INLINE sum_s #-}
 
 sum_r :: (Num a, Elt a) => Int ->Array a -> Array a
 {-# INLINE_BACKEND sum_r #-}
@@ -498,6 +498,18 @@ sum_r :: (Num a, Elt a) => Int ->Array a -> Array a
   #-}
 
 
+-- Scattered Segmented Folds --------------------------------------------------
+fold_ss :: Elt a => (a -> a -> a) -> a -> SSegd -> VV.Vector (Array a) -> Array a
+{-# INLINE_BACKEND fold_ss #-}
+
+fold1_ss :: Elt a => (a -> a -> a) -> SSegd -> VV.Vector (Array a) -> Array a
+{-# INLINE_BACKEND fold1_ss #-}
+
+sum_ss :: (Num a, Elt a) => SSegd -> VV.Vector (Array a) -> Array a
+sum_ss = fold_ss (Prelude.+) 0
+{-# INLINE sum_ss #-}
+
+
 -- Operations on Segment Descriptors ------------------------------------------
 indices_s :: Segd -> Array Int
 {-# INLINE_BACKEND indices_s #-}
@@ -675,18 +687,26 @@ pick xs !x = map (x==) xs
 
 
 
-
 -- Counting -------------------------------------------------------------------
 -- | Count the number of elements in array that are equal to the given value.
 count :: (Elt a, Eq a) => Array a -> a -> Int
-{-# INLINE_BACKEND count #-}
 count xs !x = sum (map (tagToInt . fromBool . (==) x) xs)
+{-# INLINE_BACKEND count #-}
 
 
 -- | Count the number of elements in segments that are equal to the given value.
 count_s :: (Elt a, Eq a) => Segd -> Array a -> a -> Array Int
+count_s segd xs !x
+        = sum_s segd (map (tagToInt . fromBool . (==) x) xs)
 {-# INLINE_BACKEND count_s #-}
-count_s segd xs !x = sum_s segd (map (tagToInt . fromBool . (==) x) xs)
+
+
+-- | Count the number of elements in segments that are equal to the given value.
+--   TODO: Adding V.map here will probably break fusion with sum_ss.
+count_ss :: (Elt a, Eq a) => SSegd -> VV.Vector (Array a) -> a -> Array Int
+{-# INLINE_BACKEND count_ss #-}
+count_ss ssegd xs !x
+        = sum_ss ssegd (VV.map (map (tagToInt . fromBool . (==) x)) xs)
 
 
 {-# RULES
index 562d27d..de00220 100644 (file)
@@ -18,15 +18,15 @@ import Data.Array.Parallel.Unlifted.Parallel
 import Data.Array.Parallel.Base.TracePrim
 import Data.Array.Parallel.Unlifted.Distributed ( DT )
 
+import Data.Array.Parallel.Unlifted.Sequential.Vector (Unbox, Vector)
 import Data.Array.Parallel.Unlifted.Parallel.UPSel
 import qualified Data.Array.Parallel.Unlifted.Parallel.UPSegd           as UPSegd
 import qualified Data.Array.Parallel.Unlifted.Parallel.UPSSegd          as UPSSegd
 import qualified Data.Array.Parallel.Unlifted.Parallel.UPVSegd          as UPVSegd
 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector         as Seq
 import qualified Data.Array.Parallel.Unlifted.Sequential.Combinators    as Seq
+import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
 
-
-import Data.Array.Parallel.Unlifted.Sequential.Vector (Unbox,Vector)
 import Prelude (($!))
 
 #include "DPH_Interface.h"
index 32fa3e4..d8748bd 100644 (file)
@@ -13,7 +13,7 @@
 
 #include "DPH_Header.h"
 
-import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
+import Data.Array.Parallel.Unlifted.Sequential.Vector (Unbox, Vector)
 import Data.Array.Parallel.Unlifted.Sequential.USel
 import Data.Array.Parallel.Unlifted.Sequential.Basics
 import Data.Array.Parallel.Unlifted.Sequential.Combinators
@@ -21,6 +21,7 @@ import Data.Array.Parallel.Unlifted.Sequential.Sums
 import qualified Data.Array.Parallel.Unlifted.Sequential.USegd  as USegd
 import qualified Data.Array.Parallel.Unlifted.Sequential.USSegd as USSegd
 import qualified Data.Array.Parallel.Unlifted.Sequential.UVSegd as UVSegd
+import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
 
 #include "DPH_Interface.h"