Fix type-equality in the type checker (fixes Trac #8553)
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 module TcTypeNats
2 ( typeNatTyCons
3 , typeNatCoAxiomRules
4 , BuiltInSynFamily(..)
5 ) where
6
7 import Type
8 import Pair
9 import TcType ( TcType, tcEqType )
10 import TyCon ( TyCon, SynTyConRhs(..), mkSynTyCon, TyConParent(..) )
11 import Coercion ( Role(..) )
12 import TcRnTypes ( Xi )
13 import CoAxiom ( CoAxiomRule(..), BuiltInSynFamily(..) )
14 import Name ( Name, BuiltInSyntax(..) )
15 import TysWiredIn ( typeNatKind, mkWiredInTyConName
16 , promotedBoolTyCon
17 , promotedFalseDataCon, promotedTrueDataCon
18 )
19 import TysPrim ( tyVarList, mkArrowKinds )
20 import PrelNames ( gHC_TYPELITS
21 , typeNatAddTyFamNameKey
22 , typeNatMulTyFamNameKey
23 , typeNatExpTyFamNameKey
24 , typeNatLeqTyFamNameKey
25 , typeNatSubTyFamNameKey
26 )
27 import FastString ( FastString, fsLit )
28 import qualified Data.Map as Map
29 import Data.Maybe ( isJust )
30
31 {-------------------------------------------------------------------------------
Built-in type constructors for functions on type-level nats
33 -}
34
-- | Every built-in type family over type-level naturals defined in this
-- module, in the order: addition, multiplication, exponentiation,
-- comparison, subtraction.
typeNatTyCons :: [TyCon]
typeNatTyCons = [ typeNatAddTyCon, typeNatMulTyCon, typeNatExpTyCon
                , typeNatLeqTyCon, typeNatSubTyCon ]
43
-- | The built-in family @(+)@ on type-level naturals. Its evaluation and
-- improvement hooks are 'matchFamAdd', 'interactTopAdd' and
-- 'interactInertAdd'.
typeNatAddTyCon :: TyCon
typeNatAddTyCon = mkTypeNatFunTyCon2 addName addOps
  where
    -- The wired-in name refers back to the TyCon being defined (knot-tied).
    addName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "+")
                typeNatAddTyFamNameKey typeNatAddTyCon
    addOps  = BuiltInSynFamily
                { sfMatchFam      = matchFamAdd
                , sfInteractTop   = interactTopAdd
                , sfInteractInert = interactInertAdd
                }
54
-- | The built-in family @(-)@ on type-level naturals. Its evaluation and
-- improvement hooks are 'matchFamSub', 'interactTopSub' and
-- 'interactInertSub'.
typeNatSubTyCon :: TyCon
typeNatSubTyCon = mkTypeNatFunTyCon2 subName subOps
  where
    -- The wired-in name refers back to the TyCon being defined (knot-tied).
    subName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "-")
                typeNatSubTyFamNameKey typeNatSubTyCon
    subOps  = BuiltInSynFamily
                { sfMatchFam      = matchFamSub
                , sfInteractTop   = interactTopSub
                , sfInteractInert = interactInertSub
                }
65
-- | The built-in family @(*)@ on type-level naturals. Its evaluation and
-- improvement hooks are 'matchFamMul', 'interactTopMul' and
-- 'interactInertMul'.
typeNatMulTyCon :: TyCon
typeNatMulTyCon = mkTypeNatFunTyCon2 mulName mulOps
  where
    -- The wired-in name refers back to the TyCon being defined (knot-tied).
    mulName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "*")
                typeNatMulTyFamNameKey typeNatMulTyCon
    mulOps  = BuiltInSynFamily
                { sfMatchFam      = matchFamMul
                , sfInteractTop   = interactTopMul
                , sfInteractInert = interactInertMul
                }
