[project @ 1999-10-06 11:10:40 by simonmar]
[nofib.git] / ghc / matrix / Matrix.hs
1 module Matrix (
2 Matrix(..),
3 Vector(..),
4 fplMatrix,
5 createVector,
6 newVector,
7 amx,
8 saxpy,
9 saxmy,
10 xmy,
11 negx,
12 innerProd
13 ) where
14
15 import MutableArray
16 import ByteArray
17 import ST
18
-- | A mutable, unboxed array of 'Double's inside the 'ST' thread @s@,
-- indexed by 'Int'.  The operations in this module walk it from index 0
-- upwards.
type Vector s = MutableByteArray s Int

-- | An immutable, unboxed array of 'Double's holding one band of a
-- 'Matrix', indexed by 'Int'.
type Diagonal = ByteArray Int

-- | A symmetric banded matrix, as built by 'fplMatrix' and applied by 'amx'.
type Matrix =
  ( Int       -- distance of the outer band from the main diagonal
  , Diagonal  -- main diagonal
  , Diagonal  -- band at distance 1 (used for both sub- and super-diagonal)
  , Diagonal  -- band at the distance given by the first component
  )
22
-- | Build the banded 'Matrix' for a given @size@, with @size * size@ rows.
--
-- * The main diagonal is all 4.
-- * The band at distance 1 is -1, except at each position @i@ where
--   @(i + 1) `mod` size == 0@; there it is 0.
-- * The band at distance @size@ is all -1.
--
-- NOTE(review): this has the shape of the five-point Laplacian stencil on a
-- @size@ x @size@ grid (presumably what "fpl" stands for) — confirm.
fplMatrix :: Int -> Matrix
fplMatrix size = (size, mainDiag, nearBand, farBand)
  where
    cells = size * size

    mainDiag = al cells (const 4)
    nearBand = al (cells - 1) coupling
    farBand  = al (cells - size) (const (-1))

    -- Coupling is dropped wherever index i + 1 is a multiple of size.
    coupling i
      | (i + 1) `mod` size == 0 = 0
      | otherwise               = -1

-- | @al len f@ builds a frozen unboxed array of @len@ 'Double's.
-- It allocates the array, writes @f i@ into element @i@ for
-- @i = 0 .. len - 1@ in ascending order, then freezes the result.
al :: Int -> (Int -> Double) -> Diagonal
al len f =
    runST (newDoubleArray (0 :: Int, len - 1) >>= fill 0 >>= freezeDoubleArray)
  where
    fill i arr
      | i < len   = writeDoubleArray arr i (f i) >> fill (i + 1) arr
      | otherwise = return arr
41
-- | Allocate a 'Vector' as long as the list and copy the list into it.
-- List element @k@ goes to index @k@.
createVector :: [Double] -> ST s (Vector s)
createVector xs = do
    vec <- newDoubleArray (0 :: Int, length xs - 1)
    sequence_ (zipWith (writeDoubleArray vec) [0 ..] xs)
    return vec
51
-- | Allocate a fresh 'Vector' with the same index bounds as the given one.
-- No elements are written here; the new vector's contents are whatever
-- 'newDoubleArray' leaves in it.
newVector :: Vector s -> ST s (Vector s)
newVector = newDoubleArray . boundsOfMutableByteArray
56
-- | In place, @x[k] := a * x[k] - y[k]@ for every index of @x@.
-- Returns @x@.
--
-- The length is taken from @x@ (its byte size over 8, i.e. 8 bytes per
-- 'Double'). @y@ is read at the same indices, so it is assumed to be at
-- least as long as @x@.
saxmy :: Double -> Vector s -> Vector s -> ST s (Vector s)
saxmy a x y = loop 0
  where
    len = sizeofMutableByteArray x `div` 8

    loop k
      | k < len = do
          xk <- readDoubleArray x k
          yk <- readDoubleArray y k
          writeDoubleArray x k (a * xk - yk)
          loop (k + 1)
      | otherwise = return x
