Separate length from data in array representation
[packages/dph.git] / dph-common / Data / Array / Parallel / Lifted / Scalar.hs
1 {-# LANGUAGE CPP #-}
2
3 #include "fusion-phases.h"
4
5 module Data.Array.Parallel.Lifted.Scalar
6 where
7
8 import Data.Array.Parallel.Lifted.PArray
9 import Data.Array.Parallel.Lifted.Unboxed
10 import Data.Array.Parallel.Lifted.Repr
11 import Data.Array.Parallel.Lifted.Instances
12 import Data.Array.Parallel.Lifted.Selector
13
14 import qualified Data.Array.Parallel.Unlifted as U
15
16 import Data.Array.Parallel.Base ((:*:)(..), fstS, pairS, unpairS,
17 fromBool, toBool)
18
19 import GHC.Exts ( Int(..), (-#) )
20 import GHC.Word ( Word8 )
21
-- | Element types whose parallel-array data ('PData') can be converted to
--   and from a flat unlifted array ('U.Array').
class U.Elt a => Scalar a where
  -- | Wrap an unlifted array as parallel-array data.
  fromUArrPD :: U.Array a -> PData a

  -- | Recover the unlifted array from parallel-array data.
  toUArrPD   :: PData a -> U.Array a

  -- | The 'PA' dictionary for this element type.
  primPA     :: PA a
26
-- | Build a 'PArray' from an explicit length and an unlifted array.
--
--   The length is stored as given; it is not checked against the array.
fromUArrPA :: Scalar a => Int -> U.Array a -> PArray a
{-# INLINE fromUArrPA #-}
fromUArrPA (I# len#) arr = PArray len# payload
  where
    payload = fromUArrPD arr
30
-- | Extract the underlying unlifted array of a 'PArray', discarding the
--   separately stored length.
toUArrPA :: Scalar a => PArray a -> U.Array a
{-# INLINE toUArrPA #-}
toUArrPA (PArray _ payload) = toUArrPD payload
34
-- | Length of a 'PArray' as a boxed 'Int' (boxes the result of 'lengthPA#').
prim_lengthPA :: Scalar a => PArray a -> Int
{-# INLINE prim_lengthPA #-}
prim_lengthPA arr = case lengthPA# arr of
                      len# -> I# len#
38
-- | Like 'fromUArrPA', but takes the length from the unlifted array itself.
fromUArrPA' :: Scalar a => U.Array a -> PArray a
{-# INLINE fromUArrPA' #-}
fromUArrPA' arr = fromUArrPA len arr
  where
    len = U.length arr
42
-- | Apply a scalar function to every element of a 'PArray'.
--
--   The result is given the same length as the input array.
scalar_map :: (Scalar a, Scalar b) => (a -> b) -> PArray a -> PArray b
{-# INLINE_PA scalar_map #-}
scalar_map f arr = fromUArrPA (prim_lengthPA arr) (U.map f (toUArrPA arr))
48
-- | Combine two 'PArray's element-wise with a scalar function.
--
--   The result is given the length of the first array.
scalar_zipWith :: (Scalar a, Scalar b, Scalar c)
               => (a -> b -> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA scalar_zipWith #-}
scalar_zipWith f as bs
  = fromUArrPA (prim_lengthPA as)
               (U.zipWith f (toUArrPA as) (toUArrPA bs))
54
-- | Combine three 'PArray's element-wise with a scalar function.
--
--   The result is given the length of the first array.
scalar_zipWith3
  :: (Scalar a, Scalar b, Scalar c, Scalar d)
  => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
{-# INLINE_PA scalar_zipWith3 #-}
scalar_zipWith3 f as bs cs
  = fromUArrPA (prim_lengthPA as)
               (U.zipWith3 f (toUArrPA as) (toUArrPA bs) (toUArrPA cs))
62
-- | Reduce a 'PArray' to a single value with @f@, starting from @z@.
--
--   The array argument is bound by a lambda so that the definition still
--   has exactly two arguments on its left-hand side.
scalar_fold :: Scalar a => (a -> a -> a) -> a -> PArray a -> a
{-# INLINE_PA scalar_fold #-}
scalar_fold f z = \arr -> U.fold f z (toUArrPA arr)
66
-- | Reduce a 'PArray' to a single value with @f@, without a separate
--   starting value (delegates to 'U.fold1').
--
--   The array argument is bound by a lambda so that the definition still
--   has exactly one argument on its left-hand side.
scalar_fold1 :: Scalar a => (a -> a -> a) -> PArray a -> a
{-# INLINE_PA scalar_fold1 #-}
scalar_fold1 f = \arr -> U.fold1 f (toUArrPA arr)
70
-- | Segmented fold: reduce each subarray of a nested array with @f@,
--   starting from @z@.
--
--   The flattened data ('concatPA#') is folded segment by segment according
--   to the segment descriptor ('segdPA#').  The fold yields one value per
--   segment, so the result is given the length of the outer array @xss@
--   rather than the length of the flattened data.
scalar_folds :: Scalar a => (a -> a -> a) -> a -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_folds #-}
scalar_folds f z xss = fromUArrPA (I# (lengthPA# xss))
                     . U.fold_s f z (segdPA# xss)
                     . toUArrPA
                     $ concatPA# xss
77
-- | Segmented fold without a starting value: reduce each subarray of a
--   nested array with @f@ (delegates to 'U.fold1_s').
--
--   The flattened data ('concatPA#') is folded segment by segment according
--   to the segment descriptor ('segdPA#').  The fold yields one value per
--   segment, so the result is given the length of the outer array @xss@
--   rather than the length of the flattened data.
scalar_fold1s :: Scalar a => (a -> a -> a) -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_fold1s #-}
scalar_fold1s f xss = fromUArrPA (I# (lengthPA# xss))
                    . U.fold1_s f (segdPA# xss)
                    . toUArrPA
                    $ concatPA# xss
84
-- | Reduce a 'PArray' whose elements have been paired with indices
--   ('U.indexed'), and return the index component of the final pair.
--
--   The user function works on ordinary tuples; it is adapted to the strict
--   pairs used by the unlifted arrays via 'unpairS' and 'pairS'.
scalar_fold1Index :: Scalar a
                  => ((Int, a) -> (Int, a) -> (Int, a)) -> PArray a -> Int
{-# INLINE_PA scalar_fold1Index #-}
scalar_fold1Index f = fstS . U.fold1 combine . U.indexed . toUArrPA
  where
    {-# INLINE combine #-}
    combine l r = pairS (f (unpairS l) (unpairS r))
92
-- | Segmented variant of 'scalar_fold1Index': for each subarray of a nested
--   array, reduce its elements paired with indices ('U.indices_s') and keep
--   the index component of the per-segment result.
--
--   The user function works on ordinary tuples; it is adapted to the strict
--   pairs used by the unlifted arrays via 'unpairS' and 'pairS'.
--
--   The segmented fold yields one index per segment, so the result is given
--   the length of the outer array (@m@), not the length of the flattened
--   data (@n@).
scalar_fold1sIndex :: Scalar a
                   => ((Int, a) -> (Int, a) -> (Int, a))
                   -> PArray (PArray a) -> PArray Int
{-# INLINE_PA scalar_fold1sIndex #-}
scalar_fold1sIndex f xss = fromUArrPA m
                         . U.fsts
                         . U.fold1_s f' segd
                         . U.zip (U.indices_s m segd n)
                         . toUArrPA
                         $ concatPA# xss
  where
    {-# INLINE f' #-}
    f' p q = pairS $ f (unpairS p) (unpairS q)

    -- Length of the outer array, i.e. the number of subarrays.
    m = I# (lengthPA# xss)
    -- Length of the flattened data.
    n = I# (lengthPA# (concatPA# xss))

    segd = segdPA# xss
111
-- | 'Int' arrays are wrapped directly in the 'PInt' constructor.
instance Scalar Int where
  fromUArrPD ints      = PInt ints
  toUArrPD (PInt ints) = ints
  primPA               = dPA_Int
116
-- | 'Word8' arrays are wrapped directly in the 'PWord8' constructor.
instance Scalar Word8 where
  fromUArrPD bytes        = PWord8 bytes
  toUArrPD (PWord8 bytes) = bytes
  primPA                  = dPA_Word8
121
-- | 'Double' arrays are wrapped directly in the 'PDouble' constructor.
instance Scalar Double where
  fromUArrPD doubles         = PDouble doubles
  toUArrPD (PDouble doubles) = doubles
  primPA                     = dPA_Double
126
-- | 'Bool' arrays are not stored as-is: each element is converted with
--   'fromBool' and the resulting tags are turned into a selector
--   ('tagsToSel2') held by 'PBool'.  Reading back takes the selector's tags
--   ('tagsSel2') and converts each with 'toBool'.
instance Scalar Bool where
  {-# INLINE fromUArrPD #-}
  fromUArrPD flags = PBool selector
    where
      selector = tagsToSel2 (U.map fromBool flags)

  {-# INLINE toUArrPD #-}
  toUArrPD (PBool selector) = U.map toBool (tagsSel2 selector)

  primPA = dPA_Bool
136
137
-- | Build a 'PArray' of pairs from an explicit length and an unlifted array
--   of strict pairs.
--
--   The input is unzipped and each component array is wrapped separately.
--   The length is stored as given; it is not checked against the array.
fromUArrPA_2 :: (Scalar a, Scalar b) => Int -> U.Array (a :*: b) -> PArray (a,b)
{-# INLINE fromUArrPA_2 #-}
fromUArrPA_2 (I# len#) pairs
  = let as :*: bs = U.unzip pairs   -- lazy pattern binding
    in  PArray len# (P_2 (fromUArrPD as) (fromUArrPD bs))
143
-- | Like 'fromUArrPA_2', but takes the length from the unlifted array itself.
fromUArrPA_2' :: (Scalar a, Scalar b) => U.Array (a :*: b) -> PArray (a, b)
{-# INLINE fromUArrPA_2' #-}
fromUArrPA_2' pairs = fromUArrPA_2 len pairs
  where
    len = U.length pairs
147
-- | Build a 'PArray' of triples from an explicit length and an unlifted
--   array of strict triples.
--
--   The input is unzipped and each component array is wrapped separately.
--   The length is stored as given; it is not checked against the array.
fromUArrPA_3 :: (Scalar a, Scalar b, Scalar c)
             => Int -> U.Array (a :*: b :*: c) -> PArray (a,b,c)
{-# INLINE fromUArrPA_3 #-}
fromUArrPA_3 (I# len#) triples
  = let as :*: bs :*: cs = U.unzip3 triples   -- lazy pattern binding
    in  PArray len# (P_3 (fromUArrPD as) (fromUArrPD bs) (fromUArrPD cs))
156
-- | Like 'fromUArrPA_3', but takes the length from the unlifted array itself.
fromUArrPA_3' :: (Scalar a, Scalar b, Scalar c)
              => U.Array (a :*: b :*: c) -> PArray (a, b, c)
{-# INLINE fromUArrPA_3' #-}
fromUArrPA_3' triples = fromUArrPA_3 len triples
  where
    len = U.length triples
160
-- | Turn a flat 'PArray' into a nested one by attaching a segment
--   descriptor.
--
--   The first argument becomes the length of the resulting outer array; the
--   length stored in the flat input array is discarded.
nestUSegdPA :: Int -> U.Segd -> PArray a -> PArray (PArray a)
{-# INLINE nestUSegdPA #-}
nestUSegdPA (I# outerLen#) segd (PArray _ flat) = PArray outerLen# nested
  where
    nested = PNested segd flat
164
-- | Like 'nestUSegdPA', but takes the outer length from the segment
--   descriptor ('U.lengthSegd').
nestUSegdPA' :: U.Segd -> PArray a -> PArray (PArray a)
{-# INLINE nestUSegdPA' #-}
nestUSegdPA' segd flat = nestUSegdPA outerLen segd flat
  where
    outerLen = U.lengthSegd segd
168