baf3c0298c04cef5c0007a95764fbcd6fffa8927
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 {-# LANGUAGE PatternGuards #-}
2 module TcTypeNats (solveNumWanted, addNumGiven) where
3
4 import TcSMonad ( TcS, Xi
5 , newWantedCoVar
6 , setWantedCoBind, setDictBind, newDictVar
7 , tcsLookupTyCon, tcsLookupClass
8 , CanonicalCt (..), CanonicalCts
9 , traceTcS
10 )
11 import TcCanonical (mkCanonicals)
12 import HsBinds (EvTerm(..))
13 import Class ( Class, className, classTyCon )
14 import Type ( tcView, mkTyConApp, mkNumberTy, isNumberTy
15 , tcEqType )
16 import TypeRep (Type(..))
17 import TyCon (TyCon, tyConName)
18 import Var (EvVar)
19 import Coercion ( mkUnsafeCoercion )
20 import Outputable
21 import PrelNames ( lessThanEqualClassName
22 , addTyFamName, mulTyFamName, expTyFamName
23 )
24 import Bag (bagToList, emptyBag, unionBags)
25 import Data.Maybe (fromMaybe)
26 import Data.List (nub, partition)
27 import Control.Monad (msum, mplus, zipWithM, (<=<), guard)
28
29
30
-- | An operand of a numeric constraint: either an uninterpreted type or a
-- literal.  A literal remembers the type it was read from (if any), so that
-- type synonyms can be preserved when it is converted back.
data Term
  = Var Xi
  | Num Integer (Maybe Xi)

-- | The type-level operators the solver knows about.
data Op
  = Add
  | Mul
  | Exp
  deriving (Eq)

-- | A numeric proposition.
data Prop
  = EqFun Op Term Term Term  -- ^ @EqFun op a b c@ stands for @a op b ~ c@
  | Leq Term Term            -- ^ @a <= b@
  | Eq Term Term             -- ^ @a ~ b@
36
-- | Operators whose two arguments may be swapped: addition and
-- multiplication, but not exponentiation.
commutative :: Op -> Bool
commutative Exp = False
commutative _   = True

-- | Operators that may be re-bracketed: addition and multiplication, but not
-- exponentiation.
associative :: Op -> Bool
associative Exp = False
associative _   = True
42
-- | A numeric literal term with no remembered source type.
num :: Integer -> Term
num = flip Num Nothing
45
-- | Equality of terms.  Literals are compared by value only (the remembered
-- source type is ignored).  Variables are compared with 'tcEqType'.  A
-- literal never equals a variable.
instance Eq Term where
  t1 == t2 =
    case (t1, t2) of
      (Var x, Var y)     -> tcEqType x y
      (Num m _, Num n _) -> m == n
      _                  -> False
50
51
52 --------------------------------------------------------------------------------
53 -- Interface with the type checker
54 --------------------------------------------------------------------------------
55
56
-- | Convert a type into a solver term.  The type is recognised as a numeric
-- literal either directly or, failing that, in the result of 'tcView'.
-- We keep the original type in numeric constants to preserve type synonyms.
toTerm :: Xi -> Term
toTerm xi = maybe (Var xi) (\n -> Num n (Just xi)) literal
  where
    literal = isNumberTy xi `mplus` (tcView xi >>= isNumberTy)
62
-- | Convert a solver term back into a type.  A literal is rendered with its
-- remembered source type when it has one; otherwise it is rebuilt with
-- 'mkNumberTy'.
fromTerm :: Term -> Xi
fromTerm term =
  case term of
    Var xi       -> xi
    Num n origTy -> fromMaybe (mkNumberTy n) origTy
66
-- | Translate a canonical constraint into a solver proposition.
--
-- Two shapes are understood:
--
--   * a two-argument @(<=)@ class constraint;
--
--   * a two-argument application of the addition, multiplication or
--     exponentiation type family.
--
-- Anything else panics.
toProp :: CanonicalCt -> Prop
toProp ct = fromMaybe unexpected (classify ct)
  where
    classify (CDictCan { cc_class = c, cc_tyargs = [xi1, xi2] })
      | className c == lessThanEqualClassName
      = Just (Leq (toTerm xi1) (toTerm xi2))
    classify (CFunEqCan { cc_fun = tc, cc_tyargs = [xi11, xi12], cc_rhs = xi2 })
      = do op <- lookup (tyConName tc) familyOps
           return (EqFun op (toTerm xi11) (toTerm xi12) (toTerm xi2))
    classify _ = Nothing

    -- Type families recognised as arithmetic operators, in lookup order.
    familyOps = [ (addTyFamName, Add)
                , (mulTyFamName, Mul)
                , (expTyFamName, Exp)
                ]

    unexpected = panic $
      "[TcTypeNats.toProp] Unexpected CanonicalCt: " ++ showSDoc (ppr ct)
82
83
-- | Try to solve the numeric constraint @prop@.  Every given and every
-- wanted constraint is used as an assumption.
--
-- Returns the constraint to keep (if any), together with the newly created
-- constraints (which share the flavour of @prop@):
--
--   * simplified: @prop@ is bound to dummy evidence and dropped; the simpler
--     sub-goals become new constraints;
--
--   * improved: @prop@ is kept; the improvements become new constraints;
--
--   * impossible: @prop@ is kept and nothing new is produced.  No error is
--     reported yet (see the XXX below).
solveNumWanted :: CanonicalCts -> CanonicalCts -> CanonicalCt ->
                  TcS (Maybe CanonicalCt, CanonicalCts)
solveNumWanted given wanted prop =
  do let asmps = map toProp $ bagToList $ unionBags given wanted
         goal  = toProp prop

     numTrace "solveNumWanted" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            defineDummy (cc_id prop) =<< fromProp goal
            evs   <- mapM (newSubGoal <=< fromProp) sgs
            goals <- mkCanonicals (cc_flavor prop) evs
            return (Nothing, goals)

       -- XXX: The new wanted might imply some of the existing wanteds...
       Improved is ->
         do numTrace "Improved by" (vmany is)
            evs   <- mapM (newSubGoal <=< fromProp) is
            goals <- mkCanonicals (cc_flavor prop) evs
            return (Just prop, goals)

       -- XXX: Report an error
       Impossible ->
         -- Spelled the same way as the corresponding trace in addNumGiven.
         do numTrace "Impossible" empty
            return (Just prop, emptyBag)
112
113
-- | Add the numeric constraint @prop@ as a new fact.  The existing given and
-- wanted constraints are used as assumptions.
--
-- The result triple depends on what the solver reports:
--
--   * simplified: @(Nothing, given ∪ wanted, facts)@, where @facts@ are the
--     simpler propositions, recorded with dummy evidence;
--
--   * improved: @(Just prop, given, wanted ∪ facts)@, where @facts@ are the
--     derived improvements, recorded with dummy evidence;
--
--   * impossible: @(Just prop, given ∪ wanted, empty)@.  No error is
--     reported yet.
addNumGiven :: CanonicalCts -> CanonicalCts -> CanonicalCt ->
               TcS (Maybe CanonicalCt, CanonicalCts, CanonicalCts)
addNumGiven given wanted prop =
  do numTrace "addNumGiven" (vmany asmps <+> text " /\\ " <+> ppr goal)
     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            facts <- recordFacts sgs
            return (Nothing, everything, facts)

       Improved is ->
         do numTrace "Improved by" (vmany is)
            facts <- recordFacts is
            return (Just prop, given, unionBags wanted facts)

       -- XXX: report error
       Impossible ->
         do numTrace "Impossible" empty
            return (Just prop, everything, emptyBag)
  where
    everything = unionBags given wanted
    asmps      = map toProp (bagToList everything)
    goal       = toProp prop

    -- Turn propositions into facts with dummy evidence and canonicalise them
    -- with the flavour of the incoming constraint.
    recordFacts ps =
      do evs <- mapM (newFact <=< fromProp) ps
         mkCanonicals (cc_flavor prop) evs
139
140
141
-- | A solver proposition translated back into type-checker terms.  It is
-- either a class applied to argument types, or an equality between two
-- types.
data CvtProp
  = CvtClass Class [Type]
  | CvtCo Type Type
144
-- | Translate a solver proposition back into a class constraint or a type
-- equality.
--
--   * @a <= b@ becomes the @(<=)@ class applied to both sides.
--
--   * @a op b ~ c@ becomes an equality between the operator's type-family
--     application and @c@.
fromProp :: Prop -> TcS CvtProp
fromProp prop =
  case prop of
    Leq t1 t2 ->
      do cls <- tcsLookupClass lessThanEqualClassName
         return $ CvtClass cls (map fromTerm [t1, t2])
    Eq t1 t2 ->
      return $ CvtCo (fromTerm t1) (fromTerm t2)
    EqFun op t1 t2 t3 ->
      do fam <- tcsLookupTyCon (familyName op)
         return $ CvtCo (mkTyConApp fam (map fromTerm [t1, t2])) (fromTerm t3)
  where
    familyName Add = addTyFamName
    familyName Mul = mulTyFamName
    familyName Exp = expTyFamName
158
159
-- | Create a fresh evidence variable for a translated proposition.  Class
-- constraints use 'newDictVar'; equalities use 'newWantedCoVar'.
newSubGoal :: CvtProp -> TcS EvVar
newSubGoal cvt =
  case cvt of
    CvtClass cls tys -> newDictVar cls tys
    CvtCo lhs rhs    -> newWantedCoVar lhs rhs
163
-- | Create an evidence variable for a proposition that is accepted as a
-- fact.  The variable is bound to dummy evidence straight away (see
-- 'defineDummy') and then returned.
newFact :: CvtProp -> TcS EvVar
newFact cvt = newSubGoal cvt >>= \ev -> defineDummy ev cvt >> return ev
169
170
-- | Bind an evidence variable to placeholder evidence.
--
-- If we decided that we want to generate evidence terms, this is where we
-- would set the evidence properly.  For now we skip this step, because
-- evidence terms are not used for anything.  They also get quite large, at
-- least if we start with a small set of axioms.
--
-- Class constraints are bound to an 'EvAxiom' labelled @"<="@.  Equalities
-- are bound to an unsafe coercion between the two sides.
defineDummy :: EvVar -> CvtProp -> TcS ()
defineDummy ev cvt =
  case cvt of
    CvtClass cls tys ->
      setDictBind ev (EvAxiom "<=" (mkTyConApp (classTyCon cls) tys))
    CvtCo lhs rhs ->
      setWantedCoBind ev (mkUnsafeCoercion lhs rhs)
181
182
183
184
185 --------------------------------------------------------------------------------
186 -- The Solver
187 --------------------------------------------------------------------------------
188
189
-- | The outcome of running the solver on one proposition.
data Result
  = Impossible         -- ^ We know that the goal cannot be solved.
  | Simplified [Prop]  -- ^ The goal was reformulated into these sub-goals.
  | Improved [Prop]    -- ^ The goal stays, but we learned these new facts.
193
-- | Run the solver on a goal under a list of assumptions.
--
-- An ordering goal @a <= b@ is checked against the ordering facts derived
-- from the assumptions ('propsToOrd').  If it cannot be proved,
-- 'improveLeq' decides between an improvement and 'Impossible'.
--
-- Every other goal is tried in this order:
--
--   1. simplification on its own ('solveWanted1');
--   2. simplification using one assumption ('solveWanted2');
--   3. improvement ('improve1', 'improve2').
--
-- If any improvement step returns @Nothing@, the result is 'Impossible'.
solve :: [Prop] -> Prop -> Result
solve asmps (Leq a b)
  | isLeq ords a b = Simplified []
  | otherwise      = maybe Impossible Improved (improveLeq ords a b)
  where
    ords = propsToOrd asmps

solve asmps prop = fromMaybe improveGoal (fmap Simplified simplified)
  where
    simplified =
      solveWanted1 prop `mplus` msum [ solveWanted2 asmp prop | asmp <- asmps ]

    improveGoal = maybe Impossible Improved $
      do eqs  <- improve1 prop
         eqs1 <- mapM (`improve2` prop) asmps
         return (concat (eqs : eqs1))
217
218
219
220
-- | Improvements that follow from the goal alone (naturals, with 0 ^ 0 = 1).
--
-- Returns @Nothing@ if the goal is unsatisfiable.  Otherwise it returns a
-- (possibly empty) list of equalities that any solution must satisfy.
-- Alternatives are tried top to bottom; a failing guard falls through to
-- the next alternative.
improve1 :: Prop -> Maybe [Prop]
improve1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) t -> Just [ Eq (num (m+n)) t ]
    EqFun Add (Num 0 _) s t         -> Just [ Eq s t ]
    EqFun Add (Num m _) s (Num n _)
      | m <= n    -> Just [ Eq (num (n-m)) s ]
      | otherwise -> Nothing
    EqFun Add r s t
      | r == t -> Just [ Eq (num 0) s ]
      | s == t -> Just [ Eq (num 0) r ]

    EqFun Mul (Num m _) (Num n _) t -> Just [ Eq (num (m*n)) t ]
    EqFun Mul (Num 0 _) _ t         -> Just [ Eq (num 0) t ]
    EqFun Mul (Num 1 _) s t         -> Just [ Eq s t ]
    EqFun Mul (Num _ _) s t
      -- The factor is >= 2 here (0 and 1 are handled above), so s = 0.
      | s == t -> Just [ Eq (num 0) s ]

    EqFun Mul r s (Num 1 _) -> Just [ Eq (num 1) r, Eq (num 1) s ]
    EqFun Mul (Num m _) s (Num n _)
      | Just a <- divide n m -> Just [ Eq (num a) s ]
      | otherwise            -> Nothing


    EqFun Exp (Num m _) (Num n _) t -> Just [ Eq (num (m ^ n)) t ]

    EqFun Exp (Num 1 _) _ t -> Just [ Eq (num 1) t ]
    EqFun Exp (Num _ _) s t
      -- m ^ s = s has no natural solution when m /= 1.
      | s == t -> Nothing
    -- Base 0 must not reach 'descreteLog': 0 ^ 0 = 1 but 0 ^ s = 0 for every
    -- s >= 1.  So the exponent is determined only when the result is 1, and
    -- a result of 0 tells us nothing beyond s >= 1.
    EqFun Exp (Num 0 _) s (Num n _)
      | n == 1    -> Just [ Eq (num 0) s ]
      | n == 0    -> Just []
      | otherwise -> Nothing
    EqFun Exp (Num m _) s (Num n _) ->
      do a <- descreteLog m n
         return [ Eq (num a) s ]

    EqFun Exp _ (Num 0 _) t -> Just [ Eq (num 1) t ]
    EqFun Exp r (Num 1 _) t -> Just [ Eq r t ]
    EqFun Exp r (Num m _) (Num n _) ->
      do a <- descreteRoot m n
         return [ Eq (num a) r ]

    _ -> Just []
