Fix treatment of -0.0
authorBen Gamari <bgamari.foss@gmail.com>
Fri, 2 Oct 2015 13:40:43 +0000 (15:40 +0200)
committerBen Gamari <ben@smart-cactus.org>
Fri, 2 Oct 2015 13:51:09 +0000 (15:51 +0200)
Here we fix a few mis-optimizations that could occur in code with
floating point comparisons with -0.0. These issues arose from our
insistence on rewriting equalities into case analyses and the
simplifier's ignorance of floating-point semantics.

For instance, in Trac #10215 (and the similar issue Trac #9238) we
turned `ds == 0.0` into a case analysis,

```
case ds of
    __DEFAULT -> ...
    0.0 -> ...
```

Where the second alternative matches where `ds` is +0.0 and *also* -0.0.
However, the simplifier doesn't realize this and will introduce a local
inlining of `ds = -- +0.0` as it believes this is the only
value that matches this pattern.

Instead of teaching the simplifier about floating-point semantics
we simply prohibit case analysis on floating-point scrutinees and keep
this logic in the comparison primops, where it belongs.

We do several things here,

 - Add test cases from relevant tickets
 - Clean up a bit of documentation
 - Desugar literal matches against floats into applications of the
   appropriate equality primitive instead of case analysis
 - Add a CoreLint to ensure we don't pattern match on floats in Core

Test Plan: validate with included testcases

Reviewers: goldfire, simonpj, austin

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D1061

GHC Trac Issues: #10215, #9238

compiler/coreSyn/CoreLint.hs
compiler/coreSyn/CoreSyn.hs
compiler/deSugar/MatchLit.hs
compiler/prelude/PrelRules.hs
compiler/typecheck/TcType.hs
testsuite/tests/deSugar/should_run/T10215.hs [new file with mode: 0644]
testsuite/tests/deSugar/should_run/T10215.stdout [new file with mode: 0644]
testsuite/tests/deSugar/should_run/T9238.hs [new file with mode: 0644]
testsuite/tests/deSugar/should_run/T9238.stdout [new file with mode: 0644]
testsuite/tests/deSugar/should_run/all.T

index 0b72ff4..ea1d968 100644 (file)
@@ -32,6 +32,7 @@ import Literal
 import DataCon
 import TysWiredIn
 import TysPrim
+import TcType ( isFloatingTy )
 import Var
 import VarEnv
 import VarSet
@@ -662,6 +663,15 @@ lintCoreExpr e@(Case scrut var alt_ty alts) =
           (ptext (sLit "No alternatives for a case scrutinee not known to diverge for sure:") <+> ppr scrut)
         }
 
+     -- See Note [Rules for floating-point comparisons] in PrelRules
+     ; let isLitPat (LitAlt _, _ , _) = True
+           isLitPat _                 = False
+     ; checkL (not $ isFloatingTy scrut_ty && any isLitPat alts)
+         (ptext (sLit $ "Lint warning: Scrutinising floating-point " ++
+                        "expression with literal pattern in case " ++
+                        "analysis (see Trac #9238).")
+          $$ text "scrut" <+> ppr scrut)
+
      ; case tyConAppTyCon_maybe (idType var) of
          Just tycon
               | debugIsOn &&
index fedf1d7..24ce641 100644 (file)
@@ -233,6 +233,10 @@ These data types are the heart of the compiler
 --       The inner case does not need a @Red@ alternative, because @x@
 --       can't be @Red@ at that program point.
 --
+--    5. Floating-point values must not be scrutinised against literals.
+--       See Trac #9238 and Note [Rules for floating-point comparisons]
+--       in PrelRules for rationale.
+--
 -- *  Cast an expression to a particular type.
 --    This is used to implement @newtype@s (a @newtype@ constructor or
 --    destructor just becomes a 'Cast' in Core) and GADTs.
@@ -329,6 +333,9 @@ simplifier calling findAlt with argument (LitAlt 3).  No no.  Integer
 literals are an opaque encoding of an algebraic data type, not of
 an unlifted literal, like all the others.
 
+Also, we do not permit case analysis with literal patterns on floating-point
+types. See Trac #9238 and Note [Rules for floating-point comparisons] in
+PrelRules for the rationale for this restriction.
 
 -------------------------- CoreSyn INVARIANTS ---------------------------
 
