cmm/CBE: Use foldLocalRegsDefd
[ghc.git] / compiler / cmm / CmmCommonBlockElim.hs
index aca39bc..c83497e 100644 (file)
@@ -24,6 +24,7 @@ import qualified Data.List as List
 import Data.Word
 import qualified Data.Map as M
 import Outputable
+import DynFlags (DynFlags)
 import UniqFM
 import UniqDFM
 import qualified TrieMap as TM
@@ -59,11 +60,11 @@ import Control.Arrow (first, second)
 -- rightfully complained: #10397
 
 -- TODO: Use optimization fuel
-elimCommonBlocks :: CmmGraph -> CmmGraph
-elimCommonBlocks g = replaceLabels env $ copyTicks env g
+elimCommonBlocks :: DynFlags -> CmmGraph -> CmmGraph
+elimCommonBlocks dflags g = replaceLabels env $ copyTicks env g
   where
-     env = iterate mapEmpty blocks_with_key
-     groups = groupByInt hash_block (postorderDfs g)
+     env = iterate dflags mapEmpty blocks_with_key
+     groups = groupByInt (hash_block dflags) (postorderDfs g)
      blocks_with_key = [ [ (successors b, [b]) | b <- bs] | bs <- groups]
 
 -- Invariant: The blocks in the list are pairwise distinct
@@ -73,42 +74,47 @@ type Key = [Label]
 type Subst = LabelMap BlockId
 
 -- The outer list groups by hash. We retain this grouping throughout.
-iterate :: Subst -> [[(Key, DistinctBlocks)]] -> Subst
-iterate subst blocks
+iterate :: DynFlags -> Subst -> [[(Key, DistinctBlocks)]] -> Subst
+iterate dflags subst blocks
     | mapNull new_substs = subst
-    | otherwise = iterate subst' updated_blocks
+    | otherwise = iterate dflags subst' updated_blocks
   where
     grouped_blocks :: [[(Key, [DistinctBlocks])]]
     grouped_blocks = map groupByLabel blocks
 
     merged_blocks :: [[(Key, DistinctBlocks)]]
-    (new_substs, merged_blocks) = List.mapAccumL (List.mapAccumL go) mapEmpty grouped_blocks
+    (new_substs, merged_blocks) =
+        List.mapAccumL (List.mapAccumL go) mapEmpty grouped_blocks
       where
         go !new_subst1 (k,dbs) = (new_subst1 `mapUnion` new_subst2, (k,db))
           where
-            (new_subst2, db) = mergeBlockList subst dbs
+            (new_subst2, db) = mergeBlockList dflags subst dbs
 
     subst' = subst `mapUnion` new_substs
     updated_blocks = map (map (first (map (lookupBid subst')))) merged_blocks
 
-mergeBlocks :: Subst -> DistinctBlocks -> DistinctBlocks -> (Subst, DistinctBlocks)
-mergeBlocks subst existing new = go new
+mergeBlocks :: DynFlags -> Subst
+            -> DistinctBlocks -> DistinctBlocks
+            -> (Subst, DistinctBlocks)
+mergeBlocks dflags subst existing new = go new
   where
     go [] = (mapEmpty, existing)
-    go (b:bs) = case List.find (eqBlockBodyWith (eqBid subst) b) existing of
-        -- This block is a duplicate. Drop it, and add it to the substitution
-        Just b' -> first (mapInsert (entryLabel b) (entryLabel b')) $ go bs
-        -- This block is not a duplicate, keep it.
-        Nothing -> second (b:) $ go bs
-
-mergeBlockList :: Subst -> [DistinctBlocks] -> (Subst, DistinctBlocks)
-mergeBlockList _ [] = pprPanic "mergeBlockList" empty
-mergeBlockList subst (b:bs) = go mapEmpty b bs
+    go (b:bs) =
+        case List.find (eqBlockBodyWith dflags (eqBid subst) b) existing of
+          -- This block is a duplicate. Drop it, and add it to the substitution
+          Just b' -> first (mapInsert (entryLabel b) (entryLabel b')) $ go bs
+          -- This block is not a duplicate, keep it.
+          Nothing -> second (b:) $ go bs
+
+mergeBlockList :: DynFlags -> Subst -> [DistinctBlocks]
+               -> (Subst, DistinctBlocks)
+mergeBlockList _      _     [] = pprPanic "mergeBlockList" empty
+mergeBlockList dflags subst (b:bs) = go mapEmpty b bs
   where
     go !new_subst1 b [] = (new_subst1, b)
     go !new_subst1 b1 (b2:bs) = go new_subst b bs
       where