76
-- | The built-in family @(^)@ on type-level naturals. Its evaluation and
-- improvement hooks are 'matchFamExp', 'interactTopExp' and
-- 'interactInertExp'.
typeNatExpTyCon :: TyCon
typeNatExpTyCon = mkTypeNatFunTyCon2 expName expOps
  where
    -- The wired-in name refers back to the TyCon being defined (knot-tied).
    expName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "^")
                typeNatExpTyFamNameKey typeNatExpTyCon
    expOps  = BuiltInSynFamily
                { sfMatchFam      = matchFamExp
                , sfInteractTop   = interactTopExp
                , sfInteractInert = interactInertExp
                }
87
-- | The built-in family @(<=?)@, of kind @Nat -> Nat -> Bool@. It cannot use
-- 'mkTypeNatFunTyCon2' because its result kind is the promoted 'Bool' kind
-- rather than @Nat@.
typeNatLeqTyCon :: TyCon
typeNatLeqTyCon =
  mkSynTyCon name leqKind leqTyVars leqRoles (BuiltInSynFamTyCon ops)
             NoParentTyCon
  where
    leqKind   = mkArrowKinds [typeNatKind, typeNatKind] boolKind
    leqTyVars = take 2 (tyVarList typeNatKind)
    leqRoles  = [Nominal, Nominal]
    -- The wired-in name refers back to the TyCon being defined (knot-tied).
    name = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "<=?")
             typeNatLeqTyFamNameKey typeNatLeqTyCon
    ops  = BuiltInSynFamily
             { sfMatchFam      = matchFamLeq
             , sfInteractTop   = interactTopLeq
             , sfInteractInert = interactInertLeq
             }
105
106
-- | Make a binary built-in type family of kind @Nat -> Nat -> Nat@.
-- Both arguments have nominal role, the family has no parent, and its
-- behaviour is supplied by the given 'BuiltInSynFamily' hooks.
mkTypeNatFunTyCon2 :: Name -> BuiltInSynFamily -> TyCon
mkTypeNatFunTyCon2 op tcb =
  mkSynTyCon op famKind famTyVars famRoles famRhs NoParentTyCon
  where
    famKind   = mkArrowKinds [typeNatKind, typeNatKind] typeNatKind
    famTyVars = take 2 (tyVarList typeNatKind)
    famRoles  = [Nominal, Nominal]
    famRhs    = BuiltInSynFamTyCon tcb
116
117
118
119
120 {-------------------------------------------------------------------------------
121 Built-in rules axioms
122 -------------------------------------------------------------------------------}
123
-- Built-in axiom rules for the type-level natural-number families.
--
-- If you add additional rules, please remember to add them to
-- `typeNatCoAxiomRules` also.
axAddDef, axMulDef, axExpDef, axLeqDef, axSubDef,
  axAdd0L, axAdd0R, axSub0R,
  axMul0L, axMul0R, axMul1L, axMul1R,
  axExp1L, axExp0R, axExp1R,
  axLeqRefl, axLeq0L
  :: CoAxiomRule

-- Definitional rules: evaluate an operator applied to two numeric literals.
axAddDef = mkBinAxiom "AddDef" typeNatAddTyCon (\x y -> Just (num (x + y)))
axMulDef = mkBinAxiom "MulDef" typeNatMulTyCon (\x y -> Just (num (x * y)))
axExpDef = mkBinAxiom "ExpDef" typeNatExpTyCon (\x y -> Just (num (x ^ y)))
axLeqDef = mkBinAxiom "LeqDef" typeNatLeqTyCon (\x y -> Just (bool (x <= y)))
-- Subtraction is partial on naturals: the rule does not fire when x < y.
axSubDef = mkBinAxiom "SubDef" typeNatSubTyCon (\x y -> fmap num (minus x y))

