Improve reasoning about ordering to support transitivity and anti-symmetry.
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 {-# LANGUAGE PatternGuards #-}
2 module TcTypeNats (solveNumWanted, addNumGiven) where
3
4 import TcSMonad ( TcS, Xi
5 , newWantedCoVar
6 , setWantedCoBind, setDictBind, newDictVar
7 , tcsLookupTyCon, tcsLookupClass
8 , CanonicalCt (..), CanonicalCts
9 , traceTcS
10 )
11 import TcCanonical (mkCanonicals)
12 import HsBinds (EvTerm(..))
13 import Class ( Class, className, classTyCon )
14 import Type ( tcView, mkTyConApp, mkNumberTy, isNumberTy
15 , tcEqType )
16 import TypeRep (Type(..))
17 import TyCon (TyCon, tyConName)
18 import Var (EvVar)
19 import Coercion ( mkUnsafeCoercion )
20 import Outputable
21 import PrelNames ( lessThanEqualClassName
22 , addTyFamName, mulTyFamName, expTyFamName
23 )
24 import Bag (bagToList, emptyBag, unionBags)
25 import Data.Maybe (fromMaybe)
26 import Data.List (nub, partition)
27 import Control.Monad (msum, mplus, zipWithM, (<=<), guard)
28
29
30
-- | A term in a numeric constraint.  Literals remember the type they were
-- read from (if any), so that type synonyms survive the round trip.
data Term
  = Var Xi                  -- ^ Any type that is not a numeric literal.
  | Num Integer (Maybe Xi)  -- ^ A literal, with the type it came from.

-- | The type-level arithmetic operations understood by the solver.
data Op = Add | Mul | Exp
  deriving Eq

-- | Propositions manipulated by the solver.
data Prop
  = EqFun Op Term Term Term  -- ^ @EqFun op a b c@ stands for @a op b ~ c@.
  | Leq Term Term            -- ^ @a <= b@
  | Eq Term Term             -- ^ @a ~ b@
36
-- | Whether @a op b@ equals @b op a@: true for addition and
-- multiplication, false for exponentiation.
commutative :: Op -> Bool
commutative Exp = False
commutative _   = True
39
-- | A numeric literal that was not read from any source type.
num :: Integer -> Term
num = flip Num Nothing
42
-- | Variables are compared with 'tcEqType'; literals are compared by value
-- only, ignoring the type they were read from.
instance Eq Term where
  t1 == t2 =
    case (t1, t2) of
      (Var x, Var y)     -> tcEqType x y
      (Num m _, Num n _) -> m == n
      _                  -> False
47
48
49 --------------------------------------------------------------------------------
50 -- Interface with the type checker
51 --------------------------------------------------------------------------------
52
53
-- | Convert a type into a solver term.  A type that is a numeric literal,
-- either directly or after expanding it with 'tcView', becomes a 'Num'.
-- We keep the original type in numeric constants to preserve type synonyms.
toTerm :: Xi -> Term
toTerm xi = maybe (Var xi) (\n -> Num n (Just xi)) literal
  where literal = isNumberTy xi `mplus` (tcView xi >>= isNumberTy)
59
-- | Convert a solver term back into a type.  A literal that was read from a
-- type is turned back into exactly that type.
fromTerm :: Term -> Xi
fromTerm t =
  case t of
    Var xi        -> xi
    Num n srcType -> fromMaybe (mkNumberTy n) srcType
63
-- | Translate a canonical constraint into a solver proposition.  Only
-- binary @<=@ dictionaries and the addition, multiplication and
-- exponentiation type families are expected; anything else panics.
toProp :: CanonicalCt -> Prop
toProp (CDictCan { cc_class = cls, cc_tyargs = [lhs, rhs] })
  | className cls == lessThanEqualClassName = Leq (toTerm lhs) (toTerm rhs)

toProp (CFunEqCan { cc_fun = fam, cc_tyargs = [arg1, arg2], cc_rhs = res })
  | Just op <- lookup (tyConName fam) familyOps =
      EqFun op (toTerm arg1) (toTerm arg2) (toTerm res)
  where
  familyOps = [ (addTyFamName, Add), (mulTyFamName, Mul), (expTyFamName, Exp) ]

toProp ct = panic $
  "[TcTypeNats.toProp] Unexpected CanonicalCt: " ++ showSDoc (ppr ct)
79
80
-- | Try to discharge the numeric wanted @prop@, using all given and wanted
-- constraints as assumptions.  The first component of the result is
-- @Just prop@ when @prop@ is not discharged; the second holds the newly
-- generated canonical constraints.
--
--   * 'Simplified': @prop@ is bound to dummy evidence and replaced by the
--     sub-goals (none, if it holds outright).
--   * 'Improved':   @prop@ is kept and the improving facts become goals.
--   * 'Impossible': @prop@ is kept and nothing is generated (no error is
--     reported yet).
solveNumWanted :: CanonicalCts -> CanonicalCts -> CanonicalCt ->
                  TcS (Maybe CanonicalCt, CanonicalCts)
solveNumWanted given wanted prop =
  do let asmps = map toProp $ bagToList $ unionBags given wanted
         goal  = toProp prop

     numTrace "solveNumWanted" (vmany asmps <+> text "|-" <+> ppr goal)

     case solve asmps goal of

       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            defineDummy (cc_id prop) =<< fromProp goal
            evs   <- mapM (newSubGoal <=< fromProp) sgs
            goals <- mkCanonicals (cc_flavor prop) evs
            return (Nothing, goals)

       -- XXX: The new wanted might imply some of the existing wanteds...
       Improved is ->
         do numTrace "Improved by" (vmany is)
            evs   <- mapM (newSubGoal <=< fromProp) is
            goals <- mkCanonicals (cc_flavor prop) evs
            return (Just prop, goals)

       -- XXX: Report an error
       Impossible ->
         do numTrace "Impossible" empty
            return (Just prop, emptyBag)
109
110
-- | Process a new numeric given @prop@ against the existing constraints.
-- The first component is @Just prop@ when @prop@ is kept; per outcome of
-- 'solve' the two bags are:
--
--   * 'Simplified': all existing constraints, and the simplified facts.
--   * 'Improved':   the givens, and the wanteds plus the improving facts.
--   * 'Impossible': all existing constraints, and nothing new (no error is
--     reported yet).
addNumGiven :: CanonicalCts -> CanonicalCts -> CanonicalCt ->
               TcS (Maybe CanonicalCt, CanonicalCts, CanonicalCts)
addNumGiven given wanted prop =
  do numTrace "addNumGiven" (vmany asmps <+> text " /\\ " <+> ppr goal)
     case solve asmps goal of
       Simplified sgs ->
         do numTrace "Simplified to" (vmany sgs)
            facts <- factsFrom sgs
            return (Nothing, everything, facts)
       Improved is ->
         do numTrace "Improved by" (vmany is)
            facts <- factsFrom is
            return (Just prop, given, unionBags wanted facts)
       Impossible ->
         do numTrace "Impossible" empty
            return (Just prop, everything, emptyBag)
  where
  everything = unionBags given wanted
  asmps      = map toProp (bagToList everything)
  goal       = toProp prop
  -- Turn propositions into canonical facts backed by dummy evidence,
  -- using the flavor of the given being added.
  factsFrom ps = mkCanonicals (cc_flavor prop) =<< mapM (newFact <=< fromProp) ps
136
137
138
-- | A proposition translated back into type-checker terms.
data CvtProp
  = CvtClass Class [Type]  -- ^ A class applied to its type arguments.
  | CvtCo Type Type        -- ^ An equality between two types.
141
-- | Translate a solver proposition back into type-checker terms, looking up
-- the class or type family that represents it.
fromProp :: Prop -> TcS CvtProp
fromProp prop =
  case prop of
    Leq t1 t2 ->
      do cls <- tcsLookupClass lessThanEqualClassName
         return (CvtClass cls (map fromTerm [t1, t2]))
    Eq t1 t2 ->
      return (CvtCo (fromTerm t1) (fromTerm t2))
    EqFun op t1 t2 t3 ->
      do fam <- tcsLookupTyCon (familyName op)
         return (CvtCo (mkTyConApp fam (map fromTerm [t1, t2])) (fromTerm t3))
  where
  familyName Add = addTyFamName
  familyName Mul = mulTyFamName
  familyName Exp = expTyFamName
155
156
-- | A fresh evidence variable for a proposition still to be proved: a
-- dictionary variable for a class, a wanted coercion variable for an
-- equality.
newSubGoal :: CvtProp -> TcS EvVar
newSubGoal cvt =
  case cvt of
    CvtClass cls tys -> newDictVar cls tys
    CvtCo lhs rhs    -> newWantedCoVar lhs rhs
160
-- | Like 'newSubGoal', but the new variable is immediately bound to dummy
-- evidence with 'defineDummy', so the proposition can serve as a fact.
newFact :: CvtProp -> TcS EvVar
newFact cvt =
  newSubGoal cvt >>= \ev ->
  defineDummy ev cvt >>
  return ev
166
167
-- Bind an evidence variable to placeholder evidence: an "<=" axiom for a
-- class constraint, an unsafe coercion for an equality.
-- If we decided that we want to generate evidence terms, here we would set
-- the evidence properly.  For now, we skip this step because evidence terms
-- are not used for anything, and they get quite large, at least, if we
-- start with a small set of axioms.
defineDummy :: EvVar -> CvtProp -> TcS ()
defineDummy ev cvt =
  case cvt of
    CvtClass cls tys ->
      setDictBind ev (EvAxiom "<=" (mkTyConApp (classTyCon cls) tys))
    CvtCo lhs rhs ->
      setWantedCoBind ev (mkUnsafeCoercion lhs rhs)
178
179
180
181
182 --------------------------------------------------------------------------------
183 -- The Solver
184 --------------------------------------------------------------------------------
185
186
-- | The outcome of trying to solve a goal.
data Result
  = Impossible         -- ^ We know that the goal cannot be solved.
  | Simplified [Prop]  -- ^ We reformulated the goal ([] means it holds).
  | Improved [Prop]    -- ^ We learned some new facts.
190
-- | Decide what to do with a goal under the given assumptions.
--
-- Ordering goals are checked against the ordering facts derived from the
-- assumptions.  Other goals are first simplified on their own, then with
-- the help of a single assumption; failing that we collect improvements
-- from the goal alone and from the goal paired with each assumption.
solve :: [Prop] -> Prop -> Result
solve asmps (Leq a b)
  | isLeq ords a b = Simplified []
  | otherwise      = maybe Impossible Improved (improveLeq ords a b)
  where ords = propsToOrd asmps

solve asmps prop =
  case solveWanted1 prop `mplus` msum (zipWith solveWanted2 asmps (repeat prop)) of
    Just sgs -> Simplified sgs
    Nothing  -> maybe Impossible Improved improvements
  where
  improvements =
    do eqs  <- improve1 prop
       eqss <- zipWithM improve2 asmps (repeat prop)
       return (concat (eqs : eqss))
214
215
216
217
-- | Facts that follow from a single proposition (mostly equalities, plus
-- one ordering fact for @0 ^ s ~ 0@).  'Nothing' means the proposition
-- can never hold; @Just []@ means there is nothing to learn.
improve1 :: Prop -> Maybe [Prop]
improve1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) t -> Just [ Eq (num (m+n)) t ]
    EqFun Add (Num 0 _) s t         -> Just [ Eq s t ]
    EqFun Add (Num m _) s (Num n _)
      | m <= n                      -> Just [ Eq (num (n-m)) s ]
      | otherwise                   -> Nothing
    EqFun Add r s t
      | r == t                      -> Just [ Eq (num 0) s ]
      | s == t                      -> Just [ Eq (num 0) r ]

    EqFun Mul (Num m _) (Num n _) t -> Just [ Eq (num (m*n)) t ]
    EqFun Mul (Num 0 _) _ t         -> Just [ Eq (num 0) t ]
    EqFun Mul (Num 1 _) s t         -> Just [ Eq s t ]
    EqFun Mul (Num _ _) s t
      | s == t                      -> Just [ Eq (num 0) s ]

    EqFun Mul r s (Num 1 _)         -> Just [ Eq (num 1) r, Eq (num 1) s ]
    EqFun Mul (Num m _) s (Num n _)
      | Just a <- divide n m        -> Just [ Eq (num a) s ]
      | otherwise                   -> Nothing

    EqFun Exp (Num m _) (Num n _) t -> Just [ Eq (num (m ^ n)) t ]

    EqFun Exp (Num 1 _) _ t         -> Just [ Eq (num 1) t ]
    EqFun Exp (Num _ _) s t
      | s == t                      -> Nothing

    -- 0 ^ s is 1 when s is 0 and 0 otherwise, so base 0 has no unique
    -- logarithm (and must not reach 'descreteLog', which would divide by 0).
    EqFun Exp (Num 0 _) s (Num n _)
      | n == 0                      -> Just [ Leq (num 1) s ]
      | n == 1                      -> Just [ Eq (num 0) s ]
      | otherwise                   -> Nothing
    EqFun Exp (Num m _) s (Num n _) ->
      fmap (\a -> [ Eq (num a) s ]) (descreteLog m n)

    EqFun Exp _ (Num 0 _) t         -> Just [ Eq (num 1) t ]
    EqFun Exp r (Num 1 _) t         -> Just [ Eq r t ]
    EqFun Exp r (Num m _) (Num n _) ->
      fmap (\a -> [ Eq (num a) r ]) (descreteRoot m n)

    _ -> Just []
