04e7df1e8c30568481b3fa3aca47fc79912df092
[darcs-mirrors/vector.git] / internal / GenUnboxTuple.hs
1 {-# LANGUAGE ParallelListComp #-}
2 module Main where
3
4 import Text.PrettyPrint
5
6 import System.Environment ( getArgs )
7
-- | Read the maximum tuple arity N from the single command-line argument and
-- print the generated instances for every arity from 2 to N.
--
-- Any other argument list, or an argument that does not parse as an 'Int',
-- raises an IO error carrying a usage message rather than a bare
-- pattern-match or "no parse" failure.
main :: IO ()
main = do
  args <- getArgs
  case args of
    -- Same acceptance rule as 'read': exactly one parse, with nothing but
    -- whitespace left over.
    [s] | [n] <- [k | (k, rest) <- reads s, ("", "") <- lex rest]
        -> mapM_ (putStrLn . render . generate) [2 .. n]
    _ -> ioError (userError "usage: GenUnboxTuple <max-arity>")
12
-- | Render the unboxed-vector boilerplate for tuples of arity @n@:
--
--   * @data instance@ declarations for @MVector s@ and @Vector@ whose
--     constructor (@MV_n@ / @V_n@) stores a length followed by one component
--     vector per tuple field;
--
--   * an @Unbox@ instance;
--
--   * @M.MVector MVector@ and @G.Vector Vector@ instances whose methods
--     delegate field-by-field to the component vectors.
--
-- Naming scheme used in the emitted code: tuple fields are @a@, @b@, ...;
-- component vectors are @as@, @bs@, ...; other generated variables carry a
-- trailing underscore (@n_@, @i_@, @m_@).
--
-- NOTE(review): field names are consecutive characters starting at @a@, so
-- arities above 26 would produce non-identifier names.
generate :: Int -> Doc
generate n =
  vcat [ data_instance "MVector s" "MV"
       , data_instance "Vector" "V"
       , class_instance "Unbox"
       , class_instance "M.MVector MVector" <+> text "where"
       , nest 2 $ vcat $ map method methods_MVector
       , class_instance "G.Vector Vector" <+> text "where"
       , nest 2 $ vcat $ map method methods_Vector
       ]
  where
    -- Tuple field variables: a, b, c, ...
    vars = map char $ take n ['a' ..]
    -- Component vector variables: as, bs, cs, ...
    varss = map (<> char 's') vars
    -- Primed component vectors (as', bs', ...) for results bound in a do block.
    varss' = map (<> char '\'') varss

    -- Single-line tuple over the field variables, each transformed by f.
    tuple f = parens $ hsep $ punctuate comma $ map f vars
    -- Same as 'tuple', but allowed to wrap onto several lines.
    vtuple f = parens $ sep $ punctuate comma $ map f vars
    -- Constructor name tagged with the arity, e.g. MV_2.
    con s = text s <> char '_' <> int n
    -- Generated helper variable with a trailing underscore, e.g. n_.
    var c = text (c : "_")

    -- data instance <ty> (a, b, ...)
    --     = <C>_n {-# UNPACK #-} !Int (<ty> a) (<ty> b) ...
    data_instance ty c
      = hang (hsep [text "data instance", text ty, tuple id])
             4
             (hsep [ char '=', con c, text "{-# UNPACK #-} !Int"
                   , vcat $ map (\v -> parens (text ty <+> v)) vars ])

    -- instance (Unbox a, Unbox b, ...) => <cls> (a, b, ...)
    class_instance cls
      = text "instance" <+> vtuple (text "Unbox" <+>)
        <+> text "=>" <+> text cls <+> tuple id

    -- Pattern (C_n n_ as bs ...) deconstructing a tuple vector.
    pat c = parens $ con c <+> var 'n' <+> sep varss
    -- Like 'pat', but every bound name is suffixed with k (n_1 as1 bs1 ...)
    -- so that two vectors can be matched in the same equation.
    patn c k = parens $ con c <+> (var 'n' <> int k)
                      <+> sep [v <> int k | v <- varss]

    -- Each gen_* below yields (argument patterns, right-hand side) for one
    -- instance method; 'method' turns that pair into pragma + equation.

    gen_length c = (pat c, var 'n')

    -- qual is the module qualifier used for the per-component slice call.
    -- NOTE(review): for the G.Vector instance this emits G.unsafeSlice,
    -- whereas basicUnsafeIndexM delegates to G.basicUnsafeIndexM -- confirm
    -- the non-basic call is intended.
    gen_unsafeSlice qual c
      = (pat c <+> var 'i' <+> var 'm',
         con c <+> var 'm'
         <+> vcat [ parens $ text qual <> char '.' <> text "unsafeSlice"
                             <+> vs <+> var 'i' <+> var 'm'
                  | vs <- varss ])

    -- Emits: M.overlaps as1 as2 || M.overlaps bs1 bs2 || ...
    -- ('empty' is the unit of <+>, so the first call gets no "||" prefix.)
    gen_overlaps
      = (patn "MV" 1 <+> patn "MV" 2,
         vcat $ zipWith (<+>) (empty : repeat (text "||"))
                  [ text "M.overlaps" <+> v <> char '1' <+> v <> char '2'
                  | v <- varss ])

    gen_unsafeNew
      = (var 'n',
         mk_do [vs <+> text "<- M.unsafeNew" <+> var 'n' | vs <- varss]
           $ text "return $" <+> con "MV" <+> var 'n' <+> sep varss)

    gen_unsafeNewWith
      = (var 'n' <+> tuple id,
         mk_do [vs <+> text "<- M.unsafeNewWith" <+> var 'n' <+> v
                 | v <- vars | vs <- varss]
           $ text "return $" <+> con "MV" <+> var 'n' <+> sep varss)

    gen_unsafeRead
      = (pat "MV" <+> var 'i',
         mk_do [v <+> text "<- M.unsafeRead" <+> vs <+> var 'i'
                 | v <- vars | vs <- varss]
           $ text "return" <+> tuple id)

    gen_unsafeWrite
      = (pat "MV" <+> var 'i' <+> tuple id,
         mk_do [text "M.unsafeWrite" <+> vs <+> var 'i' <+> v
                 | v <- vars | vs <- varss]
               empty)

    gen_clear
      = (pat "MV", mk_do [text "M.clear" <+> vs | vs <- varss] empty)

    gen_set
      = (pat "MV" <+> tuple id,
         mk_do [text "M.set" <+> vs <+> v | vs <- varss | v <- vars] empty)

    gen_unsafeCopy
      = (patn "MV" 1 <+> patn "MV" 2,
         mk_do [text "M.unsafeCopy" <+> vs <> char '1' <+> vs <> char '2'
                 | vs <- varss]
               empty)

    -- The tuple-level unsafeGrow generated here returns the grown vector, so
    -- the component-level M.unsafeGrow does too: each result must be bound
    -- (as', bs', ...) and used to rebuild the tuple vector. Reusing the
    -- original as, bs, ... would leave components shorter than m_+n_.
    gen_unsafeGrow
      = (pat "MV" <+> var 'm',
         mk_do [vs' <+> text "<- M.unsafeGrow" <+> vs <+> var 'm'
                 | vs <- varss | vs' <- varss']
           $ text "return $" <+> con "MV"
                             <+> parens (var 'm' <> char '+' <> var 'n')
                             <+> sep varss')

    gen_unsafeFreeze
      = (pat "MV",
         mk_do [vs' <+> text "<- G.unsafeFreeze" <+> vs
                 | vs <- varss | vs' <- varss']
           $ text "return $" <+> con "V" <+> var 'n' <+> sep varss')

    gen_basicUnsafeIndexM
      = (pat "V" <+> var 'i',
         mk_do [v <+> text "<- G.basicUnsafeIndexM" <+> vs <+> var 'i'
                 | vs <- varss | v <- vars]
           $ text "return" <+> tuple id)

    -- "do" followed by the statements and a final line (which may be 'empty').
    mk_do cmds ret = hang (text "do") 2 $ vcat $ cmds ++ [ret]

    -- {-# INLINE <name> #-} followed by "<name> <patterns> = <rhs>".
    method (s, (p, e)) = text "{-# INLINE" <+> text s <+> text "#-}"
                         $$ hang (text s <+> p) 4 (char '=' <+> e)

    -- Methods of the M.MVector MVector instance, in emission order.
    methods_MVector = [ ("length",        gen_length "MV")
                      , ("unsafeSlice",   gen_unsafeSlice "M" "MV")
                      , ("overlaps",      gen_overlaps)
                      , ("unsafeNew",     gen_unsafeNew)
                      , ("unsafeNewWith", gen_unsafeNewWith)
                      , ("unsafeRead",    gen_unsafeRead)
                      , ("unsafeWrite",   gen_unsafeWrite)
                      , ("clear",         gen_clear)
                      , ("set",           gen_set)
                      , ("unsafeCopy",    gen_unsafeCopy)
                      , ("unsafeGrow",    gen_unsafeGrow)
                      ]

    -- Methods of the G.Vector Vector instance, in emission order.
    methods_Vector = [ ("unsafeFreeze",      gen_unsafeFreeze)
                     , ("basicLength",       gen_length "V")
                     , ("basicUnsafeSlice",  gen_unsafeSlice "G" "V")
                     , ("basicUnsafeIndexM", gen_basicUnsafeIndexM)
                     ]