Merge branch 'master' into type-nats
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 {-# LANGUAGE PatternGuards #-}
2 module TcTypeNats (solveNumWanted, addNumGiven) where
3
4 import TcSMonad ( TcS, Xi
5 , newWantedCoVar
6 , setWantedCoBind, setDictBind, newDictVar
7 , tcsLookupTyCon, tcsLookupClass
8 , CanonicalCt (..), CanonicalCts
9 , traceTcS
10 )
11 import TcCanonical (mkCanonicals)
12 import HsBinds (EvTerm(..))
13 import Class ( Class, className, classTyCon )
14 import Type ( tcView, mkTyConApp, mkNumberTy, isNumberTy
15 , tcEqType )
16 import TypeRep (Type(..))
17 import TyCon (TyCon, tyConName)
18 import Var (EvVar)
19 import Coercion ( mkUnsafeCoercion )
20 import Outputable
21 import PrelNames ( lessThanEqualClassName
22 , addTyFamName, mulTyFamName, expTyFamName
23 )
24 import Bag (bagToList, emptyBag, unionBags)
25 import Data.Maybe (fromMaybe)
26 import Control.Monad (msum, mplus, zipWithM, (<=<))
27
28
29
-- | An operand of a numeric constraint: an arbitrary type, or a literal
-- together with the type it was read from (if any).
data Term = Var Xi
          | Num Integer (Maybe Xi)

-- | The arithmetic operations understood by the solver.
data Op = Add | Mul | Exp
  deriving Eq

-- | A numeric constraint.
data Prop = EqFun Op Term Term Term   -- ^ @t1 op t2 ~ t3@
          | Leq Term Term             -- ^ @t1 <= t2@
          | Eq Term Term              -- ^ @t1 ~ t2@
35
-- | Does the operator allow its two arguments to be swapped?
commutative :: Op -> Bool
commutative Add = True
commutative Mul = True
commutative Exp = False
38
-- | A literal term that was not read from any type.
num :: Integer -> Term
num = flip Num Nothing
41
-- | Types are compared with 'tcEqType'; literals are compared by value,
-- ignoring the type they were read from.
instance Eq Term where
  t1 == t2 =
    case (t1, t2) of
      (Var x, Var y)     -> tcEqType x y
      (Num x _, Num y _) -> x == y
      _                  -> False
46
47
48 --------------------------------------------------------------------------------
49 -- Interface with the type checker
50 --------------------------------------------------------------------------------
51
52
-- | Convert a type into a solver 'Term'.
--
-- Number types, checked both directly and after one 'tcView' step, become
-- 'Num'. The original type is kept alongside the literal so that type
-- synonyms are preserved when converting back.
toTerm :: Xi -> Term
toTerm xi =
  maybe (Var xi) (\n -> Num n (Just xi))
        (isNumberTy xi `mplus` (tcView xi >>= isNumberTy))
58
-- | Convert a 'Term' back into a type. A literal reuses its original type
-- when available; otherwise a fresh number type is built.
fromTerm :: Term -> Xi
fromTerm term =
  case term of
    Var xi   -> xi
    Num n mb -> fromMaybe (mkNumberTy n) mb
62
-- | Convert a canonical constraint into a solver 'Prop'.
--
-- Understood inputs:
--
-- * a two-argument class constraint whose class is the @<=@ class;
-- * a two-argument application of the addition, multiplication or
--   exponentiation family.
--
-- Any other constraint panics.
toProp :: CanonicalCt -> Prop
toProp ct =
  case ct of
    CDictCan { cc_class = cls, cc_tyargs = [x, y] }
      | className cls == lessThanEqualClassName -> Leq (toTerm x) (toTerm y)
    CFunEqCan { cc_fun = fam, cc_tyargs = [x, y], cc_rhs = z }
      | Just op <- lookup (tyConName fam) families ->
          EqFun op (toTerm x) (toTerm y) (toTerm z)
    _ -> panic $
           "[TcTypeNats.toProp] Unexpected CanonicalCt: " ++ showSDoc (ppr ct)
  where
    -- Checked in this order, as the original guards were.
    families = [ (addTyFamName, Add), (mulTyFamName, Mul), (expTyFamName, Exp) ]
78
79
-- | Try to solve the wanted numeric constraint @prop@, using every given
-- and wanted constraint as an assumption.
--
-- Results:
--
-- * Simplified: @prop@'s evidence is bound to a dummy; the result is
--   @(Nothing, new sub-goals)@.
-- * Improved: the result is @(Just prop, new wanteds for the learned facts)@.
-- * Impossible: the result is @(Just prop, empty)@.
solveNumWanted :: CanonicalCts -> CanonicalCts -> CanonicalCt ->
                  TcS (Maybe CanonicalCt, CanonicalCts)
solveNumWanted given wanted prop =
  do let asmps = map toProp $ bagToList $ unionBags given wanted
         goal  = toProp prop

     numTrace "solveNumWanted" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            -- The goal is dropped (Nothing below), so bind its evidence
            -- variable to dummy evidence here.
            defineDummy (cc_id prop) =<< fromProp goal
            evs   <- mapM (newSubGoal <=< fromProp) sgs
            goals <- mkCanonicals (cc_flavor prop) evs
            return (Nothing, goals)

       -- XXX: The new wanted might imply some of the existing wanteds...
       Improved is ->
         do numTrace "Improved by" (vmany is)
            evs   <- mapM (newSubGoal <=< fromProp) is
            goals <- mkCanonicals (cc_flavor prop) evs
            return (Just prop, goals)

       -- XXX: Report an error
       Impossible ->
         do numTrace "Impossible" empty
            return (Just prop, emptyBag)
108
109
-- | Add the given numeric constraint @prop@, consulting the existing given
-- and wanted constraints.
--
-- The shape of the result depends on the solver outcome:
--
-- * simplified: @(Nothing, given ++ wanted, new facts)@
-- * improved:   @(Just prop, given, wanted ++ new facts)@
-- * impossible: @(Just prop, given ++ wanted, empty)@
--
-- NOTE(review): the improved case merges the new facts into the wanteds,
-- while the simplified case returns them separately. Confirm this asymmetry
-- is intended.
addNumGiven :: CanonicalCts -> CanonicalCts -> CanonicalCt ->
               TcS (Maybe CanonicalCt, CanonicalCts, CanonicalCts)
addNumGiven given wanted prop =
  do numTrace "addNumGiven" (vmany asmps <+> text " /\\ " <+> ppr goal)
     case solve asmps goal of
       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            facts <- factsFor sgs
            return (Nothing, allCts, facts)
       Improved is ->
         do numTrace "Improved by" (vmany is)
            facts <- factsFor is
            return (Just prop, given, unionBags wanted facts)
       -- XXX: report error
       Impossible ->
         do numTrace "Impossible" empty
            return (Just prop, allCts, emptyBag)
  where
    allCts = unionBags given wanted
    asmps  = map toProp (bagToList allCts)
    goal   = toProp prop
    -- Create fresh evidence for each prop via 'newFact', then canonicalise.
    factsFor ps = mkCanonicals (cc_flavor prop) =<< mapM (newFact <=< fromProp) ps
135
136
137
-- | A 'Prop' translated back into type-checker terms: either a class
-- constraint (class plus argument types) or an equality between two types.
data CvtProp
  = CvtClass Class [Type]
  | CvtCo Type Type
140
-- | Translate a solver 'Prop' back into a class constraint or a type
-- equality, looking up the @<=@ class or the operator's type family.
fromProp :: Prop -> TcS CvtProp
fromProp prop =
  case prop of
    Leq a b ->
      do cls <- tcsLookupClass lessThanEqualClassName
         return (CvtClass cls (map fromTerm [a, b]))
    Eq a b ->
      return (CvtCo (fromTerm a) (fromTerm b))
    EqFun op a b c ->
      do fam <- tcsLookupTyCon (famName op)
         return (CvtCo (mkTyConApp fam (map fromTerm [a, b])) (fromTerm c))
  where
    famName Add = addTyFamName
    famName Mul = mulTyFamName
    famName Exp = expTyFamName
154
155
-- | Allocate an evidence variable for a converted prop. Class constraints
-- use 'newDictVar'; equalities use 'newWantedCoVar'.
newSubGoal :: CvtProp -> TcS EvVar
newSubGoal cvt =
  case cvt of
    CvtClass cls tys -> newDictVar cls tys
    CvtCo lhs rhs    -> newWantedCoVar lhs rhs
159
-- | Allocate an evidence variable for a converted prop and immediately bind
-- it to placeholder evidence with 'defineDummy'.
newFact :: CvtProp -> TcS EvVar
newFact cvt =
  newSubGoal cvt >>= \ev -> defineDummy ev cvt >> return ev
165
166
-- | Bind an evidence variable to placeholder evidence for a converted prop.
--
-- If we decided to generate real evidence terms, this is where we would set
-- them properly. For now we skip that step: evidence terms are not used for
-- anything, and they get quite large, at least when starting from a small
-- set of axioms.
defineDummy :: EvVar -> CvtProp -> TcS ()
defineDummy ev cvt =
  case cvt of
    CvtClass cls tys ->
      setDictBind ev $ EvAxiom "<=" $ mkTyConApp (classTyCon cls) tys
    CvtCo lhs rhs ->
      setWantedCoBind ev $ mkUnsafeCoercion lhs rhs
177
178
179
180
181 --------------------------------------------------------------------------------
182 -- The Solver
183 --------------------------------------------------------------------------------
184
185
-- | Outcome of running the solver on one constraint.
data Result
  = Impossible          -- ^ The goal cannot be solved.
  | Simplified [Prop]   -- ^ The goal was reformulated as these sub-goals.
  | Improved [Prop]     -- ^ We learned these new facts.
189
-- | Run the solver on @prop@ under the assumptions @asmps@.
--
-- The solver first tries to simplify the goal: on its own, then against
-- each assumption in turn; the first rule that fires wins. If none fires,
-- it collects improvements from the goal alone and from every
-- (assumption, goal) pair. If any improvement rule fails, the goal is
-- 'Impossible'.
solve :: [Prop] -> Prop -> Result
solve asmps prop =
  case solveWanted1 prop `mplus` msum (map (`solveWanted2` prop) asmps) of
    Just sgs -> Simplified sgs
    Nothing  -> maybe Impossible Improved improvements
  where
    improvements =
      do fromGoal  <- improve1 prop
         fromPairs <- zipWithM improve2 asmps (repeat prop)
         return (concat (fromGoal : fromPairs))
204
205
206
207
-- | Facts that follow from a single constraint on its own.
--
-- 'Nothing' means the constraint has no solution in the natural numbers.
-- @Just eqs@ lists equalities implied by it (possibly none).
improve1 :: Prop -> Maybe [Prop]
improve1 prop =
  case prop of

    Leq s t@(Num 0 _) -> Just [ Eq s t ]


    EqFun Add (Num m _) (Num n _) t -> Just [ Eq (num (m+n)) t ]
    EqFun Add (Num 0 _) s t -> Just [ Eq s t ]
    EqFun Add (Num m _) s (Num n _)
      | m <= n -> Just [ Eq (num (n-m)) s ]
      | otherwise -> Nothing
    EqFun Add r s t
      | r == t -> Just [ Eq (num 0) s ]
      | s == t -> Just [ Eq (num 0) r ]

    EqFun Mul (Num m _) (Num n _) t -> Just [ Eq (num (m*n)) t ]
    EqFun Mul (Num 0 _) _ t -> Just [ Eq (num 0) t ]
    EqFun Mul (Num 1 _) s t -> Just [ Eq s t ]
    EqFun Mul (Num _ _) s t
      | s == t -> Just [ Eq (num 0) s ]

    EqFun Mul r s (Num 1 _) -> Just [ Eq (num 1) r, Eq (num 1) s ]
    EqFun Mul (Num m _) s (Num n _)
      | Just a <- divide n m -> Just [ Eq (num a) s ]
      | otherwise -> Nothing


    EqFun Exp (Num m _) (Num n _) t -> Just [ Eq (num (m ^ n)) t ]

    EqFun Exp (Num 1 _) _ t -> Just [ Eq (num 1) t ]
    EqFun Exp (Num _ _) s t
      | s == t -> Nothing

    -- A zero base is settled directly: 0^0 = 1 and 0^s = 0 for s >= 1.
    -- (0^s = 0 only says s >= 1, which is not an equation.)
    EqFun Exp (Num 0 _) s (Num n _)
      | n == 1    -> Just [ Eq (num 0) s ]
      | n == 0    -> Just []
      | otherwise -> Nothing
    -- Here the base is at least 2, so m^s = 1 forces s = 0 and m^s = 0
    -- has no solution.
    EqFun Exp (Num _ _) s (Num 1 _) -> Just [ Eq (num 0) s ]
    EqFun Exp (Num _ _) _ (Num 0 _) -> Nothing
    EqFun Exp (Num m _) s (Num n _) ->
      do a <- descreteLog m n
         return [ Eq (num a) s ]

    EqFun Exp _ (Num 0 _) t -> Just [ Eq (num 1) t ]
    EqFun Exp r (Num 1 _) t -> Just [ Eq r t ]
    -- Here the exponent is at least 2, so r^m = 1 forces r = 1.
    EqFun Exp r (Num _ _) (Num 1 _) -> Just [ Eq (num 1) r ]
    EqFun Exp r (Num m _) (Num n _) ->
      do a <- descreteRoot m n
         return [ Eq (num a) r ]

    _ -> Just []
250
-- | Equalities implied by an assumption together with a goal.
--
-- This function never fails: when no rule applies it returns @Just []@.
improve2 :: Prop -> Prop -> Maybe [ Prop ]

improve2 (EqFun Add a1 b1 c1) (EqFun Add a2 b2 c2)
  -- Same arguments (in either order) give the same sum.
  | a1 == a2 && b1 == b2 = Just [ Eq c1 c2 ]
  | a1 == b2 && b1 == a2 = Just [ Eq c1 c2 ]
  -- Same sum and one shared summand: cancel it.
  | c1 == c2 && a1 == a2 = Just [ Eq b1 b2 ]
  | c1 == c2 && a1 == b2 = Just [ Eq b1 a2 ]
  | c1 == c2 && b1 == b2 = Just [ Eq a1 a2 ]
  | c1 == c2 && b1 == a2 = Just [ Eq a1 b2 ]

improve2 (EqFun Mul a1 b1 c1) (EqFun Mul a2 b2 c2)
  | a1 == a2 && b1 == b2 = Just [ Eq c1 c2 ]
  | a1 == b2 && b1 == a2 = Just [ Eq c1 c2 ]
  -- m * b = c and n * b = c with m /= n forces b = 0, hence c = 0.
  | c1 == c2 && b1 == b2, Num m _ <- a1, Num n _ <- a2, m /= n
  = Just [ Eq (num 0) b1, Eq (num 0) c1 ]

improve2 _ _ = Just []
278
279
-- | Try to discharge a wanted constraint on its own, without assumptions.
--
-- @Just sgs@ gives the sub-goals that replace the constraint (empty when it
-- holds outright). 'Nothing' means no rule applies.
solveWanted1 :: Prop -> Maybe [ Prop ]
solveWanted1 prop =
  case prop of

    Leq (Num m _) (Num n _) | m <= n -> Just []
    Leq (Num 0 _) _ -> Just []
    Leq s t | s == t -> Just []

    EqFun Add (Num m _) (Num n _) (Num mn _) | m + n == mn -> Just []
    EqFun Add (Num 0 _) s t | s == t -> Just []
    EqFun Add r s t | r == s ->
      Just [ EqFun Mul (num 2) r t ]

    EqFun Mul (Num m _) (Num n _) (Num mn _) | m * n == mn -> Just []
    EqFun Mul (Num 0 _) _ (Num 0 _) -> Just []
    EqFun Mul (Num 1 _) s t | s == t -> Just []
    EqFun Mul r s t | r == s ->
      Just [ EqFun Exp r (num 2) t ]

    -- Simple normalization of commutative operators
    EqFun op r@(Var _) s@(Num _ _) t | commutative op ->
      Just [ EqFun op s r t ]

    EqFun Exp (Num m _) (Num n _) (Num mn _) | m ^ n == mn -> Just []
    EqFun Exp (Num 1 _) _ (Num 1 _) -> Just []
    EqFun Exp _ (Num 0 _) (Num 1 _) -> Just []
    EqFun Exp r (Num 1 _) t | r == t -> Just []
    -- For n >= 2, r ^ n = r holds exactly when r <= 1. For n == 0 it instead
    -- forces r = 1 (0 ^ 0 = 1), so that case must not be simplified to
    -- r <= 1; it is left to improvement. (n == 1 is handled just above.)
    EqFun Exp r (Num n _) t | r == t, n > 1 ->
      Just [ Leq r (num 1) ]

    _ -> Nothing
311
312
-- | Try to discharge a wanted goal using a single assumption.
--
-- @solveWanted2 asmp goal@ returns:
--
-- * @Just []@ when @asmp@ entails @goal@ outright;
-- * @Just sgs@ when @goal@ can be replaced by the sub-goals @sgs@;
-- * 'Nothing' when this assumption does not help.
solveWanted2 :: Prop -> Prop -> Maybe [Prop]

solveWanted2 (Leq a1 b1) (Leq a2 b2)
  | a1 == a2 && b1 == b2 = Just []
solveWanted2 (Leq (Num lo1 _) b1) (Leq (Num lo2 _) b2)
  | b1 == b2 && lo1 >= lo2 = Just []
solveWanted2 (Leq a1 (Num hi1 _)) (Leq a2 (Num hi2 _))
  | a1 == a2 && hi1 <= hi2 = Just []

-- A summand never exceeds the sum.
solveWanted2 (EqFun Add x y z) (Leq lhs rhs)
  | z == rhs && x == lhs = Just []
  | z == rhs && y == lhs = Just []

-- XXX: others, transitivity

-- The goal is the assumption itself, up to commutativity.
solveWanted2 (EqFun op1 x1 y1 z1) (EqFun op2 x2 y2 z2)
  | op1 == op2 && z1 == z2 &&
    ( x1 == x2 && y1 == y2
      || commutative op1 && x1 == y2 && y1 == x2
    ) = Just []

-- m + b = c |- n + d = c
solveWanted2 (EqFun Add (Num m _) b c1) (EqFun Add (Num n _) d c2)
  | c1 == c2 = if m >= n
                 then Just [ EqFun Add (num (m - n)) b d ]
                 else Just [ EqFun Add (num (n - m)) d b ]

-- m * b = c |- n * d = c
solveWanted2 (EqFun Mul (Num m _) b c1) (EqFun Mul (Num n _) d c2)
  | c1 == c2, Just k <- divide m n = Just [ EqFun Mul (num k) b d ]
  | c1 == c2, Just k <- divide n m = Just [ EqFun Mul (num k) d b ]

-- hm: m * b = c |- c + b = d <=> (m + 1) * b = d
solveWanted2 (EqFun Mul (Num m _) s1 t1) (EqFun Add r2 s2 t2)
  | t1 == r2 && s1 == s2 = Just [ EqFun Mul (num (m + 1)) s1 t2 ]
  | t1 == s2 && s1 == r2 = Just [ EqFun Mul (num (m + 1)) r2 t2 ]

solveWanted2 _ _ = Nothing
352
353
354
355
356 --------------------------------------------------------------------------------
-- Discrete Math
358 --------------------------------------------------------------------------------
359
-- | @descreteRoot root n@ finds a natural number @x@ with @x ^ root == n@,
-- if one exists.
--
-- This is a binary search over the closed interval @[0, n]@, which contains
-- every root of @n@ when @root >= 1@. The original version never probed the
-- upper bound, so it missed e.g. @descreteRoot 2 1 == Just 1@.
--
-- For @root == 0@, every @x@ satisfies @x ^ 0 == 1@; 0 is returned when
-- @n == 1@. Negative inputs have no natural root and yield 'Nothing'
-- (a negative exponent would otherwise make '^' fail).
descreteRoot :: Integer -> Integer -> Maybe Integer
descreteRoot root target
  | root < 0 || target < 0 = Nothing
  | otherwise              = search 0 target
  where
    -- Invariant: if a root exists, it lies in [lo, hi].
    search lo hi
      | lo > hi   = Nothing
      | otherwise =
          let mid = lo + (hi - lo) `div` 2
          in case compare (mid ^ root) target of
               EQ -> Just mid
               LT -> search (mid + 1) hi
               GT -> search lo (mid - 1)
370
-- | @descreteLog base n@ finds the smallest natural number @a@ with
-- @base ^ a == n@, if one exists.
--
-- Intended for natural numbers. A negative base or target yields 'Nothing',
-- except that @n == 1@ always gives exponent 0.
--
-- The original version mishandled several inputs:
--
-- * it claimed @base ^ 0 == 0@ for every base;
-- * it missed @base ^ 0 == 1@;
-- * it divided by zero for base 0;
-- * it looped forever for base 1.
descreteLog :: Integer -> Integer -> Maybe Integer
descreteLog base target
  | target == 1             = Just 0      -- base ^ 0 == 1 for every base
  | base == 0               = if target == 0 then Just 1 else Nothing
  | base < 2 || target < 1  = Nothing     -- 1 ^ a == 1; b ^ a >= 1 for b >= 2
  | otherwise               = go 0 target
  where
    -- Strip one factor of 'base' per step; base >= 2 guarantees progress.
    go acc m
      | m == 1    = Just acc
      | otherwise = case m `divMod` base of
                      (q, 0) -> go (acc + 1) q
                      _      -> Nothing
377
-- | Exact division: @Just q@ when @dividend == q * divisor@ (using
-- 'divMod'), and 'Nothing' when the divisor is zero or does not divide
-- evenly.
divide :: Integer -> Integer -> Maybe Integer
divide dividend divisor
  | divisor /= 0, (q, 0) <- dividend `divMod` divisor = Just q
  | otherwise                                         = Nothing
383
384
385 --------------------------------------------------------------------------------
386 -- Debugging
387 --------------------------------------------------------------------------------
388
-- | Emit a solver trace message, tagged with "[numerics] ".
numTrace :: String -> SDoc -> TcS ()
numTrace label = traceTcS ("[numerics] " ++ label)
391
-- | Render a list as comma-separated items inside braces.
vmany :: Outputable a => [a] -> SDoc
vmany = braces . hcat . punctuate comma . map ppr
394
-- | Literals print as their value; other terms print as their type.
instance Outputable Term where
  ppr term =
    case term of
      Num n _ -> integer n
      Var xi  -> ppr xi
398
-- | Infix rendering: @a op b ~ c@, @a <= b@ and @a ~ b@.
instance Outputable Prop where
  ppr prop =
    case prop of
      EqFun op a b c -> ppr a <+> ppr op <+> ppr b <+> tilde <+> ppr c
      Leq a b        -> ppr a <+> text "<=" <+> ppr b
      Eq a b         -> ppr a <+> tilde <+> ppr b
    where
      tilde = char '~'
403
-- | Each operator prints as its usual arithmetic symbol.
instance Outputable Op where
  ppr Add = char '+'
  ppr Mul = char '*'
  ppr Exp = char '^'
409