257
-- | Facts that follow from an assumption together with the goal.  Equal
-- arguments force equal results.  For addition, equal results plus one
-- shared argument force the remaining arguments to be equal.  For
-- multiplication, @m * b ~ c@ and @n * b ~ c@ with distinct literals
-- @m@ and @n@ force @b ~ 0@ and @c ~ 0@.
improve2 :: Prop -> Prop -> Maybe [ Prop ]
improve2 asmp prop =
  case (asmp, prop) of

    (EqFun Add x1 y1 z1, EqFun Add x2 y2 z2)
      | x1 == x2, y1 == y2 -> Just [ Eq z1 z2 ]
      | x1 == y2, y1 == x2 -> Just [ Eq z1 z2 ]
      | z1 == z2, x1 == x2 -> Just [ Eq y1 y2 ]
      | z1 == z2, x1 == y2 -> Just [ Eq y1 x2 ]
      | z1 == z2, y1 == y2 -> Just [ Eq x1 x2 ]
      | z1 == z2, y1 == x2 -> Just [ Eq x1 y2 ]

    (EqFun Mul x1 y1 z1, EqFun Mul x2 y2 z2)
      | x1 == x2, y1 == y2 -> Just [ Eq z1 z2 ]
      | x1 == y2, y1 == x2 -> Just [ Eq z1 z2 ]
      | z1 == z2, y1 == y2, Num m _ <- x1, Num n _ <- x2, m /= n
        -> Just [ Eq (num 0) y1, Eq (num 0) z1 ]

    _ -> Just []
