Yet another variation on associativity.
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 {-# LANGUAGE PatternGuards #-}
2 module TcTypeNats ( canonicalNum, NumericsResult(..) ) where
3
4 import TcSMonad ( TcS, Xi
5 , newWantedCoVar
6 , setWantedCoBind, setDictBind, newDictVar
7 , getWantedLoc
8 , tcsLookupTyCon, tcsLookupClass
9 , CanonicalCt (..), CanonicalCts
10 , mkFrozenError
11 , traceTcS
12 )
13 import TcRnTypes ( CtFlavor(..) )
14 import TcCanonical (mkCanonicals)
15 import HsBinds (EvTerm(..))
16 import Class ( Class, className, classTyCon )
17 import Type ( tcView, mkTyConApp, mkNumberTy, isNumberTy
18 , tcEqType, tcCmpType )
19 import TypeRep (Type(..))
20 import TyCon (TyCon, tyConName)
21 import Var (EvVar)
22 import Coercion ( mkUnsafeCoercion )
23 import Outputable
24 import PrelNames ( lessThanEqualClassName
25 , addTyFamName, mulTyFamName, expTyFamName
26 )
27 import Bag (bagToList, emptyBag, unionManyBags, unionBags)
28 import Data.Maybe (fromMaybe)
29 import Data.List (nub, partition)
30 import Control.Monad (msum, mplus, zipWithM, (<=<), guard)
31
32
-- | An operand of a numeric constraint.
data Term
  = Var Xi                 -- ^ A type that is not a numeric literal.
  | Num Integer (Maybe Xi) -- ^ A literal, with the type it was read from
                           --   ('Nothing' when the solver made it up).

-- | The supported type-level operators.
data Op
  = Add
  | Mul
  | Exp
  deriving (Eq)

-- | A numeric constraint.
data Prop
  = EqFun Op Term Term Term -- ^ @EqFun op a b c@ stands for @a op b ~ c@.
  | Leq Term Term           -- ^ @a <= b@.
  | Eq Term Term            -- ^ @a ~ b@.
38
-- | Operators whose two arguments may be swapped.
commutative :: Op -> Bool
commutative Add = True
commutative Mul = True
commutative Exp = False
41
-- | Operators for which nested applications may be regrouped.
associative :: Op -> Bool
associative Add = True
associative Mul = True
associative Exp = False
44
-- | A literal term that does not come from any source type.
num :: Integer -> Term
num = flip Num Nothing
47
-- | Variables are compared as types; literals only by value (the type a
-- literal was read from is ignored).
instance Eq Term where
  t1 == t2 =
    case (t1, t2) of
      (Var x, Var y)     -> tcEqType x y
      (Num x _, Num y _) -> x == y
      _                  -> False
52
-- | Literals come before variables; literals are ordered by value and
-- variables with 'tcCmpType'.
instance Ord Term where
  compare t1 t2 =
    case (t1, t2) of
      (Num x _, Num y _) -> compare x y
      (Var x, Var y)     -> tcCmpType x y
      (Num _ _, Var _)   -> LT
      (Var _, Num _ _)   -> GT
58
59
60 --------------------------------------------------------------------------------
61 -- Interface with the type checker
62 --------------------------------------------------------------------------------
63
-- | The outcome of processing one numeric constraint.
data NumericsResult = NumericsResult
  { numNewWork :: CanonicalCts
    -- ^ New constraints to be processed.
  , numInert   :: Maybe CanonicalCts
    -- ^ Replacement inert constraints; Nothing for "no change".
  , numNext    :: Maybe CanonicalCt
    -- ^ What remains of the processed constraint: the constraint itself
    --   when it was only improved, a frozen error when it is impossible,
    --   'Nothing' when it was solved or simplified away.
  }
69
-- | Classify a type as a numeric literal (looking through one synonym
-- expansion) or as an opaque term.
-- We keep the original type in numeric constants to preserve type synonyms.
toTerm :: Xi -> Term
toTerm xi = maybe (Var xi) literal (directly `mplus` viaSynonym)
  where
  literal n  = Num n (Just xi)
  directly   = isNumberTy xi
  viaSynonym = tcView xi >>= isNumberTy
75
-- | Turn a term back into a type, reusing a literal's original type when
-- it has one.
fromTerm :: Term -> Xi
fromTerm term =
  case term of
    Var xi         -> xi
    Num n origType -> fromMaybe (mkNumberTy n) origType
79
-- | Translate a canonical constraint into the solver's 'Prop' language.
-- Only a two-argument '<=' dictionary, or an equation whose left side is
-- the '+', '*' or '^' family applied to two arguments, is understood;
-- anything else is a panic.
toProp :: CanonicalCt -> Prop
toProp ct =
  case ct of
    CDictCan { cc_class = c, cc_tyargs = [xi1, xi2] }
      | className c == lessThanEqualClassName ->
          Leq (toTerm xi1) (toTerm xi2)

    CFunEqCan { cc_fun = tc, cc_tyargs = [xi11, xi12], cc_rhs = xi2 }
      | Just op <- lookup (tyConName tc) familyOps ->
          EqFun op (toTerm xi11) (toTerm xi12) (toTerm xi2)

    _ -> panic $
           "[TcTypeNats.toProp] Unexpected CanonicalCt: " ++ showSDoc (ppr ct)
  where
  familyOps = [ (addTyFamName, Add), (mulTyFamName, Mul), (expTyFamName, Exp) ]
95
96
-- | Process one numeric constraint against the given, derived and wanted
-- inert constraints, choosing the handler by the constraint's flavor.
canonicalNum :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                TcS NumericsResult
canonicalNum given derived wanted prop = handler given derived wanted prop
  where
  handler =
    case cc_flavor prop of
      Wanted {}  -> solveNumWanted
      Derived {} -> addNumDerived
      Given {}   -> addNumGiven
104
105
-- | Try to solve a wanted numeric constraint, using the given, derived and
-- wanted constraints as assumptions.
solveNumWanted :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                  TcS NumericsResult
solveNumWanted given derived wanted prop =
  do let goal  = toProp prop
         asmps = map toProp (bagToList (unionManyBags [given, derived, wanted]))
         -- Fresh wanted evidence for each prop, canonicalised with the
         -- flavor of the goal.
         newGoals ps =
           mkCanonicals (cc_flavor prop) =<< mapM (newSubGoal <=< fromProp) ps

     numTrace "solveNumWanted" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of
       Impossible -> impossible prop

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            fromProp goal >>= defineDummy (cc_id prop)
            goals <- newGoals sgs
            return NumericsResult
              { numNext = Nothing, numInert = Nothing, numNewWork = goals }

       -- XXX: The new wanted might imply some of the existing wanteds...
       Improved imps ->
         do numTrace "Improved by" (vmany imps)
            goals <- newGoals imps
            return NumericsResult
              { numNext = Just prop, numInert = Nothing, numNewWork = goals }
133
134
-- | Process a derived numeric constraint, using only the givens as
-- assumptions.
--
-- XXX: Need to understand derived work better.
addNumDerived :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
                 TcS NumericsResult
addNumDerived given derived wanted prop =
  do let goal  = toProp prop
         asmps = map toProp (bagToList given)
         -- Fresh evidence variables for each prop.
         newEvs ps = mapM (newSubGoal <=< fromProp) ps

     numTrace "addNumDerived" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of
       Impossible -> impossible prop

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            fromProp goal >>= defineDummy (cc_id prop)
            goals <- mkCanonicals (cc_flavor prop) =<< newEvs sgs
            return NumericsResult
              { numNext = Nothing, numInert = Nothing, numNewWork = goals }

       -- XXX: watch out for cycles because of the wanteds being restarted:
       -- W => D && D => W, we could solve W by W
       -- if W <=> D, then we should simplify W to D, not make it derived.
       Improved imps ->
         do numTrace "Improved by" (vmany imps)
            evs   <- newEvs imps
            goals <- mkCanonicals (Derived (getWantedLoc prop)) evs
            return NumericsResult
              { numNext    = Just prop
              , numInert   = Just (unionBags given derived)
              , numNewWork = unionBags wanted goals
              }
166
167
-- | Process a given numeric constraint against the other givens.  Its
-- consequences become new constraints of the same flavor, each bound to
-- dummy evidence via 'newFact'.
addNumGiven :: CanonicalCts -> CanonicalCts -> CanonicalCts -> CanonicalCt ->
               TcS NumericsResult
addNumGiven given derived wanted prop =
  do let goal  = toProp prop
         asmps = map toProp (bagToList given)
         -- New facts, canonicalised with the flavor of the constraint.
         newFacts ps =
           mkCanonicals (cc_flavor prop) =<< mapM (newFact <=< fromProp) ps

     numTrace "addNumGiven" (vmany asmps <+> text " /\\ " <+> ppr goal)
     case solve asmps goal of
       Impossible -> impossible prop

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            facts <- newFacts sgs
            return NumericsResult
              { numNext = Nothing, numInert = Nothing, numNewWork = facts }

       Improved imps ->
         do numTrace "Improved by" (vmany imps)
            facts <- newFacts imps
            return NumericsResult
              { numNext    = Just prop
              , numInert   = Just (unionBags given derived)
              , numNewWork = unionBags wanted facts
              }
193
-- | Report a constraint that can never hold: it is replaced by a frozen
-- error and no new work is produced.
impossible :: CanonicalCt -> TcS NumericsResult
impossible c =
  numTrace "Impossible" empty >>
  return NumericsResult
    { numNewWork = emptyBag
    , numInert   = Nothing
    , numNext    = Just (mkFrozenError (cc_flavor c) (cc_id c))
    }
200
201
202
-- | A solver 'Prop' converted back into type-checker terms.
data CvtProp
  = CvtClass Class [Type] -- ^ A class constraint: the class and its arguments.
  | CvtCo Type Type       -- ^ An equality between two types.
205
-- | Convert a solver 'Prop' back into a class constraint or a type
-- equality, looking up the '<=' class or the relevant type family.
fromProp :: Prop -> TcS CvtProp
fromProp prop =
  case prop of
    Leq t1 t2 ->
      do cls <- tcsLookupClass lessThanEqualClassName
         return (CvtClass cls (map fromTerm [t1, t2]))

    Eq t1 t2 ->
      return (CvtCo (fromTerm t1) (fromTerm t2))

    EqFun op t1 t2 t3 ->
      do fam <- tcsLookupTyCon (familyName op)
         return (CvtCo (mkTyConApp fam (map fromTerm [t1, t2])) (fromTerm t3))
  where
  familyName Add = addTyFamName
  familyName Mul = mulTyFamName
  familyName Exp = expTyFamName
219
220
-- | A fresh evidence variable for a class constraint or an equality.
newSubGoal :: CvtProp -> TcS EvVar
newSubGoal cvt =
  case cvt of
    CvtClass cls tys -> newDictVar cls tys
    CvtCo ty1 ty2    -> newWantedCoVar ty1 ty2
224
-- | A fresh evidence variable for @prop@, bound straight away to dummy
-- evidence.
newFact :: CvtProp -> TcS EvVar
newFact prop =
  newSubGoal prop >>= \ev -> defineDummy ev prop >> return ev
230
231
-- | Bind an evidence variable to placeholder evidence: a class constraint
-- gets an axiom labelled "<=", an equality gets an unsafe coercion.
--
-- If we decided that we want to generate evidence terms,
-- here we would set the evidence properly. For now, we skip this
-- step because evidence terms are not used for anything, and they
-- get quite large, at least, if we start with a small set of axioms.
defineDummy :: EvVar -> CvtProp -> TcS ()
defineDummy ev cvt =
  case cvt of
    CvtClass cls tys ->
      setDictBind ev (EvAxiom "<=" (mkTyConApp (classTyCon cls) tys))
    CvtCo ty1 ty2 ->
      setWantedCoBind ev (mkUnsafeCoercion ty1 ty2)
242
243
244
245
246 --------------------------------------------------------------------------------
247 -- The Solver
248 --------------------------------------------------------------------------------
249
250
-- | The solver's verdict on a goal.
data Result
  = Impossible        -- ^ We know that the goal cannot be solved.
  | Simplified [Prop] -- ^ We reformulated the goal.
  | Improved [Prop]   -- ^ We learned some new facts.
254
-- | Decide a goal under a list of assumptions.
--
-- Ordering goals are checked against the ordering facts derived from the
-- assumptions.  Other goals are, in turn: simplified on their own, simplified
-- using one assumption, proved by AC-flattening, or else mined for
-- improvements.
solve :: [Prop] -> Prop -> Result
solve asmps (Leq a b)
  | isLeq order a b = Simplified []
  | otherwise       = maybe Impossible Improved (improveLeq order a b)
  where
  order = propsToOrd asmps

solve asmps prop =
  case solveWanted1 prop `mplus` msum [ solveWanted2 asmp prop | asmp <- asmps ] of
    Just sgs -> Simplified sgs
    Nothing
      | solveAC asmps prop -> Simplified []
      | otherwise          -> maybe Impossible Improved improvements
  where
  -- Facts from the goal alone, followed by those from each assumption.
  improvements =
    do eqs  <- improve1 prop
       eqss <- mapM (`improve2` prop) asmps
       return (concat (eqs : eqss))
280
281
282
283
-- | Facts that follow from a single constraint on its own.
--
--   * 'Nothing' — the constraint can never hold.
--
--   * 'Just ps' — the constraint implies everything in @ps@; the empty
--     list means nothing new was learned.
improve1 :: Prop -> Maybe [Prop]
improve1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) t -> Just [ Eq (num (m+n)) t ]
    EqFun Add (Num 0 _) s t         -> Just [ Eq s t ]
    EqFun Add (Num m _) s (Num n _)
      | m <= n    -> Just [ Eq (num (n-m)) s ]
      | otherwise -> Nothing
    EqFun Add r s (Num 0 _)         -> Just [ Eq (num 0) r, Eq (num 0) s ]
    EqFun Add r s t
      | r == t -> Just [ Eq (num 0) s ]
      | s == t -> Just [ Eq (num 0) r ]

    EqFun Mul (Num m _) (Num n _) t -> Just [ Eq (num (m*n)) t ]
    EqFun Mul (Num 0 _) _ t         -> Just [ Eq (num 0) t ]
    EqFun Mul (Num 1 _) s t         -> Just [ Eq s t ]
    EqFun Mul (Num _ _) s t
      | s == t -> Just [ Eq (num 0) s ]  -- m * s = s with m > 1

    EqFun Mul r s (Num 1 _)         -> Just [ Eq (num 1) r, Eq (num 1) s ]
    EqFun Mul (Num m _) s (Num n _)
      | Just a <- divide n m -> Just [ Eq (num a) s ]
      | otherwise            -> Nothing


    EqFun Exp (Num m _) (Num n _) t -> Just [ Eq (num (m ^ n)) t ]

    EqFun Exp (Num 1 _) _ t         -> Just [ Eq (num 1) t ]
    EqFun Exp (Num _ _) s t
      | s == t -> Nothing               -- m ^ s never equals s when m /= 1

    -- 0 ^ s is 1 for s = 0 and 0 otherwise.  Handle base 0 here so that
    -- descreteLog is only asked about bases with a unique logarithm
    -- (it used to divide by zero on 0 ^ s ~ 1, and 0 ^ s ~ 0 wrongly
    -- produced s ~ 0).
    EqFun Exp (Num 0 _) s (Num n _)
      | n == 1    -> Just [ Eq (num 0) s ]
      | n == 0    -> Just []            -- s >= 1, but no single value for s
      | otherwise -> Nothing
    EqFun Exp (Num m _) s (Num n _) ->
      do a <- descreteLog m n
         return [ Eq (num a) s ]

    EqFun Exp _ (Num 0 _) t -> Just [ Eq (num 1) t ]
    EqFun Exp r (Num 1 _) t -> Just [ Eq r t ]
    EqFun Exp r (Num m _) (Num n _) ->
      do a <- descreteRoot m n
         return [ Eq (num a) r ]

    _ -> Just []