-- Single-argument identity, absorption and reflexivity rules.
axAdd0L   = mkAxiom1 "Add0L"   (\t -> (num 0 .+. t) === t)
axAdd0R   = mkAxiom1 "Add0R"   (\t -> (t .+. num 0) === t)
axSub0R   = mkAxiom1 "Sub0R"   (\t -> (t .-. num 0) === t)
axMul0L   = mkAxiom1 "Mul0L"   (\t -> (num 0 .*. t) === num 0)
axMul0R   = mkAxiom1 "Mul0R"   (\t -> (t .*. num 0) === num 0)
axMul1L   = mkAxiom1 "Mul1L"   (\t -> (num 1 .*. t) === t)
axMul1R   = mkAxiom1 "Mul1R"   (\t -> (t .*. num 1) === t)
axExp1L   = mkAxiom1 "Exp1L"   (\t -> (num 1 .^. t) === num 1)
axExp0R   = mkAxiom1 "Exp0R"   (\t -> (t .^. num 0) === num 1)
axExp1R   = mkAxiom1 "Exp1R"   (\t -> (t .^. num 1) === t)
axLeqRefl = mkAxiom1 "LeqRefl" (\t -> (t <== t) === bool True)
axLeq0L   = mkAxiom1 "Leq0L"   (\t -> (num 0 <== t) === bool True)
170 axLeqRefl = mkAxiom1 "LeqRefl" $ \t -> (t <== t) === bool True
171 axLeq0L = mkAxiom1 "Leq0L" $ \t -> (num 0 <== t) === bool True
172
-- | All built-in axiom rules for type-level naturals, indexed by their
-- 'coaxrName'.
--
-- Every rule declared in the "Built-in rules axioms" section must appear
-- here. 'axSub0R' was previously missing even though 'matchFamSub' produces
-- it when reducing @s - 0@, so it was absent from this name-indexed map.
typeNatCoAxiomRules :: Map.Map FastString CoAxiomRule
typeNatCoAxiomRules = Map.fromList $ map (\x -> (coaxrName x, x))
  [ axAddDef
  , axMulDef
  , axExpDef
  , axLeqDef
  , axAdd0L
  , axAdd0R
  , axMul0L
  , axMul0R
  , axMul1L
  , axMul1R
  , axExp1L
  , axExp0R
  , axExp1R
  , axLeqRefl
  , axLeq0L
  , axSubDef
  , axSub0R
  ]
192
193
194
195 {-------------------------------------------------------------------------------
196 Various utilities for making axioms and types
197 -------------------------------------------------------------------------------}
198
-- | @s .+. t@ builds the type-family application @s + t@.
(.+.) :: Type -> Type -> Type
(.+.) s t = mkTyConApp typeNatAddTyCon [s, t]

-- | @s .-. t@ builds the type-family application @s - t@.
(.-.) :: Type -> Type -> Type
(.-.) s t = mkTyConApp typeNatSubTyCon [s, t]

-- | @s .*. t@ builds the type-family application @s * t@.
(.*.) :: Type -> Type -> Type
(.*.) s t = mkTyConApp typeNatMulTyCon [s, t]

-- | @s .^. t@ builds the type-family application @s ^ t@.
(.^.) :: Type -> Type -> Type
(.^.) s t = mkTyConApp typeNatExpTyCon [s, t]

-- | @s <== t@ builds the type-family application @s <=? t@.
(<==) :: Type -> Type -> Type
(<==) s t = mkTyConApp typeNatLeqTyCon [s, t]
213
-- | Pair the left- and right-hand sides of an equation proved by an axiom.
(===) :: Type -> Type -> Pair Type
(===) = Pair

-- | The type-level numeric literal for an integer.
num :: Integer -> Type
num n = mkNumLitTy n

-- | The promoted kind of booleans.
boolKind :: Kind
boolKind = mkTyConApp promotedBoolTyCon []

-- | The promoted boolean literal ('True or 'False) for a 'Bool'.
bool :: Bool -> Type
bool True  = mkTyConApp promotedTrueDataCon  []
bool False = mkTyConApp promotedFalseDataCon []
226
-- | Recognise a promoted boolean literal. The result is 'Just' only when the
-- type is a nullary application of the promoted 'True or 'False data
-- constructor; every other type yields 'Nothing'.
isBoolLitTy :: Type -> Maybe Bool
isBoolLitTy ty =
  case splitTyConApp_maybe ty of
    Just (con, [])
      | con == promotedFalseDataCon -> Just False
      | con == promotedTrueDataCon  -> Just True
    _ -> Nothing