260
-- | Improvements that follow from combining the goal with one assumption.
--
-- For two additions (in either argument order):
--
--   * equal arguments force equal results;
--   * equal results plus one shared argument force the other arguments to
--     be equal.
--
-- For two multiplications:
--
--   * equal arguments force equal results;
--   * @m * b = c@ and @n * b = c@ with distinct literals @m@ and @n@ force
--     @b = 0@ and @c = 0@.
--
-- This never fails: when no rule applies, the result is @Just []@.
improve2 :: Prop -> Prop -> Maybe [ Prop ]
improve2 asmp prop = Just (fromMaybe [] (derive asmp prop))
  where
    derive (EqFun Add a1 b1 c1) (EqFun Add a2 b2 c2)
      | a1 == a2, b1 == b2 = Just [ Eq c1 c2 ]
      | a1 == b2, b1 == a2 = Just [ Eq c1 c2 ]
      | c1 == c2, a1 == a2 = Just [ Eq b1 b2 ]
      | c1 == c2, a1 == b2 = Just [ Eq b1 a2 ]
      | c1 == c2, b1 == b2 = Just [ Eq a1 a2 ]
      | c1 == c2, b1 == a2 = Just [ Eq a1 b2 ]
    derive (EqFun Mul a1 b1 c1) (EqFun Mul a2 b2 c2)
      | a1 == a2, b1 == b2 = Just [ Eq c1 c2 ]
      | a1 == b2, b1 == a2 = Just [ Eq c1 c2 ]
      | c1 == c2, b1 == b2, Num m _ <- a1, Num n _ <- a2, m /= n
      = Just [ Eq (num 0) b1, Eq (num 0) c1 ]
    derive _ _ = Nothing
