Weaken the improvement for subtraction.
[ghc.git] / compiler / typecheck / TcTypeNats.hs
1 module TcTypeNats
2 ( typeNatTyCons
3 , typeNatCoAxiomRules
4 , TcBuiltInSynFamily(..)
5 ) where
6
7 import Type
8 import Pair
9 import TcType ( TcType )
10 import TyCon ( TyCon, SynTyConRhs(..), mkSynTyCon, TyConParent(..) )
11 import Coercion ( Role(..) )
12 import TcRnTypes ( Xi )
13 import TcEvidence ( mkTcAxiomRuleCo, TcCoercion )
14 import CoAxiom ( CoAxiomRule(..) )
15 import Name ( Name, BuiltInSyntax(..) )
16 import TysWiredIn ( typeNatKind, mkWiredInTyConName
17 , promotedBoolTyCon
18 , promotedFalseDataCon, promotedTrueDataCon
19 )
20 import TysPrim ( tyVarList, mkArrowKinds )
21 import PrelNames ( gHC_TYPELITS
22 , typeNatAddTyFamNameKey
23 , typeNatMulTyFamNameKey
24 , typeNatExpTyFamNameKey
25 , typeNatLeqTyFamNameKey
26 , typeNatSubTyFamNameKey
27 )
28 import FamInst ( TcBuiltInSynFamily(..) )
29 import FastString ( FastString, fsLit )
30 import qualified Data.Map as Map
31 import Data.Maybe ( isJust )
32
33 {-------------------------------------------------------------------------------
Built-in type constructors for functions on type-level nats
35 -}
36
-- | Every built-in type family on type-level naturals defined in this module.
typeNatTyCons :: [TyCon]
typeNatTyCons =
  [ typeNatAddTyCon, typeNatMulTyCon, typeNatExpTyCon
  , typeNatLeqTyCon, typeNatSubTyCon
  ]
45
-- | The built-in type family @(+)@ on type-level naturals.
typeNatAddTyCon :: TyCon
typeNatAddTyCon = mkTypeNatFunTyCon2 addName addOps
  where
    -- The name refers back to the TyCon being defined (a deliberate knot).
    addName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "+")
                typeNatAddTyFamNameKey typeNatAddTyCon
    addOps  = TcBuiltInSynFamily
                { sfMatchFam      = matchFamAdd
                , sfInteractTop   = interactTopAdd
                , sfInteractInert = interactInertAdd
                }
56
-- | The built-in type family @(-)@ on type-level naturals.
typeNatSubTyCon :: TyCon
typeNatSubTyCon = mkTypeNatFunTyCon2 subName subOps
  where
    -- The name refers back to the TyCon being defined (a deliberate knot).
    subName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "-")
                typeNatSubTyFamNameKey typeNatSubTyCon
    subOps  = TcBuiltInSynFamily
                { sfMatchFam      = matchFamSub
                , sfInteractTop   = interactTopSub
                , sfInteractInert = interactInertSub
                }
67
-- | The built-in type family @(*)@ on type-level naturals.
typeNatMulTyCon :: TyCon
typeNatMulTyCon = mkTypeNatFunTyCon2 mulName mulOps
  where
    -- The name refers back to the TyCon being defined (a deliberate knot).
    mulName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "*")
                typeNatMulTyFamNameKey typeNatMulTyCon
    mulOps  = TcBuiltInSynFamily
                { sfMatchFam      = matchFamMul
                , sfInteractTop   = interactTopMul
                , sfInteractInert = interactInertMul
                }
78
-- | The built-in type family @(^)@ on type-level naturals.
typeNatExpTyCon :: TyCon
typeNatExpTyCon = mkTypeNatFunTyCon2 expName expOps
  where
    -- The name refers back to the TyCon being defined (a deliberate knot).
    expName = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "^")
                typeNatExpTyFamNameKey typeNatExpTyCon
    expOps  = TcBuiltInSynFamily
                { sfMatchFam      = matchFamExp
                , sfInteractTop   = interactTopExp
                , sfInteractInert = interactInertExp
                }