324
-- | Facts that follow from combining the goal @prop@ with one assumption
-- @asmp@ about the same operator.  Every case returns 'Just'; 'Just []'
-- means nothing was learned.
improve2 :: Prop -> Prop -> Maybe [ Prop ]
improve2 asmp prop =
  case (asmp, prop) of

    (EqFun Add a1 b1 c1, EqFun Add a2 b2 c2)
      -- the same arguments (in either order) give the same result
      | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
      | a1 == b2 && b1 == a2 -> Just [ Eq c1 c2 ]
      -- the same result and a shared argument: the other arguments agree
      | c1 == c2 && a1 == a2 -> Just [ Eq b1 b2 ]
      | c1 == c2 && a1 == b2 -> Just [ Eq b1 a2 ]
      | c1 == c2 && b1 == b2 -> Just [ Eq a1 a2 ]
      | c1 == c2 && b1 == a2 -> Just [ Eq a1 b2 ]

    (EqFun Mul a1 b1 c1, EqFun Mul a2 b2 c2)
      | a1 == a2 && b1 == b2 -> Just [ Eq c1 c2 ]
      | a1 == b2 && b1 == a2 -> Just [ Eq c1 c2 ]
      -- m * b = c and n * b = c with m /= n force b = 0 and c = 0
      | c1 == c2 && b1 == b2, Num m _ <- a1, Num n _ <- a2, m /= n
                             -> Just [ Eq (num 0) b1, Eq (num 0) c1 ]

    _ -> Just []
