Refactor a bit and report errors due to impossible constraints.
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 {-# LANGUAGE PatternGuards #-}
2 module TcTypeNats ( canonicalNum, NumericsResult(..) ) where
3
4 import TcSMonad ( TcS, Xi
5 , newWantedCoVar
6 , setWantedCoBind, setDictBind, newDictVar
7 , getWantedLoc
8 , tcsLookupTyCon, tcsLookupClass
9 , CanonicalCt (..), CanonicalCts
10 , mkFrozenError
11 , traceTcS
12 )
13 import TcRnTypes ( CtFlavor(..) )
14 import TcCanonical (mkCanonicals)
15 import HsBinds (EvTerm(..))
16 import Class ( Class, className, classTyCon )
17 import Type ( tcView, mkTyConApp, mkNumberTy, isNumberTy
18 , tcEqType )
19 import TypeRep (Type(..))
20 import TyCon (TyCon, tyConName)
21 import Var (EvVar)
22 import Coercion ( mkUnsafeCoercion )
23 import Outputable
24 import PrelNames ( lessThanEqualClassName
25 , addTyFamName, mulTyFamName, expTyFamName
26 )
27 import Bag (bagToList, emptyBag, unionManyBags, unionBags)
28 import Data.Maybe (fromMaybe)
29 import Data.List (nub, partition)
30 import Control.Monad (msum, mplus, zipWithM, (<=<), guard)
31
32
33
-- | A term of a numeric proposition: an arbitrary (non-literal) type, or a
-- numeric literal that may remember the type it was read from.
data Term
  = Var Xi
  | Num Integer (Maybe Xi)

-- | The arithmetic type families understood by this solver.
data Op
  = Add
  | Mul
  | Exp
  deriving Eq

-- | Numeric propositions.
data Prop
  = EqFun Op Term Term Term   -- ^ @a op b ~ c@
  | Leq Term Term             -- ^ @a <= b@
  | Eq Term Term              -- ^ @a ~ b@
39
-- | Addition and multiplication are commutative; exponentiation is not.
commutative :: Op -> Bool
commutative Exp = False
commutative _   = True
42
-- | Addition and multiplication are associative; exponentiation is not.
associative :: Op -> Bool
associative Exp = False
associative _   = True
45
-- | A numeric literal term with no remembered source type.
num :: Integer -> Term
num = flip Num Nothing
48
-- | Variables are compared with type equality; literals are compared by
-- value only, ignoring the type they were read from.
instance Eq Term where
  lhs == rhs =
    case (lhs, rhs) of
      (Var x, Var y)     -> x `tcEqType` y
      (Num m _, Num n _) -> m == n
      _                  -> False
53
54
55 --------------------------------------------------------------------------------
56 -- Interface with the type checker
57 --------------------------------------------------------------------------------
58
-- | What the numeric solver hands back to the constraint solver.
data NumericsResult = NumericsResult
  { numNewWork :: CanonicalCts
    -- ^ Newly generated constraints to be processed.
  , numInert   :: Maybe CanonicalCts
    -- ^ Replacement inert constraints; 'Nothing' for "no change".
  , numNext    :: Maybe CanonicalCt
    -- ^ The constraint to pass on (the original one, or a frozen error);
    -- 'Nothing' when it was discharged.
  }
64
-- | Convert a type into a 'Term'.  A numeric literal becomes 'Num',
-- looking through any number of type synonyms (e.g. @type A = B@,
-- @type B = 5@).  Any other type becomes 'Var'.
--
-- We keep the original type in numeric constants to preserve type synonyms.
toTerm :: Xi -> Term
toTerm xi =
  case numLit xi of
    Just n  -> Num n (Just xi)
    Nothing -> Var xi
  where
  -- Try the type itself first.  Otherwise expand one layer of synonym and
  -- try again.  Synonym expansion always terminates, so this does too.
  numLit ty = isNumberTy ty `mplus` (numLit =<< tcView ty)
70
-- | Convert a term back into a type.  For a literal, reuse the type it was
-- read from when there is one, so that synonyms are preserved.
fromTerm :: Term -> Xi
fromTerm term =
  case term of
    Var xi         -> xi
    Num n original -> fromMaybe (mkNumberTy n) original
74
-- | Translate a numeric canonical constraint into a 'Prop'.
-- Accepted inputs:
--   * a @<=@ class constraint with two arguments;
--   * an equation for the Add, Mul or Exp type family.
-- Panics on anything else.
toProp :: CanonicalCt -> Prop
toProp (CDictCan { cc_class = cls, cc_tyargs = [lhs, rhs] })
  | className cls == lessThanEqualClassName = Leq (toTerm lhs) (toTerm rhs)

toProp (CFunEqCan { cc_fun = fam, cc_tyargs = [x, y], cc_rhs = z })
  | Just op <- lookup (tyConName fam) familyOps =
      EqFun op (toTerm x) (toTerm y) (toTerm z)
  where
  familyOps = [ (addTyFamName, Add), (mulTyFamName, Mul), (expTyFamName, Exp) ]

toProp p = panic $
  "[TcTypeNats.toProp] Unexpected CanonicalCt: " ++ showSDoc (ppr p)
90
91
-- | Process one numeric constraint against the given, derived and wanted
-- constraints supplied.  Dispatches on the constraint's flavour.
canonicalNum :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                TcS NumericsResult
canonicalNum given derived wanted prop = handler given derived wanted prop
  where
  handler = case cc_flavor prop of
              Given {}   -> addNumGiven
              Derived {} -> addNumDerived
              Wanted {}  -> solveNumWanted
99
100
-- | Try to discharge a wanted numeric constraint.  All the given, derived
-- and wanted constraints are used as assumptions.
solveNumWanted :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                  TcS NumericsResult
solveNumWanted given derived wanted prop =
  do numTrace "solveNumWanted" (vmany asmps <+> text "|-" <+> ppr goal)
     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            -- The goal is discharged: bind its evidence (a dummy) and
            -- replace it by its sub-goals.
            fromProp goal >>= defineDummy (cc_id prop)
            newGoals <- subGoals sgs
            return NumericsResult { numNewWork = newGoals
                                  , numInert   = Nothing
                                  , numNext    = Nothing }

       -- XXX: The new wanted might imply some of the existing wanteds...
       Improved is ->
         do numTrace "Improved by" (vmany is)
            newGoals <- subGoals is
            return NumericsResult { numNewWork = newGoals
                                  , numInert   = Nothing
                                  , numNext    = Just prop }

       Impossible -> impossible prop
  where
  asmps = map toProp (bagToList (unionManyBags [given, derived, wanted]))
  goal  = toProp prop
  -- Fresh canonical constraints, with the goal's flavour, for these props.
  subGoals ps = mapM (newSubGoal <=< fromProp) ps >>= mkCanonicals (cc_flavor prop)
128
129
-- XXX: Need to understand derived work better.

-- | Process a derived numeric constraint.  Only the givens are used as
-- assumptions.
addNumDerived :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                 TcS NumericsResult
addNumDerived given derived wanted prop =
  do numTrace "addNumDerived" (vmany asmps <+> text "|-" <+> ppr goal)
     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            fromProp goal >>= defineDummy (cc_id prop)
            newGoals <- canonicalise (cc_flavor prop) sgs
            return NumericsResult { numNewWork = newGoals
                                  , numInert   = Nothing
                                  , numNext    = Nothing }

       -- XXX: watch out for cycles because of the wanteds being restarted:
       -- W => D && D => W, we could solve W by W
       -- if W <=> D, then we should simplify W to D, not make it derived.
       Improved is ->
         do numTrace "Improved by" (vmany is)
            newGoals <- canonicalise (Derived (getWantedLoc prop)) is
            -- Keep givens/deriveds inert; put the wanteds back on the work
            -- list together with the newly derived facts.
            return NumericsResult { numNewWork = unionBags wanted newGoals
                                  , numInert   = Just (unionBags given derived)
                                  , numNext    = Just prop }

       Impossible -> impossible prop
  where
  asmps = map toProp (bagToList given)
  goal  = toProp prop
  canonicalise flav ps = mapM (newSubGoal <=< fromProp) ps >>= mkCanonicals flav
161
162
-- | Add a new given numeric constraint.  It is checked against the
-- existing givens.
addNumGiven :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
               TcS NumericsResult
addNumGiven given derived wanted prop =
  do numTrace "addNumGiven" (vmany asmps <+> text " /\\ " <+> ppr goal)
     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            newFacts <- facts sgs
            return NumericsResult { numNewWork = newFacts
                                  , numInert   = Nothing
                                  , numNext    = Nothing }

       Improved is ->
         do numTrace "Improved by" (vmany is)
            newFacts <- facts is
            return NumericsResult { numNewWork = unionBags wanted newFacts
                                  , numInert   = Just (unionBags given derived)
                                  , numNext    = Just prop }

       Impossible -> impossible prop
  where
  asmps = map toProp (bagToList given)
  goal  = toProp prop
  -- New facts (with dummy evidence) sharing the flavour of the given.
  facts ps = mapM (newFact <=< fromProp) ps >>= mkCanonicals (cc_flavor prop)
188
-- | The constraint can never hold.  Pass on a frozen error in its place
-- (presumably reported later) and produce no new work.
impossible :: CanonicalCt -> TcS NumericsResult
impossible c =
  numTrace "Impossible" empty >>
  return NumericsResult { numNewWork = emptyBag
                        , numInert   = Nothing
                        , numNext    = Just (mkFrozenError (cc_flavor c) (cc_id c)) }
195
196
197
-- | A proposition in the type checker's vocabulary.
data CvtProp
  = CvtClass Class [Type]   -- ^ a class applied to arguments
  | CvtCo Type Type         -- ^ an equality @t1 ~ t2@
200
-- | Convert a proposition back into a type-checker constraint.  Looks up
-- the @<=@ class or the arithmetic type family it mentions.
fromProp :: Prop -> TcS CvtProp
fromProp prop =
  case prop of
    Leq t1 t2 ->
      do leqClass <- tcsLookupClass lessThanEqualClassName
         return (CvtClass leqClass (map fromTerm [t1, t2]))

    Eq t1 t2 ->
      return (CvtCo (fromTerm t1) (fromTerm t2))

    EqFun op t1 t2 t3 ->
      do fam <- tcsLookupTyCon (famName op)
         return (CvtCo (mkTyConApp fam (map fromTerm [t1, t2])) (fromTerm t3))
  where
  famName Add = addTyFamName
  famName Mul = mulTyFamName
  famName Exp = expTyFamName
214
215
-- | A fresh evidence variable for a converted proposition.  Class
-- constraints get a dictionary variable; equalities get a wanted
-- coercion variable.
newSubGoal :: CvtProp -> TcS EvVar
newSubGoal cvt =
  case cvt of
    CvtClass cls args -> newDictVar cls args
    CvtCo lhs rhs     -> newWantedCoVar lhs rhs
219
-- | A fresh evidence variable for a proposition, immediately bound to
-- dummy evidence.
newFact :: CvtProp -> TcS EvVar
newFact prop =
  newSubGoal prop >>= \ev ->
  defineDummy ev prop >>
  return ev
225
226
-- | Bind an evidence variable to placeholder evidence for a proposition.
--
-- If we decided that we want to generate evidence terms, here we would
-- set the evidence properly.  For now, we skip this step because evidence
-- terms are not used for anything, and they get quite large, at least if
-- we start with a small set of axioms.
--
-- The class axiom is always labelled "<=", the only class 'fromProp'
-- produces.
defineDummy :: EvVar -> CvtProp -> TcS ()
defineDummy ev cvt =
  case cvt of
    CvtClass cls args ->
      setDictBind ev (EvAxiom "<=" (mkTyConApp (classTyCon cls) args))
    CvtCo lhs rhs ->
      setWantedCoBind ev (mkUnsafeCoercion lhs rhs)
237
238
239
240
241 --------------------------------------------------------------------------------
242 -- The Solver
243 --------------------------------------------------------------------------------
244
245
-- | Outcome of checking one proposition against a set of assumptions.
data Result
  = Impossible          -- ^ We know that the goal cannot be solved.
  | Simplified [Prop]   -- ^ We reformulated the goal as these propositions.
  | Improved [Prop]     -- ^ We learned some new facts; the goal remains.
249
-- | Check a proposition against a list of assumptions.
--
-- For @<=@, reason with the ordering facts implied by the assumptions.
-- Otherwise, try these in turn:
--   1. solving the goal on its own;
--   2. solving it from a single assumption;
--   3. collecting improvements.
solve :: [Prop] -> Prop -> Result
solve asmps (Leq a b)
  | isLeq ords a b = Simplified []
  | otherwise      = maybe Impossible Improved (improveLeq ords a b)
  where
  ords = propsToOrd asmps

solve asmps prop =
  case solveWanted1 prop `mplus` msum (zipWith solveWanted2 asmps (repeat prop)) of
    Just sgs -> Simplified sgs
    Nothing  -> maybe Impossible Improved improvements
  where
  improvements =
    do eqs  <- improve1 prop
       eqss <- zipWithM improve2 asmps (repeat prop)
       return (concat (eqs : eqss))
273
274
275
276
-- | Facts that follow from a single proposition on its own (naturals).
-- 'Nothing' means the proposition can never hold.  @Just ps@ means the
-- facts @ps@ follow from it; the list may be empty.
improve1 :: Prop -> Maybe [Prop]
improve1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) t -> Just [ Eq (num (m+n)) t ]
    EqFun Add (Num 0 _) s t -> Just [ Eq s t ]
    EqFun Add (Num m _) s (Num n _)
      | m <= n    -> Just [ Eq (num (n-m)) s ]
      | otherwise -> Nothing
    EqFun Add r s t
      | r == t -> Just [ Eq (num 0) s ]
      | s == t -> Just [ Eq (num 0) r ]

    EqFun Mul (Num m _) (Num n _) t -> Just [ Eq (num (m*n)) t ]
    EqFun Mul (Num 0 _) _ t -> Just [ Eq (num 0) t ]
    EqFun Mul (Num 1 _) s t -> Just [ Eq s t ]
    EqFun Mul (Num _ _) s t
      | s == t -> Just [ Eq (num 0) s ]

    EqFun Mul r s (Num 1 _) -> Just [ Eq (num 1) r, Eq (num 1) s ]
    EqFun Mul (Num m _) s (Num n _)
      | Just a <- divide n m -> Just [ Eq (num a) s ]
      | otherwise            -> Nothing

    EqFun Exp (Num m _) (Num n _) t -> Just [ Eq (num (m ^ n)) t ]

    EqFun Exp (Num 1 _) _ t -> Just [ Eq (num 1) t ]
    -- m ^ s = s has no solution when the base is not 1.
    EqFun Exp (Num _ _) s t
      | s == t -> Nothing
    -- The base is not 1 here, so m ^ s = 1 forces s = 0 (also for m = 0).
    EqFun Exp (Num _ _) s (Num 1 _) -> Just [ Eq (num 0) s ]
    -- 0 ^ s = 0 holds exactly when s >= 1; a base >= 2 never yields 0.
    EqFun Exp (Num m _) s (Num 0 _)
      | m == 0    -> Just [ Leq (num 1) s ]
      | otherwise -> Nothing
    -- 0 ^ s is always 0 or 1, and both are handled above.
    EqFun Exp (Num 0 _) _ (Num _ _) -> Nothing
    -- Here the base is >= 2 and the result is neither 0 nor 1.
    EqFun Exp (Num m _) s (Num n _) ->
      do a <- descreteLog m n
         return [ Eq (num a) s ]

    EqFun Exp _ (Num 0 _) t -> Just [ Eq (num 1) t ]
    EqFun Exp r (Num 1 _) t -> Just [ Eq r t ]
    -- The exponent is >= 2 here, so r ^ m = 1 forces r = 1.
    EqFun Exp r (Num _ _) (Num 1 _) -> Just [ Eq (num 1) r ]
    EqFun Exp r (Num m _) (Num n _) ->
      do a <- descreteRoot m n
         return [ Eq (num a) r ]

    _ -> Just []
