Implement deterministic CallInfoSet
authorBartosz Nitka <niteria@gmail.com>
Mon, 6 Jun 2016 11:36:21 +0000 (04:36 -0700)
committerBartosz Nitka <niteria@gmail.com>
Mon, 25 Jul 2016 14:41:26 +0000 (07:41 -0700)
We need CallInfoSet to be deterministic because it determines the
order that the binds get generated.

Currently it's not deterministic because it's keyed on
`CallKey = [Maybe Type]` and `Ord CallKey` is implemented
with `cmpType` which is nondeterministic.

See Note [CallInfoSet determinism] for more details.

Test Plan: ./validate

Reviewers: simonpj, bgamari, austin, simonmar

Reviewed By: simonmar

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D2242

GHC Trac Issues: #4012

(cherry picked from commit 48e9a1f5521fa3185510d144dd28a87e452ce134)

compiler/specialise/Specialise.hs

index 97e294d..0c1d398 100644 (file)
@@ -36,6 +36,7 @@ import Outputable
 import FastString
 import State
 import UniqDFM
+import TrieMap
 
 #if __GLASGOW_HASKELL__ < 709
 import Control.Applicative (Applicative(..))
@@ -44,9 +45,6 @@ import Control.Monad
 #if __GLASGOW_HASKELL__ > 710
 import qualified Control.Monad.Fail as MonadFail
 #endif
-import Data.Map (Map)
-import qualified Data.Map as Map
-import qualified FiniteMap as Map
 
 {-
 ************************************************************************
@@ -663,10 +661,10 @@ specImports dflags this_mod top_env done callers rule_base cds
   where
     go :: RuleBase -> [CallInfoSet] -> CoreM ([CoreRule], [CoreBind])
     go _ [] = return ([], [])
-    go rb (CIS fn calls_for_fn : other_calls)
+    go rb (cis@(CIS fn _calls_for_fn) : other_calls)
       = do { (rules1, spec_binds1) <- specImport dflags this_mod top_env
                                                  done callers rb fn $
-                                      Map.toList calls_for_fn
+                                      ciSetToList cis
            ; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls
            ; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) }
 
@@ -1731,19 +1729,71 @@ type CallDetails  = DIdEnv CallInfoSet
   -- The order of specialized binds and rules depends on how we linearize
   -- CallDetails, so to get determinism we must use a deterministic set here.
   -- See Note [Deterministic UniqFM] in UniqDFM
-newtype CallKey   = CallKey [Maybe Type]                        -- Nothing => unconstrained type argument
-
--- CallInfo uses a Map, thereby ensuring that
--- we record only one call instance for any key
---
--- The list of types and dictionaries is guaranteed to
--- match the type of f
-data CallInfoSet = CIS Id (Map CallKey ([DictExpr], VarSet))
-                        -- Range is dict args and the vars of the whole
-                        -- call (including tyvars)
-                        -- [*not* include the main id itself, of course]
+newtype CallKey   = CallKey [Maybe Type]
+  -- Nothing => unconstrained type argument
+
+data CallInfoSet = CIS Id (Bag CallInfo)
+  -- The list of types and dictionaries is guaranteed to
+  -- match the type of f
+
+{-
+Note [CallInfoSet determinism]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+CallInfoSet holds a Bag of (CallKey, [DictExpr], VarSet) triplets for a given
+Id. They represent the types that the function is instantiated at along with
+the dictionaries and free variables.
+
+We use this information to generate specialized versions of a given function.
+CallInfoSet used to be defined as:
+
+  data CallInfoSet = CIS Id (Map CallKey ([DictExpr], VarSet))
+
+Unfortunately this was not deterministic. The Ord instance of CallKey was
+defined in terms of cmpType which is not deterministic.
+See Note [cmpType nondeterminism].
+The end result was that if the function had multiple specializations they would
+be generated in arbitrary order.
+
+We need a container that:
+a) when turned into a list has only one element per each CallKey and the list
+has deterministic order
+b) supports union
+c) supports singleton
+d) supports filter
+
+We can't use UniqDFM here because there's no one Unique that we can key on.
+
+The current approach is to implement the set as a Bag with duplicates.
+This makes b), c), d) trivial and pushes a) towards the end. The deduplication
+is done by using a TrieMap for membership tests on CallKey. This lets us delete
+the nondeterministic Ord CallKey instance.
+
+An alternative approach would be to augument the Map the same way that UniqDFM
+is augumented, by keeping track of insertion order and using it to order the
+resulting lists. It would mean keeping the nondeterministic Ord CallKey
+instance making it easy to reintroduce nondeterminism in the future.
+-}
+
+ciSetToList :: CallInfoSet -> [CallInfo]
+ciSetToList (CIS _ b) = snd $ foldrBag combine (emptyTM, []) b
+  where
+  -- This is where we eliminate duplicates, recording the CallKeys we've
+  -- already seen in the TrieMap. See Note [CallInfoSet determinism].
+  combine :: CallInfo -> (CallKeySet, [CallInfo]) -> (CallKeySet, [CallInfo])
+  combine ci@(CallKey key, _) (set, acc)
+    | Just _ <- lookupTM key set = (set, acc)
+    | otherwise = (insertTM key () set, ci:acc)
+
+type CallKeySet = ListMap (MaybeMap TypeMap) ()
+  -- We only use it in ciSetToList to check for membership
+
+ciSetFilter :: (CallInfo -> Bool) -> CallInfoSet -> CallInfoSet
+ciSetFilter p (CIS id a) = CIS id (filterBag p a)
 
 type CallInfo = (CallKey, ([DictExpr], VarSet))