234
-- | Does the type denote a numeric literal that satisfies the predicate?
-- Types that are not numeric literals never do.
known :: (Integer -> Bool) -> TcType -> Bool
known p x = maybe False p (isNumLitTy x)
239
240
241
242
-- | Build a definitional axiom for a binary operator @tc@.
-- The rule takes exactly two type arguments and no coercion assumptions.
-- It fires only when both arguments are numeric literals @x@ and @y@ and
-- @f x y@ yields some @z@, in which case it proves @tc s t ~ z@.
mkBinAxiom :: String -> TyCon ->
              (Integer -> Integer -> Maybe Type) -> CoAxiomRule
mkBinAxiom str tc f =
  CoAxiomRule
    { coaxrName      = fsLit str
    , coaxrTypeArity = 2
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = proves
    }
  where
    proves [s, t] [] = do
      x <- isNumLitTy s
      y <- isNumLitTy t
      fmap (mkTyConApp tc [s, t] ===) (f x y)
    proves _ _ = Nothing
260
-- | Build a rule with one type argument and no coercion assumptions. Given
-- the argument @s@, the rule proves the equation @f s@.
mkAxiom1 :: String -> (Type -> Pair Type) -> CoAxiomRule
mkAxiom1 str f =
  CoAxiomRule
    { coaxrName      = fsLit str
    , coaxrTypeArity = 1
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = proves
    }
  where
    proves [s] [] = Just (f s)
    proves _   _  = Nothing