285
286
-- | Goals that can be discharged or reformulated on their own.  @Just []@
-- means the goal holds, @Just ps@ replaces it by the sub-goals @ps@, and
-- 'Nothing' means no rule applies.
solveWanted1 :: Prop -> Maybe [ Prop ]
solveWanted1 prop =
  case prop of

    EqFun Add (Num m _) (Num n _) (Num mn _) | m + n == mn -> Just []
    EqFun Add (Num 0 _) s t | s == t -> Just []
    EqFun Add r s t | r == s ->
      Just [ EqFun Mul (num 2) r t ]

    EqFun Mul (Num m _) (Num n _) (Num mn _) | m * n == mn -> Just []
    EqFun Mul (Num 0 _) _ (Num 0 _) -> Just []
    EqFun Mul (Num 1 _) s t | s == t -> Just []
    EqFun Mul r s t | r == s ->
      Just [ EqFun Exp r (num 2) t ]

    -- Simple normalization of commutative operators
    EqFun op r@(Var _) s@(Num _ _) t | commutative op ->
      Just [ EqFun op s r t ]

    EqFun Exp (Num m _) (Num n _) (Num mn _) | m ^ n == mn -> Just []
    EqFun Exp (Num 1 _) _ (Num 1 _) -> Just []
    EqFun Exp _ (Num 0 _) (Num 1 _) -> Just []
    EqFun Exp r (Num 1 _) t | r == t -> Just []
    -- r ^ 0 ~ r forces r ~ 1 (r = 0 fails, since 0 ^ 0 = 1); for an
    -- exponent of 2 or more, r ^ n ~ r holds exactly when r <= 1.
    EqFun Exp r (Num n _) t | r == t ->
      Just [ if n == 0 then Eq r (num 1) else Leq r (num 1) ]

    _ -> Nothing
