f2ff2ca939a837234ab9a76fbb13fa4e432a707b
[darcs-mirrors/vector.git] / internal / GenUnboxTuple.hs
1 {-# LANGUAGE ParallelListComp #-}
2 module Main where
3
-- Newer versions of base export a Semigroup (<>) from the Prelude, which
-- clashes with the (<>) exported by Text.PrettyPrint; hide the Prelude one
-- so the unqualified (<>) below refers to the pretty-printing combinator.
import Prelude hiding ((<>))

import Text.PrettyPrint

import System.Environment ( getArgs )
7
-- | Entry point. Expects exactly one command-line argument @N@ and prints,
-- for every tuple arity from 2 up to @N@, the code produced by 'generate'.
-- An @N@ below 2 prints nothing. A missing, extra or non-numeric argument
-- raises an IO error carrying a usage message, instead of an opaque
-- pattern-match or "no parse" failure.
main :: IO ()
main = do
  args <- getArgs
  case map parseArity args of
    [Just n] -> mapM_ (putStrLn . render . generate) [2 .. n]
    _        -> ioError (userError usage)
  where
    usage = "usage: GenUnboxTuple N   (N = largest tuple arity to generate)"

    -- Total counterpart of 'read': accept exactly one parse, allowing only
    -- trailing whitespace (the same check the Haskell report uses to
    -- define 'read').
    parseArity :: String -> Maybe Int
    parseArity s = case [k | (k, rest) <- reads s, ("", "") <- lex rest] of
                     [k] -> Just k
                     _   -> Nothing
12
-- | Render everything needed for unboxed vectors of @n@-tuples.
--
-- The output is split into three CPP-guarded sections, so the including
-- module can pick the parts it needs:
--
--   * @DEFINE_INSTANCES@: the @MVector s@ and @Vector@ data instances for the
--     tuple type, plus its @Unbox@, @M.MVector@ and @G.Vector@ instances;
--
--   * @DEFINE_MUTABLE@: @zipN@ and @unzipN@ for @MVector s@;
--
--   * @DEFINE_IMMUTABLE@: @zipN@, a @stream/zipN@ rewrite rule and @unzipN@
--     for @Vector@.
--
-- Generated code names the tuple components @a@, @b@, @c@, ..., the
-- corresponding component vectors @as@, @bs@, @cs@, ..., and uses the
-- constructors @MV_n@ and @V_n@. 'main' calls this for @n >= 2@.
--
-- The naming scheme only makes sense for small arities. From 19 components
-- on, a variable is named @s@ and collides with the state type variable in
-- @MVector s@. Past 26 components the names leave the alphabet.
generate :: Int -> Doc
generate n =
  vcat [ text "#ifdef DEFINE_INSTANCES"
       , data_instance "MVector s" "MV"
       , data_instance "Vector" "V"
       , class_instance "Unbox"
       , class_instance "M.MVector MVector" <+> text "where"
       , nest 2 $ vcat $ map method methods_MVector
       , class_instance "G.Vector Vector" <+> text "where"
       , nest 2 $ vcat $ map method methods_Vector
       , text "#endif"
       , text "#ifdef DEFINE_MUTABLE"
       , define_zip "MVector s" "MV"
       , define_unzip "MVector s" "MV"
       , text "#endif"
       , text "#ifdef DEFINE_IMMUTABLE"
       , define_zip "Vector" "V"
       , define_zip_rule
       , define_unzip "Vector" "V"
       , text "#endif"
       ]

  where
    -- Names shared by all generators below.
    vars  = map char $ take n ['a' ..]    -- tuple components: a, b, c, ...
    varss = map (<> char 's') vars        -- component vectors: as, bs, ...
    tuple xs  = parens $ hsep $ punctuate comma xs   -- kept on one line
    vtuple xs = parens $ sep  $ punctuate comma xs   -- may break lines
    con s = text s <> char '_' <> int n   -- constructor name, e.g. MV_3
    -- Auxiliary names (n_, i_, m_) carry a '_' suffix so they differ from
    -- the single-letter tuple variables.
    var c = text (c : "_")

    -- "zip" for pairs, "zip3", "zip4", ... for larger tuples; likewise for
    -- "unzip" and "zipWith".
    arityName s | n == 2    = text s
                | otherwise = text s <> int n

    -- Stack docs vertically, prefixing every doc after the first with @op@
    -- (the first starts an expression, the rest continue it). Total: an
    -- empty list yields 'empty'.
    leading op ds = vcat $ zipWith (<+>) (empty : repeat op) ds

    -- Append a prime to a name: as -> as'.
    primed d = d <> char '\''

    -- data instance <ty> (a, b, ...) = <c>_n {-# UNPACK #-} !Int !(<ty> a) ...
    data_instance ty c
      = hang (hsep [text "data instance", text ty, tuple vars])
             4
             (hsep [ char '=', con c, text "{-# UNPACK #-} !Int"
                   , vcat $ map (\v -> char '!' <> parens (text ty <+> v)) vars ])

    -- instance (Unbox a, Unbox b, ...) => <cls> (a, b, ...)
    -- The caller appends "where" when methods follow.
    class_instance cls
      = text "instance" <+> vtuple [text "Unbox" <+> v | v <- vars]
                        <+> text "=>" <+> text cls <+> tuple vars

    -- zipN for vector type @ty@ with constructor prefix @c@. Every input is
    -- sliced to the shortest length, and the slices are packed into the
    -- tuple constructor.
    -- 'vcat' (not 'sep') keeps the leading "--" comment on its own line
    -- whatever the layout width, so it can never swallow the declarations.
    define_zip ty c
      = vcat [ text "-- | /O(1)/ Zip" <+> int n <+> text "vectors"
             , name <+> text "::"
                    <+> vtuple [text "Unbox" <+> v | v <- vars]
                    <+> text "=>"
                    <+> sep (punctuate (text " ->") [text ty <+> v | v <- vars])
                    <+> text "->"
                    <+> text ty <+> tuple vars
             , text "{-# INLINE_STREAM" <+> name <+> text "#-}"
             , name <+> sep varss
                    <+> text "="
                    <+> con c
                    <+> text "len"
                    <+> sep [ parens (text "unsafeSlice" <+> char '0'
                                      <+> text "len" <+> vs)
                            | vs <- varss ]
             , nest 2 $ hang (text "where")
                             2
                             (text "len ="
                              <+> sep (punctuate (text " `min`")
                                                 [text "length" <+> vs | vs <- varss]))
             ]
      where
        name = arityName "zip"

    -- RULES pragma rewriting "G.stream (zipN as bs ...)" into
    -- "Stream.zipWithN (,...) (G.stream as) (G.stream bs) ...".
    define_zip_rule
      = hang (text "{-# RULES" <+> text "\"stream/" <> arityName "zip"
              <> text " [Vector.Unboxed]\" forall" <+> sep varss <+> char '.')
             2 $
             text "G.stream" <+> parens (arityName "zip" <+> sep varss)
             <+> char '='
             <+> text "Stream." <> arityName "zipWith" <+> tupleCon
             <+> sep [parens $ text "G.stream" <+> vs | vs <- varss]
             $$ text "#-}"
      where
        -- The n-tuple constructor written without spaces, e.g. (,,).
        -- Applying 'tuple' to empty docs would render it as (, ,).
        tupleCon = parens $ hcat $ replicate (n - 1) comma

    -- unzipN for vector type @ty@ with constructor prefix @c@. It only
    -- unpacks the constructor's component vectors into a tuple.
    define_unzip ty c
      = vcat [ text "-- | /O(1)/ Unzip" <+> int n <+> text "vectors"
             , name <+> text "::"
                    <+> vtuple [text "Unbox" <+> v | v <- vars]
                    <+> text "=>"
                    <+> text ty <+> tuple vars
                    <+> text "->" <+> vtuple [text ty <+> v | v <- vars]
             , text "{-# INLINE" <+> name <+> text "#-}"
             , name <+> pat c <+> text "="
                    <+> vtuple varss
             ]
      where
        name = arityName "unzip"

    -- Constructor pattern binding every field: (MV_n n_ as bs ...).
    pat c = parens $ con c <+> var 'n' <+> sep varss
    -- The same pattern with every bound name suffixed by @k@, so two values
    -- can be matched side by side: (MV_n n_1 as1 bs1 ...).
    patn c k = parens $ con c <+> (var 'n' <> int k)
                      <+> sep [v <> int k | v <- varss]

    qM s = text "M." <> text s   -- "M."-qualified name (M.MVector methods)
    qG s = text "G." <> text s   -- "G."-qualified name (G.Vector methods)

    -- Each gen_* below returns (argument patterns, right-hand side) for one
    -- class method, given that method's name @rec@; 'method' assembles the
    -- equation. Most forward the call to each component vector in turn.

    gen_length c _ = (pat c, var 'n')

    gen_unsafeSlice qual c rec
      = ( var 'i' <+> var 'm' <+> pat c
        , con c <+> var 'm'
                <+> vcat [ parens (text qual <> char '.' <> text rec
                                   <+> var 'i' <+> var 'm' <+> vs)
                         | vs <- varss ]
        )

    gen_overlaps rec
      = ( patn "MV" 1 <+> patn "MV" 2
        , leading (text "||")
                  [qM rec <+> vs <> char '1' <+> vs <> char '2' | vs <- varss]
        )

    gen_unsafeNew rec
      = ( var 'n'
        , mk_do [vs <+> text "<-" <+> qM rec <+> var 'n' | vs <- varss]
                (text "return $" <+> con "MV" <+> var 'n' <+> sep varss)
        )

    gen_unsafeReplicate rec
      = ( var 'n' <+> tuple vars
        , mk_do [ vs <+> text "<-" <+> qM rec <+> var 'n' <+> v
                | (v, vs) <- zip vars varss ]
                (text "return $" <+> con "MV" <+> var 'n' <+> sep varss)
        )

    gen_unsafeRead rec
      = ( pat "MV" <+> var 'i'
        , mk_do [ v <+> text "<-" <+> qM rec <+> vs <+> var 'i'
                | (v, vs) <- zip vars varss ]
                (text "return" <+> tuple vars)
        )

    gen_unsafeWrite rec
      = ( pat "MV" <+> var 'i' <+> tuple vars
        , mk_do [qM rec <+> vs <+> var 'i' <+> v | (v, vs) <- zip vars varss]
                empty
        )

    gen_clear rec
      = (pat "MV", mk_do [qM rec <+> vs | vs <- varss] empty)

    gen_set rec
      = ( pat "MV" <+> tuple vars
        , mk_do [qM rec <+> vs <+> v | (v, vs) <- zip vars varss] empty
        )

    -- The target is always a mutable vector; the source has constructor
    -- prefix @c@, and its method is qualified with @q@.
    gen_unsafeCopy c q rec
      = ( patn "MV" 1 <+> patn c 2
        , mk_do [q rec <+> vs <> char '1' <+> vs <> char '2' | vs <- varss]
                empty
        )

    -- The grown vector's length is m_ + n_.
    gen_unsafeGrow rec
      = ( pat "MV" <+> var 'm'
        , mk_do [ primed vs <+> text "<-" <+> qM rec <+> vs <+> var 'm'
                | vs <- varss ]
                (text "return $" <+> con "MV"
                                 <+> parens (var 'm' <> char '+' <> var 'n')
                                 <+> sep (map primed varss))
        )

    gen_unsafeFreeze rec
      = ( pat "MV"
        , mk_do [primed vs <+> text "<-" <+> qG rec <+> vs | vs <- varss]
                (text "return $" <+> con "V" <+> var 'n'
                                 <+> sep (map primed varss))
        )

    gen_basicUnsafeIndexM rec
      = ( pat "V" <+> var 'i'
        , mk_do [ v <+> text "<-" <+> qG rec <+> vs <+> var 'i'
                | (v, vs) <- zip vars varss ]
                (text "return" <+> tuple vars)
        )

    -- Chains G.elemseq over the components with (.); each call is given a
    -- dummy vector annotated with the component's type.
    gen_elemseq rec
      = ( char '_' <+> tuple vars
        , leading (char '.')
                  [ qG rec <+> parens (text "undefined :: Vector" <+> v) <+> v
                  | v <- vars ]
        )

    -- A "do" block of @cmds@ followed by @ret@ (pass 'empty' for none).
    mk_do cmds ret = hang (text "do") 2 (vcat (cmds ++ [ret]))

    -- One instance method: an INLINE pragma plus the equation built by
    -- generator @f@ for method name @s@. The pragma is spelled with single
    -- spaces, like the one emitted by define_unzip.
    method (s, f) = case f s of
      (p, e) -> text "{-# INLINE" <+> text s <+> text "#-}"
                $$ hang (text s <+> p) 4 (char '=' <+> e)

    -- Methods of the M.MVector instance, paired with their generators.
    methods_MVector = [ ("basicLength",          gen_length "MV")
                      , ("basicUnsafeSlice",     gen_unsafeSlice "M" "MV")
                      , ("basicOverlaps",        gen_overlaps)
                      , ("basicUnsafeNew",       gen_unsafeNew)
                      , ("basicUnsafeReplicate", gen_unsafeReplicate)
                      , ("basicUnsafeRead",      gen_unsafeRead)
                      , ("basicUnsafeWrite",     gen_unsafeWrite)
                      , ("basicClear",           gen_clear)
                      , ("basicSet",             gen_set)
                      , ("basicUnsafeCopy",      gen_unsafeCopy "MV" qM)
                      , ("basicUnsafeGrow",      gen_unsafeGrow)
                      ]

    -- Methods of the G.Vector instance, paired with their generators.
    methods_Vector = [ ("unsafeFreeze",      gen_unsafeFreeze)
                     , ("basicLength",       gen_length "V")
                     , ("basicUnsafeSlice",  gen_unsafeSlice "G" "V")
                     , ("basicUnsafeIndexM", gen_basicUnsafeIndexM)
                     , ("basicUnsafeCopy",   gen_unsafeCopy "V" qG)
                     , ("elemseq",           gen_elemseq)
                     ]