-        (new_subst2, b) =  mergeBlocks subst b1 b2
+        (new_subst2, b) =  mergeBlocks dflags subst b1 b2
         new_subst = new_subst1 `mapUnion` new_subst2
 
 
@@ -175,8 +181,8 @@ data HashEnv = HashEnv { localRegHashEnv :: !(LocalRegEnv DeBruijn)
                        , nextIndex       :: !DeBruijn
                        }
 
-hash_block :: CmmBlock -> HashCode
-hash_block block =
+hash_block :: DynFlags -> CmmBlock -> HashCode
+hash_block dflags block =
   --pprTrace "hash_block" (ppr (entryLabel block) $$ ppr hash)
   hash
   where hash_fst _ (env, h) = (env, h)
@@ -196,20 +202,24 @@ hash_block block =
 
         hash_node :: HashEnv -> CmmNode O x -> (HashEnv, Word32)
         hash_node env n =
-            case n of
-              n | dont_care n -> pure_ 0  -- don't care
-              CmmAssign (CmmLocal r) e -> (bind_local_reg r env, hash_e env e)
-              CmmAssign r e   -> pure_ $ hash_reg env r + hash_e env e
-              CmmStore e e'   -> pure_ $ hash_e env e + hash_e env e'
-              CmmUnsafeForeignCall t _ as
-                              -> pure_ $ hash_tgt env t + hash_list (hash_e env) as
-              CmmBranch _     -> pure_ 23 -- NB. ignore the label
-              CmmCondBranch p _ _ _ -> pure_ $ hash_e env p
-              CmmCall e _ _ _ _ _   -> pure_ $ hash_e env e
-              CmmForeignCall t _ _ _ _ _ _ -> pure_ $ hash_tgt env t
-              CmmSwitch e _   -> pure_ $ hash_e env e
-              _               -> error "hash_node: unknown Cmm node!"
-          where pure_ x = (env, x)
+            (env', hash)
+          where
+            hash =
+              case n of
+                n | dont_care n -> 0  -- don't care
+                -- don't include register as it is a binding occurrence
+                CmmAssign (CmmLocal _) e -> hash_e env e
+                CmmAssign r e   -> hash_reg env r + hash_e env e
+                CmmStore e e'   -> hash_e env e + hash_e env e'
+                CmmUnsafeForeignCall t _ as
+                                -> hash_tgt env t + hash_list (hash_e env) as
+                CmmBranch _     ->  23 -- NB. ignore the label
+                CmmCondBranch p _ _ _ -> hash_e env p
+                CmmCall e _ _ _ _ _   -> hash_e env e
+                CmmForeignCall t _ _ _ _ _ _ -> hash_tgt env t
+                CmmSwitch e _   -> hash_e env e
+                _               -> error "hash_node: unknown Cmm node!"
+            env' = foldLocalRegsDefd dflags (flip bind_local_reg) env n
 
         hash_reg :: HashEnv -> CmmReg -> Word32
         hash_reg env (CmmLocal localReg)
@@ -281,38 +291,45 @@ type LocalRegMapping = LocalRegEnv LocalReg
 -- CmmStackSlot and CmmBlock, so we have to use a special equality for
 -- these.
 --
-eqMiddleWith :: (BlockId -> BlockId -> Bool)
+eqMiddleWith :: DynFlags
+             -> (BlockId -> BlockId -> Bool)
              -> LocalRegMapping
              -> CmmNode O O -> CmmNode O O
              -> (LocalRegMapping, Bool)
-eqMiddleWith eqBid env a b =
+eqMiddleWith dflags eqBid env a b =
   case (a, b) of
-    (CmmAssign (CmmLocal r1) e1,  CmmAssign (CmmLocal r2) e2) ->
+     -- registers aren't compared since they are binding occurrences
+    (CmmAssign (CmmLocal _) e1,  CmmAssign (CmmLocal _) e2) ->
         let eq = eqExprWith eqBid env e1 e2
-            env' = addToUFM env r1 r2
         in (env', eq)
 
     (CmmAssign r1 e1,  CmmAssign r2 e2) ->
         let eq = r1 == r2
               && eqExprWith eqBid env e1 e2
-        in (env, eq)
+        in (env', eq)
 
     (CmmStore l1 r1,  CmmStore l2 r2) ->
         let eq = eqExprWith eqBid env l1 l2
               && eqExprWith eqBid env r1 r2
-        in (env, eq)
+        in (env', eq)
 
-    (CmmUnsafeForeignCall t1 r1 a1,  CmmUnsafeForeignCall t2 r2 a2) ->
+     -- result registers aren't compared since they are binding occurrences
+    (CmmUnsafeForeignCall t1 _ a1,  CmmUnsafeForeignCall t2 _ a2) ->
         let eq = t1 == t2
-              && r1 == r2
               && and (zipWith (eqExprWith eqBid env) a1 a2)
-        in (env, eq)
+        in (env', eq)
 
     _ -> (env, False)
+  where
+    env' = List.foldl' (\acc (ra,rb) -> addToUFM acc ra rb) emptyUFM
+           $ List.zip defd_a defd_b
+    defd_a = foldLocalRegsDefd dflags (flip (:)) [] a
+    defd_b = foldLocalRegsDefd dflags (flip (:)) [] b
 
 eqExprWith :: (BlockId -> BlockId -> Bool)
            -> LocalRegMapping
-           -> CmmExpr -> CmmExpr -> Bool
+           -> CmmExpr -> CmmExpr
+           -> Bool
 eqExprWith eqBid env = eq
  where
   CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
@@ -340,47 +357,50 @@ eqExprWith eqBid env = eq
 
 -- Equality on the body of a block, modulo a function mapping block
 -- IDs to block IDs.
-eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
-eqBlockBodyWith eqBid block block'
+eqBlockBodyWith :: DynFlags
+                -> (BlockId -> BlockId -> Bool)
+                -> CmmBlock -> CmmBlock -> Bool
+eqBlockBodyWith dflags eqBid block block'
   {-
   | equal     = pprTrace "equal" (vcat [ppr block, ppr block']) True
   | otherwise = pprTrace "not equal" (vcat [ppr block, ppr block']) False
   -}
-  = equal_go emptyUFM nodes nodes'
+  = equal
   where (_,m,l)   = blockSplit block
         nodes     = filter (not . dont_care) (blockToList m)
         (_,m',l') = blockSplit block'
         nodes'    = filter (not . dont_care) (blockToList m')
 
-        -- Compare middle nodes, accumulating a local register mapping as we go.
-        -- We also must ensure that the lists are of equal length. Finally,
-        -- compare the last nodes.
-        equal_go :: LocalRegMapping -> [CmmNode O O] -> [CmmNode O O] -> Bool
-        equal_go acc (a:as) (b:bs)
-          | let (acc', eq) = eqMiddleWith eqBid acc a b
-          , eq
-          = equal_go acc' as bs
-        equal_go acc [] [] = eqLastWith eqBid acc l l'
-        equal_go _   _  _  = False
+        (env_mid, eqs_mid) =
+            List.mapAccumL (\acc (a,b) -> eqMiddleWith dflags eqBid acc a b)
+                           emptyUFM
+                           (List.zip nodes nodes')
+        equal = and eqs_mid && eqLastWith eqBid env_mid l l'
 
 
 eqLastWith :: (BlockId -> BlockId -> Bool) -> LocalRegMapping
            -> CmmNode O C -> CmmNode O C -> Bool
 eqLastWith eqBid env a b =
-  case (a, b) of
-    (CmmBranch bid1, CmmBranch bid2) ->
-        eqBid bid1 bid2
-    (CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
-        eqExprWith eqBid env c1 c2
-        && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
-    (CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
-        eqExprWith eqBid env t1 t2
-        && eqMaybeWith eqBid c1 c2
-        && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
-    (CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
-        eqExprWith eqBid env e1 e2
-        && eqSwitchTargetWith eqBid ids1 ids2
-    _ -> False
+    case (a, b) of
+      (CmmBranch bid1, CmmBranch bid2) -> eqBid bid1 bid2
+      (CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
+          eqExprWith eqBid env c1 c2 && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
+      (CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
+             t1 == t2
+          && eqMaybeWith eqBid c1 c2
+          && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
+      (CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
+          eqExprWith eqBid env e1 e2 && eqSwitchTargetWith eqBid ids1 ids2
+      -- result registers aren't compared since they are binding occurrences
+      (CmmForeignCall t1 _ a1 s1 ret_args1 ret_off1 intrbl1,
+       CmmForeignCall t2 _ a2 s2 ret_args2 ret_off2 intrbl2) ->
+             t1 == t2
+          && and (zipWith (eqExprWith eqBid env) a1 a2)
+          && s1 == s2
+          && ret_args1 == ret_args2
+          && ret_off1 == ret_off2
+          && intrbl1 == intrbl2
+      _ -> False
 
 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'