+                    -- Range is dict args and the vars of the whole
+                    -- call (including tyvars)
+                    -- [*not* include the main id itself, of course]
 
 instance Outputable CallInfoSet where
   ppr (CIS fn map) = hang (text "CIS" <+> ppr fn)
@@ -1761,24 +1811,12 @@ ppr_call_key_ty (Just ty) = char '@' <+> pprParendType ty
 instance Outputable CallKey where
   ppr (CallKey ts) = ppr ts
 
--- Type isn't an instance of Ord, so that we can control which
--- instance we use.  That's tiresome here.  Oh well
-instance Eq CallKey where
-  k1 == k2 = case k1 `compare` k2 of { EQ -> True; _ -> False }
-
-instance Ord CallKey where
-  compare (CallKey k1) (CallKey k2) = cmpList cmp k1 k2
-                where
-                  cmp Nothing   Nothing   = EQ
-                  cmp Nothing   (Just _)  = LT
-                  cmp (Just _)  Nothing   = GT
-                  cmp (Just t1) (Just t2) = cmpType t1 t2
-
 unionCalls :: CallDetails -> CallDetails -> CallDetails
 unionCalls c1 c2 = plusDVarEnv_C unionCallInfoSet c1 c2
 
 unionCallInfoSet :: CallInfoSet -> CallInfoSet -> CallInfoSet
-unionCallInfoSet (CIS f calls1) (CIS _ calls2) = CIS f (calls1 `Map.union` calls2)
+unionCallInfoSet (CIS f calls1) (CIS _ calls2) =
+  CIS f (calls1 `unionBags` calls2)
 
 callDetailsFVs :: CallDetails -> VarSet
 callDetailsFVs calls =
@@ -1787,14 +1825,15 @@ callDetailsFVs calls =
   -- immediately by converting to a nondeterministic set.
 
 callInfoFVs :: CallInfoSet -> VarSet
-callInfoFVs (CIS _ call_info) = Map.foldRight (\(_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info
+callInfoFVs (CIS _ call_info) =
+  foldrBag (\(_, (_,fv)) vs -> unionVarSet fv vs) emptyVarSet call_info
 
 ------------------------------------------------------------
 singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails
 singleCall id tys dicts
   = MkUD {ud_binds = emptyBag,
           ud_calls = unitDVarEnv id $ CIS id $
-                     Map.singleton (CallKey tys) (dicts, call_fvs) }
+                     unitBag (CallKey tys, (dicts, call_fvs)) }
   where
     call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
     tys_fvs  = tyCoVarsOfTypes (catMaybes tys)
@@ -2047,7 +2086,7 @@ callsForMe fn (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
                           , ud_calls = delDVarEnv orig_calls fn }
     calls_for_me = case lookupDVarEnv orig_calls fn of
                         Nothing -> []
-                        Just (CIS _ calls) -> filter_dfuns (Map.toList calls)
+                        Just cis -> filter_dfuns (ciSetToList cis)
 
     dep_set = foldlBag go (unitVarSet fn) orig_dbs
     go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set
@@ -2081,11 +2120,9 @@ splitDictBinds dbs bndr_set
 deleteCallsMentioning :: VarSet -> CallDetails -> CallDetails
 -- Remove calls *mentioning* bs
 deleteCallsMentioning bs calls
-  = mapDVarEnv filter_calls calls
+  = mapDVarEnv (ciSetFilter keep_call) calls
   where
-    filter_calls :: CallInfoSet -> CallInfoSet
-    filter_calls (CIS f calls) = CIS f (Map.filter keep_call calls)
-    keep_call (_, fvs) = not (fvs `intersectsVarSet` bs)
+    keep_call (_, (_, fvs)) = not (fvs `intersectsVarSet` bs)
 
 deleteCallsFor :: [Id] -> CallDetails -> CallDetails
 -- Remove calls *for* bs