89
-- | The built-in type family @(<=?) :: Nat -> Nat -> Bool@.
-- Its result kind is the promoted 'Bool' rather than 'Nat', so it is
-- built with 'mkSynTyCon' directly instead of via 'mkTypeNatFunTyCon2'.
typeNatLeqTyCon :: TyCon
typeNatLeqTyCon =
  mkSynTyCon leqName leqKind leqTyVars [Nominal, Nominal]
             (BuiltInSynFamTyCon leqOps) NoParentTyCon
  where
    leqKind   = mkArrowKinds [typeNatKind, typeNatKind] boolKind
    leqTyVars = take 2 (tyVarList typeNatKind)
    -- The name refers back to the TyCon being defined (a deliberate knot).
    leqName   = mkWiredInTyConName UserSyntax gHC_TYPELITS (fsLit "<=?")
                  typeNatLeqTyFamNameKey typeNatLeqTyCon
    leqOps    = TcBuiltInSynFamily
                  { sfMatchFam      = matchFamLeq
                  , sfInteractTop   = interactTopLeq
                  , sfInteractInert = interactInertLeq
                  }
107
108
-- | Build a binary built-in type family of kind @Nat -> Nat -> Nat@.
-- Both arguments get the 'Nominal' role, and the TyCon has no parent.
mkTypeNatFunTyCon2 :: Name -> TcBuiltInSynFamily -> TyCon
mkTypeNatFunTyCon2 op tcb = mkSynTyCon op kind tvs roles rhs NoParentTyCon
  where
    kind  = mkArrowKinds [typeNatKind, typeNatKind] typeNatKind
    tvs   = take 2 (tyVarList typeNatKind)
    roles = [Nominal, Nominal]
    rhs   = BuiltInSynFamTyCon tcb
118
119
120
121
122 {-------------------------------------------------------------------------------
123 Built-in rules axioms
124 -------------------------------------------------------------------------------}
125
-- Built-in coercion-axiom rules for type-level naturals.
-- If you add additional rules, please remember to add them to
-- `typeNatCoAxiomRules` also.
axAddDef, axMulDef, axExpDef, axLeqDef, axSubDef           -- literal evaluation
  , axAdd0L, axAdd0R, axSub0R                              -- additive identities
  , axMul0L, axMul0R, axMul1L, axMul1R                     -- multiplicative laws
  , axExp1L, axExp0R, axExp1R                              -- exponent laws
  , axLeqRefl, axLeq0L                                     -- ordering laws
  :: CoAxiomRule
146
-- Definitional axioms: each evaluates its operator on two literal arguments.
-- Subtraction yields no result when it would go below zero (see 'minus').
axAddDef = mkBinAxiom "AddDef" typeNatAddTyCon (\x y -> Just (num (x + y)))

axMulDef = mkBinAxiom "MulDef" typeNatMulTyCon (\x y -> Just (num (x * y)))

axExpDef = mkBinAxiom "ExpDef" typeNatExpTyCon (\x y -> Just (num (x ^ y)))

axLeqDef = mkBinAxiom "LeqDef" typeNatLeqTyCon (\x y -> Just (bool (x <= y)))

