Makefiles need real tab characters.
[packages/dph.git] / dph-lifted-copy / Data / Array / Parallel / Lifted / Scalar.hs
1 {-# OPTIONS -fno-warn-orphans #-}
2 {-# LANGUAGE CPP #-}
3
4 #include "fusion-phases.h"
5
6 module Data.Array.Parallel.Lifted.Scalar
7 where
8 import Data.Array.Parallel.Lifted.PArray
9 import Data.Array.Parallel.PArray.PReprInstances
10 import Data.Array.Parallel.PArray.PDataInstances
11 import qualified Data.Array.Parallel.Unlifted as U
12 import Data.Array.Parallel.Base (fromBool, toBool)
13 import GHC.Exts (Int(..))
14
15
16 -- Pretend Bools are scalars --------------------------------------------------
-- | Treat 'Bool' as a scalar element type. This is an orphan instance,
--   which is why the module disables the orphan-instance warning.
--   Bools are converted to tags with 'fromBool' and stored in a 'PBool'
--   selector; converting back reads the tags and maps 'toBool' over them.
instance Scalar Bool where
  {-# INLINE toScalarPData #-}
  toScalarPData flags = PBool $ U.tagsToSel2 $ U.map fromBool flags

  {-# INLINE fromScalarPData #-}
  fromScalarPData (PBool sel2) = U.map toBool $ U.tagsSel2 sel2
24
25
26 -- Projections ----------------------------------------------------------------
-- | Length of a 'PArray' of scalar elements, boxed as an ordinary 'Int'.
prim_lengthPA :: Scalar a => PArray a -> Int
{-# INLINE prim_lengthPA #-}
prim_lengthPA arr = case lengthPA# arr of
  len# -> I# len#
30
31
32 -- Conversion -----------------------------------------------------------------
-- | Create a 'PArray' out of a scalar 'U.Array'.
--   The first argument becomes the length of the result. It is taken
--   as given and is not checked against the length of the 'U.Array'.
--
--   TODO: ditch this version, just use @fromUArrPA'@.
fromUArrPA :: Scalar a => Int -> U.Array a -> PArray a
{-# INLINE fromUArrPA #-}
fromUArrPA (I# len#) arr = PArray len# $ toScalarPData arr
41
42
-- | Create a 'PArray' out of a scalar 'U.Array', taking the length of the
--   result from the 'U.Array' itself.
fromUArrPA' :: Scalar a => U.Array a -> PArray a
{-# INLINE fromUArrPA' #-}
fromUArrPA' arr = fromUArrPA len arr
  where
    len = U.length arr
48
49
-- | Convert a 'PArray' of scalar elements back to a plain 'U.Array'.
--   The length stored in the 'PArray' is ignored; the result is just the
--   converted scalar payload.
toUArrPA :: Scalar a => PArray a -> U.Array a
{-# INLINE toUArrPA #-}
toUArrPA (PArray _ pdata) = fromScalarPData pdata
54
55
56 -- Tuple Conversions ----------------------------------------------------------
-- | Convert a 'U.Array' of pairs to a 'PArray'.
--   The first argument becomes the length of the result. The pairs are
--   unzipped, and each component is stored as its own scalar payload.
fromUArrPA_2
  :: (Scalar a, Scalar b)
  => Int -> U.Array (a,b) -> PArray (a,b)
{-# INLINE fromUArrPA_2 #-}
fromUArrPA_2 (I# n#) ps
  = let (firsts, seconds) = U.unzip ps
    in  PArray n# (P_2 (toScalarPData firsts) (toScalarPData seconds))
66
67
-- | Convert a 'U.Array' of pairs to a 'PArray', taking the length of the
--   result from the 'U.Array' itself.
fromUArrPA_2'
  :: (Scalar a, Scalar b)
  => U.Array (a,b) -> PArray (a, b)
{-# INLINE fromUArrPA_2' #-}
fromUArrPA_2' ps = fromUArrPA_2 len ps
  where
    len = U.length ps
76
77
-- | Convert a 'U.Array' of left-nested triples @((a,b),c)@ to a 'PArray'
--   of flat triples. The first argument becomes the length of the result.
--   The input is unzipped twice, and each component is stored as its own
--   scalar payload.
fromUArrPA_3
  :: (Scalar a, Scalar b, Scalar c)
  => Int -> U.Array ((a,b),c) -> PArray (a,b,c)
{-# INLINE fromUArrPA_3 #-}
fromUArrPA_3 (I# n#) ps
  = let (pairs, thirds)   = U.unzip ps
        (firsts, seconds) = U.unzip pairs
    in  PArray n# (P_3 (toScalarPData firsts)
                       (toScalarPData seconds)
                       (toScalarPData thirds))
90
91
-- | Convert a 'U.Array' of left-nested triples to a 'PArray', taking the
--   length of the result from the 'U.Array' itself.
fromUArrPA_3'
  :: (Scalar a, Scalar b, Scalar c)
  => U.Array ((a,b),c) -> PArray (a, b, c)
{-# INLINE fromUArrPA_3' #-}
fromUArrPA_3' ps = fromUArrPA_3 len ps
  where
    len = U.length ps
99
100
101 -- Nesting arrays -------------------------------------------------------------
-- | O(1). Create a nested array from a segment descriptor and a flat array
--   of data elements. Nothing is copied; the pieces are just wrapped.
nestUSegdPA
  :: Int      -- ^ length of the resulting outer array (one element per
              --   segment); not checked against the segment descriptor
  -> U.Segd   -- ^ segment descriptor
  -> PArray a -- ^ flat array of data elements; its own length is ignored
  -> PArray (PArray a)
{-# INLINE nestUSegdPA #-}
nestUSegdPA (I# len#) segd (PArray _ elems) = PArray len# (PNested segd elems)
112
113
-- | O(1). Create a nested array. The length of the result is the number of
--   segments in the descriptor (@U.lengthSegd@), not the length of the
--   flat data array.
nestUSegdPA'
  :: U.Segd   -- ^ segment descriptor
  -> PArray a -- ^ flat array of data elements
  -> PArray (PArray a)
{-# INLINE nestUSegdPA' #-}
nestUSegdPA' segd elems = nestUSegdPA nsegs segd elems
  where
    nsegs = U.lengthSegd segd
124
125
126 -- Scalar Operators -----------------------------------------------------------
127 -- These work on PArrays of scalar elements.
128 -- TODO: Why do we need these versions as well as the standard ones?
129
-- | Apply a worker function to every element of an array, yielding a new
--   array with the same length as the input.
scalar_map
  :: (Scalar a, Scalar b)
  => (a -> b) -> PArray a -> PArray b
{-# INLINE_PA scalar_map #-}
scalar_map f xs = fromUArrPA (prim_lengthPA xs) (U.map f (toUArrPA xs))
140
141
-- | Zip two arrays with a worker function, yielding a new array.
--   The length recorded for the result is that of the first array.
--   NOTE(review): presumably both arrays are expected to have equal length.
scalar_zipWith
  :: (Scalar a, Scalar b, Scalar c)
  => (a -> b -> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA scalar_zipWith #-}
scalar_zipWith f xs ys
  = fromUArrPA len (U.zipWith f (toUArrPA xs) (toUArrPA ys))
  where
    len = prim_lengthPA xs
151
152
-- | Zip three arrays with a worker function, yielding a new array.
--   The length recorded for the result is that of the first array.
--   NOTE(review): presumably all three arrays are expected to have equal
--   length.
scalar_zipWith3
  :: (Scalar a, Scalar b, Scalar c, Scalar d)
  => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
{-# INLINE_PA scalar_zipWith3 #-}
scalar_zipWith3 f xs ys zs
  = fromUArrPA len (U.zipWith3 f (toUArrPA xs) (toUArrPA ys) (toUArrPA zs))
  where
    len = prim_lengthPA xs
162
163
-- | Fold an array of scalars with a combining function and an initial
--   state, using @U.fold@ on the underlying 'U.Array'.
scalar_fold
  :: Scalar a
  => (a -> a -> a) -> a -> PArray a -> a
{-# INLINE_PA scalar_fold #-}
scalar_fold f z = \xs -> U.fold f z (toUArrPA xs)
172
173
-- | Fold an array of scalars, using its first element as the initial state
--   (via @U.fold1@ on the underlying 'U.Array').
--   NOTE(review): presumably fails on an empty array, since there is no
--   initial state; confirm against @U.fold1@.
scalar_fold1
  :: Scalar a
  => (a -> a -> a) -> PArray a -> a
{-# INLINE_PA scalar_fold1 #-}
scalar_fold1 f = \xs -> U.fold1 f (toUArrPA xs)
182
183
-- | Segmented fold of an array of arrays.
--   Each segment (inner array) is folded individually with @f@ and initial
--   state @z@. This yields one result per segment, so the result has one
--   element per inner array.
scalar_folds
  :: Scalar a
  => (a -> a -> a) -> a -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_folds #-}
scalar_folds f z xss
  = fromUArrPA (U.lengthSegd segd)
  . U.fold_s f z segd
  . toUArrPA
  $ concatPA# xss
  where
    -- U.fold_s produces one value per segment, so the result length is the
    -- segment count. The flat element count of (concatPA# xss) would be
    -- wrong whenever a segment does not hold exactly one element.
    segd = segdPA# xss
196
197
-- | Segmented fold of an array of arrays, using the first element of each
--   segment to initialise the state for that segment.
--   Each segment is folded individually, yielding one result per segment,
--   so the result has one element per inner array.
--   NOTE(review): presumably fails on empty segments, since they have no
--   initial state; confirm against @U.fold1_s@.
scalar_fold1s
  :: Scalar a
  => (a -> a -> a) -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_fold1s #-}
scalar_fold1s f xss
  = fromUArrPA (U.lengthSegd segd)
  . U.fold1_s f segd
  . toUArrPA
  $ concatPA# xss
  where
    -- U.fold1_s produces one value per segment, so the result length is the
    -- segment count, not the flat element count of (concatPA# xss).
    segd = segdPA# xss
211
212
-- | Fold an array after pairing each element with its index (via
--   @U.indexed@), using the first pair as the initial state.
--   Only the index component of the final accumulated pair is returned.
--   This is useful, for example, for locating an extremal element.
--   NOTE(review): presumably fails on an empty array, as with @U.fold1@.
scalar_fold1Index
  :: Scalar a
  => ((Int, a) -> (Int, a) -> (Int, a)) -> PArray a -> Int
{-# INLINE_PA scalar_fold1Index #-}
scalar_fold1Index f = \xs -> fst (U.fold1 f (U.indexed (toUArrPA xs)))
222
223
-- | Segmented version of 'scalar_fold1Index'. Each element is paired with
--   the index produced by @U.indices_s@ (presumably its position within its
--   own segment; confirm). Each segment is then folded with @U.fold1_s@, and
--   the index component of every segment's result is kept.
--   The result has the same length as the outer array.
scalar_fold1sIndex
  :: Scalar a
  => ((Int, a) -> (Int, a) -> (Int, a))
  -> PArray (PArray a) -> PArray Int
{-# INLINE_PA scalar_fold1sIndex #-}
scalar_fold1sIndex f (PArray m# (PNested segd elems))
  = let withIdx = U.zip (U.indices_s segd) (fromScalarPData elems)
        perSeg  = U.fold1_s f segd withIdx
    in  PArray m# (toScalarPData (U.fsts perSeg))
239