316
-- | Facts that follow from an assumption together with a proposition.
-- Always returns a (possibly empty) list of new facts.
improve2 :: Prop -> Prop -> Maybe [ Prop ]
improve2 asmp prop =
  case asmp of

    EqFun Add a1 b1 c1 ->
      case prop of
        EqFun Add a2 b2 c2
          -- Functional dependency (Add is commutative).
          | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
          | a1 == b2 && b1 == a2 -> Just [ Eq c1 c2 ]
          -- Cancellation: addition is injective in each argument.
          | c1 == c2 && a1 == a2 -> Just [ Eq b1 b2 ]
          | c1 == c2 && a1 == b2 -> Just [ Eq b1 a2 ]
          | c1 == c2 && b1 == b2 -> Just [ Eq a1 a2 ]
          | c1 == c2 && b1 == a2 -> Just [ Eq a1 b2 ]
        _ -> Just []

    EqFun Mul a1 b1 c1 ->
      case prop of
        EqFun Mul a2 b2 c2
          | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
          | a1 == b2 && b1 == a2 -> Just [ Eq c1 c2 ]
          -- m * b = c and n * b = c with m /= n force b = 0, hence c = 0.
          | c1 == c2 && b1 == b2, Num m _ <- a1, Num n _ <- a2, m /= n
            -> Just [ Eq (num 0) b1, Eq (num 0) c1 ]
        _ -> Just []

    -- Exp is a type family too: equal arguments force equal results.
    -- It is not commutative, so only the same-order match applies.
    EqFun Exp a1 b1 c1 ->
      case prop of
        EqFun Exp a2 b2 c2
          | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
        _ -> Just []

    _ -> Just []
