fe447c2bcc541d0557e960ee611a672bdb739e30
[darcs-mirrors/vector.git] / internal / GenUnboxTuple.hs
1 {-# LANGUAGE ParallelListComp #-}
2 module Main where
3
4 import Text.PrettyPrint
5
6 import System.Environment ( getArgs )
7
-- | Print the generated code for every tuple arity from 2 up to the arity
-- given as the sole command-line argument.
--
-- The argument is parsed with exactly the rule the Prelude's 'read' uses
-- (one 'reads' parse whose remainder lexes to nothing), but a missing or
-- malformed argument produces a usage error instead of a pattern-match or
-- "no parse" crash.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [s] | [n] <- [x | (x, t) <- reads s, ("", "") <- lex t]
        -> mapM_ (putStrLn . render . generate) [2 .. n]
    _   -> ioError (userError "usage: GenUnboxTuple <max-arity>")
12
-- | Render the Unbox support code for tuples of arity @n@.
--
-- The output is split into three CPP-guarded sections, each selected by
-- defining the corresponding macro:
--
--   * @DEFINE_INSTANCES@: the @MVector s@ / @Vector@ data instances and the
--     @Unbox@, @M.MVector@ and @G.Vector@ class instances;
--   * @DEFINE_MUTABLE@: @zip@/@unzip@ over @MVector s@;
--   * @DEFINE_IMMUTABLE@: @zip@/@unzip@ over @Vector@ plus a
--     @stream/zip@ RULES pragma.
--
-- Helper-generated variables carry a trailing underscore (@n_@, @i_@,
-- @m_@), which keeps them distinct from the component variables
-- @a@, @b@, ... and the component vectors @as@, @bs@, ....
generate :: Int -> Doc
generate n =
  vcat [ text "#ifdef DEFINE_INSTANCES"
       , data_instance "MVector s" "MV"
       , data_instance "Vector" "V"
       , class_instance "Unbox"
       , class_instance "M.MVector MVector" <+> text "where"
       , nest 2 $ vcat $ map method methods_MVector
       , class_instance "G.Vector Vector" <+> text "where"
       , nest 2 $ vcat $ map method methods_Vector
       , text "#endif"
       , text "#ifdef DEFINE_MUTABLE"
       , define_zip "MVector s" "MV"
       , define_unzip "MVector s" "MV"
       , text "#endif"
       , text "#ifdef DEFINE_IMMUTABLE"
       , define_zip "Vector" "V"
       , define_zip_rule
       , define_unzip "Vector" "V"
       , text "#endif"
       ]
  where
    -- Component variables a, b, ... and component vectors as, bs, ...
    vars  = map char $ take n ['a'..]
    varss = map (<> char 's') vars

    -- Parenthesised, comma-separated lists: 'tuple' always stays on one
    -- line, 'vtuple' may break vertically when it does not fit.
    tuple xs  = parens $ hsep $ punctuate comma xs
    vtuple xs = parens $ sep $ punctuate comma xs

    -- Arity-tagged constructor name, e.g. con "MV" == MV_<n>.
    con s = text s <> char '_' <> int n

    -- Generated local variable, e.g. var 'n' == n_.
    var c = text (c : "_")

    -- Primed copy of a component-vector name, e.g. as -> as'.
    primed vs = vs <> char '\''

    -- data instance <ty> (a, b, ...)
    --     = <C>_n {-# UNPACK #-} !Int (<ty> a) (<ty> b) ...
    data_instance ty c
      = hang (hsep [text "data instance", text ty, tuple vars])
             4
             (hsep [ char '=', con c, text "{-# UNPACK #-} !Int"
                   , vcat $ map (\v -> parens (text ty <+> v)) vars ])

    -- instance (Unbox a, Unbox b, ...) => <cls> (a, b, ...)
    class_instance cls
      = text "instance" <+> vtuple [text "Unbox" <+> v | v <- vars]
                        <+> text "=>" <+> text cls <+> tuple vars

    -- zip / zipN: signature, INLINE_STREAM pragma, and an equation that
    -- wraps every component as (unsafeSlice xs 0 len), where len is the
    -- `min` of the component lengths.
    define_zip ty c
      = sep [ name <+> text "::"
                   <+> vtuple [text "Unbox" <+> v | v <- vars]
                   <+> text "=>"
                   <+> sep (punctuate (text " ->") [text ty <+> v | v <- vars])
                   <+> text "->"
                   <+> text ty <+> tuple vars
            , text "{-# INLINE_STREAM" <+> name <+> text "#-}"
            , name <+> sep varss
                   <+> text "="
                   <+> con c
                   <+> text "len"
                   <+> sep [ parens $ text "unsafeSlice" <+> vs
                                                         <+> char '0'
                                                         <+> text "len"
                           | vs <- varss ]
            , nest 2 $ hang (text "where") 2 $
                text "len =" <+> sep (punctuate (text " `min`")
                                                [text "length" <+> vs | vs <- varss])
            ]
      where
        name | n == 2    = text "zip"
             | otherwise = text "zip" <> int n

    -- RULES pragma rewriting  G.stream (zipN as bs ...)  to
    -- Stream.zipWithN (,,...) (G.stream as) (G.stream bs) ...
    define_zip_rule
      = hang (text "{-# RULES" <+> text "\"stream/" <> name "zip"
                <> text " [Vector.Unboxed]\" forall" <+> sep varss <+> char '.')
             2
             (text "G.stream" <+> parens (name "zip" <+> sep varss)
                <+> char '='
                <+> text "Stream." <> name "zipWith" <+> tuple (replicate n empty)
                <+> sep [parens $ text "G.stream" <+> vs | vs <- varss]
              $$ text "#-}")
      where
        name s | n == 2    = text s
               | otherwise = text s <> int n

    -- unzip / unzipN: pattern-matches the tuple constructor and returns
    -- the component vectors as a tuple.
    define_unzip ty c
      = sep [ name <+> text "::"
                   <+> vtuple [text "Unbox" <+> v | v <- vars]
                   <+> text "=>"
                   <+> text ty <+> tuple vars
                   <+> text "->" <+> vtuple [text ty <+> v | v <- vars]
            , text "{-# INLINE" <+> name <+> text "#-}"
            , name <+> pat c <+> text "="
                   <+> vtuple varss
            ]
      where
        name | n == 2    = text "unzip"
             | otherwise = text "unzip" <> int n

    -- Constructor pattern  (C_n n_ as bs ...).
    pat c = parens $ con c <+> var 'n' <+> sep varss

    -- Constructor pattern for the k-th of several operands:
    -- (C_n n_k as<k> bs<k> ...).
    patn c k = parens $ con c <+> (var 'n' <> int k)
                              <+> sep [v <> int k | v <- varss]

    -- Qualified references into the M. and G. namespaces.
    qM s = text "M." <> text s
    qG s = text "G." <> text s

    -- Method generators.  Each takes the name of the method being defined
    -- (supplied by 'method') and returns (argument patterns, right-hand
    -- side); most forward to the same-named method on every component.

    gen_length c _ = (pat c, var 'n')

    gen_unsafeSlice qual c rec
      = (pat c <+> var 'i' <+> var 'm',
         con c <+> var 'm'
               <+> vcat [ parens $ text qual <> char '.' <> text rec
                                     <+> vs <+> var 'i' <+> var 'm'
                        | vs <- varss ])

    -- First check as-is, every later one prefixed with "||"
    -- ('empty' is the identity for '<+>').
    gen_overlaps rec
      = (patn "MV" 1 <+> patn "MV" 2,
         vcat $ zipWith (<+>) (empty : repeat (text "||")) overlaps)
      where
        overlaps = [qM rec <+> v <> char '1' <+> v <> char '2' | v <- varss]

    gen_unsafeNew rec
      = (var 'n',
         mk_do [vs <+> text "<-" <+> qM rec <+> var 'n' | vs <- varss]
           $ text "return $" <+> con "MV" <+> var 'n' <+> sep varss)

    gen_unsafeNewWith rec
      = (var 'n' <+> tuple vars,
         mk_do [ vs <+> text "<-" <+> qM rec <+> var 'n' <+> v
               | (v, vs) <- zip vars varss ]
           $ text "return $" <+> con "MV" <+> var 'n' <+> sep varss)

    gen_unsafeRead rec
      = (pat "MV" <+> var 'i',
         mk_do [ v <+> text "<-" <+> qM rec <+> vs <+> var 'i'
               | (v, vs) <- zip vars varss ]
           $ text "return" <+> tuple vars)

    gen_unsafeWrite rec
      = (pat "MV" <+> var 'i' <+> tuple vars,
         mk_do [ qM rec <+> vs <+> var 'i' <+> v
               | (v, vs) <- zip vars varss ]
               empty)

    gen_clear rec
      = (pat "MV", mk_do [qM rec <+> vs | vs <- varss] empty)

    gen_set rec
      = (pat "MV" <+> tuple vars,
         mk_do [qM rec <+> vs <+> v | (v, vs) <- zip vars varss] empty)

    gen_unsafeCopy rec
      = (patn "MV" 1 <+> patn "MV" 2,
         mk_do [qM rec <+> vs <> char '1' <+> vs <> char '2' | vs <- varss]
               empty)

    -- Bind each grown component (as' <- M.basicUnsafeGrow as m_) and build
    -- the result from those, as gen_unsafeFreeze does; pairing the old,
    -- shorter components with the larger length would be wrong.
    -- NOTE(review): relies on basicUnsafeGrow returning the grown vector,
    -- as in the M.MVector class of this vector version — confirm.
    gen_unsafeGrow rec
      = (pat "MV" <+> var 'm',
         mk_do [ primed vs <+> text "<-" <+> qM rec <+> vs <+> var 'm'
               | vs <- varss ]
           $ text "return $" <+> con "MV"
                             <+> parens (var 'm' <> char '+' <> var 'n')
                             <+> sep (map primed varss))

    gen_unsafeFreeze rec
      = (pat "MV",
         mk_do [primed vs <+> text "<-" <+> qG rec <+> vs | vs <- varss]
           $ text "return $" <+> con "V" <+> var 'n'
                             <+> sep (map primed varss))

    gen_basicUnsafeIndexM rec
      = (pat "V" <+> var 'i',
         mk_do [ v <+> text "<-" <+> qG rec <+> vs <+> var 'i'
               | (v, vs) <- zip vars varss ]
           $ text "return" <+> tuple vars)

    -- A do-block: the statements followed by the final expression.
    mk_do cmds ret = hang (text "do") 2 (vcat (cmds ++ [ret]))

    -- INLINE pragma plus one defining equation for a class method.
    -- '<+>' already inserts the space before "#-}".
    method (s, f) = text "{-# INLINE" <+> text s <+> text "#-}"
                    $$ hang (text s <+> p) 4 (char '=' <+> e)
      where
        (p, e) = f s

    methods_MVector = [ ("basicLength",        gen_length "MV")
                      , ("basicUnsafeSlice",   gen_unsafeSlice "M" "MV")
                      , ("basicOverlaps",      gen_overlaps)
                      , ("basicUnsafeNew",     gen_unsafeNew)
                      , ("basicUnsafeNewWith", gen_unsafeNewWith)
                      , ("basicUnsafeRead",    gen_unsafeRead)
                      , ("basicUnsafeWrite",   gen_unsafeWrite)
                      , ("basicClear",         gen_clear)
                      , ("basicSet",           gen_set)
                      , ("basicUnsafeCopy",    gen_unsafeCopy)
                      , ("basicUnsafeGrow",    gen_unsafeGrow) ]

    methods_Vector = [ ("unsafeFreeze",      gen_unsafeFreeze)
                     , ("basicLength",       gen_length "V")
                     , ("basicUnsafeSlice",  gen_unsafeSlice "G" "V")
                     , ("basicUnsafeIndexM", gen_basicUnsafeIndexM) ]
212 ,("basicUnsafeSlice", gen_unsafeSlice "G" "V")
213 ,("basicUnsafeIndexM", gen_basicUnsafeIndexM)]