axSubDef = mkBinAxiom "SubDef" typeNatSubTyCon (\x y -> fmap num (minus x y))
161
-- Identities over one arbitrary argument @t@; each pair is (lhs, rhs).
axAdd0L   = mkAxiom1 "Add0L"   (\t -> Pair (num 0 .+. t) t)
axAdd0R   = mkAxiom1 "Add0R"   (\t -> Pair (t .+. num 0) t)
axSub0R   = mkAxiom1 "Sub0R"   (\t -> Pair (t .-. num 0) t)
axMul0L   = mkAxiom1 "Mul0L"   (\t -> Pair (num 0 .*. t) (num 0))
axMul0R   = mkAxiom1 "Mul0R"   (\t -> Pair (t .*. num 0) (num 0))
axMul1L   = mkAxiom1 "Mul1L"   (\t -> Pair (num 1 .*. t) t)
axMul1R   = mkAxiom1 "Mul1R"   (\t -> Pair (t .*. num 1) t)
axExp1L   = mkAxiom1 "Exp1L"   (\t -> Pair (num 1 .^. t) (num 1))
axExp0R   = mkAxiom1 "Exp0R"   (\t -> Pair (t .^. num 0) (num 1))
axExp1R   = mkAxiom1 "Exp1R"   (\t -> Pair (t .^. num 1) t)
axLeqRefl = mkAxiom1 "LeqRefl" (\t -> Pair (t <== t) (bool True))
axLeq0L   = mkAxiom1 "Leq0L"   (\t -> Pair (num 0 <== t) (bool True))
174
-- | Every built-in coercion-axiom rule of this module, keyed by its
-- 'coaxrName'.
--
-- This table must stay in sync with the axiom declarations in this module.
-- In particular it must contain 'axSub0R': 'matchFamSub' builds evidence
-- with that rule, so leaving it out makes a by-name lookup of "Sub0R" fail.
-- NOTE(review): presumably callers resolve axiom rules by name through this
-- map (e.g. when reading coercions back in) -- confirm against its users.
typeNatCoAxiomRules :: Map.Map FastString CoAxiomRule
typeNatCoAxiomRules = Map.fromList [ (coaxrName rule, rule) | rule <- rules ]
  where
    rules =
      -- definitional axioms (evaluation on literals)
      [ axAddDef, axMulDef, axExpDef, axLeqDef, axSubDef
      -- identities over a single argument
      , axAdd0L, axAdd0R, axSub0R
      , axMul0L, axMul0R, axMul1L, axMul1R
      , axExp1L, axExp0R, axExp1R
      -- ordering laws
      , axLeqRefl, axLeq0L
      ]
194
195
196
197 {-------------------------------------------------------------------------------
198 Various utilities for making axioms and types
199 -------------------------------------------------------------------------------}
200
-- | @s .+. t@ is the type @s + t@.
(.+.) :: Type -> Type -> Type
(.+.) s t = mkTyConApp typeNatAddTyCon [s, t]

-- | @s .-. t@ is the type @s - t@.
(.-.) :: Type -> Type -> Type
(.-.) s t = mkTyConApp typeNatSubTyCon [s, t]

-- | @s .*. t@ is the type @s * t@.
(.*.) :: Type -> Type -> Type
(.*.) s t = mkTyConApp typeNatMulTyCon [s, t]

-- | @s .^. t@ is the type @s ^ t@.
(.^.) :: Type -> Type -> Type
(.^.) s t = mkTyConApp typeNatExpTyCon [s, t]

-- | @s <== t@ is the type @s <=? t@.
(<==) :: Type -> Type -> Type
(<==) s t = mkTyConApp typeNatLeqTyCon [s, t]

-- | The two sides of an equation between types.
(===) :: Type -> Type -> Pair Type
(===) = Pair
218
-- | A type-level natural-number literal.
num :: Integer -> Type
num n = mkNumLitTy n

-- | The kind of promoted booleans.
boolKind :: Kind
boolKind = mkTyConApp promotedBoolTyCon []

-- | The promoted boolean literal for the given value.
bool :: Bool -> Type
bool True  = mkTyConApp promotedTrueDataCon []
bool False = mkTyConApp promotedFalseDataCon []
228
-- | Recognise a promoted boolean literal.
-- The type must be the promoted True or False constructor applied to no
-- arguments; anything else gives 'Nothing'.
isBoolLitTy :: Type -> Maybe Bool
isBoolLitTy ty =
  case splitTyConApp_maybe ty of
    Just (tc, [])
      | tc == promotedFalseDataCon -> Just False
      | tc == promotedTrueDataCon  -> Just True
    _                              -> Nothing