index 25021f5..fddfa80 100644 (file)
@@ -295,10 +295,12 @@ tidyNPat tidy_lit_pat (OverLit val False _ ty) mb_neg _
                             = mk_con_pat intDataCon    (HsIntPrim    "" int_lit)
   | isWordTy ty,   Just int_lit <- mb_int_lit
                             = mk_con_pat wordDataCon   (HsWordPrim   "" int_lit)
-  | isFloatTy ty,  Just rat_lit <- mb_rat_lit = mk_con_pat floatDataCon  (HsFloatPrim  rat_lit)
-  | isDoubleTy ty, Just rat_lit <- mb_rat_lit = mk_con_pat doubleDataCon (HsDoublePrim rat_lit)
   | isStringTy ty, Just str_lit <- mb_str_lit
                             = tidy_lit_pat (HsString "" str_lit)
+     -- NB: do /not/ convert Float or Double literals to F# 3.8 or D# 5.3
+     -- If we do convert to the constructor form, we'll generate a case
+     -- expression on a Float# or Double# and that's not allowed in Core; see
+     -- Trac #9238 and Note [Rules for floating-point comparisons] in PrelRules
   where
     mk_con_pat :: DataCon -> HsLit -> Pat Id
     mk_con_pat con lit = unLoc (mkPrefixConPat con [noLoc $ LitPat lit] [])
@@ -309,15 +311,6 @@ tidyNPat tidy_lit_pat (OverLit val False _ ty) mb_neg _
                    (Just _,  HsIntegral _ i) -> Just (-i)
                    _ -> Nothing
 
-    mb_rat_lit :: Maybe FractionalLit
-    mb_rat_lit = case (mb_neg, val) of
-       (Nothing, HsIntegral _ i) -> Just (integralFractionalLit (fromInteger i))
-       (Just _,  HsIntegral _ i) -> Just (integralFractionalLit
-                                                             (fromInteger (-i)))
-       (Nothing, HsFractional f) -> Just f
-       (Just _, HsFractional f)  -> Just (negateFractionalLit f)
-       _ -> Nothing
-
     mb_str_lit :: Maybe FastString
     mb_str_lit = case (mb_neg, val) of
                    (Nothing, HsIsString _ s) -> Just s
index d44c224..f87dce4 100644 (file)
@@ -241,19 +241,19 @@ primOpRules nm CharGeOp   = mkRelOpRule nm (>=) [ boundsCmp Ge ]
 primOpRules nm CharLeOp   = mkRelOpRule nm (<=) [ boundsCmp Le ]
 primOpRules nm CharLtOp   = mkRelOpRule nm (<)  [ boundsCmp Lt ]
 
-primOpRules nm FloatGtOp  = mkFloatingRelOpRule nm (>)  []
-primOpRules nm FloatGeOp  = mkFloatingRelOpRule nm (>=) []
-primOpRules nm FloatLeOp  = mkFloatingRelOpRule nm (<=) []
-primOpRules nm FloatLtOp  = mkFloatingRelOpRule nm (<)  []
-primOpRules nm FloatEqOp  = mkFloatingRelOpRule nm (==) [ litEq True ]
-primOpRules nm FloatNeOp  = mkFloatingRelOpRule nm (/=) [ litEq False ]
-
-primOpRules nm DoubleGtOp = mkFloatingRelOpRule nm (>)  []
-primOpRules nm DoubleGeOp = mkFloatingRelOpRule nm (>=) []
-primOpRules nm DoubleLeOp = mkFloatingRelOpRule nm (<=) []
-primOpRules nm DoubleLtOp = mkFloatingRelOpRule nm (<)  []
-primOpRules nm DoubleEqOp = mkFloatingRelOpRule nm (==) [ litEq True ]
-primOpRules nm DoubleNeOp = mkFloatingRelOpRule nm (/=) [ litEq False ]
+primOpRules nm FloatGtOp  = mkFloatingRelOpRule nm (>)
+primOpRules nm FloatGeOp  = mkFloatingRelOpRule nm (>=)
+primOpRules nm FloatLeOp  = mkFloatingRelOpRule nm (<=)
+primOpRules nm FloatLtOp  = mkFloatingRelOpRule nm (<)
+primOpRules nm FloatEqOp  = mkFloatingRelOpRule nm (==)
+primOpRules nm FloatNeOp  = mkFloatingRelOpRule nm (/=)
+
+primOpRules nm DoubleGtOp = mkFloatingRelOpRule nm (>)
+primOpRules nm DoubleGeOp = mkFloatingRelOpRule nm (>=)
+primOpRules nm DoubleLeOp = mkFloatingRelOpRule nm (<=)
+primOpRules nm DoubleLtOp = mkFloatingRelOpRule nm (<)
+primOpRules nm DoubleEqOp = mkFloatingRelOpRule nm (==)
+primOpRules nm DoubleNeOp = mkFloatingRelOpRule nm (/=)
 
 primOpRules nm WordGtOp   = mkRelOpRule nm (>)  [ boundsCmp Gt ]
 primOpRules nm WordGeOp   = mkRelOpRule nm (>=) [ boundsCmp Ge ]