352
353
-- | Solve a goal on its own, without assumptions.
--
-- 'Just []' means the goal holds outright; 'Just ps' means the goal is
-- reformulated as @ps@; 'Nothing' means no rule applies.
solveWanted1 :: Prop -> Maybe [ Prop ]
solveWanted1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) (Num mn _) | m + n == mn -> Just []
    EqFun Add (Num 0 _) s t | s == t -> Just []
    EqFun Add r s t | r == s ->
      Just [ EqFun Mul (num 2) r t ]

    EqFun Mul (Num m _) (Num n _) (Num mn _) | m * n == mn -> Just []
    EqFun Mul (Num 0 _) _ (Num 0 _) -> Just []
    EqFun Mul (Num 1 _) s t | s == t -> Just []
    EqFun Mul r s t | r == s ->
      Just [ EqFun Exp r (num 2) t ]

    -- Simple normalization of commutative operators
    EqFun op r@(Var _) s@(Num _ _) t | commutative op ->
      Just [ EqFun op s r t ]

    EqFun Exp (Num m _) (Num n _) (Num mn _) | m ^ n == mn -> Just []
    EqFun Exp (Num 1 _) _ (Num 1 _) -> Just []
    EqFun Exp _ (Num 0 _) (Num 1 _) -> Just []
    EqFun Exp r (Num 1 _) t | r == t -> Just []
    -- r ^ 0 = 1, so r ^ 0 ~ r holds only for r = 1; the 'r <= 1' rule
    -- below would also admit r = 0, where 0 ^ 0 = 1 /= 0.
    EqFun Exp r (Num 0 _) t | r == t ->
      Just [ Eq r (num 1) ]
    -- r ^ n ~ r with n >= 2 holds exactly when r is 0 or 1.
    EqFun Exp r (Num _ _) t | r == t ->
      Just [ Leq r (num 1) ]

    _ -> Nothing