236
-- | Holds when the type is a numeric literal satisfying the predicate.
-- A type that is not a literal never counts as known.
known :: (Integer -> Bool) -> TcType -> Bool
known p x = maybe False p (isNumLitTy x)
241
242
243
244
-- | Build a definitional axiom for the binary family @tc@.
--
-- Applied to exactly two type arguments, both numeric literals, and no
-- coercion assumptions, the rule proves @tc s t ~ r@. Here @r@ is computed
-- by @f@ from the literal values. The rule fails ('Nothing') when an
-- argument is not a literal, when @f@ returns 'Nothing', or when the
-- arguments have the wrong shape.
mkBinAxiom :: String -> TyCon ->
              (Integer -> Integer -> Maybe Type) -> CoAxiomRule
mkBinAxiom str tc f =
  CoAxiomRule
    { coaxrName      = fsLit str
    , coaxrTypeArity = 2
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = proves
    }
  where
    proves [s, t] [] = do x <- isNumLitTy s
                          y <- isNumLitTy t
                          fmap (mkTyConApp tc [s, t] ===) (f x y)
    proves _ _       = Nothing
262
-- | Build an axiom over one type argument and no coercion assumptions.
-- @f@ produces the equation it proves; other argument shapes give
-- 'Nothing'.
mkAxiom1 :: String -> (Type -> Pair Type) -> CoAxiomRule
mkAxiom1 str f =
  CoAxiomRule
    { coaxrName      = fsLit str
    , coaxrTypeArity = 1
    , coaxrAsmpRoles = []
    , coaxrRole      = Nominal
    , coaxrProves    = proves
    }
  where
    proves [s] [] = Just (f s)
    proves _   _  = Nothing
