Prepare dph for a vectInfoVar type change
[packages/dph.git] / dph-lifted-copy / Data / Array / Parallel / Lifted / Scalar.hs
1 {-# OPTIONS -fno-warn-orphans #-}
2 {-# LANGUAGE CPP #-}
3
4 #include "fusion-phases.h"
5
6 module Data.Array.Parallel.Lifted.Scalar
7 where
8 import Data.Array.Parallel.Lifted.PArray
9 import Data.Array.Parallel.PArray.PReprInstances
10 import Data.Array.Parallel.PArray.PDataInstances
11 import qualified Data.Array.Parallel.Unlifted as U
12 import Data.Array.Parallel.Base (fromBool, toBool)
13 import GHC.Exts (Int(..))
14
15
16 -- Pretend Bools are scalars --------------------------------------------------
-- | Bool arrays are stored as a 'PBool' wrapping a two-way selector.
--   Converting to scalar data builds that selector from the tags produced
--   by 'fromBool'; converting back reads the selector's tags and maps
--   'toBool' over them.
instance Scalar Bool where
  {-# INLINE toScalarPData #-}
  toScalarPData bs = PBool sel
    where
      sel = U.tagsToSel2 (U.map fromBool bs)

  {-# INLINE fromScalarPData #-}
  fromScalarPData (PBool sel) = U.map toBool $ U.tagsSel2 sel
24
25
26 -- Projections ----------------------------------------------------------------
-- | Number of elements in a 'PArray', returned as a boxed 'Int'.
prim_lengthPA :: Scalar a => PArray a -> Int
{-# INLINE prim_lengthPA #-}
prim_lengthPA arr = case lengthPA# arr of
  n# -> I# n#
30
31
32 -- Conversion -----------------------------------------------------------------
-- | Create a 'PArray' out of a scalar 'U.Array'.
--   The length of the result is taken from the array itself
--   (see 'fromUArray'' for the variant that takes the length explicitly).
--
--   TODO: ditch this version, just use fromUArrPA'
--   NOTE(review): 'fromUArrPA'' is not defined in this module; presumably
--   'fromUArray'' is meant — confirm.
--
fromUArray :: Scalar a => U.Array a -> PArray a
{-# INLINE fromUArray #-}
fromUArray xs
  = case U.length xs of
      I# n# -> PArray n# (toScalarPData xs)
43
-- | Create a 'PArray' out of a scalar 'U.Array' using a caller-supplied
--   length. The length is not checked against the array here, so callers
--   must pass the array's real length.
--
--   TODO: Why do we want this version that takes the length explicitly?
--   Is there some fusion issue that requires this?
fromUArray' :: Scalar a => Int -> U.Array a -> PArray a
{-# INLINE fromUArray' #-}
fromUArray' len xs
  = case len of
      I# n# -> PArray n# (toScalarPData xs)
50
51
-- | Convert a 'PArray' of scalars back to a plain 'U.Array'.
--   The stored length is discarded; only the element data is converted.
toUArray :: Scalar a => PArray a -> U.Array a
{-# INLINE toUArray #-}
toUArray arr = case arr of
  PArray _ pdata -> fromScalarPData pdata
56
57
58 -- Tuple Conversions ----------------------------------------------------------
-- | Convert a 'U.Array' of pairs to a 'PArray' of pairs.
--   The two components are unzipped and stored as separate scalar arrays;
--   the length comes from the input array.
fromUArray2
  :: (Scalar a, Scalar b)
  => U.Array (a,b) -> PArray (a,b)
{-# INLINE fromUArray2 #-}
fromUArray2 ps
  = case U.length ps of
      I# n# -> PArray n# (P_2 (toScalarPData firsts) (toScalarPData seconds))
  where
    (firsts, seconds) = U.unzip ps
68
69
-- | Convert a 'U.Array' of nested pairs @((a,b),c)@ to a 'PArray' of
--   triples. Each component is unzipped into its own scalar array;
--   the length comes from the input array.
fromUArray3
  :: (Scalar a, Scalar b, Scalar c)
  => U.Array ((a,b),c) -> PArray (a,b,c)
{-# INLINE fromUArray3 #-}
fromUArray3 ps
  = case U.length ps of
      I# n# -> PArray n# (P_3 (toScalarPData firsts)
                              (toScalarPData seconds)
                              (toScalarPData thirds))
  where
    (pairs,  thirds)  = U.unzip ps
    (firsts, seconds) = U.unzip pairs
82
83
84 -- Nesting arrays -------------------------------------------------------------
-- | O(1). Create a nested array.
--   The outer length is the segment count reported by 'U.lengthSegd';
--   the length stored in the flat data array is ignored.
nestUSegd
  :: U.Segd     -- ^ segment descriptor
  -> PArray a   -- ^ array of data elements.
  -> PArray (PArray a)
{-# INLINE nestUSegd #-}
nestUSegd segd (PArray _ flat)
  = case U.lengthSegd segd of
      I# n# -> PArray n# (PNested segd flat)
95
96
97 -- Scalar Operators -----------------------------------------------------------
98 -- These work on PArrays of scalar elements.
99 -- TODO: Why do we need these versions as well as the standard ones?
100
-- | Apply a worker function to every element of an array, yielding a new
--   array whose length is that of the input.
scalar_map
  :: (Scalar a, Scalar b)
  => (a -> b) -> PArray a -> PArray b
{-# INLINE_PA scalar_map #-}
scalar_map f xs
  = fromUArray' len mapped
  where
    len    = prim_lengthPA xs
    mapped = U.map f (toUArray xs)
111
112
-- | Combine two arrays element-wise with the given function.
--   The result length is taken from the first array.
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith
  :: (Scalar a, Scalar b, Scalar c)
  => (a -> b -> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA scalar_zipWith #-}
scalar_zipWith f xs ys
  = fromUArray' (prim_lengthPA xs) zipped
  where
    zipped = U.zipWith f (toUArray xs) (toUArray ys)
122
123
-- | Combine three arrays element-wise with the given function.
--   The result length is taken from the first array.
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith3
  :: (Scalar a, Scalar b, Scalar c, Scalar d)
  => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
{-# INLINE_PA scalar_zipWith3 #-}
scalar_zipWith3 f xs ys zs
  = fromUArray' (prim_lengthPA xs) zipped
  where
    zipped = U.zipWith3 f (toUArray xs) (toUArray ys) (toUArray zs)
133
134
135
136
-- | Combine four arrays element-wise with the given function.
--   The result length is taken from the first array.
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith4
  :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e)
  => (a -> b -> c -> d -> e)
  -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e
{-# INLINE_PA scalar_zipWith4 #-}
scalar_zipWith4 f ws xs ys zs
  = fromUArray' (prim_lengthPA ws) zipped
  where
    zipped = U.zipWith4 f (toUArray ws) (toUArray xs) (toUArray ys) (toUArray zs)
146
147
-- | Combine five arrays element-wise with the given function.
--   The result length is taken from the first array.
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith5
  :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e, Scalar f)
  => (a -> b -> c -> d -> e -> f)
  -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e -> PArray f
{-# INLINE_PA scalar_zipWith5 #-}
scalar_zipWith5 f vs ws xs ys zs
  = fromUArray' (prim_lengthPA vs) zipped
  where
    zipped = U.zipWith5 f (toUArray vs) (toUArray ws) (toUArray xs)
                          (toUArray ys) (toUArray zs)
157
158
-- | Combine six arrays element-wise with the given function.
--   The result length is taken from the first array.
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith6
  :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e, Scalar f, Scalar g)
  => (a -> b -> c -> d -> e -> f -> g)
  -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e -> PArray f -> PArray g
{-# INLINE_PA scalar_zipWith6 #-}
scalar_zipWith6 f us vs ws xs ys zs
  = fromUArray' (prim_lengthPA us) zipped
  where
    zipped = U.zipWith6 f (toUArray us) (toUArray vs) (toUArray ws)
                          (toUArray xs) (toUArray ys) (toUArray zs)
169
-- | Combine seven arrays element-wise with the given function.
--   The result length is taken from the first array ('ts'), like every
--   other scalar_zipWith variant. (Previously the second array was used.)
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith7
  :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e, Scalar f, Scalar g, Scalar h)
  => (a -> b -> c -> d -> e -> f -> g -> h)
  -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e -> PArray f -> PArray g -> PArray h
{-# INLINE_PA scalar_zipWith7 #-}
scalar_zipWith7 f ts us vs ws xs ys zs
  = fromUArray' (prim_lengthPA ts)
  $ U.zipWith7 f (toUArray ts) (toUArray us) (toUArray vs) (toUArray ws)
                 (toUArray xs) (toUArray ys) (toUArray zs)
180
181
-- | Combine eight arrays element-wise with the given function.
--   The result length is taken from the first array.
--   NOTE(review): no length check is done here; the arrays are presumably
--   required to have equal length.
scalar_zipWith8
  :: (Scalar a, Scalar b, Scalar c, Scalar d, Scalar e, Scalar f, Scalar g, Scalar h, Scalar i)
  => (a -> b -> c -> d -> e -> f -> g -> h -> i)
  -> PArray a -> PArray b -> PArray c -> PArray d -> PArray e -> PArray f -> PArray g -> PArray h -> PArray i
{-# INLINE_PA scalar_zipWith8 #-}
scalar_zipWith8 f ss ts us vs ws xs ys zs
  = fromUArray' (prim_lengthPA ss) zipped
  where
    zipped = U.zipWith8 f (toUArray ss) (toUArray ts) (toUArray us) (toUArray vs)
                          (toUArray ws) (toUArray xs) (toUArray ys) (toUArray zs)
192
-- | Fold an array with a binary operator and an initial value, by
--   delegating to 'U.fold' on the underlying scalar array.
--   NOTE(review): the evaluation order and whether the operator must be
--   associative are determined by 'U.fold' — confirm against the backend.
scalar_fold
  :: Scalar a
  => (a -> a -> a) -> a -> PArray a -> a
{-# INLINE_PA scalar_fold #-}
scalar_fold f z
  = \arr -> U.fold f z (toUArray arr)
201
202
-- | Fold an array with a binary operator, seeding the fold with the first
--   element, by delegating to 'U.fold1' on the underlying scalar array.
--   NOTE(review): behaviour on an empty array is whatever 'U.fold1' does
--   (presumably an error) — confirm.
scalar_fold1
  :: Scalar a
  => (a -> a -> a) -> PArray a -> a
{-# INLINE_PA scalar_fold1 #-}
scalar_fold1 f
  = \arr -> U.fold1 f (toUArray arr)
211
212
-- | Segmented fold of an array of arrays.
--   Each segment is folded individually with the given operator and
--   initial value, yielding an array with one fold result per segment.
scalar_folds
  :: Scalar a
  => (a -> a -> a) -> a -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_folds #-}
-- The result holds one value per segment, so its length is the outer
-- length of xss (the segment count), not the element count of concatPA# xss.
-- prim_lengthPA cannot be used on xss (it would need Scalar (PArray a)),
-- so the unboxed length is read directly.
scalar_folds f z xss
  = fromUArray' (I# (lengthPA# xss))
  . U.fold_s f z (segdPA# xss)
  . toUArray
  $ concatPA# xss
225
226
-- | Segmented fold of an array of arrays, using the first element of each
--   segment to initialise the state for that segment.
--   Each segment is folded individually, yielding an array with one fold
--   result per segment.
--   NOTE(review): behaviour on empty segments is whatever 'U.fold1_s' does
--   (presumably an error) — confirm.
scalar_fold1s
  :: Scalar a
  => (a -> a -> a) -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_fold1s #-}
-- The result holds one value per segment, so its length is the outer
-- length of xss (the segment count), not the element count of concatPA# xss.
scalar_fold1s f xss
  = fromUArray' (I# (lengthPA# xss))
  . U.fold1_s f (segdPA# xss)
  . toUArray
  $ concatPA# xss
240
241
-- | Fold an array whose elements are paired with their index
--   (via 'U.indexed', presumably the element's position) using 'U.fold1',
--   and return only the index component of the final pair.
--   The worker decides which (index, value) pair survives each step.
scalar_fold1Index
  :: Scalar a
  => ((Int, a) -> (Int, a) -> (Int, a)) -> PArray a -> Int
{-# INLINE_PA scalar_fold1Index #-}
scalar_fold1Index f
  = \arr -> fst (U.fold1 f (U.indexed (toUArray arr)))
251
252
-- | Segmented version of 'scalar_fold1Index': every element is paired with
--   the index from 'U.indices_s' (presumably its position within its own
--   segment), each segment is folded with 'U.fold1_s', and the index
--   component of each segment's result is returned.
--   The result has the same outer length as the input.
scalar_fold1sIndex
  :: Scalar a
  => ((Int, a) -> (Int, a) -> (Int, a))
  -> PArray (PArray a) -> PArray Int
{-# INLINE_PA scalar_fold1sIndex #-}
scalar_fold1sIndex f (PArray m# (PNested segd xs))
  = PArray m# (toScalarPData winners)
  where
    elems   = fromScalarPData xs
    indexed = U.zip (U.indices_s segd) elems
    winners = U.fsts (U.fold1_s f segd indexed)
268