rewrite branchChainElim; other refactoring in CmmContFlowOpt
authorSimon Marlow <marlowsd@gmail.com>
Tue, 23 Aug 2011 15:08:34 +0000 (16:08 +0100)
committerSimon Marlow <marlowsd@gmail.com>
Thu, 25 Aug 2011 10:12:32 +0000 (11:12 +0100)
compiler/cmm/CmmContFlowOpt.hs
compiler/cmm/CmmUtils.hs
compiler/utils/Digraph.lhs

index f8007cc..0019213 100644 (file)
 {-# OPTIONS_GHC -fno-warn-warnings-deprecations -fno-warn-incomplete-patterns #-}
 
 module CmmContFlowOpt
-    ( runCmmOpts, oldCmmCfgOpts, cmmCfgOpts
-    , branchChainElim, removeUnreachableBlocks, predMap
-    , replaceLabels, replaceBranches, runCmmContFlowOpts
+    ( runCmmContFlowOpts
+    , removeUnreachableBlocks, replaceBranches
     )
 where
 
 import BlockId
 import Cmm
 import CmmUtils
-import qualified OldCmm as Old
-
+import Digraph
 import Maybes
+import Outputable
+
 import Compiler.Hoopl
 import Control.Monad
-import Outputable
 import Prelude hiding (succ, unzip, zip)
-import Util
 
-------------------------------------
-runCmmContFlowOpts :: CmmGroup -> CmmGroup
-runCmmContFlowOpts prog = runCmmOpts cmmCfgOpts prog
+-----------------------------------------------------------------------------
+--
+-- Control-flow optimisations
+--
+-----------------------------------------------------------------------------
 
-oldCmmCfgOpts :: Old.ListGraph Old.CmmStmt -> Old.ListGraph Old.CmmStmt
-cmmCfgOpts    :: CmmGraph -> CmmGraph
+runCmmContFlowOpts :: CmmGroup -> CmmGroup
+runCmmContFlowOpts = map (optProc cmmCfgOpts)
 
-oldCmmCfgOpts = oldBranchChainElim  -- boring, but will get more exciting later
-cmmCfgOpts    =
-  removeUnreachableBlocks . blockConcat . branchChainElim
+cmmCfgOpts :: CmmGraph -> CmmGraph
+cmmCfgOpts = removeUnreachableBlocks . blockConcat . branchChainElim
         -- Here branchChainElim can ultimately be replaced
         -- with a more exciting combination of optimisations
 
-runCmmOpts :: (g -> g) -> GenCmmGroup d h g -> GenCmmGroup d h g
--- Lifts a transformer on a single graph to one on the whole program
-runCmmOpts opt = map (optProc opt)
-
 optProc :: (g -> g) -> GenCmmDecl d h g -> GenCmmDecl d h g
-optProc _   top@(CmmData {}) = top
 optProc opt (CmmProc info lbl g) = CmmProc info lbl (opt g)
+optProc _   top                  = top
 
-----------------------------------------------------------------
-oldBranchChainElim :: Old.ListGraph Old.CmmStmt -> Old.ListGraph Old.CmmStmt
--- If L is not captured in an instruction, we can remove any
--- basic block of the form L: goto L', and replace L with L' everywhere else.
--- How does L get captured? In a CallArea.
-oldBranchChainElim (Old.ListGraph blocks)
-  | null lone_branch_blocks     -- No blocks to remove
-  = Old.ListGraph blocks
-  | otherwise
-  = Old.ListGraph new_blocks
-  where
-    (lone_branch_blocks, others) = partitionWith isLoneBranch blocks
-    new_blocks = map (replaceLabels env) others
-    env = mkClosureBlockEnv lone_branch_blocks
-
-    isLoneBranch :: Old.CmmBasicBlock -> Either (BlockId, BlockId) Old.CmmBasicBlock
-    isLoneBranch (Old.BasicBlock id [Old.CmmBranch target]) | id /= target = Left (id, target)
-    isLoneBranch other_block                                           = Right other_block
-       -- An infinite loop is not a link in a branch chain!
-
-    replaceLabels :: BlockEnv BlockId -> Old.CmmBasicBlock -> Old.CmmBasicBlock
-    replaceLabels env (Old.BasicBlock id stmts)
-      = Old.BasicBlock id (map replace stmts)
-      where
-        replace (Old.CmmBranch id)       = Old.CmmBranch (lookup id)
-        replace (Old.CmmCondBranch e id) = Old.CmmCondBranch e (lookup id)
-        replace (Old.CmmSwitch e tbl)    = Old.CmmSwitch e (map (fmap lookup) tbl)
-        replace other_stmt           = other_stmt
-
-        lookup id = mapLookup id env `orElse` id 
-
-----------------------------------------------------------------
-branchChainElim :: CmmGraph -> CmmGraph
--- Remove any basic block of the form L: goto L',
--- and replace L with L' everywhere else,
--- unless L is the successor of a call instruction and L'
--- is the entry block. You don't want to set the successor
--- of a function call to the entry block because there is no good way
--- to store both the infotables for the call and from the callee,
--- while putting the stack pointer in a consistent place.
+-----------------------------------------------------------------------------
+--
+-- Branch Chain Elimination
+--
+-----------------------------------------------------------------------------
+
+-- | Remove any basic block of the form L: goto L', and replace L with
+-- L' everywhere else, unless L is the successor of a call instruction
+-- and L' is the entry block. You don't want to set the successor of a
+-- function call to the entry block because there is no good way to
+-- store both the infotables for the call and from the callee, while
+-- putting the stack pointer in a consistent place.
 --
 -- JD isn't quite sure when it's safe to share continuations for different
 -- function calls -- have to think about where the SP will be,
 -- so we'll table that problem for now by leaving all call successors alone.
+
+branchChainElim :: CmmGraph -> CmmGraph
 branchChainElim g
-  | null lone_branch_blocks     -- No blocks to remove
-  = g
-  | otherwise
-  = replaceLabels env $ ofBlockList (g_entry g) (self_branches ++ others)
+  | null lone_branch_blocks = g    -- No blocks to remove
+  | otherwise               = pprTrace "branchChainElim" (ppr forest) $ replaceLabels (mapFromList edges) g
   where
     blocks = toBlockList g
-    (lone_branch_blocks, others) = partitionWith isLoneBranch blocks
-    env = mkClosureBlockEnv lone_branch_blocks
-    self_branches =
-      let loop_to (id, _) =
-            if lookup id == id then
-              Just $ blockOfNodeList (JustC (CmmEntry id), [], JustC (mkBranchNode id))
-            else
-              Nothing
-      in  mapMaybe loop_to lone_branch_blocks
-    lookup id = mapLookup id env `orElse` id
+
+    lone_branch_blocks :: [(BlockId, BlockId)]
+      -- each (L,K) is a block of the form
+      --   L : goto K
+    lone_branch_blocks = mapCatMaybes isLoneBranch blocks
 
     call_succs = foldl add emptyBlockSet blocks
       where add :: BlockSet -> CmmBlock -> BlockSet
@@ -110,37 +72,67 @@ branchChainElim g
                 (CmmCall _ (Just k) _ _ _) -> setInsert k succs
                 (CmmForeignCall {succ=k})  -> setInsert k succs
                 _                          -> succs
-    isLoneBranch :: CmmBlock -> Either (BlockId, BlockId) CmmBlock
-    isLoneBranch block | (JustC (CmmEntry id), [], JustC (CmmBranch target)) <- blockToNodeList block,
-                         id /= target && not (setMember id call_succs)
-                       = Left (id,target)
-    isLoneBranch other = Right other
-       -- An infinite loop is not a link in a branch chain!
-
-maybeReplaceLabels :: (CmmNode O C -> Bool) -> BlockEnv BlockId -> CmmGraph -> CmmGraph
-maybeReplaceLabels lpred env =
-  replace_eid . mapGraphNodes (id, middle, last)
+
+    isLoneBranch :: CmmBlock -> Maybe (BlockId, BlockId)
+    isLoneBranch block
+      | (JustC (CmmEntry id), [], JustC (CmmBranch target)) <- blockToNodeList block
+      , not (setMember id call_succs)
+      = Just (id,target)
+      | otherwise
+      = Nothing
+
+    -- We build a graph from lone_branch_blocks (every node has only
+    -- one out edge).  Then we
+    --   - topologically sort the graph: if from A we can reach B,
+    --     then A occurs before B in the result list.
+    --   - depth-first search starting from the nodes in this list.
+    --     This gives us a [[node]], in which each list is a dependency
+    --     chain.
+    --   - for each list [a1,a2,...an] replace branches to ai with an.
+    --
+    -- This approach nicely deals with cycles by ignoring them.
+    -- Branches in a cycle will be redirected to somewhere in the
+    -- cycle, but we don't really care where.  A cycle should be dead code,
+    -- and so will be eliminated by removeUnreachableBlocks.
+    --
+    fromNode (b,_) = b
+    toNode   a     = (a,a)
+
+    all_block_ids :: LabelSet
+    all_block_ids = setFromList (map fst lone_branch_blocks)
+                      `setUnion`
+                    setFromList (map snd lone_branch_blocks)
+
+    forest = dfsTopSortG $ graphFromVerticesAndAdjacency nodes lone_branch_blocks
+        where nodes = map toNode $ setElems $ all_block_ids
+
+    edges  = [ (fromNode y, fromNode x)
+             | (x:xs) <- map reverse forest, y <- xs ]
+
+----------------------------------------------------------------
+
+replaceLabels :: BlockEnv BlockId -> CmmGraph -> CmmGraph
+replaceLabels env =
+  replace_eid . mapGraphNodes1 txnode
    where
      replace_eid g = g {g_entry = lookup (g_entry g)}
-     lookup id = fmap lookup (mapLookup id env) `orElse` id
-     
-     middle = mapExpDeep exp
-     last l = if lpred l then mapExpDeep exp (last' l) else l
-     last' :: CmmNode O C -> CmmNode O C
-     last' (CmmBranch bid)             = CmmBranch (lookup bid)
-     last' (CmmCondBranch p t f)       = CmmCondBranch p (lookup t) (lookup f)
-     last' (CmmSwitch e arms)          = CmmSwitch e (map (liftM lookup) arms)
-     last' (CmmCall t k a res r)       = CmmCall t (liftM lookup k) a res r
-     last' (CmmForeignCall t r a bid u i) = CmmForeignCall t r a (lookup bid) u i
-
+     lookup id = mapLookup id env `orElse` id
+
+     txnode :: CmmNode e x -> CmmNode e x
+     txnode (CmmBranch bid)         = CmmBranch (lookup bid)
+     txnode (CmmCondBranch p t f)   = CmmCondBranch (exp p) (lookup t) (lookup f)
+     txnode (CmmSwitch e arms)      = CmmSwitch (exp e) (map (liftM lookup) arms)
+     txnode (CmmCall t k a res r)   = CmmCall (exp t) (liftM lookup k) a res r
+     txnode fc@CmmForeignCall{}     = fc{ args = map exp (args fc)
+                                        , succ = lookup (succ fc) }
+     txnode other                   = mapExpDeep exp other
+
+     exp :: CmmExpr -> CmmExpr
      exp (CmmLit (CmmBlock bid))                = CmmLit (CmmBlock (lookup bid))
      exp (CmmStackSlot (CallArea (Young id)) i) = CmmStackSlot (CallArea (Young (lookup id))) i
      exp e                                      = e
 
 
-replaceLabels :: BlockEnv BlockId -> CmmGraph -> CmmGraph
-replaceLabels = maybeReplaceLabels (const True)
-
 replaceBranches :: BlockEnv BlockId -> CmmGraph -> CmmGraph
 replaceBranches env g = mapGraphNodes (id, id, last) g
   where
@@ -151,6 +143,8 @@ replaceBranches env g = mapGraphNodes (id, id, last) g
     last l@(CmmCall {})          = l
     last l@(CmmForeignCall {})   = l
     lookup id = fmap lookup (mapLookup id env) `orElse` id
+            -- XXX: this is a recursive lookup, it follows chains until the lookup
+            -- returns Nothing, at which point we return the last BlockId
 
 ----------------------------------------------------------------
 -- Build a map from a block to its set of predecessors. Very useful.
@@ -159,7 +153,13 @@ predMap blocks = foldr add_preds mapEmpty blocks -- find the back edges
   where add_preds block env = foldl (add (entryLabel block)) env (successors block)
         add bid env b' =
           mapInsert b' (setInsert bid (mapLookup b' env `orElse` setEmpty)) env
-----------------------------------------------------------------
+
+-----------------------------------------------------------------------------
+--
+-- Block concatenation
+--
+-----------------------------------------------------------------------------
+
 -- If a block B branches to a label L, L is not the entry block,
 -- and L has no other predecessors,
 -- then we can splice the block starting with L onto the end of B.
@@ -171,43 +171,51 @@ predMap blocks = foldr add_preds mapEmpty blocks -- find the back edges
 -- we are about to eliminate is not named in another instruction.
 --
 -- Note: This optimization does _not_ subsume branch chain elimination.
+
 blockConcat  :: CmmGraph -> CmmGraph
 blockConcat g@(CmmGraph {g_entry=eid}) =
   replaceLabels concatMap $ ofBlockMap (g_entry g) blocks'
-  where blocks = postorderDfs g
-        (blocks', concatMap) =
+  where
+     blocks = postorderDfs g
+
+     (blocks', concatMap) =
            foldr maybe_concat (toBlockMap g, mapEmpty) $ blocks
-        maybe_concat :: CmmBlock -> (LabelMap CmmBlock, LabelMap Label) -> (LabelMap CmmBlock, LabelMap Label)
-        maybe_concat b unchanged@(blocks', concatMap) =
-          let bid = entryLabel b
-          in case blockToNodeList b of
-               (JustC h, m, JustC (CmmBranch b')) ->
-                  if canConcatWith b' then
-                    (mapInsert bid (splice blocks' h m b') blocks',
-                     mapInsert b' bid concatMap)
-                  else unchanged
-               _ -> unchanged
-        num_preds bid = liftM setSize (mapLookup bid backEdges) `orElse` 0
-        canConcatWith b' = b' /= eid && num_preds b' == 1
-        backEdges = predMap blocks
-        splice :: forall map n e x.
-                  IsMap map =>
-                  map (Block n e x) -> n C O -> [n O O] -> KeyOf map -> Block n C x
-        splice blocks' h m bid' =
+
+     maybe_concat :: CmmBlock -> (LabelMap CmmBlock, LabelMap Label) -> (LabelMap CmmBlock, LabelMap Label)
+     maybe_concat b unchanged@(blocks', concatMap) =
+       let bid = entryLabel b
+       in case blockToNodeList b of
+            (JustC h, m, JustC (CmmBranch b')) ->
+               if canConcatWith b' then
+                 (mapInsert bid (splice blocks' h m b') blocks',
+                  mapInsert b' bid concatMap)
+               else unchanged
+            _ -> unchanged
+
+     num_preds bid = liftM setSize (mapLookup bid backEdges) `orElse` 0
+
+     canConcatWith b' = b' /= eid && num_preds b' == 1
+
+     backEdges = predMap blocks
+
+     splice :: forall map n e x.
+               IsMap map =>
+               map (Block n e x) -> n C O -> [n O O] -> KeyOf map -> Block n C x
+     splice blocks' h m bid' =
           case mapLookup bid' blocks' of
             Nothing -> panic "unknown successor block"
-            Just block | (_, m', l') <- blockToNodeList block -> blockOfNodeList (JustC h, (m ++ m'), l')
-----------------------------------------------------------------
-mkClosureBlockEnv :: [(BlockId, BlockId)] -> BlockEnv BlockId
-mkClosureBlockEnv blocks = mapFromList $ map follow blocks
-    where singleEnv = mapFromList blocks :: BlockEnv BlockId
-          follow (id, next) = (id, endChain id next)
-          endChain orig id = case mapLookup id singleEnv of
-                               Just id' | id /= orig -> endChain orig id'
-                               _ -> id
-----------------------------------------------------------------
+            Just block | (_, m', l') <- blockToNodeList block
+                -> blockOfNodeList (JustC h, (m ++ m'), l')
+
+
+-----------------------------------------------------------------------------
+--
+-- Removing unreachable blocks
+--
+-----------------------------------------------------------------------------
+
 removeUnreachableBlocks :: CmmGraph -> CmmGraph
-removeUnreachableBlocks g =
-  if length blocks < mapSize (toBlockMap g) then ofBlockList (g_entry g) blocks
-                                           else g
-    where blocks = postorderDfs g
+removeUnreachableBlocks g
+  | length blocks < mapSize (toBlockMap g) = ofBlockList (g_entry g) blocks
+  | otherwise = g
+  where blocks = postorderDfs g
index 47a5b09..a06d629 100644 (file)
@@ -51,7 +51,7 @@ module CmmUtils(
         lastNode, replaceLastNode, insertBetween,
         ofBlockMap, toBlockMap, insertBlock,
         ofBlockList, toBlockList, bodyToBlockList,
-        foldGraphBlocks, mapGraphNodes, postorderDfs,
+        foldGraphBlocks, mapGraphNodes, postorderDfs, mapGraphNodes1,
       
         analFwd, analBwd, analRewFwd, analRewBwd,
         dataflowPassFwd, dataflowPassBwd
@@ -418,6 +418,10 @@ mapGraphNodes :: ( CmmNode C O -> CmmNode C O
 mapGraphNodes funs@(mf,_,_) g =
   ofBlockMap (entryLabel $ mf $ CmmEntry $ g_entry g) $ mapMap (blockMapNodes3 funs) $ toBlockMap g
 
+mapGraphNodes1 :: (forall e x. CmmNode e x -> CmmNode e x) -> CmmGraph -> CmmGraph
+mapGraphNodes1 f g = modifyGraph (graphMapBlocks (blockMapNodes f)) g
+
+
 foldGraphBlocks :: (CmmBlock -> a -> a) -> a -> CmmGraph -> a
 foldGraphBlocks k z g = mapFold k z $ toBlockMap g
 
index b9d2da3..aa0f654 100644 (file)
@@ -8,7 +8,7 @@ module Digraph(
         Graph, graphFromVerticesAndAdjacency, graphFromEdgedVertices,
 
         SCC(..), Node, flattenSCC, flattenSCCs,
-        stronglyConnCompG, topologicalSortG, 
+        stronglyConnCompG, topologicalSortG, dfsTopSortG,
         verticesG, edgesG, hasVertexG,
         reachableG, transposeG,
         outdegreeG, indegreeG,
@@ -288,6 +288,12 @@ topologicalSortG :: Graph node -> [node]
 topologicalSortG graph = map (gr_vertex_to_node graph) result
   where result = {-# SCC "Digraph.topSort" #-} topSort (gr_int_graph graph)
 
+dfsTopSortG :: Graph node -> [[node]]
+dfsTopSortG graph =
+  map (map (gr_vertex_to_node graph) . flattenTree) $ dfs g (topSort g)
+  where
+    g = gr_int_graph graph
+
 reachableG :: Graph node -> node -> [node]
 reachableG graph from = map (gr_vertex_to_node graph) result
   where from_vertex = expectJust "reachableG" (gr_node_to_vertex graph from)