273
274
275 {-------------------------------------------------------------------------------
276 Evaluation
277 -------------------------------------------------------------------------------}
278
-- | Reduce @s + t@. A literal 0 on either side reduces to the other operand;
-- otherwise two literals are added. The result carries the axiom justifying
-- the step and its type arguments.
matchFamAdd :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamAdd [s, t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (axAdd0L, [t], t)
    (_, Just 0)      -> Just (axAdd0R, [s], s)
    (Just x, Just y) -> Just (axAddDef, [s, t], num (x + y))
    _                -> Nothing
matchFamAdd _ = Nothing
288
-- | Reduce @s - t@. Subtracting a literal 0 is the identity. Two literals
-- reduce only when the difference is a natural number (see 'minus').
matchFamSub :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamSub [s, t] =
  case (isNumLitTy s, isNumLitTy t) of
    (_, Just 0)      -> Just (axSub0R, [s], s)
    (Just x, Just y) -> fmap (\z -> (axSubDef, [s, t], num z)) (minus x y)
    _                -> Nothing
matchFamSub _ = Nothing
297
-- | Reduce @s * t@. A literal 0 on either side absorbs, and a literal 1 on
-- either side is the identity (0 is checked before 1). Otherwise two
-- literals are multiplied.
matchFamMul :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamMul [s, t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (axMul0L, [t], num 0)
    (_, Just 0)      -> Just (axMul0R, [s], num 0)
    (Just 1, _)      -> Just (axMul1L, [t], t)
    (_, Just 1)      -> Just (axMul1R, [s], s)
    (Just x, Just y) -> Just (axMulDef, [s, t], num (x * y))
    _                -> Nothing
matchFamMul _ = Nothing
309
-- | Reduce @s ^ t@. The special cases are tried in order: exponent 0 gives 1,
-- base 1 gives 1, and exponent 1 is the identity. Otherwise two literals
-- are exponentiated.
matchFamExp :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamExp [s, t] =
  case (isNumLitTy s, isNumLitTy t) of
    (_, Just 0)      -> Just (axExp0R, [s], num 1)
    (Just 1, _)      -> Just (axExp1L, [t], num 1)
    (_, Just 1)      -> Just (axExp1R, [s], s)
    (Just x, Just y) -> Just (axExpDef, [s, t], num (x ^ y))
    _                -> Nothing
matchFamExp _ = Nothing
320
-- | Reduce @s <=? t@. A literal 0 on the left is always 'True; two literals
-- are compared; and syntactically equal operands (by 'tcEqType') give 'True.
matchFamLeq :: [Type] -> Maybe (CoAxiomRule, [Type], Type)
matchFamLeq [s, t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (axLeq0L, [t], bool True)
    (Just x, Just y) -> Just (axLeqDef, [s, t], bool (x <= y))
    _ | tcEqType s t -> Just (axLeqRefl, [s], bool True)
      | otherwise    -> Nothing
matchFamLeq _ = Nothing
330
331 {-------------------------------------------------------------------------------
332 Interact with axioms
333 -------------------------------------------------------------------------------}
334
-- | Improvements from a top-level equation @s + t ~ r@. A sum of 0 forces both
-- summands to 0. A literal summand together with a literal sum determines
-- the other summand, provided the difference is a natural number.
interactTopAdd :: [Xi] -> Xi -> [Pair Type]
interactTopAdd [s, t] r =
  case (isNumLitTy s, isNumLitTy t, isNumLitTy r) of
    (_, _, Just 0) -> [s === num 0, t === num 0]   -- (s + t ~ 0) => (s ~ 0, t ~ 0)
    (mbX, mbY, Just z)
      | Just x <- mbX, Just y <- minus z x -> [t === num y]  -- (5 + t ~ 8) => (t ~ 3)
      | Just y <- mbY, Just x <- minus z y -> [s === num x]  -- (s + 5 ~ 8) => (s ~ 3)
    _ -> []
interactTopAdd _ _ = []
345
346 {-
347 Note [Weakened interaction rule for subtraction]
348 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
349
350 A simpler interaction here might be:
351
352 `s - t ~ r` --> `t + r ~ s`
353
354 This would enable us to reuse all the code for addition.
355 Unfortunately, this works a little too well at the moment.
356 Consider the following example:
357
358 0 - 5 ~ r --> 5 + r ~ 0 --> (5 = 0, r = 0)
359
360 This (correctly) spots that the constraint cannot be solved.
361
362 However, this may be a problem if the constraint did not
363 need to be solved in the first place! Consider the following example:
364
365 f :: Proxy (If (5 <=? 0) (0 - 5) (5 - 0)) -> Proxy 5
366 f = id
367
368 Currently, GHC is strict while evaluating functions, so this does not
369 work, because even though the `If` should evaluate to `5 - 0`, we
370 also evaluate the "then" branch which generates the constraint `0 - 5 ~ r`,
371 which fails.
372
373 So, for the time being, we only add an improvement when the RHS is a constant,
374 which happens to work OK for the moment, although clearly we need to do
375 something more general.
376 -}
-- | Improvement from a top-level equation @s - t ~ r@. It fires only when @r@
-- is a literal @z@ and then yields @s ~ z + t@.
-- See Note [Weakened interaction rule for subtraction].
interactTopSub :: [Xi] -> Xi -> [Pair Type]
interactTopSub [s, t] r =
  -- (s - t ~ 5) => (5 + t ~ s)
  maybe [] (\z -> [s === (num z .+. t)]) (isNumLitTy r)
interactTopSub _ _ = []
383
384
385
386
387
-- | Improvements from a top-level equation @s * t ~ r@. A product of 1 forces
-- both factors to 1. A literal factor together with a literal product
-- determines the other factor when the division is exact (see 'divide').
interactTopMul :: [Xi] -> Xi -> [Pair Type]
interactTopMul [s, t] r =
  case (isNumLitTy s, isNumLitTy t, isNumLitTy r) of
    (_, _, Just 1) -> [s === num 1, t === num 1]   -- (s * t ~ 1) => (s ~ 1, t ~ 1)
    (mbX, mbY, Just z)
      | Just x <- mbX, Just y <- divide z x -> [t === num y]  -- (3 * t ~ 15) => (t ~ 5)
      | Just y <- mbY, Just x <- divide z y -> [s === num x]  -- (s * 3 ~ 15) => (s ~ 5)
    _ -> []
interactTopMul _ _ = []
398
-- | Improvements from a top-level equation @s ^ t ~ r@. A result of 0 forces
-- the base to 0. A literal base with a literal result determines the
-- exponent via 'logExact'. A literal exponent with a literal result
-- determines the base via 'rootExact'.
interactTopExp :: [Xi] -> Xi -> [Pair Type]
interactTopExp [s, t] r =
  case (isNumLitTy s, isNumLitTy t, isNumLitTy r) of
    (_, _, Just 0) -> [s === num 0]   -- (s ^ t ~ 0) => (s ~ 0)
    (mbX, mbY, Just z)
      | Just x <- mbX, Just y <- logExact z x  -> [t === num y]  -- (2 ^ t ~ 8) => (t ~ 3)
      | Just y <- mbY, Just x <- rootExact z y -> [s === num x]  -- (s ^ 2 ~ 9) => (s ~ 3)
    _ -> []
interactTopExp _ _ = []
409
-- | Improvement from a top-level equation @(s <=? t) ~ r@. When @t@ is the
-- literal 0 and @r@ is 'True, @s@ must be 0.
interactTopLeq :: [Xi] -> Xi -> [Pair Type]
interactTopLeq [s, t] r =
  case (isNumLitTy t, isBoolLitTy r) of
    (Just 0, Just True) -> [s === num 0]   -- (s <= 0) => (s ~ 0)
    _                   -> []
interactTopLeq _ _ = []
417
418
419
420 {-------------------------------------------------------------------------------
421 Interaction with inerts
422 -------------------------------------------------------------------------------}
423
-- | Improvements from two equations @x1 + y1 ~ z1@ and @x2 + y2 ~ z2@. When
-- the results coincide, addition is cancellative: equal first summands force
-- equal second summands, and vice versa.
interactInertAdd :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertAdd [x1, y1] z1 [x2, y2] z2
  | not (tcEqType z1 z2) = []
  | tcEqType x1 x2       = [y1 === y2]
  | tcEqType y1 y2       = [x1 === x2]
  | otherwise            = []
interactInertAdd _ _ _ _ = []
430
-- | Improvements from two equations @x1 - y1 ~ z1@ and @x2 - y2 ~ z2@. When
-- the results coincide, equal minuends force equal subtrahends, and vice
-- versa.
interactInertSub :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertSub [x1, y1] z1 [x2, y2] z2
  | not (tcEqType z1 z2) = []
  | tcEqType x1 x2       = [y1 === y2]
  | tcEqType y1 y2       = [x1 === x2]
  | otherwise            = []
interactInertSub _ _ _ _ = []
437
-- | Improvements from two equations @x1 * y1 ~ z1@ and @x2 * y2 ~ z2@. When
-- the results coincide, a shared factor known to be a non-zero literal can
-- be cancelled.
interactInertMul :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertMul [x1, y1] z1 [x2, y2] z2
  | not (tcEqType z1 z2)                = []
  | known (/= 0) x1 && tcEqType x1 x2   = [y1 === y2]
  | known (/= 0) y1 && tcEqType y1 y2   = [x1 === x2]
  | otherwise                           = []
interactInertMul _ _ _ _ = []
445
-- | Improvements from two equations @x1 ^ y1 ~ z1@ and @x2 ^ y2 ~ z2@. When
-- the results coincide, a shared literal base greater than 1 forces equal
-- exponents. A shared literal exponent greater than 0 forces equal bases.
interactInertExp :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertExp [x1, y1] z1 [x2, y2] z2
  | not (tcEqType z1 z2)               = []
  | known (> 1) x1 && tcEqType x1 x2   = [y1 === y2]
  | known (> 0) y1 && tcEqType y1 y2   = [x1 === x2]
  | otherwise                          = []
interactInertExp _ _ _ _ = []
453
454
-- | Improvements from two comparisons @(x1 <=? y1) ~ z1@ and
-- @(x2 <=? y2) ~ z2@. This fires only when both results are the literal 'True:
--   * antisymmetry: @x1 <= y1@ and @y1 <= x1@ give @x1 ~ y1@;
--   * transitivity (in either order) adds the chained comparison.
interactInertLeq :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertLeq [x1, y1] z1 [x2, y2] z2
  | not bothTrue                     = []
  | tcEqType x1 y2 && tcEqType y1 x2 = [x1 === y1]
  | tcEqType y1 x2                   = [(x1 <== y2) === bool True]
  | tcEqType y2 x1                   = [(x2 <== y1) === bool True]
  | otherwise                        = []
  where
    bothTrue  = isJust (trueLit z1 >> trueLit z2)
    trueLit z = do True <- isBoolLitTy z
                   return ()
interactInertLeq _ _ _ _ = []
465
466
467
468
469 {- -----------------------------------------------------------------------------
470 These inverse functions are used for simplifying propositions using
471 concrete natural numbers.
472 ----------------------------------------------------------------------------- -}
473
-- | Subtract two natural numbers. Returns 'Nothing' when the result would be
-- negative.
minus :: Integer -> Integer -> Maybe Integer
minus x y
  | x < y     = Nothing
  | otherwise = Just (x - y)
477
-- | Compute the exact logarithm of a natural number. The logarithm base is
-- the second argument. Returns 'Nothing' unless 'genLog' reports a result
-- that needed no rounding.
logExact :: Integer -> Integer -> Maybe Integer
logExact x y =
  case genLog x y of
    Just (z, True) -> Just z
    _              -> Nothing
483
484
-- | Divide two natural numbers exactly. Returns 'Nothing' on division by
-- zero or when there is a non-zero remainder.
divide :: Integer -> Integer -> Maybe Integer
divide _ 0 = Nothing
divide x y
  | remainder == 0 = Just quotient
  | otherwise      = Nothing
  where
    (quotient, remainder) = x `divMod` y
491
-- | Compute the exact root of a natural number. The second argument says
-- which root is computed. Returns 'Nothing' unless 'genRoot' reports a
-- result that needed no rounding.
rootExact :: Integer -> Integer -> Maybe Integer
rootExact x y =
  case genRoot x y of
    Just (z, True) -> Just z
    _              -> Nothing
497
498
499
{- | Compute the n-th root of a natural number, rounded down to the closest
natural number. The boolean indicates whether the result is exact: True
means no rounding was done, and False means it was rounded down. The second
argument specifies which root is computed. The 0-th root is undefined
('Nothing'), and the 1st root is the number itself. -}
genRoot :: Integer -> Integer -> Maybe (Integer, Bool)
genRoot x0 root
  | root == 0 = Nothing
  | root == 1 = Just (x0, True)
  | otherwise = Just (bisect 0 (x0 + 1))
  where
    -- Binary search for the largest @mid@ with @mid ^ root <= x0@.
    -- The search stops when the midpoint can no longer move.
    bisect lo hi =
      let mid = lo + (hi - lo) `div` 2
      in case compare (mid ^ root) x0 of
           EQ -> (mid, True)
           LT | mid /= lo -> bisect mid hi
           GT | mid /= hi -> bisect lo mid
           _ -> (lo, False)
517
{- | Compute the logarithm of a number in the given base, rounded down to the
closest integer. The boolean indicates whether the result is exact: True
means no rounding happened, and False means it was rounded down. The
logarithm base is the second argument. There is no answer ('Nothing') for
base 1, for the logarithm of 0, or for base 0 unless the number is 1. -}
genLog :: Integer -> Integer -> Maybe (Integer, Bool)
genLog x base
  | base == 0 = if x == 1 then Just (0, True) else Nothing
  | base == 1 = Nothing
  | x == 0    = Nothing
  | otherwise = Just (exactly 0 x)
  where
    -- Divide by the base while it divides evenly. Once a division leaves a
    -- remainder, switch to counting with rounding.
    exactly acc n
      | n == 1    = (acc, True)
      | n < base  = (acc, False)
      | otherwise =
          let acc'   = acc + 1
              (q, r) = n `divMod` base
          in acc' `seq`
               if r == 0 then exactly acc' q
                         else (roundedDown acc' q, False)

    -- Count the remaining whole divisions by the base.
    roundedDown acc n
      | n < base  = acc
      | otherwise = let acc' = acc + 1
                    in acc' `seq` roundedDown acc' (n `div` base)
541
542
543
544
545
546
547