Separate length from data in array representation
[packages/dph.git] / dph-common / Data / Array / Parallel / Lifted / Combinators.hs
1 {-# LANGUAGE CPP #-}
2
3 #include "fusion-phases.h"
4
5 module Data.Array.Parallel.Lifted.Combinators (
6 lengthPA, replicatePA, singletonPA, mapPA, crossMapPA,
7 zipWithPA, zipPA, unzipPA,
8 packPA, filterPA, combine2PA, indexPA, concatPA, appPA, enumFromToPA_Int,
9
10 lengthPA_v, replicatePA_v, singletonPA_v, zipPA_v, unzipPA_v,
11 indexPA_v, appPA_v, enumFromToPA_v
12 ) where
13
14 import Data.Array.Parallel.Lifted.PArray
15 import Data.Array.Parallel.Lifted.Closure
16 import Data.Array.Parallel.Lifted.Unboxed ( elementsSegd# )
17 import Data.Array.Parallel.Lifted.Repr
18 import Data.Array.Parallel.Lifted.Instances
19 import Data.Array.Parallel.Lifted.Scalar
20 import Data.Array.Parallel.Lifted.Selector
21
22 import qualified Data.Array.Parallel.Unlifted as U
23 import Data.Array.Parallel.Base ( fromBool )
24
25 import GHC.Exts (Int(..), (+#), (-#), Int#, (<#))
26
-- | Vectorised @length@: the number of elements in a parallel array.
--   The element dictionary is not consulted.
lengthPA_v :: PA a -> PArray a -> Int
{-# INLINE_PA lengthPA_v #-}
lengthPA_v _ arr = case lengthPA# arr of
                     n# -> I# n#
30
-- | Lifted @length@: given an array of arrays, return the length of each
--   inner array.
--
--   The result has one entry per segment of the input, so its length is the
--   number of segments (the length of the outer array), not the total number
--   of inner elements.
lengthPA_l :: PA a -> PArray (PArray a) -> PArray Int
{-# INLINE_PA lengthPA_l #-}
lengthPA_l _ xss = fromUArrPA (I# (lengthPA# xss)) (U.lengthsSegd segd)
  where
    segd = segdPA# xss
36
-- | Closure pairing the scalar ('lengthPA_v') and lifted ('lengthPA_l')
--   implementations of @length@.
lengthPA :: PA a -> (PArray a :-> Int)
{-# INLINE lengthPA #-}
lengthPA pa = closure1 scalar lifted
  where
    scalar = lengthPA_v pa
    lifted = lengthPA_l pa
40
-- | Vectorised @replicate@: a parallel array holding @n@ copies of @x@.
replicatePA_v :: PA a -> Int -> a -> PArray a
{-# INLINE_PA replicatePA_v #-}
replicatePA_v pa n x = case n of
                         I# n# -> replicatePA# pa n# x
44
-- | Lifted @replicate@: for each index @i@, build an array holding @ns!i@
--   copies of @xs!i@.
--
--   The copies are materialised with 'replicatelPA#', so the flat data of the
--   result has exactly as many elements as the segment descriptor (the sum of
--   @ns@). The previous version reused @xs@ directly, which has only one
--   element per segment and so did not match the descriptor.
replicatePA_l :: PA a -> PArray Int -> PArray a -> PArray (PArray a)
{-# INLINE_PA replicatePA_l #-}
replicatePA_l pa (PArray n# (PInt ns)) xs
  = segmentPA# n# segd (replicatelPA# pa segd xs)
  where
    segd = U.lengthsToSegd ns
49
-- | Closure for @replicate@; the first (count) argument uses the 'Int'
--   dictionary.
replicatePA :: PA a -> (Int :-> a :-> PArray a)
{-# INLINE replicatePA #-}
replicatePA pa =
  let scalar = replicatePA_v pa
      lifted = replicatePA_l pa
  in  closure2 dPA_Int scalar lifted
53
-- | Vectorised singleton: a one-element parallel array containing @x@.
singletonPA_v :: PA a -> a -> PArray a
{-# INLINE_PA singletonPA_v #-}
singletonPA_v pa x = replicatePA# pa 1# x
57
-- | Lifted singleton: wraps every element of the input in its own
--   one-element array. The flat data is reused unchanged. Only a segment
--   descriptor is built, with @n@ unit-length segments starting at
--   @0, 1, .., n-1@ and @n@ elements in total.
singletonPA_l :: PA a -> PArray a -> PArray (PArray a)
{-# INLINE_PA singletonPA_l #-}
singletonPA_l _ (PArray n# xs) = PArray n# (PNested segd xs)
  where
    len  = I# n#
    segd = U.mkSegd (U.replicate len 1) (U.enumFromStepLen 0 1 len) len
65
-- | Closure pairing 'singletonPA_v' and 'singletonPA_l'.
singletonPA :: PA a -> (a :-> PArray a)
{-# INLINE singletonPA #-}
singletonPA pa = closure1 scalar lifted
  where
    scalar = singletonPA_v pa
    lifted = singletonPA_l pa
69
-- | Vectorised @map@: replicates the closure once per element, then applies
--   the resulting array of closures pointwise with '$:^'.
mapPA_v :: PA a -> PA b -> (a :-> b) -> PArray a -> PArray b
{-# INLINE_PA mapPA_v #-}
mapPA_v pa pb f xs = fs $:^ xs
  where
    fs = replicatePA# (dPA_Clo pa pb) (lengthPA# xs) f
74
-- | Lifted @map@. Each closure @fs!i@ is replicated across segment @i@ of
--   @xss@ and applied to the flattened elements. The flat result is then
--   given the segmentation of @xss@ via 'copySegdPA#'.
mapPA_l :: PA a -> PA b
        -> PArray (a :-> b) -> PArray (PArray a) -> PArray (PArray b)
{-# INLINE_PA mapPA_l #-}
mapPA_l pa pb fs xss = copySegdPA# xss ys
  where
    spread = replicatelPA# (dPA_Clo pa pb) (segdPA# xss) fs
    ys     = spread $:^ concatPA# xss
81
-- | Closure for @map@; the first (function) argument uses the closure
--   dictionary built from @pa@ and @pb@.
mapPA :: PA a -> PA b -> ((a :-> b) :-> PArray a :-> PArray b)
{-# INLINE mapPA #-}
mapPA pa pb = closure2 dict scalar lifted
  where
    dict   = dPA_Clo pa pb
    scalar = mapPA_v pa pb
    lifted = mapPA_l pa pb
85
-- | Vectorised cross map. Every element @x@ of @xs@ is paired with each
--   element of @f x@. Each @x@ is replicated once per element of its
--   result array, and the copies are zipped with the concatenated results.
crossMapPA_v :: PA a -> PA b -> PArray a -> (a :-> PArray b) -> PArray (a,b)
{-# INLINE_PA crossMapPA_v #-}
crossMapPA_v pa pb xs f = zipPA# xs' (concatPA# yss)
  where
    yss = mapPA_v pa (dPA_PArray pb) f xs
    xs' = replicatelPA# pa (segdPA# yss) xs
92
-- | Lifted cross map: applies the cross map independently to each pair of
--   closure @fs!i@ and inner array @xss!i@. The output is segmented like the
--   per-segment concatenation of all the @f x@ results.
crossMapPA_l :: PA a -> PA b
             -> PArray (PArray a)
             -> PArray (a :-> PArray b)
             -> PArray (PArray (a,b))
{-# INLINE_PA crossMapPA_l #-}
crossMapPA_l pa pb xss fs = copySegdPA# yss (zipPA# xs' ys)
  where
    -- one result array per input element, grouped by input segment
    ysss = mapPA_l pa (dPA_PArray pb) fs xss
    yss  = concatPA_l pb ysss
    ys   = concatPA# yss
    -- every input element copied once per element of its own result array
    xs'  = replicatelPA# pa (segdPA# (concatPA# ysss)) (concatPA# xss)
103
-- | Closure for the cross map; the first argument is a parallel array of @a@.
crossMapPA :: PA a -> PA b -> (PArray a :-> (a :-> PArray b) :-> PArray (a,b))
{-# INLINE crossMapPA #-}
crossMapPA pa pb = closure2 arrDict scalar lifted
  where
    arrDict = dPA_PArray pa
    scalar  = crossMapPA_v pa pb
    lifted  = crossMapPA_l pa pb
108
-- | Vectorised @zip@; the element dictionaries are not consulted.
zipPA_v :: PA a -> PA b -> PArray a -> PArray b -> PArray (a,b)
{-# INLINE_PA zipPA_v #-}
zipPA_v _ _ xs ys = xs `zipPA#` ys
112
-- | Lifted @zip@. The flattened inner arrays are zipped, and the result
--   takes the segmentation of @xss@.
zipPA_l :: PA a -> PA b
        -> PArray (PArray a) -> PArray (PArray b) -> PArray (PArray (a,b))
{-# INLINE_PA zipPA_l #-}
zipPA_l _ _ xss yss = copySegdPA# xss pairs
  where
    pairs = zipPA# (concatPA# xss) (concatPA# yss)
117
-- | Closure for @zip@; the first argument is a parallel array of @a@.
zipPA :: PA a -> PA b -> (PArray a :-> PArray b :-> PArray (a,b))
{-# INLINE zipPA #-}
zipPA pa pb =
  let scalar = zipPA_v pa pb
      lifted = zipPA_l pa pb
  in  closure2 (dPA_PArray pa) scalar lifted
121
-- | Vectorised @zipWith@: the binary closure is replicated to the length of
--   @xs@, then applied pointwise to both arrays with '$:^'.
zipWithPA_v :: PA a -> PA b -> PA c
            -> (a :-> b :-> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA zipWithPA_v #-}
zipWithPA_v pa pb pc f xs ys = fs $:^ xs $:^ ys
  where
    fs = replicatePA# (dPA_Clo pa (dPA_Clo pb pc)) (lengthPA# xs) f
129
-- | Lifted @zipWith@. Closure @fs!i@ is replicated across segment @i@ of
--   @xss@ and applied to the flattened elements of both nested arrays. The
--   result takes the segmentation of @xss@.
zipWithPA_l :: PA a -> PA b -> PA c
            -> PArray (a :-> b :-> c) -> PArray (PArray a) -> PArray (PArray b)
            -> PArray (PArray c)
{-# INLINE_PA zipWithPA_l #-}
zipWithPA_l pa pb pc fs xss yss = copySegdPA# xss zs
  where
    spread = replicatelPA# (dPA_Clo pa (dPA_Clo pb pc)) (segdPA# xss) fs
    zs     = spread $:^ concatPA# xss $:^ concatPA# yss
138
-- | Closure for @zipWith@; it takes three arguments, the first two described
--   by the closure and array dictionaries.
zipWithPA :: PA a -> PA b -> PA c
          -> ((a :-> b :-> c) :-> PArray a :-> PArray b :-> PArray c)
{-# INLINE zipWithPA #-}
zipWithPA pa pb pc = closure3 funDict arrDict scalar lifted
  where
    funDict = dPA_Clo pa (dPA_Clo pb pc)
    arrDict = dPA_PArray pa
    scalar  = zipWithPA_v pa pb pc
    lifted  = zipWithPA_l pa pb pc
145
-- | Vectorised @unzip@; the element dictionaries are not consulted.
unzipPA_v :: PA a -> PA b -> PArray (a,b) -> (PArray a, PArray b)
{-# INLINE_PA unzipPA_v #-}
unzipPA_v _ _ pairs = unzipPA# pairs
149
-- | Lifted @unzip@. The flattened pairs are unzipped, each half is given the
--   segmentation of the input, and the two nested halves are zipped into one
--   array of pairs.
unzipPA_l :: PA a -> PA b -> PArray (PArray (a, b)) -> PArray (PArray a, PArray b)
{-# INLINE_PA unzipPA_l #-}
unzipPA_l _ _ xyss = zipPA# (copySegdPA# xyss xs) (copySegdPA# xyss ys)
  where
    flat     = concatPA# xyss
    (xs, ys) = unzipPA# flat
155
-- | Closure pairing 'unzipPA_v' and 'unzipPA_l'.
unzipPA :: PA a -> PA b -> (PArray (a, b) :-> (PArray a, PArray b))
{-# INLINE unzipPA #-}
unzipPA pa pb = closure1 scalar lifted
  where
    scalar = unzipPA_v pa pb
    lifted = unzipPA_l pa pb
159
-- | Vectorised pack: keeps the elements of @xs@ whose flag in @bs@ is
--   'True'. The number of surviving elements is counted first and passed to
--   'packPA#'.
packPA_v :: PA a -> PArray a -> PArray Bool -> PArray a
{-# INLINE_PA packPA_v #-}
packPA_v pa xs bs
  = case U.count flags True of
      I# n# -> packPA# pa xs n# flags
  where
    flags = toUArrPA bs
164
-- | Lifted pack: for each segment @i@, keeps the elements of @xss!i@ whose
--   flag in @bss!i@ is 'True'.
--
--   The packed flat data has one segment per input segment, whose length is
--   the number of 'True' flags in that segment (@segd'@). The result must be
--   segmented with @segd'@. Reusing the input descriptor of @xss@ would
--   describe elements that were removed.
packPA_l :: PA a
         -> PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
{-# INLINE_PA packPA_l #-}
packPA_l pa xss bss
  = segmentPA# (lengthPA# xss) segd'
  $ packPA# pa (concatPA# xss) (elementsSegd# segd') flags
  where
    flags = toUArrPA (concatPA# bss)
    -- per-segment count of surviving elements
    segd' = U.lengthsToSegd
          . U.sum_s (segdPA# xss)
          . U.map fromBool
          $ flags
176
-- | Closure for pack; the first argument is a parallel array of @a@.
packPA :: PA a -> (PArray a :-> PArray Bool :-> PArray a)
{-# INLINE packPA #-}
packPA pa = closure2 arrDict (packPA_v pa) (packPA_l pa)
  where
    arrDict = dPA_PArray pa
180
181
-- TODO: should the selector be a boolean array?
-- | Vectorised two-way combine: merges @xs@ and @ys@ into one array of their
--   combined length. The merge is driven by the selector built from the tag
--   array @bs@ with 'tagsToSel2'.
combine2PA_v :: PA a -> PArray a -> PArray a -> PArray Int -> PArray a
{-# INLINE_PA combine2PA_v #-}
combine2PA_v pa xs ys bs
  = combine2PA# pa (lengthPA# xs +# lengthPA# ys) sel xs ys
  where
    sel = tagsToSel2 (toUArrPA bs)
189
-- | Lifted two-way combine. Not yet implemented: forcing the result raises
--   an error naming this function.
combine2PA_l :: PA a
             -> PArray (PArray a) -> PArray (PArray a) -> PArray (PArray Int)
             -> PArray (PArray a)
{-# INLINE_PA combine2PA_l #-}
combine2PA_l _ _ _ _ = error "combine2PA_l nyi"
193
194
-- | Closure for the two-way combine; the first two arguments are parallel
--   arrays of @a@.
--
--   Uses plain @INLINE@, like every other closure wrapper in this module; the
--   @INLINE_PA@ phase control is reserved for the @_v@ / @_l@ workers.
combine2PA :: PA a -> (PArray a :-> PArray a :-> PArray Int :-> PArray a)
{-# INLINE combine2PA #-}
combine2PA pa = closure3 (dPA_PArray pa) (dPA_PArray pa)
                         (combine2PA_v pa) (combine2PA_l pa)
198
199
-- | Vectorised @filter@: maps the predicate over @xs@, then packs @xs@ with
--   the resulting flags.
filterPA_v :: PA a -> (a :-> Bool) -> PArray a -> PArray a
{-# INLINE_PA filterPA_v #-}
filterPA_v pa p xs = packPA_v pa xs flags
  where
    flags = mapPA_v pa dPA_Bool p xs
203
-- | Lifted @filter@: applies each predicate @ps!i@ across segment @i@ of
--   @xss@, then packs segment-wise with 'packPA_l'.
filterPA_l :: PA a
           -> PArray (a :-> Bool) -> PArray (PArray a) -> PArray (PArray a)
{-# INLINE_PA filterPA_l #-}
filterPA_l pa ps xss = packPA_l pa xss flagss
  where
    flagss = mapPA_l pa dPA_Bool ps xss
208
-- | Closure for @filter@; the first argument is a predicate closure.
filterPA :: PA a -> ((a :-> Bool) :-> PArray a :-> PArray a)
{-# INLINE filterPA #-}
filterPA pa = closure2 predDict (filterPA_v pa) (filterPA_l pa)
  where
    predDict = dPA_Clo pa dPA_Bool
212
-- | Vectorised indexing: the element of @xs@ at position @i@.
indexPA_v :: PA a -> PArray a -> Int -> a
{-# INLINE_PA indexPA_v #-}
indexPA_v pa xs i = case i of
                      I# i# -> indexPA# pa xs i#
216
-- | Lifted indexing: picks element @is!i@ from inner array @xss!i@. Each
--   local index is offset by its segment start in the flat data, and all
--   elements are gathered with one 'bpermutePA#'.
indexPA_l :: PA a -> PArray (PArray a) -> PArray Int -> PArray a
{-# INLINE_PA indexPA_l #-}
indexPA_l pa xss is
  = bpermutePA# pa (concatPA# xss) (lengthPA# xss) positions
  where
    positions = U.zipWith (+) (U.indicesSegd (segdPA# xss)) (toUArrPA is)
223
-- | Closure for indexing; the first argument is a parallel array of @a@.
indexPA :: PA a -> (PArray a :-> Int :-> a)
{-# INLINE indexPA #-}
indexPA pa =
  let scalar = indexPA_v pa
      lifted = indexPA_l pa
  in  closure2 (dPA_PArray pa) scalar lifted
227
-- | Vectorised @concat@; the element dictionary is not consulted.
concatPA_v :: PA a -> PArray (PArray a) -> PArray a
{-# INLINE_PA concatPA_v #-}
concatPA_v _ nested = concatPA# nested
231
-- | Lifted @concat@: collapses the two outer segmentation levels of a
--   triply nested array, keeping the innermost flat data @xs@ unchanged.
--
--   Each new segment has the summed lengths of the inner segments belonging
--   to one outer segment. It starts where the first of those inner segments
--   starts.
--
--   NOTE(review): if an outer segment is empty, its start index in @segd1@
--   may equal the number of inner segments, which would make the
--   'U.bpermute' below read out of range. Confirm that this cannot occur.
concatPA_l :: PA a -> PArray (PArray (PArray a)) -> PArray (PArray a)
{-# INLINE_PA concatPA_l #-}
concatPA_l _ (PArray m# (PNested segd1 (PNested segd2 xs)))
  = PArray m# (PNested segd xs)
  where
    lens = U.sum_s segd1 (U.lengthsSegd segd2)
    idxs = U.bpermute (U.indicesSegd segd2) (U.indicesSegd segd1)
    segd = U.mkSegd lens idxs (U.elementsSegd segd2)
240
-- | Closure pairing 'concatPA_v' and 'concatPA_l'.
concatPA :: PA a -> (PArray (PArray a) :-> PArray a)
{-# INLINE concatPA #-}
concatPA pa = closure1 scalar lifted
  where
    scalar = concatPA_v pa
    lifted = concatPA_l pa
244
-- | Vectorised append of two parallel arrays.
appPA_v :: PA a -> PArray a -> PArray a -> PArray a
{-# INLINE_PA appPA_v #-}
appPA_v dict front back = appPA# dict front back
248
-- | Lifted append: for each index @i@, appends @yss!i@ to @xss!i@.
--
--   The result has one segment per pair of input arrays, so its outer length
--   is the length of @xss@ (assumed equal to that of @yss@), not the sum of
--   the two. This also matches the segment count of the descriptor built
--   below.
appPA_l :: PA a -> PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
{-# INLINE_PA appPA_l #-}
appPA_l pa xss yss
  = segmentPA# (lengthPA# xss) segd xys
  where
    xsegd = segdPA# xss
    ysegd = segdPA# yss

    -- Segment i holds xs_i followed by ys_i. Its start index is the sum of
    -- the two input start indices. NOTE(review): this matches the layout of
    -- applPA# only when both input descriptors are contiguous; confirm.
    segd = U.mkSegd (U.zipWith (+) (U.lengthsSegd xsegd) (U.lengthsSegd ysegd))
                    (U.zipWith (+) (U.indicesSegd xsegd) (U.indicesSegd ysegd))
                    (U.elementsSegd xsegd + U.elementsSegd ysegd)

    xys = applPA# pa xsegd (concatPA# xss) ysegd (concatPA# yss)
264
-- | Closure for append; the first argument is a parallel array of @a@.
appPA :: PA a -> (PArray a :-> PArray a :-> PArray a)
{-# INLINE appPA #-}
appPA pa = closure2 arrDict scalar lifted
  where
    arrDict = dPA_PArray pa
    scalar  = appPA_v pa
    lifted  = appPA_l pa
268
269
-- | Vectorised @enumFromTo@ on 'Int': the values from @m@ to @n@ inclusive.
--   The length comes from 'distance', so the array is empty when @n < m@.
enumFromToPA_v :: Int -> Int -> PArray Int
{-# INLINE_PA enumFromToPA_v #-}
enumFromToPA_v m n = fromUArrPA len (U.enumFromTo m n)
  where
    len = distance m n
273
-- | Number of values in the inclusive range from @lo@ to @hi@, clamped at
--   zero when @hi < lo@.
distance :: Int -> Int -> Int
{-# INLINE_STREAM distance #-}
distance lo hi = let count = hi - lo + 1
                 in  max count 0
277
-- | Lifted @enumFromTo@: for each index @i@, the 'Int's from @ms!i@ to
--   @ns!i@ inclusive.
--
--   Segment lengths come from 'distance'. The flat array of all generated
--   values holds @elementsSegd segd@ elements, which is the length it must be
--   given. The outer array keeps the length of @ms@ (one segment per range).
enumFromToPA_l :: PArray Int -> PArray Int -> PArray (PArray Int)
{-# INLINE_PA enumFromToPA_l #-}
enumFromToPA_l ms ns
  = segmentPA# (lengthPA# ms) segd
  . fromUArrPA (U.elementsSegd segd)
  . U.enumFromToEach (U.elementsSegd segd)
  $ U.zip (toUArrPA ms) (toUArrPA ns)
  where
    segd = U.lengthsToSegd
         $ U.zipWith distance (toUArrPA ms) (toUArrPA ns)
288
-- | Closure for 'Int' @enumFromTo@, pairing 'enumFromToPA_v' and
--   'enumFromToPA_l'.
enumFromToPA_Int :: Int :-> Int :-> PArray Int
{-# INLINE enumFromToPA_Int #-}
enumFromToPA_Int = closure2 dPA_Int scalar lifted
  where
    scalar = enumFromToPA_v
    lifted = enumFromToPA_l
292