Slight optimisation
[packages/dph.git] / dph-prim-seq / Data / Array / Parallel / Unlifted / Sequential / Flat / Permute.hs
1 {-# LANGUAGE ScopedTypeVariables #-}
2 ----------------------------------------------------------------------------
3 -- |
4 -- Module : Data.Array.Parallel.Unlifted.Sequential.Flat.Permute
5 -- Copyright : (c) [2001..2002] Manuel M T Chakravarty & Gabriele Keller
6 -- (c) 2006 Manuel M T Chakravarty & Roman Leshchinskiy
7 -- License : see libraries/ndp/LICENSE
8 --
9 -- Maintainer : Roman Leshchinskiy <rl@cse.unsw.edu.au>
10 -- Stability : experimental
11 -- Portability : portable
12 --
13 -- Description ---------------------------------------------------------------
14 --
15 -- Permutations on flat unlifted arrays.
16 --
17 -- Todo ----------------------------------------------------------------------
18 --
19
20 {-# LANGUAGE CPP #-}
21
22 #include "fusion-phases.h"
23
24 module Data.Array.Parallel.Unlifted.Sequential.Flat.Permute (
25 permuteU, permuteMU, mbpermuteU, bpermuteU, bpermuteDftU, reverseU, updateU,
26 atomicUpdateMU
27 ) where
28
29 import Data.Array.Parallel.Base (
30 ST, runST, (:*:)(..), Rebox(..))
31 import Data.Array.Parallel.Base.DTrace
32 import Data.Array.Parallel.Stream (
33 Step(..), Stream(..), mapS, sArgs)
34 import Data.Array.Parallel.Unlifted.Sequential.Flat.UArr (
35 UA, UArr, MUArr,
36 lengthU, newU, newDynU, newMU, unsafeFreezeAllMU, writeMU,
37 sliceU)
38 import Data.Array.Parallel.Unlifted.Sequential.Flat.Stream (
39 unstreamU, streamU, unstreamMU)
40 import Data.Array.Parallel.Unlifted.Sequential.Flat.Basics (
41 (!:))
42 import Data.Array.Parallel.Unlifted.Sequential.Flat.Enum (
43 enumFromToU)
44 import Data.Array.Parallel.Unlifted.Sequential.Flat.Combinators (
45 mapU)
46
47 -- |Permutations
48 -- -------------
49
50 permuteMU :: UA e => MUArr e s -> UArr e -> UArr Int -> ST s ()
51 permuteMU mpa arr is = permute 0
52 where
53 n = lengthU arr
54 permute i
55 | i == n = return ()
56 | otherwise = writeMU mpa (is!:i) (arr!:i) >> permute (i + 1)
57
58
59 -- |Standard permutation
60 --
61 permuteU :: UA e => UArr e -> UArr Int -> UArr e
62 {-# INLINE_U permuteU #-}
63 permuteU arr is = newU (lengthU arr) $ \mpa -> permuteMU mpa arr is
64
65 -- |Back permutation operation (ie, the permutation vector determines for each
66 -- position in the result array its origin in the input array)
67 --
68 -- WARNING: DO NOT rewrite this as unstreamU . bpermuteUS es . streamU
69 -- because GHC won't be able to figure out its strictness.
70 --
71 bpermuteU :: UA e => UArr e -> UArr Int -> UArr e
72 {-# INLINE_U bpermuteU #-}
73 bpermuteU es is = unstreamU (bpermuteUS es (streamU is))
74
75 mbpermuteU:: (UA e, UA d) => (e -> d) -> UArr e -> UArr Int -> UArr d
76 {-# INLINE_STREAM mbpermuteU #-}
77 mbpermuteU f es is = unstreamU (mbpermuteUS f es (streamU is))
78
79
80
81 bpermuteUS :: UA e => UArr e -> Stream Int -> Stream e
82 {-# INLINE_STREAM bpermuteUS #-}
83 bpermuteUS !a s = mapS (a!:) s
84
85 mbpermuteUS:: (UA e, UA d) => (e -> d) -> UArr e -> Stream Int -> Stream d
86 {-# INLINE_STREAM mbpermuteUS #-}
87 mbpermuteUS f !a = mapS (f . (a!:))
88
89 -- |Default back permute
90 --
91 -- * The values of the index-value pairs are written into the position in the
92 -- result array that is indicated by the corresponding index.
93 --
94 -- * All positions not covered by the index-value pairs will have the value
95 -- determined by the initialiser function for that index position.
96 --
97 bpermuteDftU :: UA e
98 => Int -- ^ length of result array
99 -> (Int -> e) -- ^ initialiser function
100 -> UArr (Int :*: e) -- ^ index-value pairs
101 -> UArr e
102 {-# INLINE_U bpermuteDftU #-}
103 bpermuteDftU n init = updateU (mapU init . enumFromToU 0 $ n-1)
104
105 atomicUpdateMU :: UA e => MUArr e s -> UArr (Int :*: e) -> ST s ()
106 {-# INLINE_U atomicUpdateMU #-}
107 atomicUpdateMU marr upd = updateM writeMU marr (streamU upd)
108
109 updateM :: UA e => (MUArr e s -> Int -> e -> ST s ())
110 -> MUArr e s -> Stream (Int :*: e) -> ST s ()
111 {-# INLINE_STREAM updateM #-}
112 updateM write marr (Stream next s _ c)
113 = traceLoopST ("updateM" `sArgs` c) $ upd s
114 where
115 upd s = case next s of
116 Done -> return ()
117 Skip s' -> upd s'
118 Yield (i :*: x) s' -> do
119 write marr i x
120 upd s'
121
122 -- | Yield an array constructed by updating the first array by the
123 -- associations from the second array (which contains index\/value pairs).
124 --
125 updateU :: UA e => UArr e -> UArr (Int :*: e) -> UArr e
126 {-# INLINE_U updateU #-}
127 updateU arr upd = update (streamU arr) (streamU upd)
128
129 update :: UA e => Stream e -> Stream (Int :*: e) -> UArr e
130 {-# INLINE_STREAM update #-}
131 update s1@(Stream _ _ n _) !s2 = newDynU n (\marr ->
132 do
133 i <- unstreamMU marr s1
134 updateM writeMU marr s2
135 return i
136 )
137
138 -- |Reverse the order of elements in an array
139 --
140 reverseU :: UA e => UArr e -> UArr e
141 {-# INLINE_U reverseU #-}
142 --reverseU a = mapU (a!:) . enumFromToU 0 $ lengthU a - 1
143 reverseU = rev . streamU
144
145 rev :: forall e. UA e => Stream e -> UArr e
146 {-# INLINE_STREAM rev #-}
147 rev (Stream next s n c) =
148 runST (do
149 marr <- newMU n
150 i <- traceLoopST ("rev" `sArgs` c) $ fill marr
151 a <- unsafeFreezeAllMU marr
152 return $ sliceU a i (n-i)
153 )
154 where
155 fill :: forall s. MUArr e s -> ST s Int
156 fill marr = fill0 s n
157 where
158 fill0 s !i = case next s of
159 Done -> return i
160 Skip s' -> s' `dseq` fill0 s' i
161 Yield x s' -> s' `dseq`
162 let i' = i-1
163 in
164 do
165 writeMU marr i' x
166 fill0 s' i'
167