381
382
-- | Try to solve (or reformulate) the goal @prop@ using one assumption.
--
-- 'Just []' means @asmp@ already entails @prop@; 'Just ps' means @prop@ is
-- reformulated as @ps@; 'Nothing' means this assumption does not help.
solveWanted2 :: Prop -> Prop -> Maybe [Prop]
solveWanted2 asmp prop =
  case (asmp, prop) of

    -- The goal is the assumption itself, possibly with swapped arguments.
    (EqFun opA ra sa ta, EqFun opG rg sg tg)
      | opA == opG && ta == tg && (sameArgs || swappedArgs) -> Just []
      where
      sameArgs    = ra == rg && sa == sg
      swappedArgs = commutative opA && ra == sg && sa == rg

    -- m + b = c |- n + d = c  <=>  d = (m - n) + b  (or b = (n - m) + d)
    (EqFun Add (Num m _) b c1, EqFun Add (Num n _) d c2)
      | c1 == c2, m >= n -> Just [ EqFun Add (num (m - n)) b d ]
      | c1 == c2         -> Just [ EqFun Add (num (n - m)) d b ]

    -- m * b = c |- n * d = c: cancel the common factor
    (EqFun Mul (Num m _) b c1, EqFun Mul (Num n _) d c2)
      | c1 == c2, Just k <- divide m n -> Just [ EqFun Mul (num k) b d ]
      | c1 == c2, Just k <- divide n m -> Just [ EqFun Mul (num k) d b ]

    -- hm: m * b = c |- c + b = d <=> (m + 1) * b = d
    (EqFun Mul (Num m _) s1 t1, EqFun Add r2 s2 t2)
      | t1 == r2 && s1 == s2 -> Just [ EqFun Mul (num (m + 1)) s1 t2 ]
      | t1 == s2 && s1 == r2 -> Just [ EqFun Mul (num (m + 1)) r2 t2 ]

    _ -> Nothing