288
289
-- | Try to solve or reformulate the goal on its own, without assumptions.
--
--   * @Just []@: the goal holds outright.
--   * @Just sgs@: the goal holds whenever the sub-goals @sgs@ hold.
--   * @Nothing@: no rule applies.
--
-- Arithmetic is on naturals, with 0 ^ 0 = 1.
solveWanted1 :: Prop -> Maybe [ Prop ]
solveWanted1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) (Num mn _) | m + n == mn -> Just []
    EqFun Add (Num 0 _) s t | s == t -> Just []
    EqFun Add r s t | r == s ->
      Just [ EqFun Mul (num 2) r t ]

    EqFun Mul (Num m _) (Num n _) (Num mn _) | m * n == mn -> Just []
    EqFun Mul (Num 0 _) _ (Num 0 _) -> Just []
    EqFun Mul (Num 1 _) s t | s == t -> Just []
    EqFun Mul r s t | r == s ->
      Just [ EqFun Exp r (num 2) t ]

    -- Simple normalization of commutative operators
    EqFun op r@(Var _) s@(Num _ _) t | commutative op ->
      Just [ EqFun op s r t ]

    EqFun Exp (Num m _) (Num n _) (Num mn _) | m ^ n == mn -> Just []
    EqFun Exp (Num 1 _) _ (Num 1 _) -> Just []
    EqFun Exp _ (Num 0 _) (Num 1 _) -> Just []
    EqFun Exp r (Num 1 _) t | r == t -> Just []
    -- r ^ 0 = r holds only for r = 1 (since 0 ^ 0 = 1).  It must not be
    -- weakened to r <= 1 like the general case below, which would accept
    -- r = 0.
    EqFun Exp r (Num 0 _) t | r == t ->
      Just [ Eq r (num 1) ]
    -- r ^ k = r with k >= 2 holds exactly for r = 0 and r = 1.
    EqFun Exp r (Num _ _) t | r == t ->
      Just [Leq r (num 1)]

    _ -> Nothing