344
345
-- | Solve a goal on its own, without any assumptions.
--   * @Just sgs@: the goal holds whenever the sub-goals @sgs@ hold, so
--     @Just []@ means it holds outright.
--   * 'Nothing': no rule applies.
solveWanted1 :: Prop -> Maybe [ Prop ]
solveWanted1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) (Num mn _) | m + n == mn -> Just []
    EqFun Add (Num 0 _) s t | s == t -> Just []
    EqFun Add r s t | r == s ->
      Just [ EqFun Mul (num 2) r t ]

    EqFun Mul (Num m _) (Num n _) (Num mn _) | m * n == mn -> Just []
    EqFun Mul (Num 0 _) _ (Num 0 _) -> Just []
    EqFun Mul (Num 1 _) s t | s == t -> Just []
    EqFun Mul r s t | r == s ->
      Just [ EqFun Exp r (num 2) t ]

    -- Simple normalization of commutative operators
    EqFun op r@(Var _) s@(Num _ _) t | commutative op ->
      Just [ EqFun op s r t ]

    EqFun Exp (Num m _) (Num n _) (Num mn _) | m ^ n == mn -> Just []
    EqFun Exp (Num 1 _) _ (Num 1 _) -> Just []
    EqFun Exp _ (Num 0 _) (Num 1 _) -> Just []
    EqFun Exp r (Num 1 _) t | r == t -> Just []
    -- r ^ 0 = r holds only for r = 1.  Note that 0 ^ 0 = 1, so the
    -- weaker "r <= 1" used below would wrongly admit r = 0.
    EqFun Exp r (Num 0 _) t | r == t ->
      Just [ Eq (num 1) r ]
    -- For an exponent >= 2, r ^ n = r holds exactly when r is 0 or 1.
    EqFun Exp r (Num _ _) t | r == t ->
      Just [ Leq r (num 1) ]

    _ -> Nothing