407
408
409 --------------------------------------------------------------------------------
410 -- Reasoning about ordering.
411 --------------------------------------------------------------------------------
412
-- | Is @a <= b@ provable from the ordering facts @ps@ (a pair @(x, y)@
-- means @x <= y@)?
-- This function assumes that the assumptions are acyclic.
isLeq :: [(Term, Term)] -> Term -> Term -> Bool
isLeq ps a b =
  case (a, b) of
    (Num 0 _, _)       -> True
    (Num m _, Num n _) -> m <= n
    _ | a == b         -> True
    -- m <= x <= c <= b, for a fact x <= c with a literal x >= m
    (Num m _, _)       -> any (\c -> isLeq ps c b) [ c | (Num x _, c) <- ps, m <= x ]
    -- a <= c <= x <= n, for a fact c <= x with a literal x <= n
    (_, Num n _)       -> any (isLeq ps a) [ c | (c, Num x _) <- ps, x <= n ]
    -- a <= c <= b, for a fact a <= c
    _                  -> any (\c -> isLeq ps c b) [ c | (a', c) <- ps, a == a' ]
421
-- | Is @a > b@ provable from the ordering facts @ps@?  Only comparisons
-- involving at least one literal are attempted.
isGt :: [(Term, Term)] -> Term -> Term -> Bool
isGt ps a b =
  case (a, b) of
    (Num m _, Num n _) -> m > n
    (Num m _, _)       -> m > 0 && isLeq ps b (num (m - 1))
    (_, Num n _)       -> isLeq ps (num (n + 1)) a
    _                  -> False
427
-- | What the goal @a <= b@ teaches us: if @b <= a@ is known the two are
-- equal; if @a > b@ is known the goal is impossible ('Nothing').
improveLeq :: [(Term,Term)] -> Term -> Term -> Maybe [Prop]
improveLeq ps a b =
  if isLeq ps b a
    then Just [Eq a b]
    else if isGt ps a b
           then Nothing
           else Just []
432
-- | Ordering constraints derived from numeric predicates; a pair @(a, b)@
-- means @a <= b@.
-- We do not consider equalities because they should be substituted away.
--
-- Facts from '*' and '^' hold only under side conditions (e.g. @a * b = c@
-- gives @b <= c@ only when @1 <= a@), so they are re-examined until no new
-- pair appears.
propsToOrd :: [Prop] -> [(Term,Term)]
propsToOrd props = loop (step [] unconditional)
  where
  loop ps = let new = filter (`notElem` ps) (step ps conditional)
            in if null new then ps else loop (new ++ ps)

  -- Reflexive pairs (t, t) carry no information ('isLeq' handles a == b
  -- itself), but they are cycles, and 'isLeq' does not terminate on cyclic
  -- input (e.g. @x + y ~ x@ used to yield (x, x)).
  step ps = nub . filter (uncurry (/=)) . concatMap (toOrd ps)

  isConditional (EqFun op _ _ _) = op == Mul || op == Exp
  isConditional _                = False

  (conditional, unconditional) = partition isConditional props

  toOrd _ (Leq a b) = [(a,b)]
  toOrd _ (Eq _ _)  = [] -- Would lead to a cycle, should be subst. away
  toOrd ps (EqFun op a b c) =
    case op of
      Add -> [(a,c),(b,c)]

      Mul -> (guard (isLeq ps (num 1) a) >> return (b,c)) ++
             (guard (isLeq ps (num 1) b) >> return (a,c))

      Exp
        | Num 0 _ <- a -> [(c, num 1)]
        -- a ^ b = a forces a <= 1 only when b /= 1, since a ^ 1 = a for
        -- every a.
        | a == c && exponentNotOne -> [(a, num 1)]
        | otherwise -> (guard (isLeq ps (num 2) a) >> return (b,c)) ++
                       (guard (isLeq ps (num 1) b) >> return (a,c))
        where
        exponentNotOne = isLeq ps (num 2) b || isLeq ps b (num 0)
461
462 --------------------------------------------------------------------------------
463 -- Associative and Commutative Operators
464 --------------------------------------------------------------------------------
465
-- | Try to prove an equation on an associative and commutative operator by
-- flattening.  Using the assumptions about the same operator, each side is
-- expanded into sorted lists of operands; the goal @x op y ~ z@ holds if
-- some expansion of @x op y@ coincides with some expansion of @z@.
--
-- A term is expanded at most once along any expansion path, so cyclic
-- assumptions such as @x * y = x@ no longer make the search diverge; the
-- price is that expansions unfolding such a cycle more than once are not
-- explored.
--
-- XXX: does not do improvements
solveAC :: [Prop] -> Prop -> Bool
solveAC ps (EqFun op x y z)
  | commutative op && associative op =
      any (`elem` sums [] z) (add (sums [] x) (sums [] y))

  where
  -- Assumptions @a op b = c@ about the same operator.
  asmps = [ (a,b,c) | EqFun op' a b c <- ps, op == op' ]

  -- Ways of writing @c@ as @a op b@.
  candidates c = [ (a,b) | (a,b,c') <- asmps, c == c' ]

  -- Sorted operand lists equal to @a@; @seen@ holds the terms already
  -- being expanded on the current path, which are not expanded again.
  sums seen a
    | a `elem` seen = [[a]]
    | otherwise     =
        [a] : [ p | (b,c) <- candidates a
                  , p     <- add (sums (a : seen) b) (sums (a : seen) c) ]

  add ls rs = [ merge u v | u <- ls, v <- rs ]

solveAC _ _ = False
494
495
-- | Merge two ascending lists into one ascending list; on ties the element
-- from the first list comes first.
merge :: Ord a => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge xs@(x:xs') ys@(y:ys') =
  if x <= y then x : merge xs' ys
            else y : merge xs ys'
502
503
504
505 --------------------------------------------------------------------------------
-- Discrete Math
507 --------------------------------------------------------------------------------
508
-- | @descreteRoot root n@ is the natural @x@ with @x ^ root == n@, if there
-- is one.  It binary-searches the closed interval @[0, n]@, which contains
-- every such @x@ when @root >= 1@ and @n >= 0@.  (The previous search never
-- probed its upper bound, so e.g. the square root of 1 was not found.)
descreteRoot :: Integer -> Integer -> Maybe Integer
descreteRoot root target = search 0 target
  where
  -- Invariant: if the root exists, it lies in [lo, hi].
  search lo hi
    | lo > hi   = Nothing
    | otherwise =
        let x = lo + div (hi - lo) 2
        in case compare (x ^ root) target of
             EQ -> Just x
             LT -> search (x + 1) hi
             GT -> search lo (x - 1)
519
-- | @descreteLog base n@ is the unique natural @e@ with @base ^ e == n@,
-- if there is one.
--
-- For bases below 2 the exponent is determined only by @0 ^ 0 == 1@:
-- @1 ^ e@ is 1 for every @e@ and @0 ^ e@ is 0 for every @e >= 1@, so all
-- other such queries give 'Nothing'.  (The previous version returned 0 for
-- @n == 0@ instead of @n == 1@, divided by zero for base 0 and looped for
-- base 1.)
descreteLog :: Integer -> Integer -> Maybe Integer
descreteLog base target
  | base == 0   = if target == 1 then Just 0 else Nothing
  | base < 2    = Nothing
  | target == 1 = Just 0
  | target < 1  = Nothing
  | otherwise   = case divMod target base of
                    (q, 0) -> fmap (1 +) (descreteLog base q)
                    _      -> Nothing
526
-- | Exact division: @Just q@ when @y /= 0@ and @x == q * y@.
divide :: Integer -> Integer -> Maybe Integer
divide x y
  | y == 0         = Nothing
  | remainder == 0 = Just quotient
  | otherwise      = Nothing
  where
  (quotient, remainder) = x `divMod` y
532
533
534 --------------------------------------------------------------------------------
535 -- Debugging
536 --------------------------------------------------------------------------------
537
-- | Trace output of the numeric solver, tagged with "[numerics]".
numTrace :: String -> SDoc -> TcS ()
numTrace msg = traceTcS ("[numerics] " ++ msg)
540
-- | Render a list as a brace-enclosed, comma-separated set.
vmany :: Outputable a => [a] -> SDoc
vmany = braces . hcat . punctuate comma . map ppr
543
-- | Literals print as their value, variables as their type.
instance Outputable Term where
  ppr term =
    case term of
      Num n _ -> integer n
      Var xi  -> ppr xi
547
-- | Props print in infix form, e.g. @a + b ~ c@ or @a <= b@.
instance Outputable Prop where
  ppr prop =
    case prop of
      EqFun op t1 t2 t3 -> ppr t1 <+> ppr op <+> ppr t2 <+> char '~' <+> ppr t3
      Leq t1 t2         -> ppr t1 <+> text "<=" <+> ppr t2
      Eq t1 t2          -> ppr t1 <+> char '~' <+> ppr t2
552
-- | Operators print as their usual symbols.
instance Outputable Op where
  ppr Add = char '+'
  ppr Mul = char '*'
  ppr Exp = char '^'
558