Merge remote branch 'origin/master' into type-nats
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 {-# LANGUAGE PatternGuards #-}
2 module TcTypeNats ( canonicalNum, NumericsResult(..) ) where
3
4 import TcSMonad ( TcS, Xi
5 , newWantedCoVar
6 , setWantedCoBind, setDictBind, newDictVar
7 , getWantedLoc
8 , tcsLookupTyCon, tcsLookupClass
9 , CanonicalCt (..), CanonicalCts
10 , mkFrozenError
11 , traceTcS
12 )
13 import TcRnTypes ( CtFlavor(..) )
14 import TcCanonical (mkCanonicals)
15 import HsBinds (EvTerm(..))
16 import Class ( Class, className, classTyCon )
17 import Type ( tcView, mkTyConApp, mkNumberTy, isNumberTy
18 , tcEqType, tcCmpType, pprType
19 )
20 import TypeRep (Type(..))
21 import TyCon (TyCon, tyConName)
22 import Var (EvVar)
23 import Coercion ( mkUnsafeCoercion )
24 import Outputable
25 import PrelNames ( lessThanEqualClassName
26 , addTyFamName, mulTyFamName, expTyFamName
27 )
28 import Bag (bagToList, emptyBag, unionManyBags, unionBags)
29 import Data.Maybe (fromMaybe)
30 import Data.List (nub, partition)
31 import Control.Monad (msum, mplus, zipWithM, (<=<), guard)
32
33 import Debug.Trace
34 import Unique(getUnique)
35 import Type(getTyVar)
36
37
-- | An operand of a numeric constraint: either an arbitrary non-literal
-- type, or a literal.  A literal remembers the type it was read from (if
-- any) so that 'fromTerm' can hand back that exact type, which keeps type
-- synonyms intact.
data Term = Var Xi | Num Integer (Maybe Xi)

-- | The arithmetic type families understood by the solver.
data Op = Add | Mul | Exp deriving Eq

-- | A numeric proposition.
data Prop = EqFun Op Term Term Term   -- ^ @EqFun op a b c@ stands for @op a b ~ c@
          | Leq Term Term             -- ^ @Leq a b@ stands for @a <= b@
          | Eq Term Term              -- ^ @Eq a b@ stands for @a ~ b@
43
-- | Operators whose arguments may be swapped: everything except 'Exp'.
commutative :: Op -> Bool
commutative op =
  case op of
    Exp -> False
    _   -> True
46
-- | Operators that may be regrouped freely: 'Add' and 'Mul'.
associative :: Op -> Bool
associative = (`elem` [Add, Mul])
49
-- | A literal created by the solver itself, so it has no original type.
num :: Integer -> Term
num = (`Num` Nothing)
52
-- | Terms are equal when they are the same type ('tcEqType') or the same
-- literal value.  The remembered original type of a literal is ignored.
instance Eq Term where
  t1 == t2 =
    case (t1, t2) of
      (Var x, Var y)     -> tcEqType x y
      (Num x _, Num y _) -> x == y
      _                  -> False
57
-- | Literals sort before non-literal terms.  Literals are ordered by value
-- and other terms by 'tcCmpType'.
instance Ord Term where
  compare t1 t2 =
    case (t1, t2) of
      (Num x _, Num y _) -> compare x y
      (Num _ _, Var _)   -> LT
      (Var _,   Num _ _) -> GT
      (Var x,   Var y)   -> tcCmpType x y
63
64
65 --------------------------------------------------------------------------------
66 -- Interface with the type checker
67 --------------------------------------------------------------------------------
68
-- | The outcome of processing one numeric constraint (see 'canonicalNum').
data NumericsResult = NumericsResult
  { numNewWork :: CanonicalCts
    -- ^ Constraints produced by this step (new sub-goals or facts).  When
    --   the inert wanteds are moved out of the inert set they are returned
    --   here too.
  , numInert :: Maybe CanonicalCts -- Nothing for "no change"
    -- ^ Replacement for the inert set.
  , numNext :: Maybe CanonicalCt
    -- ^ In this module this is one of three things.  'Nothing' means the
    --   constraint was discharged.  @Just prop@ means the processed
    --   constraint is kept.  Otherwise it is a frozen error (see
    --   'impossible').
  }
74
-- | Classify a type as a numeric literal or as an opaque term.  The type
-- itself is checked first, then its 'tcView' expansion.
--
-- We keep the original type in numeric constants to preserve type synonyms.
toTerm :: Xi -> Term
toTerm xi = maybe (Var xi) (\n -> Num n (Just xi)) literal
  where
    literal = isNumberTy xi `mplus` (tcView xi >>= isNumberTy)
80
-- | Turn a 'Term' back into a type.  A literal that was read from a type
-- gives back that original type, which keeps synonyms.  A literal made by
-- the solver is rebuilt with 'mkNumberTy'.
fromTerm :: Term -> Xi
fromTerm term =
  case term of
    Var xi          -> xi
    Num n origType  -> fromMaybe (mkNumberTy n) origType
84
-- | Translate a canonical constraint into a solver 'Prop'.  Only two kinds
-- of constraint are expected:
--
--   * a two-argument @<=@ dictionary;
--   * a two-argument application of the Add, Mul or Exp family.
--
-- Anything else is a compiler panic.
toProp :: CanonicalCt -> Prop
toProp ct =
  case ct of
    CDictCan { cc_class = cls, cc_tyargs = [xi1, xi2] }
      | className cls == lessThanEqualClassName ->
          Leq (toTerm xi1) (toTerm xi2)

    CFunEqCan { cc_fun = fam, cc_tyargs = [xi11, xi12], cc_rhs = xi2 }
      | Just op <- lookup (tyConName fam) arithFamilies ->
          EqFun op (toTerm xi11) (toTerm xi12) (toTerm xi2)

    _ -> panic $
      "[TcTypeNats.toProp] Unexpected CanonicalCt: " ++ showSDoc (ppr ct)
  where
    -- Family names and the operator each one denotes.
    arithFamilies = [ (addTyFamName, Add)
                    , (mulTyFamName, Mul)
                    , (expTyFamName, Exp) ]
100
101
-- | Entry point.  Process one numeric constraint @prop@ against the current
-- given, derived and wanted constraints, choosing a handler by the
-- constraint's flavour.
canonicalNum :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                TcS NumericsResult
canonicalNum given derived wanted prop = handler given derived wanted prop
  where
    handler = case cc_flavor prop of
                Wanted {}  -> solveNumWanted
                Derived {} -> addNumDerived
                Given {}   -> addNumGiven
109
110
-- | Handle a wanted numeric constraint.  The given, derived and wanted
-- constraints are all used as assumptions.
--
--   * 'Simplified': dummy evidence is recorded for the wanted (see
--     'defineDummy').  The wanted is then replaced by new sub-goals of the
--     same flavour.
--   * 'Improved': the wanted is kept (@numNext = Just prop@) and the
--     improvements are added as new work of the same flavour.
--   * 'Impossible': the wanted becomes a frozen error.
solveNumWanted :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                  TcS NumericsResult
solveNumWanted given derived wanted prop =
  do let asmps = map toProp $ bagToList $ unionManyBags [given,derived,wanted]
         goal = toProp prop

     numTrace "solveNumWanted" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            -- The goal is now justified by its sub-goals; bind placeholder
            -- evidence for it.
            defineDummy (cc_id prop) =<< fromProp goal
            evs <- mapM (newSubGoal <=< fromProp) sgs
            goals <- mkCanonicals (cc_flavor prop) evs
            return NumericsResult
              { numNext = Nothing, numInert = Nothing, numNewWork = goals }

       -- XXX: The new wanted might imply some of the existing wanteds...
       Improved is ->
         do numTrace "Improved by" (vmany is)
            evs <- mapM (newSubGoal <=< fromProp) is
            goals <- mkCanonicals (cc_flavor prop) evs
            return NumericsResult
              { numNext = Just prop, numInert = Nothing, numNewWork = goals }

       Impossible -> impossible prop
138
139
-- | Handle a derived numeric constraint.  Only the givens are used as
-- assumptions.
--
--   * 'Simplified': dummy evidence is recorded for the constraint and its
--     sub-goals are added as new work of the same flavour.
--   * 'Improved': the constraint is kept (@numNext = Just prop@).  The
--     improvements become new derived constraints.  Givens and deriveds
--     form the new inert set.  The wanteds go back into the new work
--     together with the improvements, presumably so they are looked at
--     again in light of the new facts.
--   * 'Impossible': the constraint becomes a frozen error.
--
-- XXX: Need to understand derived work better.
addNumDerived :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                 TcS NumericsResult
addNumDerived given derived wanted prop =
  do let asmps = map toProp $ bagToList given
         goal = toProp prop

     numTrace "addNumDerived" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            defineDummy (cc_id prop) =<< fromProp goal
            evs <- mapM (newSubGoal <=< fromProp) sgs
            goals <- mkCanonicals (cc_flavor prop) evs
            return NumericsResult
              { numNext = Nothing, numInert = Nothing, numNewWork = goals }

       -- XXX: watch out for cycles because of the wanteds being restarted:
       -- W => D && D => W, we could solve W by W
       -- if W <=> D, then we should simplify W to D, not make it derived.
       Improved is ->
         do numTrace "Improved by" (vmany is)
            evs <- mapM (newSubGoal <=< fromProp) is
            goals <- mkCanonicals (Derived (getWantedLoc prop)) evs
            return NumericsResult
              { numNext = Just prop, numInert = Just (unionBags given derived)
              , numNewWork = unionBags wanted goals }

       Impossible -> impossible prop
171
172
-- | Handle a given numeric constraint.  The inert givens are used as
-- assumptions.
--
--   * 'Simplified': the given is dropped (@numNext = Nothing@) and its
--     simpler facts are added as new givens with dummy evidence (see
--     'newFact').
--   * 'Improved': the given is kept and the improvements become new
--     givens.  Givens and deriveds form the new inert set.  The wanteds go
--     back into the new work together with the new facts.
--   * 'Impossible': the given becomes a frozen error.
addNumGiven :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
               TcS NumericsResult
addNumGiven given derived wanted prop =
  do let asmps = map toProp (bagToList given)
         goal = toProp prop

     numTrace "addNumGiven" (vmany asmps <+> text " /\\ " <+> ppr goal)
     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            evs <- mapM (newFact <=< fromProp) sgs
            facts <- mkCanonicals (cc_flavor prop) evs
            return NumericsResult
              { numNext = Nothing, numInert = Nothing, numNewWork = facts }

       Improved is ->
         do numTrace "Improved by" (vmany is)
            evs <- mapM (newFact <=< fromProp) is
            facts <- mkCanonicals (cc_flavor prop) evs
            return NumericsResult
              { numNext = Just prop, numInert = Just (unionBags given derived)
              , numNewWork = unionBags wanted facts }

       Impossible -> impossible prop
198
-- | Report a constraint that can never hold.  It is replaced by a frozen
-- error ('mkFrozenError'), the inert set is left unchanged and no new work
-- is produced.
impossible :: CanonicalCt -> TcS NumericsResult
impossible c = numTrace "Impossible" empty >> return result
  where
    result = NumericsResult
               { numNewWork = emptyBag
               , numInert   = Nothing
               , numNext    = Just (mkFrozenError (cc_flavor c) (cc_id c))
               }
205
206
207
-- | A 'Prop' converted back into type-checker terms.
data CvtProp = CvtClass Class [Type]   -- ^ A class constraint: the class and its arguments.
             | CvtCo Type Type         -- ^ An equality @t1 ~ t2@.
210
-- | Convert a solver 'Prop' back into something the type checker can
-- create evidence for.  'Leq' becomes a @<=@ class constraint.  'Eq'
-- becomes an equality, and so does an arithmetic fact, as an equality
-- between the family application and its result.
fromProp :: Prop -> TcS CvtProp
fromProp prop =
  case prop of
    Leq t1 t2 ->
      do cls <- tcsLookupClass lessThanEqualClassName
         return (CvtClass cls (map fromTerm [t1, t2]))
    Eq t1 t2 ->
      return (CvtCo (fromTerm t1) (fromTerm t2))
    EqFun op t1 t2 t3 ->
      do fam <- tcsLookupTyCon (familyName op)
         return (CvtCo (mkTyConApp fam (map fromTerm [t1, t2])) (fromTerm t3))
  where
    familyName Add = addTyFamName
    familyName Mul = mulTyFamName
    familyName Exp = expTyFamName
224
225
-- | Allocate a fresh evidence variable for a converted proposition: a
-- dictionary variable for a class constraint, or a wanted coercion
-- variable for an equality.
newSubGoal :: CvtProp -> TcS EvVar
newSubGoal cvt =
  case cvt of
    CvtClass cls tys -> newDictVar cls tys
    CvtCo lhs rhs    -> newWantedCoVar lhs rhs
229
-- | Allocate an evidence variable for a new fact and bind it straight away
-- to dummy evidence (see 'defineDummy').
newFact :: CvtProp -> TcS EvVar
newFact cvt = newSubGoal cvt >>= \ev -> defineDummy ev cvt >> return ev
235
236
-- | Bind an evidence variable to placeholder evidence.  A class constraint
-- gets an 'EvAxiom' named @"<="@; an equality gets an unsafe coercion.
--
-- If we decided that we want to generate evidence terms,
-- here we would set the evidence properly. For now, we skip this
-- step because evidence terms are not used for anything, and they
-- get quite large, at least, if we start with a small set of axioms.
defineDummy :: EvVar -> CvtProp -> TcS ()
defineDummy ev cvt =
  case cvt of
    CvtClass cls tys ->
      setDictBind ev (EvAxiom "<=" (mkTyConApp (classTyCon cls) tys))
    CvtCo lhs rhs ->
      setWantedCoBind ev (mkUnsafeCoercion lhs rhs)
247
248
249
250
251 --------------------------------------------------------------------------------
252 -- The Solver
253 --------------------------------------------------------------------------------
254
255
-- | What 'solve' concluded about one goal.
data Result = Impossible              -- We know that the goal cannot be solved.
            | Simplified [Prop]       -- We reformulated the goal: it holds iff
                                      -- these hold ([] means it holds outright).
            | Improved [Prop]         -- We learned some new facts; the goal
                                      -- itself remains open.
259
-- | Decide what can be done with a goal under the given assumptions.
--
-- An ordering goal is checked against the ordering facts derived from the
-- assumptions.  Any other goal is tried, in order, against:
--
--   1. rules that use the goal alone ('solveWanted1');
--   2. rules that combine the goal with one assumption ('solveWanted2');
--   3. associative/commutative normalisation ('solveAC').
--
-- If none of these applies, improvements are collected from the goal
-- itself ('improve1') and from each assumption paired with the goal
-- ('improve2').  Any contradiction makes the goal 'Impossible'.
solve :: [Prop] -> Prop -> Result
solve asmps (Leq a b)
  | isLeq ords a b = Simplified []
  | otherwise      = maybe Impossible Improved (improveLeq ords a b)
  where
    ords = propsToOrd asmps

solve asmps prop
  | Just sgs <- solveWanted1 prop                         = Simplified sgs
  | Just sgs <- msum [ solveWanted2 a prop | a <- asmps ] = Simplified sgs
  | solveAC asmps prop                                    = Simplified []
  | otherwise =
      case improve1 prop of
        Nothing  -> Impossible
        Just eqs -> maybe Impossible (Improved . concat . (eqs :))
                          (zipWithM improve2 asmps (repeat prop))
285
286
287
288
-- | Improvements that follow from the goal alone.  A result of 'Nothing'
-- means the goal is unsatisfiable.  @Just []@ means there is nothing to
-- learn.
--
-- Alternatives are tried in order.  When all guards of an alternative fail,
-- matching continues with the next alternative, so the later clauses may
-- rely on the special constants (0, 1) having been handled already.
improve1 :: Prop -> Maybe [Prop]
improve1 prop =
  case prop of

    -- m + n = t  ==>  t = m + n
    EqFun Add (Num m _) (Num n _) t -> Just [ Eq (num (m+n)) t ]
    -- 0 + s = t  ==>  s = t
    EqFun Add (Num 0 _) s t -> Just [ Eq s t ]
    -- m + s = n  ==>  s = n - m, impossible when m > n
    EqFun Add (Num m _) s (Num n _)
      | m <= n -> Just [ Eq (num (n-m)) s ]
      | otherwise -> Nothing
    -- r + s = 0  ==>  r = 0 and s = 0
    EqFun Add r s (Num 0 _) -> Just [ Eq (num 0) r, Eq (num 0) s ]
    -- r + s = r  ==>  s = 0;   r + s = s  ==>  r = 0
    EqFun Add r s t
      | r == t -> Just [ Eq (num 0) s ]
      | s == t -> Just [ Eq (num 0) r ]

    -- m * n = t  ==>  t = m * n
    EqFun Mul (Num m _) (Num n _) t -> Just [ Eq (num (m*n)) t ]
    -- 0 * _ = t  ==>  t = 0
    EqFun Mul (Num 0 _) _ t -> Just [ Eq (num 0) t ]
    -- 1 * s = t  ==>  s = t
    EqFun Mul (Num 1 _) s t -> Just [ Eq s t ]
    -- m * s = s with m >= 2 (0 and 1 handled above)  ==>  s = 0
    EqFun Mul (Num _ _) s t
      | s == t -> Just [ Eq (num 0) s ]

    -- r * s = 1  ==>  r = 1 and s = 1
    EqFun Mul r s (Num 1 _) -> Just [ Eq (num 1) r, Eq (num 1) s ]
    -- m * s = n  ==>  s = n / m, impossible when m does not divide n
    EqFun Mul (Num m _) s (Num n _)
      | Just a <- divide n m -> Just [ Eq (num a) s ]
      | otherwise -> Nothing


    -- m ^ n = t  ==>  t = m ^ n
    EqFun Exp (Num m _) (Num n _) t -> Just [ Eq (num (m ^ n)) t ]

    -- 1 ^ _ = t  ==>  t = 1
    EqFun Exp (Num 1 _) _ t -> Just [ Eq (num 1) t ]
    -- m ^ s = s with m /= 1 has no solution
    EqFun Exp (Num _ _) s t
      | s == t -> Nothing
    -- m ^ s = n  ==>  s = log_m n
    -- NOTE(review): the base m can be 0 here (only base 1 is excluded
    -- above).  For 0 ^ s = 0 every s >= 1 is a solution, so no single
    -- exponent exists.  Confirm how descreteLog treats base 0.
    EqFun Exp (Num m _) s (Num n _) ->
      do a <- descreteLog m n
         return [ Eq (num a) s ]

    -- _ ^ 0 = t  ==>  t = 1
    EqFun Exp _ (Num 0 _) t -> Just [ Eq (num 1) t ]
    -- r ^ 1 = t  ==>  r = t
    EqFun Exp r (Num 1 _) t -> Just [ Eq r t ]
    -- r ^ m = n  ==>  r is the m-th root of n
    EqFun Exp r (Num m _) (Num n _) ->
      do a <- descreteRoot m n
         return [ Eq (num a) r ]

    -- No improvement.
    _ -> Just []
329
-- | Improvements obtained by combining one assumption @asmp@ with the goal
-- @prop@.  Only an 'Add' paired with an 'Add', or a 'Mul' paired with a
-- 'Mul', can produce facts; every other pair yields none.  Every clause
-- returns 'Just'.  (In 'solve', a 'Nothing' here would make the goal
-- 'Impossible'.)
improve2 :: Prop -> Prop -> Maybe [ Prop ]
improve2 asmp prop =
  case asmp of

    EqFun Add a1 b1 c1 ->
      case prop of
        EqFun Add a2 b2 c2
          -- Same operands (in either order): same sum.
          | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
          | a1 == b2 && b1 == a2 -> Just [ Eq c1 c2 ]
          -- Same sum and one shared operand: the remaining operands are
          -- equal (an addend can always be cancelled).
          | c1 == c2 && a1 == a2 -> Just [ Eq b1 b2 ]
          | c1 == c2 && a1 == b2 -> Just [ Eq b1 a2 ]
          | c1 == c2 && b1 == b2 -> Just [ Eq a1 a2 ]
          | c1 == c2 && b1 == a2 -> Just [ Eq a1 b2 ]
        _ -> Just []


    EqFun Mul a1 b1 c1 ->
      case prop of
        EqFun Mul a2 b2 c2
          -- Same operands (in either order): same product.
          | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
          | a1 == b2 && b1 == a2 -> Just [ Eq c1 c2 ]
          -- m * b = c and n * b = c with distinct constants m, n force
          -- b = 0 and c = 0.  A shared factor is not cancelled in general
          -- because it might be 0.
          | c1 == c2 && b1 == b2, Num m _ <- a1, Num n _ <- a2, m /= n
            -> Just [ Eq (num 0) b1, Eq (num 0) c1 ]
        _ -> Just []


    _ -> Just []
357
358
-- | Try to discharge or rewrite a goal using only the goal itself.
-- Returns @Just []@ when the goal holds outright, @Just sgs@ when it is
-- equivalent to the sub-goals @sgs@, and 'Nothing' when no rule applies.
--
-- Alternatives are tried in order, and a failed guard falls through to the
-- next alternative.
solveWanted1 :: Prop -> Maybe [ Prop ]
solveWanted1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) (Num mn _) | m + n == mn -> Just []
    EqFun Add (Num 0 _) s t | s == t -> Just []
    -- r + r = t  <=>  2 * r = t
    EqFun Add r s t | r == s ->
      Just [ EqFun Mul (num 2) r t ]

    EqFun Mul (Num m _) (Num n _) (Num mn _) | m * n == mn -> Just []
    EqFun Mul (Num 0 _) _ (Num 0 _) -> Just []
    EqFun Mul (Num 1 _) s t | s == t -> Just []
    -- r * r = t  <=>  r ^ 2 = t
    EqFun Mul r s t | r == s ->
      Just [ EqFun Exp r (num 2) t ]

    -- Simple normalization of commutative operators
    EqFun op r@(Var _) s@(Num _ _) t | commutative op ->
      Just [ EqFun op s r t ]

    EqFun Exp (Num m _) (Num n _) (Num mn _) | m ^ n == mn -> Just []
    EqFun Exp (Num 1 _) _ (Num 1 _) -> Just []
    EqFun Exp _ (Num 0 _) (Num 1 _) -> Just []
    EqFun Exp r (Num 1 _) t | r == t -> Just []
    -- r ^ 0 = r  <=>  r = 1, because r ^ 0 = 1 for every r, including 0.
    -- This must come before the next rule.  Otherwise the goal would be
    -- weakened to r <= 1, which wrongly accepts r = 0 (e.g. 0 ^ 0 ~ 0).
    EqFun Exp r (Num 0 _) t | r == t ->
      Just [ Eq (num 1) r ]
    -- r ^ k = r with k >= 2 (k = 0 and k = 1 handled above)  <=>  r <= 1,
    -- since only 0 and 1 are fixed points.
    EqFun Exp r (Num _ _) t | r == t ->
      Just [Leq r (num 1)]

    _ -> Nothing
