dph-lifted-vseg: eliminate sharing in arrays during zipl
[packages/dph.git] / dph-lifted-reference / Data / Array / Parallel / PArray.hs
index 00a0526..f927b36 100644 (file)
@@ -14,8 +14,14 @@ module Data.Array.Parallel.PArray
         -- * Projections
         , length,       lengthl
         , index,        indexl
-        , extract)
+        , extract
+        
+        -- * Pack and Combine
+        , pack,         packl
+        , packByTag
+        , combine2)
 where
+import Data.Array.Parallel.Base                 (Tag)
 import Data.Vector                              (Vector)
 import qualified Data.Array.Parallel.Unlifted   as U
 import qualified Data.Array.Parallel.Array      as A
@@ -109,8 +115,7 @@ replicatel = lift2 replicate
 replicates :: U.Segd -> PArray a -> PArray a
 replicates segd (PArray n# vec)
  | I# n# /= U.lengthSegd segd
- = die "replicates"  
-        $ unlines 
+ = die "replicates" $ unlines
         [ "segd length mismatch"
         , "  segd length  = " ++ show (U.lengthSegd segd)
         , "  array length = " ++ show (I# n#) ]
@@ -183,7 +188,53 @@ extract (PArray _ vec) start len@(I# len#)
         = PArray len# $ V.slice start len vec
 
 
+-- Pack and Combine -----------------------------------------------------------
+-- | Select the elements of an array that have their tag set to True.
+pack    :: PArray a -> PArray Bool -> PArray a
+pack (PArray n1# xs) (PArray n2# bs)
+ | I# n1# /= I# n2#
+ = die "pack" $ unlines
+        [ "array length mismatch"
+        , "  data  length = " ++ show (I# n1#)
+        , "  flags length = " ++ show (I# n2#) ]
+
+ | otherwise
+ = let  xs'      = V.ifilter (\i _ -> bs V.! i) xs
+        !(I# n') = V.length xs'
+   in   PArray n' xs'
 
+-- | Lifted pack.
+packl :: PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
+packl   = lift2 pack
 
 
+-- | Filter an array based on some tags.
+packByTag :: PArray a -> U.Array Tag -> Tag -> PArray a
+packByTag (PArray n1# xs) tags tag
+ | I# n1# /= U.length tags
+ = die "packByTag" $ unlines
+        [ "array length mismatch"
+        , "  data  length = " ++ show (I# n1#)
+        , "  flags length = " ++ (show $ U.length tags) ]
 
+ | otherwise
+ = let  xs'      = V.ifilter (\i _ -> tags U.!: i == tag) xs
+        !(I# n') = V.length xs'
+   in   PArray n' xs'
+
+
+-- | Combine two arrays based on a selector.
+combine2 :: U.Sel2 -> PArray a -> PArray a -> PArray a
+combine2 tags (PArray n1# vec1) (PArray n2# vec2)
+ = let  
+        go [] [] [] = []
+        go (0 : bs) (x : xs) ys       = x : go bs xs ys
+        go (1 : bs) xs       (y : ys) = y : go bs xs ys
+        vec3    = V.fromList
+                $ go    (V.toList $ V.convert $ U.tagsSel2 tags)
+                        (V.toList vec1)
+                        (V.toList vec2)
+        !(I# n') = V.length vec3
+   
+    in  PArray n' vec3