373
374
-- | Solve a goal using a single assumption.
--   * @Just sgs@: the goal follows from the assumption together with the
--     sub-goals @sgs@.
--   * 'Nothing': this assumption does not help.
solveWanted2 :: Prop -> Prop -> Maybe [Prop]
solveWanted2 asmp goal =
  case (asmp, goal) of

    -- The goal is the assumption itself (up to commutativity).
    (EqFun op1 r1 s1 t1, EqFun op2 r2 s2 t2)
      | op1 == op2, t1 == t2, sameArgs -> Just []
      where
      sameArgs = (r1 == r2 && s1 == s2)
              || (commutative op1 && r1 == s2 && s1 == r2)

    -- m + b = c |- n + d = c: cancel the smaller constant.
    (EqFun Add (Num m _) b c1, EqFun Add (Num n _) d c2)
      | c1 == c2 -> Just [ if m >= n then EqFun Add (num (m - n)) b d
                                     else EqFun Add (num (n - m)) d b ]

    -- m * b = c |- n * d = c: divide out the smaller constant, if exact.
    (EqFun Mul (Num m _) b c1, EqFun Mul (Num n _) d c2)
      | c1 == c2 ->
          fmap (\x -> [ EqFun Mul (num x) b d ]) (divide m n)
            `mplus` fmap (\x -> [ EqFun Mul (num x) d b ]) (divide n m)

    -- hm: m * b = c |- c + b = d <=> (m + 1) * b = d
    (EqFun Mul (Num m _) s1 t1, EqFun Add r2 s2 t2)
      | t1 == r2 && s1 == s2 -> Just [ EqFun Mul (num (m + 1)) s1 t2 ]
      | t1 == s2 && s1 == r2 -> Just [ EqFun Mul (num (m + 1)) r2 t2 ]

    _ -> Nothing
