88d63717c5f0db104f65e3a91528e5a73b8a7b97
[packages/dph.git] / dph-common / Data / Array / Parallel / Lifted / Combinators.hs
1 {-# LANGUAGE CPP #-}
2
3 #include "fusion-phases.h"
4
5 module Data.Array.Parallel.Lifted.Combinators (
6 lengthPA, replicatePA, singletonPA, mapPA, crossMapPA,
7 zipWithPA, zipPA, unzipPA,
8 packPA, filterPA, combine2PA, indexPA, concatPA, appPA, enumFromToPA_Int,
9
10 lengthPA_v, replicatePA_v, singletonPA_v, zipPA_v, unzipPA_v,
11 indexPA_v, appPA_v, enumFromToPA_v
12 ) where
13
14 import Data.Array.Parallel.Lifted.PArray
15 import Data.Array.Parallel.Lifted.Closure
16 import Data.Array.Parallel.Lifted.Unboxed
17 import Data.Array.Parallel.Lifted.Repr
18 import Data.Array.Parallel.Lifted.Instances
19
20 import GHC.Exts (Int(..), (+#), (-#), Int#, (<#))
21
-- | Vectorised 'length': the element count reported by 'lengthPA#',
-- re-boxed as an ordinary 'Int'.
lengthPA_v :: PA a -> PArray a -> Int
{-# INLINE_PA lengthPA_v #-}
lengthPA_v pa xs = case lengthPA# pa xs of
                     len# -> I# len#
25
-- | Lifted 'length': for a nested array, an 'Int' array holding each
-- segment's length. Built directly from the outer length and the
-- segment-lengths field of the 'PNested' representation.
lengthPA_l :: PA a -> PArray (PArray a) -> PArray Int
{-# INLINE_PA lengthPA_l #-}
lengthPA_l _ xss = case xss of
                     PNested outer# segLens _ _ -> PInt outer# segLens
29
-- | 'length' packaged as a closure carrying both its vectorised and its
-- lifted implementation.
lengthPA :: PA a -> (PArray a :-> Int)
{-# INLINE lengthPA #-}
lengthPA pa = closure1 scalar lifted
  where
    scalar = lengthPA_v pa
    lifted = lengthPA_l pa
33
-- | Vectorised 'replicate': unbox the count and hand it to
-- 'replicatePA#' together with the value to copy.
replicatePA_v :: PA a -> Int -> a -> PArray a
{-# INLINE_PA replicatePA_v #-}
replicatePA_v pa count x = case count of
                             I# count# -> replicatePA# pa count# x
37
-- | Lifted 'replicate'. The array of counts is used directly as the
-- segment-lengths vector of the result. The segment descriptor built from
-- it supplies both the segment start indices and the layout
-- 'replicatelPA#' uses to copy each element of @xs@.
replicatePA_l :: PA a -> PArray Int -> PArray a -> PArray (PArray a)
{-# INLINE_PA replicatePA_l #-}
replicatePA_l pa (PInt n# counts) xs
  = let descr  = lengthsToSegdPA# counts
        starts = indicesSegdPA# descr
        flat   = replicatelPA# pa descr xs
    in  PNested n# counts starts flat
45
-- | 'replicate' as a two-argument closure. The 'dPA_Int' dictionary is
-- handed to 'closure2' together with both implementations.
replicatePA :: PA a -> (Int :-> a :-> PArray a)
{-# INLINE replicatePA #-}
replicatePA pa = closure2 dPA_Int scalar lifted
  where
    scalar = replicatePA_v pa
    lifted = replicatePA_l pa
49
-- | Vectorised singleton: an array holding exactly one copy of @x@,
-- expressed via 'replicatePA_v' with a count of one.
singletonPA_v :: PA a -> a -> PArray a
{-# INLINE_PA singletonPA_v #-}
singletonPA_v pa x = replicatePA_v pa (1 :: Int) x
53
-- | Lifted singleton: every element of @xs@ becomes its own segment.
-- All segment lengths are 1 ('replicatePA_Int#' of 1). The start indices
-- come from 'upToPA_Int#' (presumably 0..n-1 — confirm). The flat data is
-- @xs@ itself, unchanged.
singletonPA_l :: PA a -> PArray a -> PArray (PArray a)
{-# INLINE_PA singletonPA_l #-}
singletonPA_l pa xs = build (lengthPA# pa xs)
  where
    build count# = PNested count#
                           (replicatePA_Int# count# 1#)
                           (upToPA_Int# count#)
                           xs
59
-- | Singleton construction as a one-argument closure.
singletonPA :: PA a -> (a :-> PArray a)
{-# INLINE singletonPA #-}
singletonPA pa = closure1 scalar lifted
  where
    scalar = singletonPA_v pa
    lifted = singletonPA_l pa
63
-- | Vectorised 'map'. The closure @f@ is replicated once per element of
-- @as@, and the resulting array of closures is applied element-wise with
-- the lifted application operator '$:^'.
mapPA_v :: PA a -> PA b -> (a :-> b) -> PArray a -> PArray b
{-# INLINE_PA mapPA_v #-}
mapPA_v pa pb f as = copies $:^ as
  where
    copies = replicatePA# (dPA_Clo pa pb) (lengthPA# pa as) f
68
-- | Lifted 'map'. The i-th closure in @fs@ is replicated across the i-th
-- segment of the nested array (using that array's segment descriptor) and
-- applied to the flat element data. The outer length, lengths and start
-- indices of the input are reused for the result.
mapPA_l :: PA a -> PA b
        -> PArray (a :-> b) -> PArray (PArray a) -> PArray (PArray b)
{-# INLINE_PA mapPA_l #-}
mapPA_l pa pb fs xss
  = case xss of
      PNested n# lens idxs flat ->
        let descr   = segdOfPA# pa xss
            spread  = replicatelPA# (dPA_Clo pa pb) descr fs
        in  PNested n# lens idxs (spread $:^ flat)
75
-- | 'map' as a two-argument closure whose first argument is itself a
-- closure (hence the 'dPA_Clo' dictionary passed to 'closure2').
mapPA :: PA a -> PA b -> ((a :-> b) :-> PArray a :-> PArray b)
{-# INLINE mapPA #-}
mapPA pa pb = closure2 dFun scalar lifted
  where
    dFun   = dPA_Clo pa pb
    scalar = mapPA_v pa pb
    lifted = mapPA_l pa pb
79
-- | Vectorised cross-map: apply @f@ to every element of @as@, giving one
-- array of @b@ per @a@, then pair each @a@ with every @b@ produced from it.
-- The @a@s are replicated to match the segment layout of the results
-- before being zipped against the concatenated results.
crossMapPA_v :: PA a -> PA b -> PArray a -> (a :-> PArray b) -> PArray (a,b)
{-# INLINE_PA crossMapPA_v #-}
crossMapPA_v pa pb as f = zipPA# pa pb left right
  where
    results = mapPA_v pa (dPA_PArray pb) f as
    left    = replicatelPA# pa (segdOfPA# pb results) as
    right   = concatPA# results
86
-- | Lifted cross-map. The closures in @fs@ are mapped over the nested
-- input with 'mapPA_l', giving a doubly nested result. Its middle level is
-- flattened with 'concatPA_l', and the resulting flat @b@ data is zipped
-- with the input's flat @a@ data, replicated to the matching segment
-- layout.
crossMapPA_l :: PA a -> PA b
             -> PArray (PArray a)
             -> PArray (a :-> PArray b)
             -> PArray (PArray (a,b))
{-# INLINE_PA crossMapPA_l #-}
crossMapPA_l pa pb ass@(PNested _ _ _ flatAs) fs
  = case concatPA_l pb mapped of
      PNested n# lens idxs flatBs ->
        PNested n# lens idxs (zipPA# pa pb spreadAs flatBs)
  where
    mapped   = mapPA_l pa (dPA_PArray pb) fs ass
    -- one level down: the per-element result arrays, still nested
    inner    = case mapped of PNested _ _ _ arrs -> arrs
    spreadAs = replicatelPA# pa (segdOfPA# pb inner) flatAs
100
-- | Cross-map as a two-argument closure; 'closure2' receives the
-- 'PArray' dictionary for @a@ together with both implementations.
crossMapPA :: PA a -> PA b -> (PArray a :-> (a :-> PArray b) :-> PArray (a,b))
{-# INLINE crossMapPA #-}
crossMapPA pa pb = closure2 dArr scalar lifted
  where
    dArr   = dPA_PArray pa
    scalar = crossMapPA_v pa pb
    lifted = crossMapPA_l pa pb
105
-- | Vectorised 'zip': a thin wrapper over the primitive 'zipPA#'.
zipPA_v :: PA a -> PA b -> PArray a -> PArray b -> PArray (a,b)
{-# INLINE_PA zipPA_v #-}
zipPA_v da db lefts rights = zipPA# da db lefts rights
109
-- | Lifted 'zip': zips the flat data of two nested arrays. The result
-- takes its segment structure from the first argument; the second
-- argument's descriptor fields are ignored.
zipPA_l :: PA a -> PA b
        -> PArray (PArray a) -> PArray (PArray b) -> PArray (PArray (a,b))
{-# INLINE_PA zipPA_l #-}
zipPA_l pa pb xss yss
  = case (xss, yss) of
      (PNested n# lens idxs xs, PNested _ _ _ ys) ->
        PNested n# lens idxs (zipPA_v pa pb xs ys)
115
-- | 'zip' as a two-argument closure.
zipPA :: PA a -> PA b -> (PArray a :-> PArray b :-> PArray (a,b))
{-# INLINE zipPA #-}
zipPA pa pb = closure2 dArr scalar lifted
  where
    dArr   = dPA_PArray pa
    scalar = zipPA_v pa pb
    lifted = zipPA_l pa pb
119
-- | Vectorised 'zipWith'. The curried closure @f@ is replicated once per
-- element of @as@, then applied to @as@ and @bs@ in turn with the lifted
-- application operator '$:^'.
zipWithPA_v :: PA a -> PA b -> PA c
            -> (a :-> b :-> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA zipWithPA_v #-}
zipWithPA_v pa pb pc f as bs = (copies $:^ as) $:^ bs
  where
    dFun   = dPA_Clo pa (dPA_Clo pb pc)
    copies = replicatePA# dFun (lengthPA# pa as) f
127
-- | Lifted 'zipWith'. The i-th closure of @fs@ is replicated over the i-th
-- segment of the first nested array and applied to the flat data of both
-- arrays. The segment structure of the result is taken from the first
-- nested array.
zipWithPA_l :: PA a -> PA b -> PA c
            -> PArray (a :-> b :-> c) -> PArray (PArray a) -> PArray (PArray b)
            -> PArray (PArray c)
{-# INLINE_PA zipWithPA_l #-}
zipWithPA_l pa pb pc fs ass bss
  = case (ass, bss) of
      (PNested n# lens idxs as, PNested _ _ _ bs) ->
        let dFun   = dPA_Clo pa (dPA_Clo pb pc)
            spread = replicatelPA# dFun (segdOfPA# pa ass) fs
        in  PNested n# lens idxs ((spread $:^ as) $:^ bs)
136
-- | 'zipWith' as a three-argument closure. 'closure3' receives the
-- dictionaries for its first two argument types plus both implementations.
zipWithPA :: PA a -> PA b -> PA c
          -> ((a :-> b :-> c) :-> PArray a :-> PArray b :-> PArray c)
{-# INLINE zipWithPA #-}
zipWithPA pa pb pc = closure3 dFun dArr scalar lifted
  where
    dFun   = dPA_Clo pa (dPA_Clo pb pc)
    dArr   = dPA_PArray pa
    scalar = zipWithPA_v pa pb pc
    lifted = zipWithPA_l pa pb pc
143
-- | Vectorised 'unzip': split an array of pairs into a pair of arrays via
-- the primitive 'unzipPA#'.
--
-- Carries the same 'INLINE_PA' pragma as every other vectorised worker in
-- this module. The argument is named @xys@ rather than @abs@ to avoid
-- shadowing 'Prelude.abs'.
unzipPA_v :: PA a -> PA b -> PArray (a, b) -> (PArray a, PArray b)
{-# INLINE_PA unzipPA_v #-}
unzipPA_v pa pb xys = unzipPA# pa pb xys
146
-- | Lifted 'unzip'. The flat pair data is split with 'unzipPA_v'. Both
-- resulting nested arrays share the input's outer length, segment lengths
-- and start indices, and the two are returned as a 'P_2' array of pairs.
--
-- Carries the same 'INLINE_PA' pragma as every other lifted worker in
-- this module.
unzipPA_l :: PA a -> PA b -> PArray (PArray (a, b)) -> PArray (PArray a, PArray b)
{-# INLINE_PA unzipPA_l #-}
unzipPA_l pa pb (PNested n# lens idxs xys)
  = P_2 n# (PNested n# lens idxs xs) (PNested n# lens idxs ys)
  where
    (xs, ys) = unzipPA_v pa pb xys
152
-- | 'unzip' as a one-argument closure.
unzipPA :: PA a -> PA b -> (PArray (a, b) :-> (PArray a, PArray b))
{-# INLINE unzipPA #-}
unzipPA pa pb = closure1 scalar lifted
  where
    scalar = unzipPA_v pa pb
    lifted = unzipPA_l pa pb
156
-- | Vectorised pack: keep the elements of @xs@ whose flag in @bs@ is set.
-- 'packPA#' receives the count from 'truesPA#' and the flags converted
-- by 'toPrimArrPA_Bool'.
packPA_v :: PA a -> PArray a -> PArray Bool -> PArray a
{-# INLINE_PA packPA_v #-}
packPA_v pa xs bs = packPA# pa xs (truesPA# bs) flags
  where
    flags = toPrimArrPA_Bool bs
160
-- | Lifted pack. The flag arrays are flattened and the flat data is packed
-- in one go with 'packPA_v'. The new segment lengths are the per-segment
-- true counts ('truesPAs_Bool#' over the flags' segmentation). The new
-- start indices are their prefix sums. The outer length is that of @bss@.
packPA_l :: PA a
         -> PArray (PArray a) -> PArray (PArray Bool) -> PArray (PArray a)
{-# INLINE_PA packPA_l #-}
packPA_l pa xss bss
  = case xss of
      PNested _ _ _ xs ->
        PNested (lengthPA# (dPA_PArray dPA_Bool) bss) newLens newIdxs
                (packPA_v pa xs flat)
  where
    flat    = concatPA# bss
    descr   = segdOfPA# dPA_Bool bss
    newLens = truesPAs_Bool# descr (toPrimArrPA_Bool flat)
    newIdxs = unsafe_scanPA_Int# (+) 0 newLens
171
-- | Pack as a two-argument closure.
packPA :: PA a -> (PArray a :-> PArray Bool :-> PArray a)
{-# INLINE packPA #-}
packPA pa = closure2 dArr scalar lifted
  where
    dArr   = dPA_PArray pa
    scalar = packPA_v pa
    lifted = packPA_l pa
175
176
-- | Vectorised binary combine. Calls 'combine2PA#' with a result length of
-- @length xs + length ys@, the raw selector data, and both input arrays.
-- Presumably it merges @xs@ and @ys@ according to the selector — confirm
-- against 'combine2PA#'.
--
-- TODO: should the selector be a boolean array?
-- TODO: fix the index vector. The selector's raw data is currently passed
-- twice, once as the selector and once in the argument position that
-- appears to expect per-element indices into @xs@/@ys@. The real indices
-- still need to be computed from the selector.
combine2PA_v :: PA a -> PArray a -> PArray a -> PArray Int -> PArray a
{-# INLINE_PA combine2PA_v #-}
combine2PA_v pa xs ys (PInt _ sel#)
  = combine2PA# pa (lengthPA# pa xs +# lengthPA# pa ys) sel# sel# xs ys
183
-- | Lifted binary combine — not implemented yet; any call fails at
-- runtime. The error message now names this function correctly (it
-- previously said @combinePA_l@).
combine2PA_l :: PA a -> PArray (PArray a) -> PArray (PArray a)
             -> PArray (PArray Int) -> PArray (PArray a)
{-# INLINE_PA combine2PA_l #-}
combine2PA_l _ _ _ _ = error "combine2PA_l: not yet implemented"
187
188
-- | Binary combine as a three-argument closure.
--
-- Uses a plain INLINE pragma, like every other closure wrapper in this
-- module ('lengthPA', 'mapPA', 'zipWithPA', ...). 'INLINE_PA' is reserved
-- for the @_v@/@_l@ workers.
combine2PA :: PA a -> (PArray a :-> PArray a :-> PArray Int :-> PArray a)
{-# INLINE combine2PA #-}
combine2PA pa = closure3 (dPA_PArray pa) (dPA_PArray pa)
                         (combine2PA_v pa) (combine2PA_l pa)
192
193
-- | Vectorised 'filter': compute the predicate's flag for each element
-- with 'mapPA_v', then keep the flagged elements with 'packPA_v'.
filterPA_v :: PA a -> (a :-> Bool) -> PArray a -> PArray a
{-# INLINE_PA filterPA_v #-}
filterPA_v pa p xs = packPA_v pa xs keep
  where
    keep = mapPA_v pa dPA_Bool p xs
197
-- | Lifted 'filter': lifted map of the predicates followed by lifted pack.
filterPA_l :: PA a
           -> PArray (a :-> Bool) -> PArray (PArray a) -> PArray (PArray a)
{-# INLINE_PA filterPA_l #-}
filterPA_l pa ps xss = packPA_l pa xss keeps
  where
    keeps = mapPA_l pa dPA_Bool ps xss
202
-- | 'filter' as a two-argument closure whose first argument is a
-- predicate closure.
filterPA :: PA a -> ((a :-> Bool) :-> PArray a :-> PArray a)
{-# INLINE filterPA #-}
filterPA pa = closure2 dPred scalar lifted
  where
    dPred  = dPA_Clo pa dPA_Bool
    scalar = filterPA_v pa
    lifted = filterPA_l pa
206
-- | Vectorised indexing: unbox the position and defer to 'indexPA#'.
indexPA_v :: PA a -> PArray a -> Int -> a
{-# INLINE_PA indexPA_v #-}
indexPA_v pa xs ix = case ix of
                       I# ix# -> indexPA# pa xs ix#
210
-- | Lifted indexing. For each segment i, adds the local index @is[i]@ to
-- the segment's start @idxs[i]@ to get a position in the flat data, then
-- gathers all those positions at once with 'bpermutePA#'. The result
-- length is taken from the index array.
--
-- The segment-lengths field is not needed and is no longer bound.
indexPA_l :: PA a -> PArray (PArray a) -> PArray Int -> PArray a
{-# INLINE_PA indexPA_l #-}
indexPA_l pa (PNested _ _ idxs xs) (PInt n# is)
  = bpermutePA# pa n# xs (unsafe_zipWithPA_Int# (+) idxs is)
215
-- | Indexing as a two-argument closure.
indexPA :: PA a -> (PArray a :-> Int :-> a)
{-# INLINE indexPA #-}
indexPA pa = closure2 dArr scalar lifted
  where
    dArr   = dPA_PArray pa
    scalar = indexPA_v pa
    lifted = indexPA_l pa
219
-- | Vectorised 'concat': a nested array's flat data already is the
-- concatenation of its segments, so the segment descriptor is discarded.
concatPA_v :: PA a -> PArray (PArray a) -> PArray a
{-# INLINE_PA concatPA_v #-}
concatPA_v _ xss = case xss of
                     PNested _ _ _ flat -> flat
223
-- | Lifted 'concat': within each outer segment, concatenate the inner
-- arrays it contains.
--
-- The innermost flat data @xs@ is reused unchanged; only the segment
-- descriptor is rebuilt.
--
-- * The new length of outer segment i is the sum of the lengths of its
--   inner arrays ('sumPAs_Int#' over the outer segmentation).
-- * The new start indices are the exclusive prefix sums of those lengths,
--   computed the same way as in 'packPA_l' and 'enumFromToPA_l'.
--
-- Looking the starts up as @idxs2[idxs1[i]]@ instead would index past the
-- end of @idxs2@ whenever an outer segment is empty and only empty
-- segments follow it (e.g. @[[[1]], []]@ or @[[]]@).
--
-- NOTE(review): the prefix-sum form assumes inner segments are stored
-- contiguously and in order in @xs@. Summing their lengths into one
-- segment already relies on that.
concatPA_l :: PA a -> PArray (PArray (PArray a)) -> PArray (PArray a)
{-# INLINE_PA concatPA_l #-}
concatPA_l pa arr@(PNested m# _ _ (PNested _ lens2 _ xs))
  = PNested m# lens idxs xs
  where
    segd = segdOfPA# (dPA_PArray pa) arr
    lens = sumPAs_Int# segd lens2
    idxs = unsafe_scanPA_Int# (+) 0 lens
232
-- | 'concat' as a one-argument closure.
concatPA :: PA a -> (PArray (PArray a) :-> PArray a)
{-# INLINE concatPA #-}
concatPA pa = closure1 scalar lifted
  where
    scalar = concatPA_v pa
    lifted = concatPA_l pa
236
-- | Vectorised append: a thin wrapper over the primitive 'appPA#'.
appPA_v :: PA a -> PArray a -> PArray a -> PArray a
{-# INLINE_PA appPA_v #-}
appPA_v pa front back = appPA# pa front back
240
-- | Lifted append: for each segment index i, append the i-th segment of
-- @yss@ to the i-th segment of @xss@. Both inputs are expected to have the
-- same outer length.
--
-- The result has one segment per input segment, so its outer length is
-- @m#@ (the outer length of @xss@). It was previously @m# +# n#@, which
-- disagreed with the lengths vector built below: that vector is an
-- element-wise sum and therefore has only @m#@ entries.
--
-- Assuming each input's start indices are the prefix sums of its
-- lengths, result segment i starts after all earlier x- and y-segments,
-- i.e. at @idxs1[i] + idxs2[i]@. The flat data is produced by 'applPA#'
-- from both inputs' segment descriptors and flat data.
appPA_l :: PA a -> PArray (PArray a) -> PArray (PArray a) -> PArray (PArray a)
{-# INLINE_PA appPA_l #-}
appPA_l pa xss@(PNested m# lens1 idxs1 xs)
           yss@(PNested _  lens2 idxs2 ys)
  = PNested m# (unsafe_zipWithPA_Int# (+) lens1 lens2)
               (unsafe_zipWithPA_Int# (+) idxs1 idxs2)
               (applPA# pa (segdOfPA# pa xss) xs
                           (segdOfPA# pa yss) ys)
249
-- | Append as a two-argument closure.
appPA :: PA a -> (PArray a :-> PArray a :-> PArray a)
{-# INLINE appPA #-}
appPA pa = closure2 dArr scalar lifted
  where
    dArr   = dPA_PArray pa
    scalar = appPA_v pa
    lifted = appPA_l pa
253
254
-- | Vectorised @[m .. n]@ on 'Int'. The elements come from
-- 'enumFromToPA_Int#'. The recorded length is @n - m + 1@, clamped at 0
-- with 'max#' so that an empty range (@n < m@) gets length 0 rather than
-- a negative count.
--
-- The unused @m@/@n@ as-patterns have been dropped.
enumFromToPA_v :: Int -> Int -> PArray Int
{-# INLINE_PA enumFromToPA_v #-}
enumFromToPA_v (I# m#) (I# n#) = PInt len# (enumFromToPA_Int# m# n#)
  where
    !len# = max# 0# (n# -# m# +# 1#)
260
-- | Maximum of two unboxed 'Int#' values; on a tie the first argument is
-- returned.
max# :: Int# -> Int# -> Int#
{-# INLINE_STREAM max# #-}
max# a# b#
  | a# <# b#  = b#
  | otherwise = a#
264
-- | Lifted @[m .. n]@: for each pair of bounds, one segment holding that
-- range.
--
-- * Segment lengths are @max 0 (hi - lo + 1)@.
-- * Start indices are their prefix sums.
-- * The flat data, whose total size is the sum of the lengths, is
--   produced by 'enumFromToEachPA_Int#'.
--
-- The outer length is taken from the lower-bounds array.
enumFromToPA_l :: PArray Int -> PArray Int -> PArray (PArray Int)
{-# INLINE_PA enumFromToPA_l #-}
enumFromToPA_l (PInt k# los#) (PInt _ his#)
  = PNested k# segLens# segStarts# (PInt total# values#)
  where
    segLen lo hi = max 0 (hi - lo + 1)

    segLens#   = unsafe_zipWithPA_Int# segLen los# his#
    segStarts# = unsafe_scanPA_Int# (+) 0 segLens#

    !total#    = sumPA_Int# segLens#
    values#    = enumFromToEachPA_Int# total# los# his#
276
-- | Integer range construction as a two-argument closure.
enumFromToPA_Int :: Int :-> Int :-> PArray Int
{-# INLINE enumFromToPA_Int #-}
enumFromToPA_Int = closure2 dPA_Int scalar lifted
  where
    scalar = enumFromToPA_v
    lifted = enumFromToPA_l
280