71
-- | In place, @y[k] := a * x[k] + y[k]@ for every index of @x@.
-- Returns @y@.
--
-- The length is taken from @x@ (its byte size over 8, i.e. 8 bytes per
-- 'Double'), so @y@ is assumed to be at least as long as @x@.
saxpy :: Double -> Vector s -> Vector s -> ST s (Vector s)
saxpy a x y = mapM_ update [0 .. len - 1] >> return y
  where
    len = sizeofMutableByteArray x `div` 8

    update k = do
      xk <- readDoubleArray x k
      yk <- readDoubleArray y k
      writeDoubleArray y k (a * xk + yk)
86
-- | In place, @x[k] := x[k] - y[k]@ for every index of @x@.
-- Returns @x@.
--
-- The length is taken from @x@ (its byte size over 8), so @y@ is assumed
-- to be at least as long as @x@.
xmy :: Vector s -> Vector s -> ST s (Vector s)
xmy x y = do
    sequence_ [ subtractAt k | k <- [0 .. len - 1] ]
    return x
  where
    len = sizeofMutableByteArray x `div` 8

    subtractAt k = do
      xk <- readDoubleArray x k
      yk <- readDoubleArray y k
      writeDoubleArray x k (xk - yk)
101
-- | Write the element-wise negation of @x@ into @u@: @u[k] := -x[k]@ for
-- every index of @x@. Returns @u@; @x@ itself is only read.
--
-- The length is taken from @x@ (its byte size over 8), so @u@ is assumed
-- to be at least as long as @x@.
negx :: Vector s -> Vector s -> ST s (Vector s)
negx x u = go 0
  where
    len = sizeofMutableByteArray x `div` 8

    go k
      | k < len = do
          readDoubleArray x k >>= writeDoubleArray u k . negate
          go (k + 1)
      | otherwise = return u
115
-- | Inner (dot) product @sum_k x[k] * y[k]@ over the indices of @x@.
--
-- The length is taken from @x@ (its byte size over 8). @y@ is read at the
-- same indices, so it is assumed to be at least as long as @x@.
-- The terms are summed left to right, starting from 0.
innerProd :: Vector s -> Vector s -> ST s Double
innerProd x y = innerProd' 0 0
  where
    n = div (sizeofMutableByteArray x) 8

    innerProd' r i
      | i >= n    = return r
      | otherwise = do
          xe <- readDoubleArray x i
          ye <- readDoubleArray y i
          let r' = r + xe * ye
          -- Force the running sum each step. 'return' does not demand its
          -- argument, so without this a chain of n suspended additions
          -- would be built and only evaluated by the caller.
          r' `seq` innerProd' r' (i + 1)
129
-- | @amx u m v@ overwrites @u@ with the product @m * v@ and returns @u@.
--
-- The matrix is applied band by band:
--
-- * The main diagonal sets every element: @u[k] := d0[k] * v[k]@.
-- * Each off-diagonal band @d@ at distance @dist@ then adds its
--   contribution symmetrically. With @hi = lo + dist@ and @c = d[lo]@,
--   @u[lo] += c * v[hi]@ and @u[hi] += c * v[lo]@.
-- * @d1@ is applied at distance 1, @d2@ at distance @offset@.
--
-- The length is taken from @v@ (its byte size over 8).
-- NOTE(review): @u@ and @v@ are presumably distinct vectors; if they alias,
-- @v@ is overwritten while it is still being read.
amx :: Vector s -> Matrix -> Vector s -> ST s (Vector s)
amx u (offset, d0, d1, d2) v = do
    mainDiagonal 0
    band d1 0 1
    band d2 0 offset
    return u
  where
    len = sizeofMutableByteArray v `div` 8

    -- u[k] := d0[k] * v[k]
    mainDiagonal k
      | k >= len  = return ()
      | otherwise = do
          vk <- readDoubleArray v k
          writeDoubleArray u k (indexDoubleArray d0 k * vk)
          mainDiagonal (k + 1)

    -- Walk the pairs (lo, hi) together until hi runs off the end;
    -- the band coefficient is indexed by lo.
    band d lo hi
      | hi >= len = return ()
      | otherwise = do
          let c = indexDoubleArray d lo
          ulo <- readDoubleArray u lo
          uhi <- readDoubleArray u hi
          vlo <- readDoubleArray v lo
          vhi <- readDoubleArray v hi
          writeDoubleArray u lo (ulo + c * vhi)
          writeDoubleArray u hi (uhi + c * vlo)
          band d (lo + 1) (hi + 1)