@@ -284,29 +284,49 @@ mkPrimOpRule nm arity rules = Just $ mkBasicRule nm arity (msum rules)
 mkRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool)
             -> [RuleM CoreExpr] -> Maybe CoreRule
 mkRelOpRule nm cmp extra
-  = mkPrimOpRule nm 2 $ rules ++ extra
+  = mkPrimOpRule nm 2 $
+    binaryCmpLit cmp : equal_rule : extra
   where
-    rules = [ binaryCmpLit cmp
-            , do equalArgs
-              -- x `cmp` x does not depend on x, so
-              -- compute it for the arbitrary value 'True'
-              -- and use that result
-                 dflags <- getDynFlags
-                 return (if cmp True True
-                         then trueValInt  dflags
-                         else falseValInt dflags) ]
-
--- Note [Rules for floating-point comparisons]
--- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
---
--- We need different rules for floating-point values because for floats
--- it is not true that x = x. The special case when this does not occur
--- are NaNs.
+        -- x `cmp` x does not depend on x, so
+        -- compute it for the arbitrary value 'True'
+        -- and use that result
+    equal_rule = do { equalArgs
+                    ; dflags <- getDynFlags
+                    ; return (if cmp True True
+                              then trueValInt  dflags
+                              else falseValInt dflags) }
+
+{- Note [Rules for floating-point comparisons]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We need different rules for floating-point values because for floats
+it is not true that x = x (for NaNs); so we do not want the equal_rule
+rule that mkRelOpRule uses.
+
+Note also that, in the case of equality/inequality, we do /not/
+want to switch to a case-expression.  For example, we do not want
+to convert
+   case (eqFloat# x 3.8#) of
+     True -> this
+     False -> that
+to
+  case x of
+    3.8#::Float# -> this
+    _            -> that
+See Trac #9238.  Reason: comparing floating-point values for equality
+delicate, and we don't want to implement that delicacy in the code for
+case expressions.  So we make it an invariant of Core that a case
+expression never scrutinises a Float# or Double#.
+
+This transformation is what the litEq rule does;
+see Note [The litEq rule: converting equality to case].
+So we /refrain/ from using litEq for mkFloatingRelOpRule.
+-}
 
 mkFloatingRelOpRule :: Name -> (forall a . Ord a => a -> a -> Bool)
-                    -> [RuleM CoreExpr] -> Maybe CoreRule
-mkFloatingRelOpRule nm cmp extra -- See Note [Rules for floating-point comparisons]
-  = mkPrimOpRule nm 2 $ binaryCmpLit cmp : extra
+                    -> Maybe CoreRule
+-- See Note [Rules for floating-point comparisons]
+mkFloatingRelOpRule nm cmp
+  = mkPrimOpRule nm 2 [binaryCmpLit cmp]
 
 -- common constants
 zeroi, onei, zerow, onew :: DynFlags -> Literal
@@ -428,24 +448,27 @@ doubleOp2 op dflags (MachDouble f1) (MachDouble f2)
 doubleOp2 _ _ _ _ = Nothing
 
 --------------------------
--- This stuff turns
---      n ==# 3#
--- into
---      case n of
---        3# -> True
---        m  -> False
---
--- This is a Good Thing, because it allows case-of case things
--- to happen, and case-default absorption to happen.  For
--- example:
---
---      if (n ==# 3#) || (n ==# 4#) then e1 else e2
--- will transform to
---      case n of
---        3# -> e1
---        4# -> e1
---        m  -> e2
--- (modulo the usual precautions to avoid duplicating e1)
+{- Note [The litEq rule: converting equality to case]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+This stuff turns
+     n ==# 3#
+into
+     case n of
+       3# -> True
+       m  -> False
+
+This is a Good Thing, because it allows case-of case things
+to happen, and case-default absorption to happen.  For
+example:
+
+     if (n ==# 3#) || (n ==# 4#) then e1 else e2
+will transform to
+     case n of
+       3# -> e1
+       4# -> e1
+       m  -> e2
+(modulo the usual precautions to avoid duplicating e1)
+-}
 
 litEq :: Bool  -- True <=> equality, False <=> inequality
       -> RuleM CoreExpr
index ffaef16..e5f49e4 100644 (file)
@@ -65,7 +65,7 @@ module TcType (
   eqType, eqTypes, eqPred, cmpType, cmpTypes, cmpPred, eqTypeX,
   tcEqType, tcEqKind,
   isSigmaTy, isRhoTy, isOverloadedTy,
-  isDoubleTy, isFloatTy, isIntTy, isWordTy, isStringTy,
+  isFloatingTy, isDoubleTy, isFloatTy, isIntTy, isWordTy, isStringTy,
   isIntegerTy, isBoolTy, isUnitTy, isCharTy,
   isTauTy, isTauTyCon, tcIsTyVarTy, tcIsForAllTy,
   isPredTy, isTyVarClassPred, isTyVarExposed, isTyVarUnderDatatype,
@@ -1439,6 +1439,11 @@ isUnitTy       = is_tc unitTyConKey
 isCharTy       = is_tc charTyConKey
 isAnyTy        = is_tc anyTyConKey
 
+-- | Does a type represent a floating-point number?
+isFloatingTy :: Type -> Bool
+isFloatingTy ty = isFloatTy ty || isDoubleTy ty
+
+-- | Is a type 'String'?
 isStringTy :: Type -> Bool
 isStringTy ty
   = case tcSplitTyConApp_maybe ty of
diff --git a/testsuite/tests/deSugar/should_run/T10215.hs b/testsuite/tests/deSugar/should_run/T10215.hs
new file mode 100644 (file)
index 0000000..9a2d224
--- /dev/null
@@ -0,0 +1,9 @@
+testF :: Float -> Bool
+testF x = x == 0 && not (isNegativeZero x)
+
+testD :: Double -> Bool
+testD x = x == 0 && not (isNegativeZero x)
+
+main :: IO ()
+main = do print $ testF (-0.0)
+          print $ testD (-0.0)
diff --git a/testsuite/tests/deSugar/should_run/T10215.stdout b/testsuite/tests/deSugar/should_run/T10215.stdout
new file mode 100644 (file)
index 0000000..abb2393
--- /dev/null
@@ -0,0 +1,3 @@
+False
+False
+
diff --git a/testsuite/tests/deSugar/should_run/T9238.hs b/testsuite/tests/deSugar/should_run/T9238.hs
new file mode 100644 (file)
index 0000000..79eeeb7
--- /dev/null
@@ -0,0 +1,16 @@
+compareDouble :: Double -> Double -> Ordering
+compareDouble x y =
+       case (isNaN x, isNaN y) of
+       (True, True)   -> EQ
+       (True, False)  -> LT
+       (False, True)  -> GT
+       (False, False) ->
+          -- Make -0 less than 0
+          case (x == 0, y == 0, isNegativeZero x, isNegativeZero y) of
+          (True, True, True, False) -> LT
+          (True, True, False, True) -> GT
+          _                         -> x `compare` y
+
+main = do
+    let l = [-0, 0]
+    print [ (x, y, compareDouble x y) | x <- l, y <- l ]
diff --git a/testsuite/tests/deSugar/should_run/T9238.stdout b/testsuite/tests/deSugar/should_run/T9238.stdout
new file mode 100644 (file)
index 0000000..8dbd09d
--- /dev/null
@@ -0,0 +1,2 @@
+[(-0.0,-0.0,EQ),(-0.0,0.0,LT),(0.0,-0.0,GT),(0.0,0.0,EQ)]
+
index 228b90d..bc72b01 100644 (file)
@@ -46,5 +46,7 @@ test('DsStaticPointers',
      ],
      compile_and_run, [''])
 test('T8952', normal, compile_and_run, [''])
+test('T9238', normal, compile_and_run, [''])
 test('T9844', normal, compile_and_run, [''])
+test('T10215', normal, compile_and_run, [''])
 test('DsStrictData', normal, compile_and_run, [''])