Hoopl: remove dependency on Hoopl package
[ghc.git] / compiler / cmm / CmmCommonBlockElim.hs
index 9b484f9..3c23e70 100644 (file)
@@ -1,9 +1,4 @@
-{-# LANGUAGE GADTs #-}
--- ToDo: remove -fno-warn-warnings-deprecations
-{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
--- ToDo: remove -fno-warn-incomplete-patterns
-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
-
+{-# LANGUAGE GADTs, BangPatterns #-}
 module CmmCommonBlockElim
   ( elimCommonBlocks
   )
@@ -13,19 +8,26 @@ where
 import BlockId
 import Cmm
 import CmmUtils
+import CmmSwitch (eqSwitchTargetWith)
 import CmmContFlowOpt
+-- import PprCmm ()
 import Prelude hiding (iterate, succ, unzip, zip)
 
-import Hoopl hiding (ChangeFlag)
+import Hoopl.Block
+import Hoopl.Graph
+import Hoopl.Label
+import Hoopl.Collections
 import Data.Bits
+import Data.Maybe (mapMaybe)
 import qualified Data.List as List
 import Data.Word
-import FastString
+import qualified Data.Map as M
 import Outputable
 import UniqFM
-
-my_trace :: String -> SDoc -> a -> a
-my_trace = if False then pprTrace else \_ _ a -> a
+import UniqDFM
+import qualified TrieMap as TM
+import Unique
+import Control.Arrow (first, second)
 
 -- -----------------------------------------------------------------------------
 -- Eliminate common blocks
@@ -41,39 +43,72 @@ my_trace = if False then pprTrace else \_ _ a -> a
 -- is made redundant by the old block.
 -- Otherwise, it is added to the useful blocks.
 
+-- To avoid comparing every block with every other block repeatedly, we group
+-- them by
+--   * a hash of the block, ignoring labels (explained below)
+--   * the list of outgoing labels
+-- The hash is invariant under relabeling, so we only ever compare within
+-- the same group of blocks.
+--
+-- The list of outgoing labels is updated as we merge blocks (that is why they
+-- are not included in the hash, which we want to calculate only once).
+--
+-- All in all, two blocks should never be compared if they have different
+-- hashes, and at most once otherwise. Previously, we were slower, and people
+-- rightfully complained: #10397
+
 -- TODO: Use optimization fuel
 elimCommonBlocks :: CmmGraph -> CmmGraph
-elimCommonBlocks g = replaceLabels env g
+elimCommonBlocks g = replaceLabels env $ copyTicks env g
   where
-     env = iterate hashed_blocks mapEmpty
-     hashed_blocks = map (\b -> (hash_block b, b)) $ postorderDfs g
+     env = iterate mapEmpty blocks_with_key
+     groups = groupByInt hash_block (postorderDfs g)
+     blocks_with_key = [ [ (successors b, [b]) | b <- bs] | bs <- groups]
+
+-- Invariant: The blocks in the list are pairwise distinct
+-- (so avoid comparing them again)
+type DistinctBlocks = [CmmBlock]
+type Key = [Label]
+type Subst = LabelMap BlockId
 
--- Iterate over the blocks until convergence
-iterate :: [(HashCode,CmmBlock)] -> BlockEnv BlockId -> BlockEnv BlockId
-iterate blocks subst =
-  case foldl common_block (False, emptyUFM, subst) blocks of
-    (changed,  _, subst)
-       | changed   -> iterate blocks subst
-       | otherwise -> subst
+-- The outer list groups by hash. We retain this grouping throughout.
+iterate :: Subst -> [[(Key, DistinctBlocks)]] -> Subst
+iterate subst blocks
+    | mapNull new_substs = subst
+    | otherwise = iterate subst' updated_blocks
+  where
+    grouped_blocks :: [[(Key, [DistinctBlocks])]]
+    grouped_blocks = map groupByLabel blocks
 
-type State  = (ChangeFlag, UniqFM [CmmBlock], BlockEnv BlockId)
+    merged_blocks :: [[(Key, DistinctBlocks)]]
+    (new_substs, merged_blocks) = List.mapAccumL (List.mapAccumL go) mapEmpty grouped_blocks
+      where
+        go !new_subst1 (k,dbs) = (new_subst1 `mapUnion` new_subst2, (k,db))
+          where
+            (new_subst2, db) = mergeBlockList subst dbs
 
-type ChangeFlag = Bool
-type HashCode = Int
+    subst' = subst `mapUnion` new_substs
+    updated_blocks = map (map (first (map (lookupBid subst')))) merged_blocks
 
--- Try to find a block that is equal (or ``common'') to b.
-common_block :: State -> (HashCode, CmmBlock) -> State
-common_block (old_change, bmap, subst) (hash, b) =
-  case lookupUFM bmap hash of
-    Just bs -> case (List.find (eqBlockBodyWith (eqBid subst) b) bs,
-                     mapLookup bid subst) of
-                 (Just b', Nothing)                         -> addSubst b'
-                 (Just b', Just b'') | entryLabel b' /= b'' -> addSubst b'
-                 _ -> (old_change, addToUFM bmap hash (b : bs), subst)
-    Nothing -> (old_change, addToUFM bmap hash [b], subst)
-  where bid = entryLabel b
-        addSubst b' = my_trace "found new common block" (ppr (entryLabel b')) $
-                      (True, bmap, mapInsert bid (entryLabel b') subst)
+mergeBlocks :: Subst -> DistinctBlocks -> DistinctBlocks -> (Subst, DistinctBlocks)
+mergeBlocks subst existing new = go new
+  where
+    go [] = (mapEmpty, existing)
+    go (b:bs) = case List.find (eqBlockBodyWith (eqBid subst) b) existing of
+        -- This block is a duplicate. Drop it, and add it to the substitution
+        Just b' -> first (mapInsert (entryLabel b) (entryLabel b')) $ go bs
+        -- This block is not a duplicate, keep it.
+        Nothing -> second (b:) $ go bs
+
+mergeBlockList :: Subst -> [DistinctBlocks] -> (Subst, DistinctBlocks)
+mergeBlockList _ [] = pprPanic "mergeBlockList" empty
+mergeBlockList subst (b:bs) = go mapEmpty b bs
+  where
+    go !new_subst1 b [] = (new_subst1, b)
+    go !new_subst1 b1 (b2:bs) = go new_subst b bs
+      where
+        (new_subst2, b) =  mergeBlocks subst b1 b2
+        new_subst = new_subst1 `mapUnion` new_subst2
 
 
 -- -----------------------------------------------------------------------------
@@ -82,9 +117,16 @@ common_block (old_change, bmap, subst) (hash, b) =
 -- Below here is mostly boilerplate: hashing blocks ignoring labels,
 -- and comparing blocks modulo a label mapping.
 
--- To speed up comparisons, we hash each basic block modulo labels.
+-- To speed up comparisons, we hash each basic block modulo jump labels.
 -- The hashing is a bit arbitrary (the numbers are completely arbitrary),
 -- but it should be fast and good enough.
+
+-- We want to get as many small buckets as possible, as comparing blocks is
+-- expensive. So include as much as possible in the hash. Ideally everything
+-- that is compared with (==) in eqBlockBodyWith.
+
+type HashCode = Int
+
 hash_block :: CmmBlock -> HashCode
 hash_block block =
   fromIntegral (foldBlockNodesB3 (hash_fst, hash_mid, hash_lst) block (0 :: Word32) .&. (0x7fffffff :: Word32))
@@ -94,18 +136,19 @@ hash_block block =
         hash_lst m h = hash_node m + h `shiftL` 1
 
         hash_node :: CmmNode O x -> Word32
-        hash_node (CmmComment (FastString u _ _ _ _)) = cvt u
+        hash_node n | dont_care n = 0 -- don't care
         hash_node (CmmAssign r e) = hash_reg r + hash_e e
         hash_node (CmmStore e e') = hash_e e + hash_e e'
         hash_node (CmmUnsafeForeignCall t _ as) = hash_tgt t + hash_list hash_e as
         hash_node (CmmBranch _) = 23 -- NB. ignore the label
-        hash_node (CmmCondBranch p _ _) = hash_e p
-        hash_node (CmmCall e _ _ _ _) = hash_e e
-        hash_node (CmmForeignCall t _ _ _ _ _) = hash_tgt t
+        hash_node (CmmCondBranch p _ _ _) = hash_e p
+        hash_node (CmmCall e _ _ _ _ _) = hash_e e
+        hash_node (CmmForeignCall t _ _ _ _ _ _) = hash_tgt t
         hash_node (CmmSwitch e _) = hash_e e
+        hash_node _ = error "hash_node: unknown Cmm node!"
 
         hash_reg :: CmmReg -> Word32
-        hash_reg   (CmmLocal _) = 117
+        hash_reg   (CmmLocal localReg) = hash_unique localReg -- important for performance, see #10397
         hash_reg   (CmmGlobal _)    = 19
 
         hash_e :: CmmExpr -> Word32
@@ -119,6 +162,7 @@ hash_block block =
         hash_lit :: CmmLit -> Word32
         hash_lit (CmmInt i _) = fromInteger i
         hash_lit (CmmFloat r _) = truncate r
+        hash_lit (CmmVec ls) = hash_list hash_lit ls
         hash_lit (CmmLabel _) = 119 -- ugh
         hash_lit (CmmLabelOff _ i) = cvt $ 199 + i
         hash_lit (CmmLabelDiffOff _ _ i) = cvt $ 299 + i
@@ -131,38 +175,131 @@ hash_block block =
         hash_list f = foldl (\z x -> f x + z) (0::Word32)
 
         cvt = fromInteger . toInteger
+
+        hash_unique :: Uniquable a => a -> Word32
+        hash_unique = cvt . getKey . getUnique
+
+-- | Ignore these node types for equality
+dont_care :: CmmNode O x -> Bool
+dont_care CmmComment {}  = True
+dont_care CmmTick {}     = True
+dont_care CmmUnwind {}   = True
+dont_care _other         = False
+
 -- Utilities: equality and substitution on the graph.
 
 -- Given a map ``subst'' from BlockID -> BlockID, we define equality.
-eqBid :: BlockEnv BlockId -> BlockId -> BlockId -> Bool
+eqBid :: LabelMap BlockId -> BlockId -> BlockId -> Bool
 eqBid subst bid bid' = lookupBid subst bid == lookupBid subst bid'
-lookupBid :: BlockEnv BlockId -> BlockId -> BlockId
+lookupBid :: LabelMap BlockId -> BlockId -> BlockId
 lookupBid subst bid = case mapLookup bid subst of
                         Just bid  -> lookupBid subst bid
                         Nothing -> bid
 
+-- Middle nodes and expressions can contain BlockIds, in particular in
+-- CmmStackSlot and CmmBlock, so we have to use a special equality for
+-- these.
+--
+eqMiddleWith :: (BlockId -> BlockId -> Bool)
+             -> CmmNode O O -> CmmNode O O -> Bool
+eqMiddleWith eqBid (CmmAssign r1 e1) (CmmAssign r2 e2)
+  = r1 == r2 && eqExprWith eqBid e1 e2
+eqMiddleWith eqBid (CmmStore l1 r1) (CmmStore l2 r2)
+  = eqExprWith eqBid l1 l2 && eqExprWith eqBid r1 r2
+eqMiddleWith eqBid (CmmUnsafeForeignCall t1 r1 a1)
+                   (CmmUnsafeForeignCall t2 r2 a2)
+  = t1 == t2 && r1 == r2 && and (zipWith (eqExprWith eqBid) a1 a2)
+eqMiddleWith _ _ _ = False
+
+eqExprWith :: (BlockId -> BlockId -> Bool)
+           -> CmmExpr -> CmmExpr -> Bool
+eqExprWith eqBid = eq
+ where
+  CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
+  CmmLoad e1 _       `eq` CmmLoad e2 _       = e1 `eq` e2
+  CmmReg r1          `eq` CmmReg r2          = r1==r2
+  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1==r2 && i1==i2
+  CmmMachOp op1 es1  `eq` CmmMachOp op2 es2  = op1==op2 && es1 `eqs` es2
+  CmmStackSlot a1 i1 `eq` CmmStackSlot a2 i2 = eqArea a1 a2 && i1==i2
+  _e1                `eq` _e2                = False
+
+  xs `eqs` ys = and (zipWith eq xs ys)
+
+  eqLit (CmmBlock id1) (CmmBlock id2) = eqBid id1 id2
+  eqLit l1 l2 = l1 == l2
+
+  eqArea Old Old = True
+  eqArea (Young id1) (Young id2) = eqBid id1 id2
+  eqArea _ _ = False
+
 -- Equality on the body of a block, modulo a function mapping block
 -- IDs to block IDs.
 eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
 eqBlockBodyWith eqBid block block'
-  = blockToList m == blockToList m' && eqLastWith eqBid l l'
+  {-
+  | equal     = pprTrace "equal" (vcat [ppr block, ppr block']) True
+  | otherwise = pprTrace "not equal" (vcat [ppr block, ppr block']) False
+  -}
+  = equal
   where (_,m,l)   = blockSplit block
+        nodes     = filter (not . dont_care) (blockToList m)
         (_,m',l') = blockSplit block'
+        nodes'    = filter (not . dont_care) (blockToList m')
+
+        equal = and (zipWith (eqMiddleWith eqBid) nodes nodes') &&
+                eqLastWith eqBid l l'
+
 
 eqLastWith :: (BlockId -> BlockId -> Bool) -> CmmNode O C -> CmmNode O C -> Bool
 eqLastWith eqBid (CmmBranch bid1) (CmmBranch bid2) = eqBid bid1 bid2
-eqLastWith eqBid (CmmCondBranch c1 t1 f1) (CmmCondBranch c2 t2 f2) =
-  c1 == c2 && eqBid t1 t2 && eqBid f1 f2
-eqLastWith eqBid (CmmCall t1 c1 a1 r1 u1) (CmmCall t2 c2 a2 r2 u2) =
-  t1 == t2 && eqMaybeWith eqBid c1 c2 && a1 == a2 && r1 == r2 && u1 == u2
-eqLastWith eqBid (CmmSwitch e1 bs1) (CmmSwitch e2 bs2) =
-  e1 == e2 && eqListWith (eqMaybeWith eqBid) bs1 bs2
+eqLastWith eqBid (CmmCondBranch c1 t1 f1 l1) (CmmCondBranch c2 t2 f2 l2) =
+  c1 == c2 && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
+eqLastWith eqBid (CmmCall t1 c1 g1 a1 r1 u1) (CmmCall t2 c2 g2 a2 r2 u2) =
+  t1 == t2 && eqMaybeWith eqBid c1 c2 && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
+eqLastWith eqBid (CmmSwitch e1 ids1) (CmmSwitch e2 ids2) =
+  e1 == e2 && eqSwitchTargetWith eqBid ids1 ids2
 eqLastWith _ _ _ = False
 
-eqListWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
-eqListWith eltEq es es' = all (uncurry eltEq) (List.zip es es')
-
 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
 eqMaybeWith _ Nothing Nothing = True
 eqMaybeWith _ _ _ = False
+
+-- | Given a block map, ensure that all "target" blocks are covered by
+-- the same ticks as the respective "source" blocks. This not only
+-- means copying ticks, but also adjusting tick scopes where
+-- necessary.
+copyTicks :: LabelMap BlockId -> CmmGraph -> CmmGraph
+copyTicks env g
+  | mapNull env = g
+  | otherwise   = ofBlockMap (g_entry g) $ mapMap copyTo blockMap
+  where -- Reverse block merge map
+        blockMap = toBlockMap g
+        revEnv = mapFoldWithKey insertRev M.empty env
+        insertRev k x = M.insertWith (const (k:)) x [k]
+        -- Copy ticks and scopes into the given block
+        copyTo block = case M.lookup (entryLabel block) revEnv of
+          Nothing -> block
+          Just ls -> foldr copy block $ mapMaybe (flip mapLookup blockMap) ls
+        copy from to =
+          let ticks = blockTicks from
+              CmmEntry  _   scp0        = firstNode from
+              (CmmEntry lbl scp1, code) = blockSplitHead to
+          in CmmEntry lbl (combineTickScopes scp0 scp1) `blockJoinHead`
+             foldr blockCons code (map CmmTick ticks)
+
+-- Group by [Label]
+groupByLabel :: [(Key, a)] -> [(Key, [a])]
+groupByLabel = go (TM.emptyTM :: TM.ListMap UniqDFM a)
+  where
+    go !m [] = TM.foldTM (:) m []
+    go !m ((k,v) : entries) = go (TM.alterTM k' adjust m) entries
+      where k' = map getUnique k
+            adjust Nothing       = Just (k,[v])
+            adjust (Just (_,vs)) = Just (k,v:vs)
+
+
+groupByInt :: (a -> Int) -> [a] -> [[a]]
+groupByInt f xs = nonDetEltsUFM $ List.foldl' go emptyUFM xs
+  -- See Note [Unique Determinism and code generation]
+  where go m x = alterUFM (Just . maybe [x] (x:)) m (f x)