Use vector 0.7 instead of arrays from dph-base
[packages/dph.git] / dph-prim-par / Data / Array / Parallel / Unlifted / Distributed / Types.hs
1 {-# OPTIONS -fno-warn-incomplete-patterns #-}
2 -----------------------------------------------------------------------------
3 -- |
4 -- Module : Data.Array.Parallel.Unlifted.Distributed.Types
5 -- Copyright : (c) 2006 Roman Leshchinskiy
6 -- License : see libraries/ndp/LICENSE
7 --
8 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
9 -- Stability : experimental
10 -- Portability : non-portable (GHC Extensions)
11 --
12 -- Distributed types.
13 --
14
15 {-# LANGUAGE CPP #-}
16
17 #include "fusion-phases.h"
18
19 module Data.Array.Parallel.Unlifted.Distributed.Types (
20 -- * Distributed types
21 DT, Dist, MDist, DPrim(..),
22
23 -- * Operations on immutable distributed types
24 indexD, unitD, zipD, unzipD, fstD, sndD, lengthD,
25 newD,
26 -- zipSD, unzipSD, fstSD, sndSD,
27 deepSeqD,
28
29 lengthUSegdD, lengthsUSegdD, indicesUSegdD, elementsUSegdD,
30
31 -- * Operations on mutable distributed types
32 newMD, readMD, writeMD, unsafeFreezeMD,
33
34 -- * Assertions
35 checkGangD, checkGangMD,
36
37 -- * Debugging functions
38 sizeD, sizeMD, measureD, debugD
39 ) where
40
41 import Data.Array.Parallel.Unlifted.Distributed.Gang (
42 Gang, gangSize )
43 import Data.Array.Parallel.Unlifted.Sequential.Vector ( Unbox, Vector )
44 import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
45 import Data.Array.Parallel.Unlifted.Sequential.Segmented
46 import Data.Array.Parallel.Base
47
48 import qualified Data.Vector.Unboxed as V
49 import qualified Data.Vector.Unboxed.Mutable as MV
50 import qualified Data.Vector as BV
51 import qualified Data.Vector.Mutable as MBV
52
53 import Data.Word (Word8)
54 import Control.Monad (liftM, liftM2, liftM3)
55
56 import Data.List ( intercalate )
57
58 infixl 9 `indexD`
59
-- | Qualify a local name with this module's name. Used to build the
-- location tag passed to 'check' / 'checkEq' in this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Types." ++ s
61
-- Distributed types
-- -----------------

-- | Class of distributable types. A @'Dist' a@ holds one element of type
-- @a@ per worker of a 'Gang'; @'MDist' a s@ is its mutable counterpart in
-- the 'ST' thread @s@. All instances must be hyperstrict, because thunks
-- must not be passed into distributed computations.
class DT a where
  -- Immutable distributed value.
  data Dist a
  -- Mutable distributed value; the extra parameter is the 'ST' state thread.
  data MDist a :: * -> *

  -- | Extract a single element of an immutable distributed value.
  indexD :: Dist a -> Int -> a

  -- | Create an uninitialised mutable distributed value for the given
  -- 'Gang'. The gang is consulted only for how many elements are needed.
  newMD :: Gang -> ST s (MDist a s)

  -- | Read one element of a mutable distributed value.
  readMD :: MDist a s -> Int -> ST s a

  -- | Overwrite one element of a mutable distributed value.
  writeMD :: MDist a s -> Int -> a -> ST s ()

  -- | Unsafely freeze a mutable distributed value.
  unsafeFreezeMD :: MDist a s -> ST s (Dist a)

  -- | Evaluate the first argument, then return the second. The default is
  -- plain 'seq'; instances with components may recurse into them.
  deepSeqD :: a -> b -> b
  deepSeqD x k = x `seq` k

  -- | Number of elements in an immutable distributed value
  -- (debugging only).
  sizeD :: Dist a -> Int

  -- | Number of elements in a mutable distributed value
  -- (debugging only).
  sizeMD :: MDist a s -> Int

  -- | Short textual description of a value, used by debugging output.
  -- The default carries no information and returns a question mark.
  measureD :: a -> String
  measureD = const "?"
103
104 -- Distributed values must always be hyperstrict.
105 -- instance DT a => HS (Dist a)
106
-- | Check, via 'checkEq' with the message \"Wrong gang\", that the gang
-- size equals the size of the distributed value; the final argument is
-- handed to 'checkEq' as the result.
checkGangD :: DT a => String -> Gang -> Dist a -> b -> b
checkGangD loc gang dist result
  = checkEq loc "Wrong gang" expected actual result
  where
    expected = gangSize gang
    actual   = sizeD dist
110
-- | Mutable counterpart of 'checkGangD': check, via 'checkEq' with the
-- message \"Wrong gang\", that the gang size equals the size of the
-- mutable distributed value; the final argument is handed to 'checkEq'
-- as the result.
checkGangMD :: DT a => String -> Gang -> MDist a s -> b -> b
checkGangMD loc gang mdist result
  = checkEq loc "Wrong gang" expected actual result
  where
    expected = gangSize gang
    actual   = sizeMD mdist
115
-- Show instance (debugging only): a distributed value is rendered as the
-- list of its elements at indices 0 .. sizeD - 1.
instance (Show a, DT a) => Show (Dist a) where
  show dist = show elems
    where
      elems = [indexD dist i | i <- [0 .. sizeD dist - 1]]
119
-- 'DT' instances
-- ----------------

-- | Distributed unit values carry no data, only their size: the number of
-- gang members, as recorded by 'newMD' and 'unitD'.
instance DT () where
  data Dist ()    = DUnit  !Int
  data MDist () s = MDUnit !Int

  -- 'check' is given the stored size and the index; presumably a bounds
  -- check (defined in the Base module).
  indexD (DUnit n) i = check (here "indexD[()]") n i $ ()
  newMD = return . MDUnit . gangSize
  readMD (MDUnit n) i = check (here "readMD[()]") n i $
                        return ()
  writeMD (MDUnit n) i () = check (here "writeMD[()]") n i $
                            return ()
  unsafeFreezeMD (MDUnit n) = return $ DUnit n

  -- The class provides no defaults for 'sizeD' / 'sizeMD'. Without these
  -- definitions, 'checkGangD', 'debugD' and 'show' on a 'Dist ()' would
  -- fail with a missing-method error.
  sizeD (DUnit n) = n
  sizeMD (MDUnit n) = n
134
-- | Distributable types whose representation is a single unboxed vector
-- holding one element per worker. 'mkDPrim' / 'unDPrim' wrap and unwrap
-- the immutable 'Dist'; 'mkMDPrim' / 'unMDPrim' do the same for the
-- mutable 'MDist'.
class Unbox a => DPrim a where
  -- Immutable representation.
  mkDPrim  :: V.Vector a -> Dist a
  unDPrim  :: Dist a -> V.Vector a

  -- Mutable representation.
  mkMDPrim :: MV.STVector s a -> MDist a s
  unMDPrim :: MDist a s -> MV.STVector s a
141
-- | 'indexD' for 'DPrim' types: index the underlying unboxed vector
-- with @(V.!)@.
primIndexD :: DPrim a => Dist a -> Int -> a
{-# INLINE primIndexD #-}
primIndexD dist i = unDPrim dist V.! i
145
-- | 'newMD' for 'DPrim' types: allocate a mutable unboxed vector with one
-- slot per gang member. The slots are not filled here.
primNewMD :: DPrim a => Gang -> ST s (MDist a s)
{-# INLINE primNewMD #-}
primNewMD gang = do
  mvec <- MV.new (gangSize gang)
  return (mkMDPrim mvec)
149
-- | 'readMD' for 'DPrim' types: read from the underlying mutable vector.
primReadMD :: DPrim a => MDist a s -> Int -> ST s a
{-# INLINE primReadMD #-}
primReadMD mdist i = MV.read (unMDPrim mdist) i
153
-- | 'writeMD' for 'DPrim' types: write into the underlying mutable vector.
primWriteMD :: DPrim a => MDist a s -> Int -> a -> ST s ()
{-# INLINE primWriteMD #-}
primWriteMD mdist i x = MV.write (unMDPrim mdist) i x
157
-- | 'unsafeFreezeMD' for 'DPrim' types: unsafely freeze the underlying
-- mutable vector and rewrap it as an immutable 'Dist'.
primUnsafeFreezeMD :: DPrim a => MDist a s -> ST s (Dist a)
{-# INLINE primUnsafeFreezeMD #-}
primUnsafeFreezeMD mdist = do
  frozen <- V.unsafeFreeze (unMDPrim mdist)
  return (mkDPrim frozen)
161
-- | 'sizeD' for 'DPrim' types: length of the underlying unboxed vector.
primSizeD :: DPrim a => Dist a -> Int
{-# INLINE primSizeD #-}
primSizeD dist = V.length (unDPrim dist)
165
-- | 'sizeMD' for 'DPrim' types: length of the underlying mutable vector.
primSizeMD :: DPrim a => MDist a s -> Int
{-# INLINE primSizeMD #-}
primSizeMD mdist = MV.length (unMDPrim mdist)
169
-- | Distributed 'Bool's are one unboxed vector; every 'DT' method
-- delegates to the shared @prim*@ helpers.
instance DT Bool where
  data Dist  Bool   = DBool  !(V.Vector Bool)
  data MDist Bool s = MDBool !(MV.STVector s Bool)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

-- | Wrapping / unwrapping of the vectors behind 'Dist' and 'MDist' 'Bool'.
instance DPrim Bool where
  mkDPrim  = DBool
  unDPrim  (DBool vec) = vec
  mkMDPrim = MDBool
  unMDPrim (MDBool mvec) = mvec
188
-- | Distributed 'Char's are one unboxed vector; every 'DT' method
-- delegates to the shared @prim*@ helpers.
instance DT Char where
  data Dist  Char   = DChar  !(V.Vector Char)
  data MDist Char s = MDChar !(MV.STVector s Char)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

-- | Wrapping / unwrapping of the vectors behind 'Dist' and 'MDist' 'Char'.
instance DPrim Char where
  mkDPrim  = DChar
  unDPrim  (DChar vec) = vec
  mkMDPrim = MDChar
  unMDPrim (MDChar mvec) = mvec
207
-- | Distributed 'Int's are one unboxed vector; every 'DT' method delegates
-- to the shared @prim*@ helpers, and 'measureD' shows the value itself.
instance DT Int where
  data Dist  Int   = DInt  !(V.Vector Int)
  data MDist Int s = MDInt !(MV.STVector s Int)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

  measureD n = "int " ++ show n

-- | Wrapping / unwrapping of the vectors behind 'Dist' and 'MDist' 'Int'.
instance DPrim Int where
  mkDPrim  = DInt
  unDPrim  (DInt vec) = vec
  mkMDPrim = MDInt
  unMDPrim (MDInt mvec) = mvec
228
-- | Distributed 'Word8's are one unboxed vector; every 'DT' method
-- delegates to the shared @prim*@ helpers.
instance DT Word8 where
  data Dist  Word8   = DWord8  !(V.Vector Word8)
  data MDist Word8 s = MDWord8 !(MV.STVector s Word8)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

-- | Wrapping / unwrapping of the vectors behind 'Dist' and 'MDist' 'Word8'.
instance DPrim Word8 where
  mkDPrim  = DWord8
  unDPrim  (DWord8 vec) = vec
  mkMDPrim = MDWord8
  unMDPrim (MDWord8 mvec) = mvec
247
-- | Distributed 'Float's are one unboxed vector; every 'DT' method
-- delegates to the shared @prim*@ helpers.
instance DT Float where
  data Dist  Float   = DFloat  !(V.Vector Float)
  data MDist Float s = MDFloat !(MV.STVector s Float)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

-- | Wrapping / unwrapping of the vectors behind 'Dist' and 'MDist' 'Float'.
instance DPrim Float where
  mkDPrim  = DFloat
  unDPrim  (DFloat vec) = vec
  mkMDPrim = MDFloat
  unMDPrim (MDFloat mvec) = mvec
266
-- | Distributed 'Double's are one unboxed vector; every 'DT' method
-- delegates to the shared @prim*@ helpers.
instance DT Double where
  data Dist  Double   = DDouble  !(V.Vector Double)
  data MDist Double s = MDDouble !(MV.STVector s Double)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

-- | Wrapping / unwrapping of the vectors behind 'Dist' and 'MDist' 'Double'.
instance DPrim Double where
  mkDPrim  = DDouble
  unDPrim  (DDouble vec) = vec
  mkMDPrim = MDDouble
  unMDPrim (MDDouble mvec) = mvec
285
-- | Pairs are distributed component-wise: a distributed pair is a
-- distributed first component next to a distributed second component.
-- The size reported is that of the first component.
instance (DT a, DT b) => DT (a,b) where
  data Dist  (a,b)   = DProd  !(Dist a)    !(Dist b)
  data MDist (a,b) s = MDProd !(MDist a s) !(MDist b s)

  -- The lazy 'where' binding keeps the result pair lazy in @d@.
  indexD d i = (indexD xs i, indexD ys i)
    where
      DProd xs ys = d

  newMD g = do
    xs <- newMD g
    ys <- newMD g
    return (MDProd xs ys)

  readMD (MDProd xs ys) i = do
    x <- readMD xs i
    y <- readMD ys i
    return (x, y)

  writeMD (MDProd xs ys) i (x, y) = do
    writeMD xs i x
    writeMD ys i y

  unsafeFreezeMD (MDProd xs ys) = do
    frozenXs <- unsafeFreezeMD xs
    frozenYs <- unsafeFreezeMD ys
    return (DProd frozenXs frozenYs)

  {-# INLINE deepSeqD #-}
  deepSeqD (x, y) k = x `deepSeqD` (y `deepSeqD` k)

  sizeD  (DProd  xs _) = sizeD  xs
  sizeMD (MDProd xs _) = sizeMD xs

  measureD (x, y) = "(" ++ inner ++ ")"
    where
      inner = measureD x ++ "," ++ measureD y
306
-- | A distributed 'Maybe' is a distributed presence flag plus a
-- distributed payload. A payload slot is only read where its flag is
-- 'True'; writing 'Nothing' touches only the flag.
instance DT a => DT (Maybe a) where
  data Dist  (Maybe a)   = DMaybe  !(Dist Bool)    !(Dist a)
  data MDist (Maybe a) s = MDMaybe !(MDist Bool s) !(MDist a s)

  indexD (DMaybe flags vals) i =
    if indexD flags i then Just (indexD vals i) else Nothing

  newMD g = do
    flags <- newMD g
    vals  <- newMD g
    return (MDMaybe flags vals)

  readMD (MDMaybe flags vals) i = do
    present <- readMD flags i
    if present
      then do
        x <- readMD vals i
        return (Just x)
      else return Nothing

  writeMD (MDMaybe flags vals) i mx = case mx of
    Nothing -> writeMD flags i False
    Just x  -> do
      writeMD flags i True
      writeMD vals i x

  unsafeFreezeMD (MDMaybe flags vals) = do
    frozenFlags <- unsafeFreezeMD flags
    frozenVals  <- unsafeFreezeMD vals
    return (DMaybe frozenFlags frozenVals)

  {-# INLINE deepSeqD #-}
  deepSeqD mx k = case mx of
    Nothing -> k
    Just x  -> deepSeqD x k

  sizeD  (DMaybe  flags _) = sizeD  flags
  sizeMD (MDMaybe flags _) = sizeMD flags

  measureD mx = case mx of
    Nothing -> "Nothing"
    Just x  -> "Just (" ++ measureD x ++ ")"
335
336 {-
337 instance DT a => DT (MaybeS a) where
338 data Dist (MaybeS a) = DMaybe !(Dist Bool) !(Dist a)
339 data MDist (MaybeS a) s = MDMaybe !(MDist Bool s) !(MDist a s)
340
341 indexD (DMaybe bs as) i
342 | bs `indexD` i = JustS $ as `indexD` i
343 | otherwise = NothingS
344 newMD g = liftM2 MDMaybe (newMD g) (newMD g)
345 readMD (MDMaybe bs as) i =
346 do
347 b <- readMD bs i
348 if b then liftM JustS $ readMD as i
349 else return NothingS
350 writeMD (MDMaybe bs as) i NothingS = writeMD bs i False
351 writeMD (MDMaybe bs as) i (JustS x) = writeMD bs i True
352 >> writeMD as i x
353 unsafeFreezeMD (MDMaybe bs as) = liftM2 DMaybe (unsafeFreezeMD bs)
354 (unsafeFreezeMD as)
355 sizeD (DMaybe b _) = sizeD b
356 sizeMD (MDMaybe b _) = sizeMD b
357
358 measureD NothingS = "Nothing"
359 measureD (JustS x) = "Just (" ++ measureD x ++ ")"
360 -}
361
-- | A distributed vector keeps each worker's unboxed 'Vector' in a boxed
-- vector, alongside the distributed lengths of those vectors. 'writeMD'
-- records the length and stores the vector forced to WHNF.
instance Unbox a => DT (V.Vector a) where
  data Dist  (Vector a)   = DVector  !(Dist Int)    !(BV.Vector (Vector a))
  data MDist (Vector a) s = MDVector !(MDist Int s) !(MBV.STVector s (Vector a))

  indexD (DVector _ vecs) i = vecs BV.! i

  -- Slots start out as an error thunk until written.
  newMD g = do
    lens <- newMD g
    vecs <- MBV.replicate (gangSize g)
              (error "MDist (Vector a) - uninitalised")
    return (MDVector lens vecs)

  readMD (MDVector _ vecs) = MBV.read vecs

  writeMD (MDVector lens vecs) i v = do
    writeMD lens i (V.length v)
    v `seq` MBV.write vecs i v

  unsafeFreezeMD (MDVector lens vecs) = do
    frozenLens <- unsafeFreezeMD lens
    frozenVecs <- BV.unsafeFreeze vecs
    return (DVector frozenLens frozenVecs)

  sizeD  (DVector  _ vecs) = BV.length  vecs
  sizeMD (MDVector _ vecs) = MBV.length vecs

  measureD v = "Vector " ++ show (V.length v)
380
-- | A distributed segment descriptor is stored as three distributed
-- components, which hold the 'lengthsUSegd', 'indicesUSegd' and
-- 'elementsUSegd' parts of each worker's 'USegd'. 'indexD' and 'readMD'
-- rebuild a 'USegd' from them with 'mkUSegd'.
instance DT USegd where
  data Dist USegd    = DUSegd  !(Dist (Vector Int))
                               !(Dist (Vector Int))
                               !(Dist Int)
  data MDist USegd s = MDUSegd !(MDist (Vector Int) s)
                               !(MDist (Vector Int) s)
                               !(MDist Int s)

  indexD (DUSegd lens idxs eles) i = mkUSegd segLens segIdxs segEles
    where
      segLens = indexD lens i
      segIdxs = indexD idxs i
      segEles = indexD eles i

  newMD g = do
    lens <- newMD g
    idxs <- newMD g
    eles <- newMD g
    return (MDUSegd lens idxs eles)

  readMD (MDUSegd lens idxs eles) i = do
    segLens <- readMD lens i
    segIdxs <- readMD idxs i
    segEles <- readMD eles i
    return (mkUSegd segLens segIdxs segEles)

  writeMD (MDUSegd lens idxs eles) i segd = do
    writeMD lens i (lengthsUSegd segd)
    writeMD idxs i (indicesUSegd segd)
    writeMD eles i (elementsUSegd segd)

  unsafeFreezeMD (MDUSegd lens idxs eles) = do
    frozenLens <- unsafeFreezeMD lens
    frozenIdxs <- unsafeFreezeMD idxs
    frozenEles <- unsafeFreezeMD eles
    return (DUSegd frozenLens frozenIdxs frozenEles)

  -- Force all three components, lengths first.
  deepSeqD segd k =
    lengthsUSegd segd `deepSeqD`
      (indicesUSegd segd `deepSeqD`
        (elementsUSegd segd `deepSeqD` k))

  sizeD  (DUSegd  _ _ eles) = sizeD  eles
  sizeMD (MDUSegd _ _ eles) = sizeMD eles

  measureD segd = "Segd " ++ show (lengthUSegd segd) ++ "|" ++ show (elementsUSegd segd)
412
-- | Per-worker number of segments: the length of each worker's
-- segment-lengths vector.
lengthUSegdD :: Dist USegd -> Dist Int
{-# INLINE_DIST lengthUSegdD #-}
lengthUSegdD = lengthD . lengthsUSegdD
416
-- | The distributed segment-lengths component of a distributed 'USegd'.
lengthsUSegdD :: Dist USegd -> Dist (Vector Int)
{-# INLINE_DIST lengthsUSegdD #-}
lengthsUSegdD segd = case segd of
  DUSegd lens _ _ -> lens
420
-- | The distributed segment-indices component of a distributed 'USegd'.
indicesUSegdD :: Dist USegd -> Dist (Vector Int)
{-# INLINE_DIST indicesUSegdD #-}
indicesUSegdD segd = case segd of
  DUSegd _ idxs _ -> idxs
424
-- | The distributed element-count component of a distributed 'USegd'.
elementsUSegdD :: Dist USegd -> Dist Int
{-# INLINE_DIST elementsUSegdD #-}
elementsUSegdD segd = case segd of
  DUSegd _ _ eles -> eles
428
-- Basic operations on immutable distributed types
-- -------------------------------------------

-- | Build an immutable distributed value for the given 'Gang'. A fresh
-- mutable value is allocated with 'newMD' and filled in by the supplied
-- 'ST' action. It is then frozen with 'unsafeFreezeMD'.
newD :: DT a => Gang -> (forall s . MDist a s -> ST s ()) -> Dist a
newD g fill =
  runST (do mdist <- newMD g
            fill mdist
            unsafeFreezeMD mdist)
438
439
-- | A distributed unit with one (empty) element per member of the 'Gang'.
unitD :: Gang -> Dist ()
{-# INLINE_DIST unitD #-}
unitD g = DUnit (gangSize g)
444
-- | Pairing of distributed values.
-- /The two values must belong to the same/ 'Gang'. Both arguments are
-- forced, and their sizes are compared with 'checkEq'. The location tag
-- names this function, so mismatch errors point at 'zipD'.
zipD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a,b)
{-# INLINE [0] zipD #-}
zipD !x !y = checkEq (here "zipD") "Size mismatch" (sizeD x) (sizeD y) $
             DProd x y
451
-- | Split a distributed pair into its two distributed components.
unzipD :: (DT a, DT b) => Dist (a,b) -> (Dist a, Dist b)
{-# INLINE_DIST unzipD #-}
unzipD d = case d of
  DProd dx dy -> (dx, dy)
456
-- | The distributed first components of a distributed pair.
fstD :: (DT a, DT b) => Dist (a,b) -> Dist a
{-# INLINE_DIST fstD #-}
fstD d = case d of
  DProd dx _ -> dx
461
-- | The distributed second components of a distributed pair.
sndD :: (DT a, DT b) => Dist (a,b) -> Dist b
{-# INLINE_DIST sndD #-}
sndD d = case d of
  DProd _ dy -> dy
466
467 {-
468 -- | Pairing of distributed values.
469 -- /The two values must belong to the same/ 'Gang'.
470 zipSD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a,b)
471 {-# INLINE [0] zipSD #-}
472 zipSD !x !y = checkEq (here "zipSD") "Size mismatch" (sizeD x) (sizeD y) $
473 SDProd x y
474
475 -- | Unpairing of distributed values.
476 unzipSD :: (DT a, DT b) => Dist (a,b) -> (Dist a, Dist b)
477 {-# INLINE_DIST unzipSD #-}
478 unzipSD (SDProd dx dy) = (dx,dy)
479
480 -- | Extract the first elements of a distributed pair.
481 fstSD :: (DT a, DT b) => Dist (a,b) -> Dist a
482 {-# INLINE_DIST fstSD #-}
483 fstSD = fst . unzipSD
484
485 -- | Extract the second elements of a distributed pair.
486 sndSD :: (DT a, DT b) => Dist (a,b) -> Dist b
487 {-# INLINE_DIST sndSD #-}
488 sndSD = snd . unzipSD
489 -}
490
-- | Yield the distributed length of a distributed array: for each worker,
-- the length of its local 'Vector' (the component that the 'Vector'
-- instance's 'writeMD' fills with 'V.length').
--
-- Marked @INLINE_DIST@ like the other small accessors in this module
-- ('unzipD', 'fstD', 'lengthUSegdD', ...), so that it inlines in the same
-- phase.
lengthD :: Unbox a => Dist (Vector a) -> Dist Int
{-# INLINE_DIST lengthD #-}
lengthD (DVector l _) = l
494
-- | Debugging aid: describe every element of a distributed value with
-- 'measureD', rendered as a bracketed, comma-separated list.
debugD :: DT a => Dist a -> String
debugD d = "[" ++ body ++ "]"
  where
    body = intercalate "," [measureD (indexD d i) | i <- [0 .. sizeD d - 1]]
499