386
387
-- | Try to discharge or rewrite the goal @prop@ using one assumption
-- @asmp@.  Returns @Just []@ when the assumption proves the goal, and
-- @Just sgs@ when the goal can be replaced by the sub-goals @sgs@.
-- Returns 'Nothing' when this assumption does not help.
solveWanted2 :: Prop -> Prop -> Maybe [Prop]
solveWanted2 asmp prop =
  case (asmp, prop) of

    -- The goal is the assumption itself, up to swapping the arguments of a
    -- commutative operator.
    (EqFun op1 r1 s1 t1, EqFun op2 r2 s2 t2)
      | op1 == op2 && t1 == t2 &&
          ( r1 == r2 && s1 == s2
            || commutative op1 && r1 == s2 && s1 == r2
          ) -> Just []

    -- m + b = c |- n + d = c: subtract the smaller constant from both sides.
    (EqFun Add (Num m _) b c1, EqFun Add (Num n _) d c2)
      | c1 == c2 -> if m >= n then Just [ EqFun Add (num (m - n)) b d ]
                              else Just [ EqFun Add (num (n - m)) d b ]

    -- m * b = c |- n * d = c: divide by whichever constant divides the
    -- other ('divide' refuses a zero divisor).
    (EqFun Mul (Num m _) b c1, EqFun Mul (Num n _) d c2)
      | c1 == c2, Just x <- divide m n -> Just [ EqFun Mul (num x) b d ]
      | c1 == c2, Just x <- divide n m -> Just [ EqFun Mul (num x) d b ]

    -- hm: m * b = c |- c + b = d <=> (m + 1) * b = d
    (EqFun Mul (Num m _) s1 t1, EqFun Add r2 s2 t2)
      | t1 == r2 && s1 == s2 -> Just [ EqFun Mul (num (m + 1)) s1 t2 ]
      | t1 == s2 && s1 == r2 -> Just [ EqFun Mul (num (m + 1)) r2 t2 ]

    _ -> Nothing