317
318
-- | Try to solve or reformulate the goal using a single assumption.
--
--   * @Just []@: the assumption directly entails the goal.
--   * @Just sgs@: the goal holds whenever the sub-goals @sgs@ hold.
--   * @Nothing@: no rule applies.
solveWanted2 :: Prop -> Prop -> Maybe [Prop]
solveWanted2 asmp prop = go asmp prop
  where
    -- The goal is the assumption itself (up to commutativity).
    go (EqFun op1 r1 s1 t1) (EqFun op2 r2 s2 t2)
      | op1 == op2, t1 == t2
      , (r1 == r2 && s1 == s2) || (commutative op1 && r1 == s2 && s1 == r2)
      = Just []

    -- m + b = c  |-  n + d = c
    go (EqFun Add (Num m _) b c1) (EqFun Add (Num n _) d c2)
      | c1 == c2, m >= n = Just [ EqFun Add (num (m - n)) b d ]
      | c1 == c2         = Just [ EqFun Add (num (n - m)) d b ]

    -- m * b = c  |-  n * d = c
    go (EqFun Mul (Num m _) b c1) (EqFun Mul (Num n _) d c2)
      | c1 == c2, Just x <- divide m n = Just [ EqFun Mul (num x) b d ]
      | c1 == c2, Just x <- divide n m = Just [ EqFun Mul (num x) d b ]

    -- m * b = c  |-  c + b = d  <=>  (m + 1) * b = d   (either argument order)
    go (EqFun Mul (Num m _) s1 t1) (EqFun Add r2 s2 t2)
      | t1 == r2, s1 == s2 = Just [ EqFun Mul (num (m + 1)) s1 t2 ]
      | t1 == s2, s1 == r2 = Just [ EqFun Mul (num (m + 1)) r2 t2 ]

    go _ _ = Nothing
