More Unboxed arrays
[darcs-mirrors/vector.git] / util / GenUnboxTuple.hs
1 {-# LANGUAGE ParallelListComp #-}
2 module Main where
3
4 import Text.PrettyPrint
5
6 import System.Environment ( getArgs )
7
-- | Command-line entry point: @GenUnboxTuple N@ prints the generated
-- instances for every tuple arity from 2 up to @N@ (nothing when @N < 2@).
--
-- The argument is accepted exactly when 'read' would accept it (surrounding
-- whitespace allowed), but a missing, extra or malformed argument now yields a
-- usage error instead of a pattern-match or @Prelude.read@ failure.  Arities
-- above 'maxArity' are rejected because 'generate' names tuple components
-- @a@..@z@ and would otherwise emit invalid identifiers.
main :: IO ()
main = do
  args <- getArgs
  case map parseArity args of
    [Just n] | n <= maxArity -> mapM_ (putStrLn . render . generate) [2 .. n]
    _                        -> ioError (userError usage)
  where
    -- One component per lower-case letter a..z.
    maxArity :: Int
    maxArity = 26

    usage :: String
    usage = "usage: GenUnboxTuple N  (N = largest tuple arity, at most "
            ++ show maxArity ++ ")"

    -- Total version of 'read': same accepted inputs, Nothing otherwise.
    parseArity :: String -> Maybe Int
    parseArity s = case [k | (k, rest) <- reads s, ("", "") <- lex rest] of
                     [k] -> Just k
                     _   -> Nothing
12
-- | Render the data instances and the 'Unbox', 'M.MVector' and 'G.Vector'
-- instance declarations for @n@-tuples.
--
-- Tuple components are named @a@, @b@, ... and the corresponding component
-- vectors @as@, @bs@, ...  Both generated vector types store their length as
-- the first, unpacked constructor field (@MV_n n as bs ...@), so every
-- generated constructor application must supply that length first.
--
-- Component names come from @['a'..]@, so the output is only valid Haskell
-- for @n <= 26@.
generate :: Int -> Doc
generate n =
  vcat [ data_instance "MVector s" "MV"
       , data_instance "Vector" "V"
       , class_instance "Unbox"
       , class_instance "M.MVector MVector" <+> text "where"
       , nest 2 $ vcat $ map method methods_MVector
       , class_instance "G.Vector Vector" <+> text "where"
       , nest 2 $ vcat $ map method methods_Vector
       ]
  where
    -- Component names a, b, c, ...
    vars = map char $ take n ['a'..]
    -- Component-vector names as, bs, cs, ...
    varss = map (<> char 's') vars
    -- "(f a, f b, ...)" on one line / allowed to break across lines.
    tuple f = parens $ hsep $ punctuate comma $ map f vars
    vtuple f = parens $ sep $ punctuate comma $ map f vars
    -- Constructor name for this arity, e.g. MV_2.
    con s = text s <> char '_' <> int n

    -- data instance T (a, b) = C_2 {-# UNPACK #-} !Int (T a) (T b)
    data_instance ty c
      = hang (hsep [text "data instance", text ty, tuple id])
             4
             (hsep [char '=', con c, text "{-# UNPACK #-} !Int"
                   , vcat $ map (\v -> parens (text ty <+> v)) vars])

    -- instance (Unbox a, Unbox b) => cls (a, b)
    class_instance cls
      = text "instance" <+> vtuple (text "Unbox" <+>)
                        <+> text "=>" <+> text cls <+> tuple id

    -- Constructor pattern binding the length to n and components to as, bs, ...
    pat c = parens $ con c <+> char 'n' <+> sep varss
    -- The same pattern with every bound name suffixed by k (n1 as1 bs1 ...),
    -- used by methods that take two vectors.
    patn c k = parens $ con c <+> (char 'n' <> int k)
                              <+> sep [v <> int k | v <- varss]

    gen_length c = (pat c, char 'n')

    gen_unsafeSlice qual c
      = (pat c <+> char 'i' <+> char 'm',
         con c <+> char 'm'
               <+> vcat [parens $ text qual <> char '.' <> text "unsafeSlice"
                                <+> vs <+> char 'i' <+> char 'm'
                        | vs <- varss])

    -- Two tuple vectors overlap iff some pair of component vectors does.
    -- The operands must be the vectors bound by patn (as1, as2, ...); the
    -- first line gets no "||" prefix (empty is the unit of <+>).
    gen_overlaps
      = (patn "MV" 1 <+> patn "MV" 2,
         vcat $ zipWith (<+>)
                        (empty : repeat (text "||"))
                        [text "M.overlaps" <+> vs <> char '1' <+> vs <> char '2'
                        | vs <- varss])

    -- The length n must be passed to the constructor along with the components.
    gen_unsafeNew
      = (char 'n',
         mk_do [vs <+> text "<- M.unsafeNew n" | vs <- varss]
           $ text "return $" <+> con "MV" <+> char 'n' <+> sep varss)

    gen_unsafeNewWith
      = (char 'n' <+> tuple id,
         mk_do [vs <+> text "<- M.unsafeNewWith n" <+> v | v <- vars
                                                         | vs <- varss]
           $ text "return $" <+> con "MV" <+> char 'n' <+> sep varss)

    gen_unsafeRead
      = (pat "MV" <+> char 'i',
         mk_do [v <+> text "<- M.unsafeRead" <+> vs <+> char 'i' | v <- vars
                                                                 | vs <- varss]
           $ text "return" <+> tuple id)

    gen_unsafeWrite
      = (pat "MV" <+> char 'i' <+> tuple id,
         mk_do [text "M.unsafeWrite" <+> vs <+> char 'i' <+> v | v <- vars
                                                               | vs <- varss]
           empty)

    gen_clear
      = (pat "MV", mk_do [text "M.clear" <+> vs | vs <- varss] empty)

    gen_set
      = (pat "MV" <+> tuple id,
         mk_do [text "M.set" <+> vs <+> v | vs <- varss | v <- vars] empty)

    gen_unsafeCopy
      = (patn "MV" 1 <+> patn "MV" 2,
         mk_do [text "M.unsafeCopy" <+> vs <> char '1' <+> vs <> char '2'
               | vs <- varss] empty)

    -- unsafeGrow yields new, larger component vectors; bind them (as', bs', ...)
    -- and build the result from those rather than from the old components.
    gen_unsafeGrow
      = (pat "MV" <+> char 'm',
         mk_do [vs <> char '\'' <+> text "<- M.unsafeGrow" <+> vs <+> char 'm'
               | vs <- varss]
           $ text "return $" <+> con "MV" <+> text "(m+n)"
                             <+> sep [vs <> char '\'' | vs <- varss])

    gen_unsafeFreeze
      = (pat "MV",
         mk_do [vs <> char '\'' <+> text "<- G.unsafeFreeze" <+> vs
               | vs <- varss]
           $ text "return $" <+> con "V" <+> char 'n'
                             <+> sep [vs <> char '\'' | vs <- varss])

    gen_basicUnsafeIndexM
      = (pat "V" <+> char 'i',
         mk_do [v <+> text "<- G.basicUnsafeIndexM" <+> vs <+> char 'i'
               | vs <- varss | v <- vars]
           $ text "return" <+> tuple id)

    -- "do" followed by the statements and the final expression.
    mk_do cmds ret = hang (text "do") 2 $ vcat $ cmds ++ [ret]

    -- An INLINE pragma followed by the equation "name pats = body".
    method (s, (p, e)) = text "{-# INLINE" <+> text s <+> text "#-}"
                      $$ hang (text s <+> p)
                              4
                              (char '=' <+> e)

    methods_MVector = [ ("length",        gen_length "MV")
                      , ("unsafeSlice",   gen_unsafeSlice "M" "MV")
                      , ("overlaps",      gen_overlaps)
                      , ("unsafeNew",     gen_unsafeNew)
                      , ("unsafeNewWith", gen_unsafeNewWith)
                      , ("unsafeRead",    gen_unsafeRead)
                      , ("unsafeWrite",   gen_unsafeWrite)
                      , ("clear",         gen_clear)
                      , ("set",           gen_set)
                      , ("unsafeCopy",    gen_unsafeCopy)
                      , ("unsafeGrow",    gen_unsafeGrow)
                      ]

    methods_Vector = [ ("unsafeFreeze",      gen_unsafeFreeze)
                     , ("basicLength",       gen_length "V")
                     , ("unsafeSlice",       gen_unsafeSlice "G" "V")
                     , ("basicUnsafeIndexM", gen_basicUnsafeIndexM)
                     ]