Make combine2 work with selectors
[packages/dph.git] / dph-prim-seq / Data / Array / Parallel / Unlifted / Sequential / Flat / Combinators.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module : Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators
4 -- Copyright : (c) [2001..2002] Manuel M T Chakravarty & Gabriele Keller
5 -- (c) 2006 Manuel M T Chakravarty & Roman Leshchinskiy
6 -- License : see libraries/ndp/LICENSE
7 --
8 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
9 -- Stability : internal
10 -- Portability : portable
11 --
12 -- Description ---------------------------------------------------------------
13 --
14 -- Standard combinators for flat unlifted arrays.
15 --
16 -- Todo ----------------------------------------------------------------------
17 --
18
19 {-# LANGUAGE CPP #-}
20
21 #include "fusion-phases.h"
22
23 module Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators (
24 mapU,
25 filterU,
26 packU,
27 foldlU, foldl1U, foldl1MaybeU, {-foldrU, foldr1U,-}
28 foldU, fold1U, fold1MaybeU,
29 scanlU, scanl1U, {-scanrU, scanr1U,-} scanU, scan1U,
30 scanResU,
31 mapAccumLU,
32 zipU, zip3U, unzipU, unzip3U, fstU, sndU,
33 zipWithU, zipWith3U,
34 combineU, combine2U
35 ) where
36
37 import Data.Array.Parallel.Base (
38 (:*:)(..), MaybeS(..), checkNotEmpty, checkEq, sndS, Rebox(..), ST, runST)
39 import Data.Array.Parallel.Base.DTrace
40 import Data.Array.Parallel.Stream (
41 Step(..), Stream(..),
42 mapS, filterS, foldS, fold1MaybeS, scan1S, scanS, mapAccumS,
43 zipWithS, zipWith3S, combineS, combine2ByTagS,
44 sArgs, sNoArgs)
45 import Data.Array.Parallel.Unlifted.Sequential.Flat.UArr (
46 UA, UArr, MUArr,
47 writeMU, newDynResU,
48 zipU, unzipU, fstU, sndU)
49 import Data.Array.Parallel.Unlifted.Sequential.Flat.Stream (
50 streamU, unstreamU)
51 import Data.Array.Parallel.Unlifted.Sequential.Flat.Basics (
52 lengthU, (!:))
53 import Data.Array.Parallel.Unlifted.Sequential.Flat.Subarrays (
54 sliceU)
55 import Data.Array.Parallel.Unlifted.Sequential.Flat.USel
56
57 import qualified GHC.Base
58
59 import Debug.Trace
60
-- |Prefix a function name with the name of this module. The result is used
-- as the location label handed to the checking functions ('checkNotEmpty',
-- 'checkEq') by the combinators below.
--
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators." ++ s
62
-- |Apply a function to every element of an array.
--
-- Implemented as a stream pipeline: 'streamU', then 'mapS', then
-- 'unstreamU'.
--
mapU :: (UA e, UA e') => (e -> e') -> UArr e -> UArr e'
{-# INLINE_U mapU #-}
mapU f = \arr -> unstreamU (mapS f (streamU arr))
68
-- |Keep only those elements of an array for which the predicate holds.
--
-- Implemented as a stream pipeline: 'streamU', then 'filterS', then
-- 'unstreamU'.
--
filterU :: UA e => (e -> Bool) -> UArr e -> UArr e
{-# INLINE_U filterU #-}
filterU p = \arr -> unstreamU (filterS p (streamU arr))
74
-- |Extract the elements of an array selected by a flag array.
--
-- The data array is zipped with the flags, the pairs are filtered on their
-- flag component ('sndS'), and the data component is projected out again
-- with 'fstU'.
--
packU :: UA e => UArr e -> UArr Bool -> UArr e
{-# INLINE_U packU #-}
packU xs = \flags -> fstU (filterU sndS (zipU xs flags))
80
81
82
-- |Left-to-right reduction of an array, starting from the given initial
-- value.
--
-- The array is streamed with 'streamU' and consumed by 'foldS'.
--
foldlU :: (UA a, Rebox b) => (b -> a -> b) -> b -> UArr a -> b
{-# INLINE_U foldlU #-}
foldlU f z xs = foldS f z elems
  where
    elems = streamU xs
88
-- |Left-to-right reduction of a non-empty array.
--
-- The array length is passed to 'checkNotEmpty'; the first element then
-- serves as the initial value for a 'foldlU' over the remaining elements.
--
-- FIXME: Rewrite for 'Stream's.
--
foldl1U :: UA a => (a -> a -> a) -> UArr a -> a
{-# INLINE_U foldl1U #-}
foldl1U f arr = checkNotEmpty (here "foldl1U") len (foldlU f seed rest)
  where
    len  = lengthU arr
    -- first element seeds the fold ...
    seed = arr !: 0
    -- ... which then runs over everything after it
    rest = sliceU arr 1 (len - 1)
97
-- |Left-to-right reduction without an initial value; the result is wrapped
-- in 'MaybeS'.
--
-- The array is streamed with 'streamU' and consumed by 'fold1MaybeS'.
--
foldl1MaybeU :: UA a => (a -> a -> a) -> UArr a -> MaybeS a
{-# INLINE_U foldl1MaybeU #-}
foldl1MaybeU f = \arr -> fold1MaybeS f (streamU arr)
101
-- |Reduction of an array; the combination function is required to be
-- associative and the second argument to be its unit.
--
-- In this sequential implementation it is simply 'foldlU'.
--
{-# INLINE_U foldU #-}
foldU :: UA a => (a -> a -> a) -> a -> UArr a -> a
foldU = foldlU
108
-- |Reduction without an initial value, with the result wrapped in 'MaybeS'.
--
-- In this sequential implementation it is simply 'foldl1MaybeU'.
--
{-# INLINE_U fold1MaybeU #-}
fold1MaybeU :: UA a => (a -> a -> a) -> UArr a -> MaybeS a
fold1MaybeU = foldl1MaybeU
112
-- |Reduction of a non-empty array; the combination function is required to
-- be associative.
--
-- In this sequential implementation it is simply 'foldl1U'.
--
{-# INLINE_U fold1U #-}
fold1U :: UA a => (a -> a -> a) -> UArr a -> a
fold1U = foldl1U
119
-- |Prefix scan proceeding from left to right, starting from the given
-- initial value.
--
-- Implemented as a stream pipeline: 'streamU', then 'scanS', then
-- 'unstreamU'.
--
scanlU :: (UA a, UA b) => (b -> a -> b) -> b -> UArr a -> UArr b
{-# INLINE_U scanlU #-}
scanlU f z = \arr -> unstreamU (scanS f z (streamU arr))
125
-- Rewrite rule for the stream pipeline that 'scanlU' expands to when the
-- scanning function is Int addition: if the scan result is only forced with
-- 'seq', force the initial value and the input array instead, so the scan
-- itself is not needed.
{-# RULES

"seq/scanlU (+)" forall seed arr k.
  seq (unstreamU (scanS GHC.Base.plusInt seed (streamU arr))) k
  = seed `seq` arr `seq` k

  #-}
133
-- |Prefix scan of a non-empty array, proceeding from left to right.
--
-- The array length is passed to 'checkNotEmpty'; the scan itself is the
-- stream pipeline 'streamU', 'scan1S', 'unstreamU'.
--
scanl1U :: UA a => (a -> a -> a) -> UArr a -> UArr a
{-# INLINE_U scanl1U #-}
scanl1U f arr = checkNotEmpty (here "scanl1U") (lengthU arr) scanned
  where
    scanned = unstreamU (scan1S f (streamU arr))
140
-- |Left-to-right prefix scan; the combination function is required to be
-- associative and the second argument to be its unit.
--
-- In this sequential implementation it is simply 'scanlU'.
--
{-# INLINE_U scanU #-}
scanU :: UA a => (a -> a -> a) -> a -> UArr a -> UArr a
scanU = scanlU
147
-- |Left-to-right prefix scan of a non-empty array; the combination function
-- is required to be associative.
--
-- In this sequential implementation it is simply 'scanl1U'.
--
{-# INLINE_U scan1U #-}
scan1U :: UA a => (a -> a -> a) -> UArr a -> UArr a
scan1U = scanl1U
154
-- |Prefix scan that returns an extra value next to the scanned array.
--
-- The array is streamed with 'streamU' and handed to 'unstreamScan';
-- presumably the extra value is the accumulator left over after the last
-- element (see 'unstreamScanM') -- confirm against 'newDynResU'.
--
scanResU :: UA a => (a -> a -> a) -> a -> UArr a -> UArr a :*: a
{-# INLINE_U scanResU #-}
scanResU f z = \arr -> unstreamScan f z (streamU arr)
158
-- |Scan a stream into a freshly created array.
--
-- The third field of the 'Stream' is passed to 'newDynResU' (presumably as
-- the size to allocate -- confirm against 'newDynResU'); the array is then
-- filled by 'unstreamScanM'.
--
unstreamScan :: UA a => (a -> a -> a) -> a -> Stream a -> UArr a :*: a
{-# INLINE_STREAM unstreamScan #-}
unstreamScan f z st@(Stream _ _ size _)
  = newDynResU size $ \buf -> unstreamScanM buf f z st
163
-- |Scan a stream into a mutable array.
--
-- For every element the stream yields, the accumulator held /before/ that
-- element is written to the next free slot, and the accumulator is then
-- updated with the combination function. The result pairs the number of
-- slots written with the final accumulator.
--
unstreamScanM :: UA a => MUArr a s -> (a -> a -> a) -> a -> Stream a
              -> ST s (Int :*: a)
{-# INLINE_U unstreamScanM #-}
unstreamScanM marr f z (Stream next s0 _ c)
  = traceLoopST ("unstreamScanM" `sArgs` c) (go s0 z 0)
  where
    -- st: stream state, acc: running accumulator, ix: next slot to write
    go st !acc !ix = case next st of
      Yield x st' -> st' `dseq` do
        -- store the accumulator as it was before consuming x
        writeMU marr ix acc
        go st' (f acc x) (ix + 1)
      Skip st'    -> st' `dseq` go st' acc ix
      Done        -> return (ix :*: acc)
177
-- |Accumulating map from left to right. The accumulator is threaded through
-- the traversal but is not part of the result.
--
-- Implemented as a stream pipeline: 'streamU', then 'mapAccumS', then
-- 'unstreamU'.
--
-- FIXME: Naming inconsistent with lists.
--
mapAccumLU :: (UA a, UA b, Rebox c) => (c -> a -> c :*: b) -> c -> UArr a -> UArr b
{-# INLINE_U mapAccumLU #-}
mapAccumLU f z = \arr -> unstreamU (mapAccumS f z (streamU arr))
185
186 -- zipU is re-exported from UArr
187
-- |Zip three arrays into one array of nested strict pairs, by zipping the
-- first two arrays and then zipping the result with the third.
--
zip3U :: (UA e1, UA e2, UA e3)
      => UArr e1 -> UArr e2 -> UArr e3 -> UArr (e1 :*: e2 :*: e3)
{-# INLINE_U zip3U #-}
zip3U a1 a2 a3 = zipU (zipU a1 a2) a3
194
-- |Combine two arrays element by element with the given function.
--
-- Both arrays are streamed with 'streamU', combined by 'zipWithS' and
-- turned back into an array with 'unstreamU'.
--
zipWithU :: (UA a, UA b, UA c)
         => (a -> b -> c) -> UArr a -> UArr b -> UArr c
{-# INLINE_U zipWithU #-}
zipWithU f a1 a2 = unstreamU $ zipWithS f s1 s2
  where
    s1 = streamU a1
    s2 = streamU a2
200
-- |Combine three arrays element by element with the given function.
--
-- All three arrays are streamed with 'streamU', combined by 'zipWith3S' and
-- turned back into an array with 'unstreamU'.
--
zipWith3U :: (UA a, UA b, UA c, UA d)
          => (a -> b -> c -> d) -> UArr a -> UArr b -> UArr c -> UArr d
{-# INLINE_U zipWith3U #-}
zipWith3U f a1 a2 a3 = unstreamU $ zipWith3S f s1 s2 s3
  where
    s1 = streamU a1
    s2 = streamU a2
    s3 = streamU a3
208
209 -- unzipU is re-exported from UArr
210
-- |Split an array of nested strict pairs into its three component arrays,
-- by unzipping the outer pair first and then the inner one.
--
unzip3U :: (UA e1, UA e2, UA e3)
        => UArr (e1 :*: e2 :*: e3) -> (UArr e1,UArr e2,UArr e3)
{-# INLINE_U unzip3U #-}
unzip3U a = (a1, a2, a3)
  where
    -- peel off the third component, then split the remaining pair
    (a12, a3) = unzipU a
    (a1, a2)  = unzipU a12
219
220 -- fstU and sndU reexported from UArr
-- |Merge two arrays into one under the control of a flag array.
--
-- 'checkEq' is given the length of the flag array and the sum of the
-- lengths of the two data arrays. The merge itself is done by 'combineS'
-- on the streamed arguments; which flag value selects which array is
-- determined by 'combineS'.
--
combineU :: UA a
         => UArr Bool -> UArr a -> UArr a -> UArr a
{-# INLINE_U combineU #-}
combineU f a1 a2
  = checkEq (here "combineU")
            "flag length not equal to sum of arg length"
            (lengthU f)
            (lengthU a1 + lengthU a2)
            combined
  where
    -- trace ("combineU:\n\t" ++ show (lengthU f) ++ "\n\t" ++ show (lengthU a1) ++ "\n\t" ++ show (lengthU a2) ++ "\n")
    combined = unstreamU (combineS (streamU f) (streamU a1) (streamU a2))
230
231
-- |Merge two arrays into one under the control of a selector.
--
-- 'checkEq' is given the selector length ('lengthUSel2') and the sum of the
-- lengths of the two data arrays. The merge itself is done by
-- 'combine2ByTagS' on the streamed selector tags ('tagsUSel2') and the two
-- streamed data arrays; which tag selects which array is determined by
-- 'combine2ByTagS'.
--
combine2U :: UA a => USel2 -> UArr a -> UArr a -> UArr a
{-# INLINE_U combine2U #-}
combine2U ts xs ys
  -- The location label names this function, so that a failed check points
  -- at an identifier that exists in this module.
  = checkEq (here "combine2U")
            ("sel length /= sum of args length")
            (lengthUSel2 ts) (lengthU xs + lengthU ys)
  $ unstreamU (combine2ByTagS (streamU (tagsUSel2 ts)) (streamU xs) (streamU ys))
239