412
413
414 --------------------------------------------------------------------------------
415 -- Reasoning about ordering.
416 --------------------------------------------------------------------------------
417
-- | @isLeq ps a b@: can @a <= b@ be derived from the ordering facts @ps@?
-- Each pair @(x, y)@ in @ps@ means @x <= y@.  The derivation uses literal
-- comparison, reflexivity and chains of facts (transitivity).
--
-- This function assumes that the assumptions are acyclic.  A cycle in @ps@
-- can make the transitive search in the last equations loop forever.
isLeq :: [(Term, Term)] -> Term -> Term -> Bool
isLeq _ (Num 0 _) _ = True            -- 0 <= anything
isLeq _ (Num m _) (Num n _) = m <= n
isLeq _ a b | a == b = True           -- reflexivity
-- m <= a if some fact x <= b has m <= x and b <= a.
isLeq ps (Num m _) a = or [ isLeq ps b a | (Num x _, b) <- ps, m <= x ]
-- a <= m if some fact b <= x has x <= m and a <= b.
isLeq ps a (Num m _) = or [ isLeq ps a b | (b, Num x _) <- ps, x <= m ]
-- a <= b if some fact a <= c has c <= b.
isLeq ps a b = or [ isLeq ps c b | (a',c) <- ps, a == a' ]
426
-- | @isGt ps a b@: can @a > b@ be derived from the ordering facts @ps@?
-- Only comparisons with at least one literal side are attempted.  Such a
-- comparison is rewritten as @b <= a - 1@ or @b + 1 <= a@ and checked with
-- 'isLeq'.
isGt :: [(Term, Term)] -> Term -> Term -> Bool
isGt ords lhs rhs =
  case (lhs, rhs) of
    (Num m _, Num n _) -> m > n
    (Num m _, _)       -> m > 0 && isLeq ords rhs (num (m - 1))
    (_, Num n _)       -> isLeq ords (num (n + 1)) lhs
    _                  -> False
432
-- | Improvements for an unproven goal @a <= b@:
--
--   * if @b <= a@ is known, antisymmetry gives @a ~ b@;
--   * if @a > b@ is known, the goal is refuted ('Nothing');
--   * otherwise nothing is learned.
improveLeq :: [(Term,Term)] -> Term -> Term -> Maybe [Prop]
improveLeq ords lhs rhs
  | antisymmetric = Just [Eq lhs rhs]
  | refuted       = Nothing
  | otherwise     = Just []
  where
    antisymmetric = isLeq ords rhs lhs
    refuted       = isGt ords lhs rhs
437
-- | Ordering facts implied by the numeric propositions.  Each pair @(a, b)@
-- means @a <= b@.
--
-- We do not consider equalities, because they should be substituted away.
-- Turning @a ~ b@ into both @a <= b@ and @b <= a@ would also create a
-- cycle, which 'isLeq' cannot cope with.
--
-- Facts from 'Leq' and 'Add' hold unconditionally and are computed once.
-- Facts from 'Mul' and 'Exp' depend on lower bounds of their operands, so
-- they are recomputed against the facts found so far until nothing new
-- appears.  This terminates: every fact pairs terms taken from @props@ or
-- the constant 1.
propsToOrd :: [Prop] -> [(Term,Term)]
propsToOrd props = loop (step [] unconditional)
  where
    loop ps = let new = filter (`notElem` ps) (step ps conditional)
              in if null new then ps else loop (new ++ ps)

    step ps = nub . concatMap (toOrd ps)

    isConditional (EqFun op _ _ _) = op == Mul || op == Exp
    isConditional _                = False

    (conditional,unconditional) = partition isConditional props

    toOrd _ (Leq a b) = [(a,b)]
    toOrd _ (Eq _ _)  = []   -- Would lead to a cycle, should be subst. away
    toOrd ps (EqFun op a b c) =
      case op of
        -- a + b = c  ==>  a <= c and b <= c
        Add -> [(a,c),(b,c)]

        -- a * b = c:  1 <= a ==> b <= c;  1 <= b ==> a <= c
        Mul -> (guard (isLeq ps (num 1) a) >> return (b,c)) ++
               (guard (isLeq ps (num 1) b) >> return (a,c))

        Exp
          -- 0 ^ b = c  ==>  c <= 1
          | Num 0 _ <- a -> [(c, num 1)]
          -- a ^ b = a  ==>  a <= 1, but only when b /= 1, because a ^ 1 = a
          -- holds for every a.  With b = 0 we get a = 1, and with b >= 2 we
          -- get a in {0, 1}.  If b might be 1, emit nothing rather than
          -- fall through, since the general rule would add the cycle (a,a).
          | a == c ->
              if b == num 0 || isLeq ps (num 2) b
                then [(a, num 1)]
                else []
          -- a ^ b = c:  2 <= a ==> b <= c;  1 <= b ==> a <= c
          | otherwise -> (guard (isLeq ps (num 2) a) >> return (b,c)) ++
                         (guard (isLeq ps (num 1) b) >> return (a,c))
466
467 --------------------------------------------------------------------------------
468 -- Associative and Commutative Operators
469 --------------------------------------------------------------------------------
470
-- XXX: recursion may non-terminate: x * y = x
-- XXX: does not do improvements

-- | Try to prove a goal @x op y ~ z@ for an associative and commutative
-- operator ('Add' or 'Mul').  Both sides are normalised to multisets of
-- operands, and assumptions with the same operator are used to expand
-- terms into their operands.  Any other goal gives 'False'.
--
-- A multiset is a sorted list of 'Term's built with 'merge', so two
-- multisets are equal exactly when they hold the same operands.
solveAC :: [Prop] -> Prop -> Bool
solveAC ps (EqFun op x y z)
  | commutative op && associative op =
    -- Either an expansion of @x op y@ matches an expansion of @z@ directly,
    -- or some extra operands @as@ satisfy @as op z ~ r@ (see cancelCands)
    -- and @as op x op y@ matches an expansion of @r@.
    (xs_ys === z) || or [ add as xs_ys === r | (as,r) <- cancelCands z ]

  where
    -- All operand multisets for the left-hand side @x op y@.
    xs_ys = add (sums x) (sums y)

    -- Operand pairs (a, b) of assumptions @a op b ~ c@.
    candidates c = [ (a,b) | (a,b,c') <- asmps, c == c' ]

    -- (xs,e) `elem` cancelCands g ==> xs + g = e
    -- Chains are followed: from us op g ~ c and vs op c ~ e we get
    -- (us + vs) op g ~ e.
    cancelCands :: Term -> [([[Term]],Term)]
    cancelCands g = do (a,b,c) <- asmps
                       let us = filter (all mayCancel)
                              $ case () of
                                  _ | a == g -> sums b
                                    | b == g -> sums a
                                    | otherwise -> [ ]
                       guard (not (null us))
                       (us,c) : [ (add us vs, e) | (vs,e) <- cancelCands c ]


    -- May @x@ be cancelled from both sides?  Always for 'Add'.  For 'Mul'
    -- only when @1 <= x@ is known, since cancelling a factor that might be
    -- 0 is unsound.  Never for 'Exp'.
    mayCancel x = case op of
                    Add -> True
                    Mul -> isLeq ordProps (num 1) x
                    Exp -> False

    -- xs `elem` sums a ==> sum xs = a
    sums a = [a] : [ p | (b,c) <- candidates a, p <- add (sums b) (sums c)]

    -- All unions of one multiset from each list.
    add :: [[Term]] -> [[Term]] -> [[Term]]
    add as bs = [ merge u v | u <- as, v <- bs ]

    -- Some multiset in the list is also an expansion of @b@.
    (===) :: [[Term]] -> Term -> Bool
    as === b = any (`elem` sums b) as

    -- Facts in a more convenient form
    asmps = [ (a,b,c) | EqFun op' a b c <- ps, op == op' ]
    ordProps = propsToOrd ps


solveAC _ _ = False
518
519
-- | Merge two ascending lists into one ascending list, keeping duplicates.
-- On ties the element from the first list comes first.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge left@(x:xrest) right@(y:yrest)
  | x <= y    = x : merge xrest right
  | otherwise = y : merge left yrest
526
527
528
529 --------------------------------------------------------------------------------
-- Discrete Math
531 --------------------------------------------------------------------------------
532
-- | @descreteRoot root n@ is the @x >= 0@ with @x ^ root == n@, if there is
-- one.
--
-- It binary-searches the closed interval @[0, n]@ and examines both ends,
-- so for example @descreteRoot 2 1 == Just 1@.  For @root >= 1@ any
-- solution lies in that interval, because then @x <= x ^ root@.  A negative
-- @root@ (where '^' would fail) or a negative @n@ gives 'Nothing'.  For
-- @root == 0@ every @x ^ 0@ is 1, so the answer is not unique; 'improve1'
-- handles exponent 0 before it calls this function.
descreteRoot :: Integer -> Integer -> Maybe Integer
descreteRoot root n
  | root < 0  = Nothing
  | otherwise = search 0 n
  where
    -- Invariant: any solution lies in [lo, hi].
    search lo hi
      | lo > hi   = Nothing
      | otherwise =
          let x = lo + div (hi - lo) 2
          in case compare (x ^ root) n of
               EQ -> Just x
               LT -> search (x + 1) hi
               GT -> search lo (x - 1)
543
-- | @descreteLog base n@ is the exponent @a >= 0@ with @base ^ a == n@, if
-- there is one.
--
--   * @n == 1@ always gives @Just 0@, since @base ^ 0 == 1@ for every base.
--   * For @base >= 2@ the exponent, when it exists, is found by repeated
--     exact division and is unique.
--   * Every other case gives 'Nothing'.
--
-- The 'Nothing' cases include @n < 1@ with @base >= 2@, which has no
-- solution, and the degenerate bases 0 and 1 with @n /= 1@.  Those bases
-- have no single answer: @0 ^ a == 0@ for every @a >= 1@, and @1 ^ a@ is
-- always 1.  Rejecting them first also avoids the division by zero (base 0)
-- and the endless loop (base 1, or @n == 0@) that plain repeated division
-- would hit.
descreteLog :: Integer -> Integer -> Maybe Integer
descreteLog base n
  | n == 1            = Just 0
  | base < 2 || n < 1 = Nothing
  | otherwise         =
      case divMod n base of
        (q, 0) -> fmap (1 +) (descreteLog base q)
        _      -> Nothing
550
-- | Exact integer division: @Just q@ when @d@ divides @n@ evenly (with
-- @n == q * d@).  Gives 'Nothing' when the division leaves a remainder or
-- @d@ is 0.
divide :: Integer -> Integer -> Maybe Integer
divide n d
  | d == 0    = Nothing
  | r == 0    = Just q
  | otherwise = Nothing
  where
    (q, r) = n `divMod` d
556
557
558 --------------------------------------------------------------------------------
559 -- Debugging
560 --------------------------------------------------------------------------------
561
-- | Emit a solver trace message in the 'TcS' monad, tagged "[numerics]".
numTrace :: String -> SDoc -> TcS ()
numTrace tag = traceTcS ("[numerics] " ++ tag)
564
-- | Pure counterpart of 'numTrace', built on 'trace'.  The message is
-- printed when the returned value is demanded.  It is not called anywhere
-- in the code shown here; presumably it is kept for ad-hoc debugging.
pNumTrace :: String -> SDoc -> a -> a
pNumTrace tag doc = trace (concat ["[numerics] ", tag, " ", showSDoc doc])
567
-- | Render a list as @{a,b,c}@, with no spaces between the items.
vmany :: Outputable a => [a] -> SDoc
vmany = braces . hcat . punctuate comma . map ppr
570
-- | Debug rendering.  A literal prints as its value.  Any other term prints
-- as its type followed by the unique of its type variable in brackets.
-- NOTE(review): 'getTyVar' presumably panics (with "numerics dbg") when
-- @xi@ is not a type variable.
instance Outputable Term where
  ppr (Num n _) = integer n
  ppr (Var xi)  = pprType xi <> brackets (text (show uniq))
    where
      uniq = getUnique (getTyVar "numerics dbg" xi)
575
-- | Debug rendering: @a + b ~ c@, @a <= b@ and @a ~ b@.
instance Outputable Prop where
  ppr prop =
      case prop of
        EqFun op t1 t2 t3 -> ppr t1 <+> ppr op <+> ppr t2 <+> tilde <+> ppr t3
        Leq t1 t2         -> ppr t1 <+> text "<=" <+> ppr t2
        Eq t1 t2          -> ppr t1 <+> tilde <+> ppr t2
    where
      tilde = char '~'
580
-- | Debug rendering of an operator as its arithmetic symbol.
instance Outputable Op where
  ppr op = char symbol
    where
      symbol = case op of
                 Add -> '+'
                 Mul -> '*'
                 Exp -> '^'
586