275
276
277 {-------------------------------------------------------------------------------
278 Evaluation
279 -------------------------------------------------------------------------------}
280
-- | Try to reduce @s + t@.
-- A literal 0 on the left is removed first, then one on the right; two
-- literals are added.
matchFamAdd :: [Type] -> Maybe (TcCoercion, TcType)
matchFamAdd [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (mkTcAxiomRuleCo axAdd0L [t] [], t)
    (_, Just 0)      -> Just (mkTcAxiomRuleCo axAdd0R [s] [], s)
    (Just x, Just y) -> Just (mkTcAxiomRuleCo axAddDef [s,t] [], num (x + y))
    _                -> Nothing
matchFamAdd _ = Nothing
290
-- | Try to reduce @s - t@.
-- Subtracting a literal 0 gives @s@. Two literals reduce only when the
-- difference is a natural number (see 'minus').
matchFamSub :: [Type] -> Maybe (TcCoercion, TcType)
matchFamSub [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (_, Just 0) -> Just (mkTcAxiomRuleCo axSub0R [s] [], s)
    (Just x, Just y)
      | Just z <- minus x y -> Just (mkTcAxiomRuleCo axSubDef [s,t] [], num z)
    _           -> Nothing
matchFamSub _ = Nothing
299
-- | Try to reduce @s * t@.
-- The order of checks is: 0 on the left, 0 on the right, 1 on the left,
-- 1 on the right, then two literals are multiplied.
matchFamMul :: [Xi] -> Maybe (TcCoercion, Xi)
matchFamMul [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (mkTcAxiomRuleCo axMul0L [t] [], num 0)
    (_, Just 0)      -> Just (mkTcAxiomRuleCo axMul0R [s] [], num 0)
    (Just 1, _)      -> Just (mkTcAxiomRuleCo axMul1L [t] [], t)
    (_, Just 1)      -> Just (mkTcAxiomRuleCo axMul1R [s] [], s)
    (Just x, Just y) -> Just (mkTcAxiomRuleCo axMulDef [s,t] [], num (x * y))
    _                -> Nothing
matchFamMul _ = Nothing
311
-- | Try to reduce @s ^ t@.
-- The order of checks is: exponent 0, base 1, exponent 1, then two literals
-- are evaluated.
matchFamExp :: [Xi] -> Maybe (TcCoercion, Xi)
matchFamExp [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (_, Just 0)      -> Just (mkTcAxiomRuleCo axExp0R [s] [], num 1)
    (Just 1, _)      -> Just (mkTcAxiomRuleCo axExp1L [t] [], num 1)
    (_, Just 1)      -> Just (mkTcAxiomRuleCo axExp1R [s] [], s)
    (Just x, Just y) -> Just (mkTcAxiomRuleCo axExpDef [s,t] [], num (x ^ y))
    _                -> Nothing
matchFamExp _ = Nothing
322
-- | Try to reduce @s <=? t@.
-- A literal 0 on the left gives True. Two literals are compared.
-- Otherwise, syntactically equal arguments give True by reflexivity.
matchFamLeq :: [Xi] -> Maybe (TcCoercion, Xi)
matchFamLeq [s,t] =
  case (isNumLitTy s, isNumLitTy t) of
    (Just 0, _)      -> Just (mkTcAxiomRuleCo axLeq0L [t] [], bool True)
    (Just x, Just y) -> Just (mkTcAxiomRuleCo axLeqDef [s,t] [], bool (x <= y))
    _ | eqType s t   -> Just (mkTcAxiomRuleCo axLeqRefl [s] [], bool True)
      | otherwise    -> Nothing
matchFamLeq _ = Nothing
332
333 {-------------------------------------------------------------------------------
334 Interact with axioms
335 -------------------------------------------------------------------------------}
336
-- | Improvement for a top-level equation @s + t ~ r@ with literal @r@.
-- A zero result forces both arguments to 0. Otherwise a literal argument
-- determines the other one by subtraction, when that is defined.
interactTopAdd :: [Xi] -> Xi -> [Pair Type]
interactTopAdd [s,t] r =
  case isNumLitTy r of
    Nothing -> []
    Just 0  -> [ s === num 0, t === num 0 ]
    Just z
      | Just x <- isNumLitTy s, Just y <- minus z x -> [ t === num y ]
      | Just y <- isNumLitTy t, Just x <- minus z y -> [ s === num x ]
      | otherwise                                    -> []
interactTopAdd _ _ = []
347
348 {- NOTE:
349 A simpler interaction here might be:
350
351 `s - t ~ r` --> `t + r ~ s`
352
353 This would enable us to reuse all the code for addition.
354 Unfortunately, this works a little too well at the moment.
355 Consider the following example:
356
357 0 - 5 ~ r --> 5 + r ~ 0 --> (5 = 0, r = 0)
358
359 This (correctly) spots that the constraint cannot be solved.
360
361 However, this may be a problem if the constraint did not
362 need to be solved in the first place! Consider the following example:
363
364 f :: Proxy (If (5 <=? 0) (0 - 5) (5 - 0)) -> Proxy 5
365 f = id
366
367 Currently, GHC is strict while evaluating functions, so this does not
368 work, because even though the `If` should evaluate to `5 - 0`, we
369 also evaluate the "else" branch which generates the constraint `0 - 5 ~ r`,
370 which fails.
371
372 So, for the time being, we only add an improvement when the RHS is a constant,
373 which happens to work OK for the moment, although clearly we need to do
374 something more general.
375 -}
-- | Improvement for a top-level equation @s - t ~ r@.
-- The rule only fires when @r@ is a literal @z@, and then yields
-- @s ~ z + t@.
interactTopSub :: [Xi] -> Xi -> [Pair Type]
interactTopSub [s,t] r =
  maybe [] (\z -> [ s === (num z .+. t) ]) (isNumLitTy r)
interactTopSub _ _ = []
382
383
384
385
386
-- | Improvement for a top-level equation @s * t ~ r@ with literal @r@.
-- A result of 1 forces both arguments to 1. Otherwise a literal argument
-- determines the other one by exact division, when that is possible.
interactTopMul :: [Xi] -> Xi -> [Pair Type]
interactTopMul [s,t] r =
  case isNumLitTy r of
    Nothing -> []
    Just 1  -> [ s === num 1, t === num 1 ]
    Just z
      | Just x <- isNumLitTy s, Just y <- divide z x -> [ t === num y ]
      | Just y <- isNumLitTy t, Just x <- divide z y -> [ s === num x ]
      | otherwise                                     -> []
interactTopMul _ _ = []
397
-- | Improvement for a top-level equation @s ^ t ~ r@ with literal @r@.
-- A zero result forces the base to 0. A literal base fixes the exponent
-- through an exact logarithm. A literal exponent fixes the base through an
-- exact root.
interactTopExp :: [Xi] -> Xi -> [Pair Type]
interactTopExp [s,t] r =
  case isNumLitTy r of
    Nothing -> []
    Just 0  -> [ s === num 0 ]
    Just z
      | Just x <- isNumLitTy s, Just y <- logExact z x  -> [ t === num y ]
      | Just y <- isNumLitTy t, Just x <- rootExact z y -> [ s === num x ]
      | otherwise                                        -> []
interactTopExp _ _ = []
408
-- | Improvement for a top-level equation @(s <=? t) ~ r@.
-- The only rule: @(s <=? 0) ~ True@ forces @s ~ 0@.
interactTopLeq :: [Xi] -> Xi -> [Pair Type]
interactTopLeq [s,t] r =
  case (isNumLitTy t, isBoolLitTy r) of
    (Just 0, Just True) -> [ s === num 0 ]
    _                   -> []
interactTopLeq _ _ = []
416
417
418
419 {-------------------------------------------------------------------------------
420 Interaction with inerts
421 -------------------------------------------------------------------------------}
422
-- | Interaction of two additions with the same result.
-- If the left operands are equal, so are the right ones, and vice versa.
interactInertAdd :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertAdd [x1,y1] z1 [x2,y2] z2
  | not (eqType z1 z2) = []
  | eqType x1 x2       = [ y1 === y2 ]
  | eqType y1 y2       = [ x1 === x2 ]
  | otherwise          = []
interactInertAdd _ _ _ _ = []
429
-- | Interaction of two subtractions with the same result.
-- If the minuends are equal, so are the subtrahends, and vice versa.
interactInertSub :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertSub [x1,y1] z1 [x2,y2] z2
  | not (eqType z1 z2) = []
  | eqType x1 x2       = [ y1 === y2 ]
  | eqType y1 y2       = [ x1 === x2 ]
  | otherwise          = []
interactInertSub _ _ _ _ = []
436
-- | Interaction of two multiplications with the same result.
-- A shared operand that is a known non-zero literal forces the other
-- operands to be equal.
interactInertMul :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertMul [x1,y1] z1 [x2,y2] z2
  | not (eqType z1 z2)                  = []
  | known (/= 0) x1 && eqType x1 x2     = [ y1 === y2 ]
  | known (/= 0) y1 && eqType y1 y2     = [ x1 === x2 ]
  | otherwise                           = []
interactInertMul _ _ _ _ = []
444
-- | Interaction of two exponentiations with the same result.
-- A shared base that is a known literal greater than 1 forces equal
-- exponents. A shared exponent that is a known positive literal forces
-- equal bases.
interactInertExp :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertExp [x1,y1] z1 [x2,y2] z2
  | not (eqType z1 z2)                 = []
  | known (> 1) x1 && eqType x1 x2     = [ y1 === y2 ]
  | known (> 0) y1 && eqType y1 y2     = [ x1 === x2 ]
  | otherwise                          = []
interactInertExp _ _ _ _ = []
452
453
-- | Interaction of two facts @(x1 <=? y1) ~ True@ and @(x2 <=? y2) ~ True@.
-- Antisymmetry turns mutual bounds into @x1 ~ y1@. Transitivity derives a
-- new @<=?@ fact when the bounds chain together.
interactInertLeq :: [Xi] -> Xi -> [Xi] -> Xi -> [Pair Type]
interactInertLeq [x1,y1] z1 [x2,y2] z2
  | not bothTrue                 = []
  | eqType x1 y2 && eqType y1 x2 = [ x1 === y1 ]
  | eqType y1 x2                 = [ (x1 <== y2) === bool True ]
  | eqType y2 x1                 = [ (x2 <== y1) === bool True ]
  | otherwise                    = []
  where
    bothTrue  = all (isJust . trueLit) [z1, z2]
    -- Succeeds only for the promoted literal True.
    trueLit z = do True <- isBoolLitTy z
                   return ()
interactInertLeq _ _ _ _ = []
464
465
466
467
468 {- -----------------------------------------------------------------------------
469 These inverse functions are used for simplifying propositions using
470 concrete natural numbers.
471 ----------------------------------------------------------------------------- -}
472
-- | Subtract two natural numbers.
-- Returns 'Nothing' when the result would be negative.
minus :: Integer -> Integer -> Maybe Integer
minus x y
  | x < y     = Nothing
  | otherwise = Just (x - y)
476
-- | Compute the exact logarithm of a natural number.
-- The base is the second argument. Returns 'Nothing' unless 'genLog'
-- reports an exact result.
logExact :: Integer -> Integer -> Maybe Integer
logExact x y =
  case genLog x y of
    Just (z, True) -> Just z
    _              -> Nothing
482
483
-- | Divide two natural numbers exactly.
-- Returns 'Nothing' for a zero divisor or a non-zero remainder.
divide :: Integer -> Integer -> Maybe Integer
divide x y
  | y == 0    = Nothing
  | r == 0    = Just q
  | otherwise = Nothing
  where
    (q, r) = x `divMod` y
490
-- | Compute the exact root of a natural number.
-- The second argument says which root to take. Returns 'Nothing' unless
-- 'genRoot' reports an exact result.
rootExact :: Integer -> Integer -> Maybe Integer
rootExact x y =
  case genRoot x y of
    Just (z, True) -> Just z
    _              -> Nothing
496
497
498
{- | Compute the n-th root of a natural number, rounded down to the nearest
natural number. The boolean is 'True' when the result is exact (no
rounding was needed) and 'False' when it was rounded down. The second
argument says which root to take; the zeroth root gives 'Nothing'. -}
genRoot :: Integer -> Integer -> Maybe (Integer, Bool)
genRoot n k
  | k == 0    = Nothing
  | k == 1    = Just (n, True)
  | otherwise = Just (go 0 (n + 1))
  where
    -- Bisection between lo and hi, probing the midpoint each step.
    go lo hi =
      let mid = lo + (hi - lo) `div` 2
      in case compare (mid ^ k) n of
           EQ             -> (mid, True)
           LT | mid /= lo -> go mid hi
           GT | mid /= hi -> go lo mid
           _              -> (lo, False)
516
{- | Compute the logarithm of a number in the given base, rounded down to
the nearest integer. The boolean is 'True' when the result is exact (no
rounding happened) and 'False' when it was rounded down. The base is the
second argument.

Returns 'Nothing' in three cases:
  * base 1;
  * argument 0;
  * base 0, unless the argument is 1, which gives @Just (0, True)@. -}
genLog :: Integer -> Integer -> Maybe (Integer, Bool)
genLog x base
  | base == 0 = if x == 1 then Just (0, True) else Nothing
  | base == 1 = Nothing
  | x == 0    = Nothing
  | otherwise = Just (exact 0 x)
  where
    -- Keep dividing while the remainder is zero; exact if we reach 1.
    exact acc n
      | n == 1    = (acc, True)
      | n < base  = (acc, False)
      | otherwise =
          let acc'   = acc + 1
              (q, r) = n `divMod` base
          in acc' `seq` if r == 0 then exact acc' q else (under acc' q, False)

    -- Once inexact, just count the remaining divisions (rounding down).
    under acc n
      | n < base  = acc
      | otherwise = let acc' = acc + 1 in acc' `seq` under acc' (n `div` base)
540
541
542
543
544
545
546