Rewrite rules for stream/zip<n> from Unboxed
[darcs-mirrors/vector.git] / internal / GenUnboxTuple.hs
1 {-# LANGUAGE ParallelListComp #-}
2 module Main where
3
4 import Text.PrettyPrint
5
6 import System.Environment ( getArgs )
7
-- | Print the generated code for every tuple arity from 2 up to the
-- maximum arity given as the single command-line argument, one
-- 'generate' section per arity.
--
-- The argument is accepted exactly when 'read' would accept it
-- (surrounding whitespace allowed). Any other argument list raises a
-- usage error instead of a pattern-match or @no parse@ failure.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [s] | [maxArity] <- [k | (k, rest) <- reads s, ("", "") <- lex rest]
        -> mapM_ (putStrLn . render . generate) [2 .. maxArity]
    _   -> ioError (userError "usage: GenUnboxTuple <maximum tuple arity>")
12
-- | Render the generated source supporting unboxed @n@-tuples.
--
-- The output has three CPP-guarded sections:
--
--   * @DEFINE_INSTANCES@: the @MVector s@ / @Vector@ data instances and the
--     @Unbox@, @M.MVector MVector@ and @G.Vector Vector@ instances for
--     @n@-tuples;
--   * @DEFINE_MUTABLE@: @zip@ / @unzip@ (suffixed with @n@ when @n > 2@)
--     on @MVector s@;
--   * @DEFINE_IMMUTABLE@: the same functions on @Vector@, plus the
--     @stream/zip@ rewrite rule.
--
-- Generated names: type variables @a, b, ...@, component vectors
-- @as, bs, ...@, and scalars with a trailing underscore (@n_@, @i_@, @m_@).
generate :: Int -> Doc
generate n =
  vcat [ text "#ifdef DEFINE_INSTANCES"
       , data_instance "MVector s" "MV"
       , data_instance "Vector" "V"
       , class_instance "Unbox"
       , class_instance "M.MVector MVector" <+> text "where"
       , nest 2 $ vcat $ map method methods_MVector
       , class_instance "G.Vector Vector" <+> text "where"
       , nest 2 $ vcat $ map method methods_Vector
       , text "#endif"
       , text "#ifdef DEFINE_MUTABLE"
       , define_zip "MVector s" "MV"
       , define_unzip "MVector s" "MV"
       , text "#endif"
       , text "#ifdef DEFINE_IMMUTABLE"
       , define_zip "Vector" "V"
         -- The rule is phrased with G.stream, which takes immutable
         -- vectors, so it is emitted here only and never for MVector.
       , define_zip_rule
       , define_unzip "Vector" "V"
       , text "#endif"
       ]
  where
    -- Type variables a, b, ... and the matching vector variables as, bs, ...
    vars  = map char $ take n ['a' ..]
    varss = map (<> char 's') vars

    -- Parenthesised, comma-separated list: 'tuple' stays on one line,
    -- 'vtuple' may break across lines.
    tuple xs  = parens $ hsep $ punctuate comma xs
    vtuple xs = parens $ sep $ punctuate comma xs

    -- Data constructor for this arity, e.g. MV_3 when n == 3.
    con s = text s <> char '_' <> int n

    -- Generated scalar variables carry a trailing underscore: n_, i_, m_.
    var c = text (c : "_")

    -- Arity-specific function name: unsuffixed for pairs (zip, unzip,
    -- zipWith), suffixed with n otherwise (zip3, unzip3, zipWith3, ...).
    arityName s
      | n == 2    = text s
      | otherwise = text s <> int n

    -- Constraint context (Unbox a, Unbox b, ...).
    unboxContext = vtuple [text "Unbox" <+> v | v <- vars]

    -- NOTE: every mix of <> and <+> below is parenthesised so the code
    -- parses whether <> is pretty's infixl 6 operator or base's infixr 6
    -- Semigroup operator (GHC rejects mixing infixl and infixr at 6).

    data_instance ty c
      = hang (hsep [text "data instance", text ty, tuple vars])
             4
             (hsep [ char '=', con c, text "{-# UNPACK #-} !Int"
                   , vcat $ map (\v -> parens (text ty <+> v)) vars ])

    class_instance cls
      = text "instance" <+> unboxContext
                        <+> text "=>" <+> text cls <+> tuple vars

    -- zip<n> for vector type ty: every component is sliced to the minimum
    -- input length and packed into constructor c. The declarations are
    -- stacked with vcat because each must start on its own line.
    define_zip ty c
      = vcat [ name <+> text "::" <+> unboxContext
                    <+> text "=>"
                    <+> sep (punctuate (text " ->") [text ty <+> v | v <- vars])
                    <+> text "->"
                    <+> text ty <+> tuple vars
             , text "{-# INLINE_STREAM" <+> name <+> text "#-}"
             , name <+> sep varss
                    <+> text "="
                    <+> con c
                    <+> text "len"
                    <+> sep [ parens $ text "unsafeSlice" <+> vs
                                       <+> char '0' <+> text "len"
                            | vs <- varss ]
             , nest 2 $ hang (text "where")
                             2
                             (text "len ="
                              <+> sep (punctuate (text " `min`")
                                                 [text "length" <+> vs | vs <- varss]))
             ]
      where
        name = arityName "zip"

    -- "stream/zip<n>" rule: streaming a zipped vector becomes zipping the
    -- component streams with the n-ary tuple constructor, e.g. (,) or (, ,).
    define_zip_rule
      = hang (text "{-# RULES"
              <+> (text "\"stream/" <> arityName "zip" <> text "\" forall")
              <+> sep varss <+> char '.')
             2
             (text "G.stream" <+> parens (arityName "zip" <+> sep varss)
              <+> char '='
              <+> (text "Stream." <> arityName "zipWith")
              <+> tuple (replicate n empty)
              <+> sep [parens $ text "G.stream" <+> vs | vs <- varss]
              $$ text "#-}")

    -- unzip<n> for vector type ty: returns the component vectors as a tuple.
    define_unzip ty c
      = vcat [ name <+> text "::" <+> unboxContext
                    <+> text "=>"
                    <+> text ty <+> tuple vars
                    <+> text "->" <+> vtuple [text ty <+> v | v <- vars]
             , text "{-# INLINE" <+> name <+> text "#-}"
             , name <+> pat c <+> text "=" <+> vtuple varss
             ]
      where
        name = arityName "unzip"

    -- Constructor pattern binding the length and all components,
    -- e.g. (MV_2 n_ as bs).
    pat c = parens $ con c <+> var 'n' <+> sep varss

    -- Like 'pat' with every variable suffixed by k, e.g. (MV_2 n_1 as1 bs1),
    -- so two tuple vectors can be matched in one equation.
    patn c k = parens $ con c <+> (var 'n' <> int k)
                      <+> sep [v <> int k | v <- varss]

    -- Qualified references to the component classes' methods.
    qM s = text "M." <> text s
    qG s = text "G." <> text s

    -- Each gen_* receives the class-method name and returns
    -- (argument patterns, right-hand side) for that method.

    gen_length c _ = (pat c, var 'n')

    gen_unsafeSlice qual c meth
      = ( pat c <+> var 'i' <+> var 'm'
        , con c <+> var 'm'
                <+> vcat [ parens $ (text qual <> char '.' <> text meth)
                                    <+> vs <+> var 'i' <+> var 'm'
                         | vs <- varss ] )

    -- One check per component, each after the first prefixed with "||".
    gen_overlaps meth
      = ( patn "MV" 1 <+> patn "MV" 2
        , vcat $ zipWith (<+>) (empty : repeat (text "||"))
                   [ qM meth <+> (v <> char '1') <+> (v <> char '2') | v <- varss ] )

    gen_unsafeNew meth
      = ( var 'n'
        , mk_do [v <+> text "<-" <+> qM meth <+> var 'n' | v <- varss]
                (text "return $" <+> con "MV" <+> var 'n' <+> sep varss) )

    gen_unsafeNewWith meth
      = ( var 'n' <+> tuple vars
        , mk_do [ vs <+> text "<-" <+> qM meth <+> var 'n' <+> v
                | (v, vs) <- zip vars varss ]
                (text "return $" <+> con "MV" <+> var 'n' <+> sep varss) )

    gen_unsafeRead meth
      = ( pat "MV" <+> var 'i'
        , mk_do [ v <+> text "<-" <+> qM meth <+> vs <+> var 'i'
                | (v, vs) <- zip vars varss ]
                (text "return" <+> tuple vars) )

    gen_unsafeWrite meth
      = ( pat "MV" <+> var 'i' <+> tuple vars
        , mk_do [ qM meth <+> vs <+> var 'i' <+> v
                | (v, vs) <- zip vars varss ]
                empty )

    gen_clear meth
      = (pat "MV", mk_do [qM meth <+> vs | vs <- varss] empty)

    gen_set meth
      = ( pat "MV" <+> tuple vars
        , mk_do [qM meth <+> vs <+> v | (v, vs) <- zip vars varss] empty )

    gen_unsafeCopy meth
      = ( patn "MV" 1 <+> patn "MV" 2
        , mk_do [ qM meth <+> (vs <> char '1') <+> (vs <> char '2')
                | vs <- varss ]
                empty )

    -- The method being generated itself returns a vector (return $ MV_n ...),
    -- so the delegated call on each component returns the grown component.
    -- Those results (as', bs', ...) must be bound and packed. Reusing the
    -- inputs would pair the new length m_+n_ with the old, shorter components.
    gen_unsafeGrow meth
      = ( pat "MV" <+> var 'm'
        , mk_do [ (vs <> char '\'') <+> text "<-" <+> qM meth <+> vs <+> var 'm'
                | vs <- varss ]
                (text "return $" <+> con "MV"
                                 <+> parens (var 'm' <> char '+' <> var 'n')
                                 <+> sep [vs <> char '\'' | vs <- varss]) )

    gen_unsafeFreeze meth
      = ( pat "MV"
        , mk_do [ (vs <> char '\'') <+> text "<-" <+> qG meth <+> vs
                | vs <- varss ]
                (text "return $" <+> con "V" <+> var 'n'
                                 <+> sep [vs <> char '\'' | vs <- varss]) )

    gen_basicUnsafeIndexM meth
      = ( pat "V" <+> var 'i'
        , mk_do [ v <+> text "<-" <+> qG meth <+> vs <+> var 'i'
                | (v, vs) <- zip vars varss ]
                (text "return" <+> tuple vars) )

    -- A do block: one command per line followed by the final statement
    -- (vcat drops it when it is 'empty').
    mk_do cmds ret = hang (text "do") 2 $ vcat $ cmds ++ [ret]

    -- An INLINE pragma followed by the method equation built by f.
    method (s, f) =
      let (p, e) = f s
      in  text "{-# INLINE" <+> text s <+> text "#-}"
          $$ hang (text s <+> p) 4 (char '=' <+> e)

    methods_MVector, methods_Vector :: [(String, String -> (Doc, Doc))]

    methods_MVector = [ ("basicLength",        gen_length "MV")
                      , ("basicUnsafeSlice",   gen_unsafeSlice "M" "MV")
                      , ("basicOverlaps",      gen_overlaps)
                      , ("basicUnsafeNew",     gen_unsafeNew)
                      , ("basicUnsafeNewWith", gen_unsafeNewWith)
                      , ("basicUnsafeRead",    gen_unsafeRead)
                      , ("basicUnsafeWrite",   gen_unsafeWrite)
                      , ("basicClear",         gen_clear)
                      , ("basicSet",           gen_set)
                      , ("basicUnsafeCopy",    gen_unsafeCopy)
                      , ("basicUnsafeGrow",    gen_unsafeGrow) ]

    methods_Vector = [ ("unsafeFreeze",      gen_unsafeFreeze)
                     , ("basicLength",       gen_length "V")
                     , ("basicUnsafeSlice",  gen_unsafeSlice "G" "V")
                     , ("basicUnsafeIndexM", gen_basicUnsafeIndexM) ]
208 ,("basicUnsafeSlice", gen_unsafeSlice "G" "V")
209 ,("basicUnsafeIndexM", gen_basicUnsafeIndexM)]