314
315
-- | Goals that can be discharged or reformulated with the help of a single
-- assumption; results are read as for 'solveWanted1'.
solveWanted2 :: Prop -> Prop -> Maybe [Prop]
solveWanted2 asmp prop =
  case (asmp, prop) of

    -- The goal is the assumption itself, up to commutativity.
    (EqFun op1 r1 s1 t1, EqFun op2 r2 s2 t2)
      | op1 == op2, t1 == t2
      , (r1 == r2 && s1 == s2) || (commutative op1 && r1 == s2 && s1 == r2)
      -> Just []

    -- m + b ~ c and n + d ~ c: cancel the smaller constant.
    (EqFun Add (Num m _) b c1, EqFun Add (Num n _) d c2)
      | c1 == c2, m >= n -> Just [ EqFun Add (num (m - n)) b d ]
      | c1 == c2         -> Just [ EqFun Add (num (n - m)) d b ]

    -- m * b ~ c and n * d ~ c: divide out the constant factor.
    (EqFun Mul (Num m _) b c1, EqFun Mul (Num n _) d c2)
      | c1 == c2, Just x <- divide m n -> Just [ EqFun Mul (num x) b d ]
      | c1 == c2, Just x <- divide n m -> Just [ EqFun Mul (num x) d b ]

    -- hm: m * b = c |- c + b = d <=> (m + 1) * b = d
    (EqFun Mul (Num m _) s1 t1, EqFun Add r2 s2 t2)
      | t1 == r2, s1 == s2 -> Just [ EqFun Mul (num (m + 1)) s1 t2 ]
      | t1 == s2, s1 == r2 -> Just [ EqFun Mul (num (m + 1)) r2 t2 ]

    _ -> Nothing