343
344
345 --------------------------------------------------------------------------------
346 -- Reasoning about ordering.
347 --------------------------------------------------------------------------------
348
-- | Decide @lhs <= rhs@ from literal values and a list of known facts.  Each
-- fact @(x, y)@ means @x <= y@; facts are chained transitively.
-- This function assumes that the facts are acyclic: a cycle can make the
-- search recurse forever.
isLeq :: [(Term, Term)] -> Term -> Term -> Bool
isLeq ps lhs rhs =
  case (lhs, rhs) of
    (Num 0 _, _)       -> True
    (Num m _, Num n _) -> m <= n
    _ | lhs == rhs     -> True
    -- m <= x <= b <= rhs
    (Num m _, _)       ->
      any (\b -> isLeq ps b rhs) [ b | (Num x _, b) <- ps, m <= x ]
    -- lhs <= b <= x <= m
    (_, Num m _)       ->
      any (isLeq ps lhs) [ b | (b, Num x _) <- ps, x <= m ]
    -- lhs <= c <= ... <= rhs
    _                  ->
      any (\c -> isLeq ps c rhs) [ c | (a', c) <- ps, lhs == a' ]
357
-- | Decide @lhs > rhs@.  At least one side must be a literal; the question is
-- then reduced to 'isLeq' by moving the literal by one.  With no literal,
-- the answer is always False.
isGt :: [(Term, Term)] -> Term -> Term -> Bool
isGt ps lhs rhs =
  case (lhs, rhs) of
    (Num m _, Num n _) -> m > n
    (Num m _, _)       -> m > 0 && isLeq ps rhs (num (m - 1))
    (_, Num n _)       -> isLeq ps (num (n + 1)) lhs
    _                  -> False
363
-- | Improvement for an ordering goal @a <= b@ that could not be proved.
--
--   * If @a > b@ is known, the goal is unsatisfiable (@Nothing@).
--   * Otherwise, if @b <= a@ is known, the goal forces @a ~ b@.
--   * Otherwise nothing is learned.
--
-- The strictly-greater check must come first.  When @a > b@ holds, @b <= a@
-- holds too, and testing it first would turn an impossible goal such as
-- @5 <= 3@ into the false improvement @5 ~ 3@.
improveLeq :: [(Term,Term)] -> Term -> Term -> Maybe [Prop]
improveLeq ps a b | isGt ps a b  = Nothing
                  | isLeq ps b a = Just [Eq a b]
                  | otherwise    = Just []
368
-- | Ordering facts derived from numeric predicates.  Each pair @(a, b)@
-- means @a <= b@.  We do not consider equalities because they should be
-- substituted away.
--
-- Facts from @(<=)@ constraints and from additions hold unconditionally.
-- Facts from multiplications and exponentiations depend on other ordering
-- facts (such as @1 <= a@), so they are recomputed until nothing new
-- appears.
propsToOrd :: [Prop] -> [(Term,Term)]
propsToOrd props = loop (step [] unconditional)
  where
    loop ps = let new = filter (`notElem` ps) (step ps conditional)
              in if null new then ps else loop (new ++ ps)

    step ps = nub . concatMap (toOrd ps)

    isConditional (EqFun op _ _ _) = op == Mul || op == Exp
    isConditional _                = False

    (conditional, unconditional) = partition isConditional props

    toOrd _ (Leq a b) = [(a,b)]
    toOrd _ (Eq _ _)  = []  -- Would lead to a cycle, should be subst. away
    toOrd ps (EqFun op a b c) =
      case op of
        -- a + b = c
        Add -> [(a,c),(b,c)]

        -- a * b = c
        Mul -> (guard (isLeq ps (num 1) a) >> return (b,c)) ++
               (guard (isLeq ps (num 1) b) >> return (a,c))

        -- a ^ b = c
        Exp
          | Num 0 _ <- a -> [(c, num 1)]
          -- a ^ b = a says nothing about a when b = 1, because a ^ 1 = a
          -- for every a.  Any other exponent forces a <= 1 (b = 0 gives
          -- a = 1; b >= 2 gives a = 0 or a = 1).
          | a == c       -> case b of
                              Num k _ -> [ (a, num 1) | k /= 1 ]
                              _       -> [ (a, num 1) | isLeq ps (num 2) b ]
          | otherwise    -> (guard (isLeq ps (num 2) a) >> return (b,c)) ++
                            (guard (isLeq ps (num 1) b) >> return (a,c))
397
398 --------------------------------------------------------------------------------
399 -- Associativity
400 --------------------------------------------------------------------------------
401
-- We try to compute new equalities by moving parens to the right:
--   if (a + b) + c = d then a + (b + c) = d
--
-- This desugars to:
--   a + b = p, p + c = q, b + c = r  =>  a + r = q
--
-- Note that this fires only if we have a name, r, for the right-paren sum,
-- or if we can compute it:
--   a + 2 = p, p + 3 = q  =>  a + 5 = q
--
-- So, when we add a new fact it could match any of the 3 premises (the
-- left-hand side) of this rule.  For each known fact with the same
-- operator, only the first matching position below is tried.
--
-- NOTE(review): impAssoc is not exported and is not called anywhere in this
-- module; confirm whether it is meant to be used by 'solve'.
impAssoc :: [Prop] -> Prop -> [Prop]
impAssoc ps (EqFun op x y z)
  | associative op = concatMap fire known
  where
    -- Known facts that use the same operator, as (left, right, result).
    known = [ (a, b, c) | EqFun op' a b c <- ps, op == op' ]

    fire (a, b, c)
      -- new fact in position 2:  a + b = x,  x + y = z  =>  a + (b + y) = z
      | c == x    = [ EqFun op a w z | w <- withY b ]
      -- new fact in position 1:  x + y = a,  a + b = c  =>  x + (y + b) = c
      | a == z    = [ EqFun op x w c | w <- withY b ]
      -- new fact in position 3: it names the right-paren sum
      | otherwise = [ EqFun op a z e | (c', d, e) <- known
                                     , c == c', sameSum b d x y ]
                 ++ [ EqFun op d z c | (d, e, a') <- known
                                     , a == a', sameSum e b x y ]

    -- Do the two pairs name the same operands, in either order?
    sameSum p q p' q' = (p == p' && q == q') || (p == q' && q == p')

    -- Names for (y `op` b).  For two literals the value is computed
    -- directly; otherwise it is looked up among the known facts.
    withY b =
      case (b, y) of
        (Num m _, Num n _) -> doOp m n
        _                  -> [ w | (u, v, w) <- known, sameSum y b u v ]

    doOp m n =
      case op of
        Add -> [ num (m + n) ]
        Mul -> [ num (m * n) ]
        _   -> []

impAssoc _ _ = []
444
445 --------------------------------------------------------------------------------
-- Discrete Math
447 --------------------------------------------------------------------------------
448
-- | @descreteRoot k n@ is a natural number @x@ with @x ^ k == n@, if there is
-- one.  It is found by binary search over the inclusive range @[0 .. n]@.
-- A negative exponent or a negative target gives @Nothing@.
--
-- The previous search never examined the upper bound of its range.  It
-- therefore missed answers equal to @n@, e.g. @descreteRoot 2 1@ returned
-- @Nothing@ instead of @Just 1@.
descreteRoot :: Integer -> Integer -> Maybe Integer
descreteRoot root n
  | root < 0 || n < 0 = Nothing
  | otherwise         = search 0 n
  where
    -- For root >= 1, x ^ root is non-decreasing on the naturals and
    -- x ^ root >= x for x >= 1.  So any solution lies in [lo, hi], which is
    -- initially [0, n].
    search lo hi
      | lo > hi   = Nothing
      | otherwise =
          let x = lo + div (hi - lo) 2
          in case compare (x ^ root) n of
               EQ -> Just x
               LT -> search (x + 1) hi
               GT -> search lo (x - 1)
459
-- | @descreteLog b n@ is the unique natural number @k@ with @b ^ k == n@, if
-- there is one.
--
-- Bases 0 and 1 are special:
--
--   * @0 ^ 0 = 1@ and @0 ^ k = 0@ for @k >= 1@;
--   * @1 ^ k = 1@ for every @k@.
--
-- So among those bases only @descreteLog 0 1 = Just 0@ has a unique answer.
--
-- The previous version used @log 0 = 0@ as its base case instead of
-- @log 1 = 0@.  That made @descreteLog 2 1@ fail and @descreteLog 2 0@
-- succeed, and base 0 divided by zero.
descreteLog :: Integer -> Integer -> Maybe Integer
descreteLog base n
  | n == 1 && base /= 1 = Just 0     -- b ^ 0 = 1
  | base < 2 || n < 1   = Nothing    -- no solution, or no unique one
  | otherwise           =
      -- n >= 2 and base >= 2, so every step strictly decreases n.
      case divMod n base of
        (x, 0) -> fmap (1 +) (descreteLog base x)
        _      -> Nothing
466
-- | Exact division.  @divide x y@ is @Just q@ when @y /= 0@ and @x == q * y@;
-- otherwise it is @Nothing@.
divide :: Integer -> Integer -> Maybe Integer
divide x y
  | y /= 0, (q, 0) <- x `divMod` y = Just q
  | otherwise                      = Nothing
472
473
474 --------------------------------------------------------------------------------
475 -- Debugging
476 --------------------------------------------------------------------------------
477
-- | Emit a solver trace message whose label is prefixed with "[numerics] ".
numTrace :: String -> SDoc -> TcS ()
numTrace label = traceTcS (concat ["[numerics] ", label])
480
-- | Render a list as its elements separated by commas, inside braces.
vmany :: Outputable a => [a] -> SDoc
vmany = braces . hcat . punctuate comma . map ppr
483
-- | Variables print as their type.  Literals print as the bare number; the
-- remembered source type is not shown.
instance Outputable Term where
  ppr term =
    case term of
      Var xi  -> ppr xi
      Num n _ -> integer n
487
-- | Pretty-print a proposition as @a op b ~ c@, @a <= b@ or @a ~ b@.
instance Outputable Prop where
  ppr prop =
    case prop of
      EqFun op t1 t2 t3 -> ppr t1 <+> ppr op <+> ppr t2 <+> tilde <+> ppr t3
      Leq t1 t2         -> ppr t1 <+> text "<=" <+> ppr t2
      Eq t1 t2          -> ppr t1 <+> tilde <+> ppr t2
    where
      tilde = char '~'
492
-- | Operators print as their usual infix symbol.
instance Outputable Op where
  ppr Add = char '+'
  ppr Mul = char '*'
  ppr Exp = char '^'
498