399
400
401 --------------------------------------------------------------------------------
402 -- Reasoning about ordering.
403 --------------------------------------------------------------------------------
404
-- | @isLeq ps a b@: does @a <= b@ follow from the ordering facts @ps@
-- (each pair @(x, y)@ meaning @x <= y@) and the ordering of literals?
--
-- The facts may be cyclic, e.g.:
--   * @a <= b@ together with @b <= a@;
--   * a self-loop @(a, a)@, which 'propsToOrd' derives from @x + a ~ a@.
-- We therefore remember the sub-goals on the current search path, and a
-- sub-goal that is already being tried fails instead of recursing forever.
-- On inputs where the search terminated before, the result is unchanged.
isLeq :: [(Term, Term)] -> Term -> Term -> Bool
isLeq ps a0 b0 = go [] a0 b0
  where
  go _ (Num 0 _) _ = True
  go _ (Num m _) (Num n _) = m <= n
  go path a b
    | a == b             = True
    | (a, b) `elem` path = False   -- already trying to prove this goal
    | otherwise          = or [ go ((a, b) : path) x y | (x, y) <- next a b ]

  -- Sub-goals, any of which implies @a <= b@.
  next (Num m _) a = [ (b, a) | (Num x _, b) <- ps, m <= x ]
  next a (Num m _) = [ (a, b) | (b, Num x _) <- ps, x <= m ]
  next a b         = [ (c, b) | (a', c) <- ps, a == a' ]
413
-- | Is @a > b@ known for certain?  Only decided when at least one side is
-- a numeric literal.
isGt :: [(Term, Term)] -> Term -> Term -> Bool
isGt ps a b =
  case (a, b) of
    (Num m _, Num n _) -> m > n
    (Num m _, _)       -> m > 0 && isLeq ps b (num (m - 1))
    (_, Num n _)       -> isLeq ps (num (n + 1)) a
    _                  -> False
419
-- | Learn from a goal @a <= b@ that could not be proved outright:
--   * if @b <= a@ is known, then @a ~ b@;
--   * if @a > b@ is known, the goal is impossible ('Nothing');
--   * otherwise nothing is learned.
improveLeq :: [(Term,Term)] -> Term -> Term -> Maybe [Prop]
improveLeq ps a b =
  if isLeq ps b a
    then Just [Eq a b]
    else if isGt ps a b
           then Nothing
           else Just []
424
-- | Ordering facts @(a, b)@, meaning @a <= b@, derived from numeric
-- predicates.  Facts from multiplication and exponentiation depend on what
-- is already known, so they are iterated to a fixed point.
--
-- We do not consider equalities because they should be substituted away.
propsToOrd :: [Prop] -> [(Term,Term)]
propsToOrd props = loop (step [] unconditional)
  where
  loop ps = let new = filter (`notElem` ps) (step ps conditional)
            in if null new then ps else loop (new ++ ps)

  step ps = nub . concatMap (toOrd ps)

  isConditional (EqFun op _ _ _) = op == Mul || op == Exp
  isConditional _ = False

  (conditional,unconditional) = partition isConditional props

  toOrd _ (Leq a b) = [(a,b)]
  toOrd _ (Eq _ _) = [] -- Would lead to a cycle, should be subst. away
  toOrd ps (EqFun op a b c) =
    case op of
      Add -> [(a,c),(b,c)]

      -- a * b = c: if a >= 1 then b <= c, and symmetrically.
      Mul -> (guard (isLeq ps (num 1) a) >> return (b,c)) ++
             (guard (isLeq ps (num 1) b) >> return (a,c))

      Exp
        -- 0 ^ b is either 0 or 1.
        | Num 0 _ <- a -> [(c, num 1)]
        -- a ^ b = a forces a <= 1, but only if b /= 1: a ^ 1 = a for
        -- every a.
        | a == c -> guard (notOne ps b) >> return (a, num 1)
        -- a >= 2 gives b < a ^ b; b >= 1 gives a <= a ^ b.
        | otherwise -> (guard (isLeq ps (num 2) a) >> return (b,c)) ++
                       (guard (isLeq ps (num 1) b) >> return (a,c))

  -- Is the exponent known to differ from 1?
  notOne _ (Num n _) = n /= 1
  notOne ps e        = isLeq ps (num 2) e
453
454 --------------------------------------------------------------------------------
455 -- Associativity (assumes commutative operators)
456 --------------------------------------------------------------------------------
457
458 {-
459 We try to compute new equalities by moving parens to the right:
460 if (a + b) + c = d then a + (b + c) = d
461 a + (c + b) = d
462 b + (a + c) = d
463 b + (c + a) = d
464 c + (a + b) = d
                        c + (b + a) = d
466
We use the rule when we already have a name for the intermediate sum
(RHS parens), or if it is a constant, because then we learn a new
relationship between variables.

Example of how it can be a constant:
  (a + 2 = p, p + 3 = q) => a + 5 = q
473
474
475 impAssoc :: [Prop] -> Prop -> [Prop]
476 impAssoc ps (EqFun op x y z) | associative op = concatMap imps asmps
477 where
478 asmps = [ (a,b,c) | EqFun op' a b c <- ps, op == op' ]
479 imps (a,b,c) =
480
481 -- (a + b = cx, cx + y = z) => (a + by = z)
482 (guard (c == x) >> [ EqFun op a by z | by <- y_plus b ]) ++
483
484 -- (az + b = c, x + y = az) => (x + yb = c)
485 (guard (a == z) >> [ EqFun op x yb c | yb <- y_plus b ]) ++
486
487 -- or, the new fact could be a name for the RHS paren intermediate:
488 [ EqFun op v z e | (c',d,e) <- asmps, c == c' -- a + b + d = e
489 , v <- [ d | sameSum x y a b ] ++
490 [ a | sameSum x y b d ] ++
491 [ b | sameSum x y a d ]
492 ] ++
493 [ EqFun op v z c | (d,e,a') <- asmps, a == a' -- d + e + b = c
494 , v <- [ b | sameSum x y d e ] ++
495 [ d | sameSum x y e b ] ++
496 [ e | sameSum x y d b ]
497 ]
498
499 sameSum a b a' b' = (a == a' && b == b') || (a == b' && b == a')
500
501 y_plus (Num m _) | Num n _ <- y = doOp m n
502 y_plus b = [ w | (u,v,w) <- asmps, sameSum y b u v ]
503
504 doOp m n = case op of
505 Add -> [ num (m + n) ]
506 Mul -> [ num (m * n) ]
507 _ -> []
508
509 impAssoc _ _ = []
510
511 -}
512 --------------------------------------------------------------------------------
-- Discrete Math
514 --------------------------------------------------------------------------------
515
-- | @descreteRoot root n@ finds the natural @x@ with @x ^ root == n@.
-- Returns 'Nothing' in these cases:
--   * @n@ is not a perfect @root@-th power;
--   * @n@ is negative;
--   * @root@ is negative.
-- For @root == 0@, every @x@ satisfies @x ^ 0 == 1@; one of them is
-- returned.
descreteRoot :: Integer -> Integer -> Maybe Integer
descreteRoot root n
  | root < 0  = Nothing
  | otherwise = search 0 n
  where
  -- Binary search over the inclusive range [lo, hi].  x ^ root is
  -- non-decreasing for x >= 0.  Both ends are examined, so e.g. the root
  -- of 1 (which is the upper end of [0, 1]) is found.
  search lo hi
    | lo > hi   = Nothing
    | otherwise =
        let x = lo + (hi - lo) `div` 2
        in case compare (x ^ root) n of
             EQ -> Just x
             LT -> search (x + 1) hi
             GT -> search lo (x - 1)
526
-- | @descreteLog base n@ finds the natural @x@ with @base ^ x == n@, when
-- that @x@ is unique.  Returns 'Nothing' when there is no such @x@.
--
-- It also returns 'Nothing' in two ambiguous cases, which callers must
-- handle themselves:
--   * @0 ^ x == 0@ holds for every @x >= 1@;
--   * @1 ^ x == 1@ holds for every @x@.
descreteLog :: Integer -> Integer -> Maybe Integer
descreteLog base n
  | base < 2  = if base == 0 && n == 1 then Just 0 else Nothing
  | n < 1     = Nothing          -- base >= 2, so base ^ x >= 1
  | n == 1    = Just 0           -- base ^ 0 == 1
  | otherwise = case n `divMod` base of
                  (q, 0) -> fmap (1 +) (descreteLog base q)
                  _      -> Nothing
533
-- | Exact division: @Just q@ when @y@ divides @x@ (using 'divMod'),
-- otherwise 'Nothing'.  Division by zero gives 'Nothing'.
divide :: Integer -> Integer -> Maybe Integer
divide x y
  | y /= 0, (q, 0) <- x `divMod` y = Just q
  | otherwise                      = Nothing
539
540
541 --------------------------------------------------------------------------------
542 -- Debugging
543 --------------------------------------------------------------------------------
544
-- | Trace a message from the numeric solver, tagged "[numerics] ".
numTrace :: String -> SDoc -> TcS ()
numTrace msg = traceTcS ("[numerics] " ++ msg)
547
-- | Render a list as a brace-enclosed, comma-separated sequence.
vmany :: Outputable a => [a] -> SDoc
vmany = braces . hcat . punctuate comma . map ppr
550
-- | Variables print as their type; literals print as their value.
instance Outputable Term where
  ppr term =
    case term of
      Var xi  -> ppr xi
      Num n _ -> integer n
554
-- | Propositions print in infix form, e.g. @a + b ~ c@ or @a <= b@.
instance Outputable Prop where
  ppr prop =
    case prop of
      EqFun op a b c -> ppr a <+> ppr op <+> ppr b <+> char '~' <+> ppr c
      Leq a b        -> ppr a <+> text "<=" <+> ppr b
      Eq a b         -> ppr a <+> char '~' <+> ppr b
559
-- | Operators print as their usual arithmetic symbol.
instance Outputable Op where
  ppr Add = char '+'
  ppr Mul = char '*'
  ppr Exp = char '^'
565