340
341
342 --------------------------------------------------------------------------------
343 -- Reasoning about ordering.
344 --------------------------------------------------------------------------------
345
-- | @isLeq ps a b@: does @a <= b@ follow from the ordering facts @ps@,
-- where each pair @(x, y)@ records @x <= y@?  Besides the facts we use
-- reflexivity, @0 <= x@, the ordering of literals, and transitivity.
--
-- The facts may be cyclic (e.g. both @a <= b@ and @b <= a@).  A query that
-- is already open on the current search path is not explored again, so the
-- search always terminates.  A derivation never needs to repeat a query,
-- so no answers are lost.  Whenever the plain depth-first search would
-- terminate, the answer is the same.
isLeq :: [(Term, Term)] -> Term -> Term -> Bool
isLeq ps = go []
  where
  go _ (Num 0 _) _         = True
  go _ (Num m _) (Num n _) = m <= n
  go path a b
    | a == b               = True
    | (a, b) `elem` path   = False   -- a cycle: this query is already open
    | otherwise            = or [ go ((a, b) : path) x y | (x, y) <- steps a b ]

  -- Intermediate queries from which @a <= b@ follows.
  steps (Num m _) b = [ (x, b) | (Num n _, x) <- ps, m <= n ]  -- m <= n <= x <= b
  steps a (Num m _) = [ (a, x) | (x, Num n _) <- ps, n <= m ]  -- a <= x <= n <= m
  steps a b         = [ (y, b) | (x, y) <- ps, a == x ]        -- a <= y <= b
