7a8da67ad8b122bcb96612c82c996ba16d9446b9
[packages/dph.git] / dph-common / Data / Array / Parallel / Lifted / Scalar.hs
1 {-# LANGUAGE CPP #-}
2
3 #include "fusion-phases.h"
4
5 module Data.Array.Parallel.Lifted.Scalar
6 where
7
8 import Data.Array.Parallel.Lifted.PArray
9 import Data.Array.Parallel.Lifted.Unboxed
10 import Data.Array.Parallel.Lifted.Repr
11 import Data.Array.Parallel.Lifted.Instances
12
13 import qualified Data.Array.Parallel.Unlifted as U
14
15 import Data.Array.Parallel.Base ((:*:)(..), fstS, pairS, unpairS)
16
17 import GHC.Exts ( Int(..), (-#) )
18 import GHC.Word ( Word8 )
19
-- | Element types whose lifted arrays ('PArray') are backed by a flat
--   unlifted array ('U.Array').
class U.Elt a => Scalar a where
  -- | Wrap an unlifted array as a 'PArray'. The 'Int' is the length to
  --   record in the result; the instances in this module store it as given
  --   and do not check it against the array.
  fromUArrPA :: Int -> U.Array a -> PArray a

  -- | Get the elements of a 'PArray' as an unlifted array.
  toUArrPA   :: PArray a -> U.Array a

  -- | The 'PA' dictionary for the element type.
  primPA     :: PA a
24
-- | Length of a scalar array as a boxed 'Int': the unboxed result of
--   'lengthPA#' (called with the 'primPA' dictionary) wrapped in 'I#'.
prim_lengthPA :: Scalar a => PArray a -> Int
{-# INLINE prim_lengthPA #-}
prim_lengthPA arr =
  case lengthPA# primPA arr of
    len# -> I# len#
28
-- | Variant of 'fromUArrPA' that takes the length to record from the
--   unlifted array itself (via 'U.length').
fromUArrPA' :: Scalar a => U.Array a -> PArray a
{-# INLINE fromUArrPA' #-}
fromUArrPA' arr = fromUArrPA len arr
  where
    len = U.length arr
32
-- | Apply a scalar function to every element of a scalar array.
--   The result is tagged with the length of the input array.
scalar_map :: (Scalar a, Scalar b) => (a -> b) -> PArray a -> PArray b
{-# INLINE_PA scalar_map #-}
scalar_map f xs = fromUArrPA len (U.map f (toUArrPA xs))
  where
    len = prim_lengthPA xs
38
-- | Combine two scalar arrays element-wise with a binary function.
--   The result is tagged with the length of the first array.
scalar_zipWith
  :: (Scalar a, Scalar b, Scalar c)
  => (a -> b -> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA scalar_zipWith #-}
scalar_zipWith f xs ys = fromUArrPA len zipped
  where
    len    = prim_lengthPA xs
    zipped = U.zipWith f (toUArrPA xs) (toUArrPA ys)
44
-- | Combine three scalar arrays element-wise with a ternary function.
--   The result is tagged with the length of the first array.
scalar_zipWith3
  :: (Scalar a, Scalar b, Scalar c, Scalar d)
  => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
{-# INLINE_PA scalar_zipWith3 #-}
scalar_zipWith3 f xs ys zs = fromUArrPA len zipped
  where
    len    = prim_lengthPA xs
    zipped = U.zipWith3 f (toUArrPA xs) (toUArrPA ys) (toUArrPA zs)
52
-- | Reduce a scalar array with 'U.fold', using the given operator and
--   starting value.
scalar_fold :: Scalar a => (a -> a -> a) -> a -> PArray a -> a
{-# INLINE_PA scalar_fold #-}
-- The array argument is taken by a lambda so the left-hand side keeps the
-- same two arguments as before; the number of left-hand-side arguments
-- decides when GHC may unfold an INLINE binding.
scalar_fold f z = \arr -> U.fold f z (toUArrPA arr)
56
-- | Reduce a scalar array with 'U.fold1', i.e. without a starting value.
scalar_fold1 :: Scalar a => (a -> a -> a) -> PArray a -> a
{-# INLINE_PA scalar_fold1 #-}
-- The array argument is taken by a lambda so the left-hand side keeps the
-- same single argument as before (relevant for INLINE unfolding).
scalar_fold1 f = \arr -> U.fold1 f (toUArrPA arr)
60
-- | Segmented fold over a nested array: the flattened data of @xss@ is
--   reduced with 'U.fold_s' under the segment descriptor of @xss@, using
--   operator @f@ and starting value @z@.
--
--   The result is tagged with the outer length of @xss@ (the number of
--   subarrays), because a segmented fold produces one value per segment.
--   Tagging it with the length of the flattened data, as was done before,
--   records a length that does not describe the wrapped array whenever the
--   subarrays are not all singletons.
scalar_folds :: Scalar a => (a -> a -> a) -> a -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_folds #-}
scalar_folds f z xss = fromUArrPA m
                     . U.fold_s f z (segdOfPA# primPA xss)
                     . toUArrPA
                     $ concatPA# xss
  where
    -- Number of subarrays in @xss@, i.e. number of segments.
    m = I# (lengthPA# (dPA_PArray primPA) xss)
67
-- | Segmented fold without a starting value: the flattened data of @xss@
--   is reduced with 'U.fold1_s' under the segment descriptor of @xss@.
--
--   The result is tagged with the outer length of @xss@ (the number of
--   subarrays), because a segmented fold produces one value per segment.
--   Tagging it with the length of the flattened data, as was done before,
--   records a length that does not describe the wrapped array whenever the
--   subarrays are not all singletons.
scalar_fold1s :: Scalar a => (a -> a -> a) -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_fold1s #-}
scalar_fold1s f xss = fromUArrPA m
                    . U.fold1_s f (segdOfPA# primPA xss)
                    . toUArrPA
                    $ concatPA# xss
  where
    -- Number of subarrays in @xss@, i.e. number of segments.
    m = I# (lengthPA# (dPA_PArray primPA) xss)
74
-- | Reduce the (index, element) pairs of a scalar array with @f@ and return
--   the index component of the final pair. The pairs come from 'U.indexed';
--   @f@ works on ordinary tuples and is adapted to strict pairs internally.
scalar_fold1Index
  :: Scalar a
  => ((Int, a) -> (Int, a) -> (Int, a)) -> PArray a -> Int
{-# INLINE_PA scalar_fold1Index #-}
-- The array argument is taken by a lambda so the left-hand side keeps the
-- same single argument as before (relevant for INLINE unfolding).
scalar_fold1Index f = \arr -> fstS (U.fold1 combine (U.indexed (toUArrPA arr)))
  where
    -- @f@ lifted to the strict pairs used by the unlifted array library.
    {-# INLINE combine #-}
    combine l r = pairS (f (unpairS l) (unpairS r))
82
-- | Segmented version of 'scalar_fold1Index': the flattened data of @xss@
--   is zipped with the indices produced by 'U.indices_s', each segment is
--   reduced with @f@ via 'U.fold1_s', and the index components of the
--   results are returned.
--   NOTE(review): 'U.indices_s' presumably numbers elements within their
--   own segment - confirm against the unlifted library.
--
--   The result is tagged with @m@, the number of subarrays, because the
--   segmented fold produces one value per segment. Previously it was tagged
--   with @n@, the length of the flattened data, which does not describe the
--   wrapped array whenever the subarrays are not all singletons.
scalar_fold1sIndex :: Scalar a
                   => ((Int, a) -> (Int, a) -> (Int, a))
                   -> PArray (PArray a) -> PArray Int
{-# INLINE_PA scalar_fold1sIndex #-}
scalar_fold1sIndex f xss = fromUArrPA m
                         . U.fsts
                         . U.fold1_s f' segd
                         . U.zip (U.indices_s m segd n)
                         . toUArrPA
                         $ concatPA# xss
  where
    -- @f@ lifted to the strict pairs used by the unlifted array library.
    {-# INLINE f' #-}
    f' p q = pairS $ f (unpairS p) (unpairS q)

    -- Number of subarrays (segments) in @xss@.
    m = I# (lengthPA# (dPA_PArray primPA) xss)
    -- Number of elements in the flattened data of @xss@.
    n = I# (lengthPA# primPA (concatPA# xss))

    segd = segdOfPA# primPA xss
101
-- | 'Int' arrays are stored as a 'PInt': an unboxed length plus the flat
--   unlifted array.
instance Scalar Int where
  fromUArrPA len arr = case len of
                         I# len# -> PInt len# arr
  toUArrPA parr      = case parr of
                         PInt _ arr -> arr
  primPA             = dPA_Int
106
-- | 'Word8' arrays are stored as a 'PWord8': an unboxed length plus the
--   flat unlifted array.
instance Scalar Word8 where
  fromUArrPA len arr = case len of
                         I# len# -> PWord8 len# arr
  toUArrPA parr      = case parr of
                         PWord8 _ arr -> arr
  primPA             = dPA_Word8
111
-- | 'Double' arrays are stored as a 'PDouble': an unboxed length plus the
--   flat unlifted array.
instance Scalar Double where
  fromUArrPA len arr = case len of
                         I# len# -> PDouble len# arr
  toUArrPA parr      = case parr of
                         PDouble _ arr -> arr
  primPA             = dPA_Double
116
-- | 'Bool' arrays are stored as a 'PBool' holding a tag array (1 for 'True',
--   0 for 'False'), a per-element index array, and two 'PVoid' arrays sized
--   by the number of 0 tags and the number of 1 tags respectively.
instance Scalar Bool where
  {-# INLINE fromUArrPA #-}
  fromUArrPA (I# n#) bs
    = PBool n# tags sel
            (PVoid (n# -# ones#))
            (PVoid ones#)
    where
      -- One tag per element: 1 for True, 0 for False.
      tags = U.map (\b -> if b then 1 else 0) bs

      -- Per element, a running sum over the tags for elements tagged 1 and
      -- a running sum over the inverted tags for elements tagged 0.
      -- NOTE(review): presumably the position of each element among the
      -- elements sharing its tag - confirm whether U.scan is exclusive.
      sel = U.zipWith3 choose tags
                       (U.scan (+) 0 tags)
                       (U.scan (+) 0 $ U.map invert tags)

      -- Number of elements tagged 1, forced and unboxed.
      !ones# = case U.sum tags of I# k# -> k#

      -- Second argument for a non-zero tag, third argument for tag 0.
      {-# INLINE choose #-}
      choose tag forOne forZero = case tag of
                                    0 -> forZero
                                    _ -> forOne

      -- Swap tag 0 and tag 1 (any non-zero tag becomes 0).
      {-# INLINE invert #-}
      invert tag = case tag of
                     0 -> 1
                     _ -> 0

  {-# INLINE toUArrPA #-}
  toUArrPA (PBool _ tags _ _ _) = U.map (\t -> t /= 0) tags

  primPA = dPA_Bool
142
143
-- | Wrap an unlifted array of strict pairs as a 'PArray' of tuples. The
--   array is unzipped and each component is wrapped with 'fromUArrPA',
--   using the given length for the tuple array and both components.
fromUArrPA_2 :: (Scalar a, Scalar b) => Int -> U.Array (a :*: b) -> PArray (a,b)
{-# INLINE fromUArrPA_2 #-}
fromUArrPA_2 len@(I# len#) ps =
  let -- Lazy pattern binding, so the unzip is only demanded via a component.
      as :*: bs = U.unzip ps
  in  P_2 len# (fromUArrPA len as) (fromUArrPA len bs)
149
150
151
-- | Variant of 'fromUArrPA_2' that takes the length to record from the
--   unlifted array itself (via 'U.length').
fromUArrPA_2' :: (Scalar a, Scalar b) => U.Array (a :*: b) -> PArray (a, b)
{-# INLINE fromUArrPA_2' #-}
fromUArrPA_2' ps = fromUArrPA_2 len ps
  where
    len = U.length ps
155
-- | Wrap an unlifted array of strict triples as a 'PArray' of 3-tuples. The
--   array is unzipped and each component is wrapped with 'fromUArrPA',
--   using the given length for the tuple array and all three components.
fromUArrPA_3
  :: (Scalar a, Scalar b, Scalar c)
  => Int -> U.Array (a :*: b :*: c) -> PArray (a,b,c)
{-# INLINE fromUArrPA_3 #-}
fromUArrPA_3 len@(I# len#) ps =
  let -- Lazy pattern binding, so the unzip is only demanded via a component.
      as :*: bs :*: cs = U.unzip3 ps
  in  P_3 len# (fromUArrPA len as) (fromUArrPA len bs) (fromUArrPA len cs)
161
-- | Variant of 'fromUArrPA_3' that takes the length to record from the
--   unlifted array itself (via 'U.length').
fromUArrPA_3'
  :: (Scalar a, Scalar b, Scalar c)
  => U.Array (a :*: b :*: c) -> PArray (a, b, c)
{-# INLINE fromUArrPA_3' #-}
fromUArrPA_3' ps = fromUArrPA_3 len ps
  where
    len = U.length ps
165
-- | Build a nested array from flat data and a segment descriptor. The
--   result is a 'PNested' holding the given outer length, the lengths and
--   indices fields of the descriptor, and the flat data unchanged.
nestUSegdPA :: Int -> U.Segd -> PArray a -> PArray (PArray a)
{-# INLINE nestUSegdPA #-}
nestUSegdPA (I# n#) segd xs = PNested n# lens idxs xs
  where
    lens = U.lengthsSegd segd
    idxs = U.indicesSegd segd
171
-- | Variant of 'nestUSegdPA' that takes the outer length from the segment
--   descriptor itself (via 'U.lengthSegd').
nestUSegdPA' :: U.Segd -> PArray a -> PArray (PArray a)
{-# INLINE nestUSegdPA' #-}
nestUSegdPA' segd xs = nestUSegdPA outerLen segd xs
  where
    outerLen = U.lengthSegd segd
175
176
177 {-
178 fromSUArrPA :: Scalar a => Int -> Int -> U.SArray a -> PArray (PArray a)
179 {-# INLINE fromSUArrPA #-}
180 fromSUArrPA (I# m#) n xss
181 = PNested m# (U.lengths_s xss)
182 (U.indices_s xss)
183 (fromUArrPA n (U.concat xss))
184
185 toSUArrPA :: Scalar a => PArray (PArray a) -> U.SArray a
186 {-# INLINE toSUArrPA #-}
187 toSUArrPA (PNested _ lens idxs xs) = U.toSegd (U.zip lens idxs) U.>: toUArrPA xs
188
189 fromSUArrPA_2 :: (Scalar a, Scalar b)
190 => Int -> Int -> U.SArray (a :*: b) -> PArray (PArray (a, b))
191 {-# INLINE fromSUArrPA_2 #-}
192 fromSUArrPA_2 (I# m#) n pss = PNested m# (U.lengths_s pss)
193 (U.indices_s pss)
194 (fromUArrPA_2 n (U.concat pss))
195
196 fromSUArrPA' :: Scalar a => U.SArray a -> PArray (PArray a)
197 {-# INLINE fromSUArrPA' #-}
198 fromSUArrPA' xss = fromSUArrPA (U.length_s xss)
199 (U.length (U.concat xss))
200 xss
201
202 fromSUArrPA_2' :: (Scalar a, Scalar b)
203 => U.SArray (a :*: b) -> PArray (PArray (a, b))
204 {-# INLINE fromSUArrPA_2' #-}
205 fromSUArrPA_2' pss = fromSUArrPA_2 (U.length_s pss)
206 (U.length (U.concat pss))
207 pss
208 -}
209