354
-- | @isGt ps a b@: does @a > b@ follow from the ordering facts @ps@?  Only
-- comparisons involving a literal are attempted.  A strict bound is turned
-- into a non-strict one (@m > a@ iff @a <= m - 1@) and checked with 'isLeq'.
isGt :: [(Term, Term)] -> Term -> Term -> Bool
isGt ps lhs rhs =
  case (lhs, rhs) of
    (Num m _, Num n _) -> m > n
    (Num m _, _)       -> m > 0 && isLeq ps rhs (num (m - 1))
    (_, Num n _)       -> isLeq ps (num (n + 1)) lhs
    _                  -> False
360
-- | Improvement for a goal @a <= b@ that could not be proved directly:
--
--   * if @a > b@ is known, the goal can never hold ('Nothing');
--   * if @b <= a@ is known, anti-symmetry forces @a ~ b@;
--   * otherwise there is nothing to learn.
--
-- The strict check must come first.  @a > b@ also implies @b <= a@, so the
-- other order would turn an unsatisfiable goal such as @5 <= 3@ (or
-- @a <= 3@ given @5 <= a@) into a bogus equation instead of 'Nothing'.
improveLeq :: [(Term,Term)] -> Term -> Term -> Maybe [Prop]
improveLeq ps a b
  | isGt ps a b  = Nothing
  | isLeq ps b a = Just [Eq a b]
  | otherwise    = Just []
365
-- | Ordering constraints derived from numeric predicates; each pair @(x, y)@
-- means @x <= y@.  Facts from multiplication and exponentiation depend on
-- the ordering facts already known, so they are recomputed until no new
-- pair appears.
-- We do not consider equalities because they should be substituted away.
-- Reflexive pairs @(x, x)@ (e.g. from @x + y ~ x@) are dropped: they carry
-- no information and would only send 'isLeq' round in circles.
propsToOrd :: [Prop] -> [(Term,Term)]
propsToOrd props = loop (step [] unconditional)
  where
  loop ps = let new = filter (`notElem` ps) (step ps conditional)
            in if null new then ps else loop (new ++ ps)

  step ps = nub . filter (uncurry (/=)) . concatMap (toOrd ps)

  isConditional (EqFun op _ _ _) = op == Mul || op == Exp
  isConditional _ = False

  (conditional,unconditional) = partition isConditional props

  toOrd _ (Leq a b) = [(a,b)]
  toOrd _ (Eq _ _) = [] -- Would lead to a cycle, should be subst. away
  toOrd ps (EqFun op a b c) =
    case op of
      Add -> [(a,c),(b,c)]

      Mul -> (guard (isLeq ps (num 1) a) >> return (b,c)) ++
             (guard (isLeq ps (num 1) b) >> return (a,c))
      Exp
        | Num 0 _ <- a -> [(c, num 1)]
          -- a ^ b ~ a gives a <= 1, but only if the exponent cannot be 1
          -- (5 ^ 1 ~ 5).
        | a == c -> guard (exponentNotOne ps b) >> return (a, num 1)
        | otherwise -> (guard (isLeq ps (num 2) a) >> return (b,c)) ++
                       (guard (isLeq ps (num 1) b) >> return (a,c))

  -- Is the exponent known to differ from 1?
  exponentNotOne _ (Num n _) = n /= 1
  exponentNotOne ps t        = isLeq ps (num 2) t
394
395
396 --------------------------------------------------------------------------------
-- Discrete Math
398 --------------------------------------------------------------------------------
399
-- | @descreteRoot root n@ is the natural number @x@ with @x ^ root == n@, if
-- one exists.  It binary-searches the inclusive range @[0, n]@.  Both ends
-- are examined, so e.g. the root of 1 is found (@descreteRoot 2 1 ==
-- Just 1@).  For @root == 0@ every @x@ gives 1, so the answer is not
-- unique; 'improve1' handles exponents 0 and 1 before calling this.
descreteRoot :: Integer -> Integer -> Maybe Integer
descreteRoot root target = search 0 target
  where
    search lo hi
      | lo > hi   = Nothing
      | otherwise =
          let x = lo + (hi - lo) `div` 2
          in case compare (x ^ root) target of
               EQ -> Just x
               LT -> search (x + 1) hi
               GT -> search lo (x - 1)
410
-- | @descreteLog base n@ is the unique natural @e@ with @base ^ e == n@.
-- It returns 'Nothing' when there is no such exponent, or when there is
-- more than one (bases 0 and 1).
--
-- Every base other than 1 gives @base ^ 0 == 1@, so @n == 1@ yields 0.
-- Bases below 2 are rejected before dividing, which avoids a division by
-- zero for base 0 and endless recursion for base 1.
descreteLog :: Integer -> Integer -> Maybe Integer
descreteLog base n
  | n == 1, base /= 1 = Just 0
  | base < 2          = Nothing   -- 0 ^ e and 1 ^ e only reach 0 and 1
  | n < 1             = Nothing   -- a base >= 2 never reaches 0
  | otherwise         =
      case divMod n base of
        (q, 0) -> fmap (1 +) (descreteLog base q)
        _      -> Nothing
417
-- | Exact division: @divide x y@ is @Just q@ when @x == q * y@, and
-- 'Nothing' when @y@ is 0 or does not divide @x@.
divide :: Integer -> Integer -> Maybe Integer
divide x y
  | y == 0    = Nothing
  | r == 0    = Just q
  | otherwise = Nothing
  where (q, r) = x `divMod` y
423
424
425 --------------------------------------------------------------------------------
426 -- Debugging
427 --------------------------------------------------------------------------------
428
-- | Emit a solver trace message, tagged with "[numerics] ".
numTrace :: String -> SDoc -> TcS ()
numTrace = traceTcS . ("[numerics] " ++)
431
-- | Render a list as a brace-enclosed, comma-separated sequence.
vmany :: Outputable a => [a] -> SDoc
vmany = braces . hcat . punctuate comma . map ppr
434
-- | A variable prints as its type; a literal prints as its value (the type
-- it was read from is not shown).
instance Outputable Term where
  ppr t =
    case t of
      Var xi  -> ppr xi
      Num n _ -> integer n
438
-- | Propositions print in infix form, e.g. @a + b ~ c@ or @a <= b@.
instance Outputable Prop where
  ppr prop =
    case prop of
      EqFun op t1 t2 t3 -> ppr t1 <+> ppr op <+> ppr t2 <+> char '~' <+> ppr t3
      Leq t1 t2         -> ppr t1 <+> text "<=" <+> ppr t2
      Eq t1 t2          -> ppr t1 <+> char '~' <+> ppr t2
443
-- | Operations print as their usual symbols.
instance Outputable Op where
  ppr Add = char '+'
  ppr Mul = char '*'
  ppr Exp = char '^'
449