Add loop level analysis to the NCG backend.
authorklebinger.andreas@gmx.at <klebinger.andreas@gmx.at>
Sun, 17 Feb 2019 23:28:39 +0000 (00:28 +0100)
committerMarge Bot <ben+marge-bot@smart-cactus.org>
Wed, 16 Oct 2019 11:04:21 +0000 (07:04 -0400)
For backends maintaining the CFG during codegen
we can now find loops and their nesting level.

This is based on the Cmm CFG and dominator analysis.

As a result we can estimate edge frequencies a lot better
for methods, resulting in far better code layout.

Speedup on nofib: ~1.5%
Increase in compile times: ~1.9%

To make this feasible this commit adds:
* Dominator analysis based on the Lengauer-Tarjan Algorithm.
* An algorithm estimating global edge frequences from branch
probabilities - In CFG.hs

A few static branch prediction heuristics:

* Expect to take the backedge in loops.
* Expect to take the branch NOT exiting a loop.
* Expect integer vs constant comparisons to be false.

We also treat heap/stack checks special for branch prediction
to avoid them being treated as loops.

compiler/cmm/Hoopl/Dataflow.hs
compiler/ghc.cabal.in
compiler/nativeGen/AsmCodeGen.hs
compiler/nativeGen/BlockLayout.hs
compiler/nativeGen/CFG.hs
compiler/nativeGen/RegAlloc/Graph/SpillCost.hs
compiler/nativeGen/X86/CodeGen.hs
compiler/utils/Dominators.hs [new file with mode: 0644]
compiler/utils/OrdList.hs

index 2a2bb72..9762a84 100644 (file)
@@ -6,8 +6,6 @@
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeFamilies #-}
 
-{-# OPTIONS_GHC -fprof-auto-top #-}
-
 --
 -- Copyright (c) 2010, João Dias, Simon Marlow, Simon Peyton Jones,
 -- and Norman Ramsey
@@ -108,6 +106,7 @@ analyzeCmm
     -> FactBase f
     -> FactBase f
 analyzeCmm dir lattice transfer cmmGraph initFact =
+    {-# SCC analyzeCmm #-}
     let entry = g_entry cmmGraph
         hooplGraph = g_graph cmmGraph
         blockMap =
@@ -169,7 +168,7 @@ rewriteCmm
     -> CmmGraph
     -> FactBase f
     -> UniqSM (CmmGraph, FactBase f)
-rewriteCmm dir lattice rwFun cmmGraph initFact = do
+rewriteCmm dir lattice rwFun cmmGraph initFact = {-# SCC rewriteCmm #-} do
     let entry = g_entry cmmGraph
         hooplGraph = g_graph cmmGraph
         blockMap1 =
index a612733..3ff27ea 100644 (file)
@@ -593,6 +593,7 @@ Library
             Instruction
             BlockLayout
             CFG
+            Dominators
             Format
             Reg
             RegClass
index 6b7727a..4c883e7 100644 (file)
@@ -562,7 +562,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count
                 Opt_D_dump_asm_native "Native code"
                 (vcat $ map (pprNatCmmDecl ncgImpl) native)
 
-        dumpIfSet_dyn dflags
+        when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags
                 Opt_D_dump_cfg_weights "CFG Weights"
                 (pprEdgeWeights nativeCfgWeights)
 
@@ -691,7 +691,7 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count
                 {-# SCC "generateJumpTables" #-}
                 generateJumpTables ncgImpl alloced
 
-        dumpIfSet_dyn dflags
+        when (not $ null nativeCfgWeights) $ dumpIfSet_dyn dflags
                 Opt_D_dump_cfg_weights "CFG Update information"
                 ( text "stack:" <+> ppr stack_updt_blks $$
                   text "linearAlloc:" <+> ppr cfgRegAllocUpdates )
@@ -705,8 +705,9 @@ cmmNativeGen dflags this_mod modLoc ncgImpl us fileIds dbgMap cmm count
             optimizedCFG =
                 optimizeCFG (cfgWeightInfo dflags) cmm <$!> postShortCFG
 
-        maybe   (return ())
-                (dumpIfSet_dyn dflags Opt_D_dump_cfg_weights "CFG Final Weights" . pprEdgeWeights)
+        maybe (return ()) (\cfg->
+                dumpIfSet_dyn dflags Opt_D_dump_cfg_weights "CFG Final Weights"
+                ( pprEdgeWeights cfg ))
                 optimizedCFG
 
         --TODO: Partially check validity of the cfg.
index 7a39071..56e3177 100644 (file)
@@ -6,6 +6,8 @@
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE FlexibleContexts #-}
 
 module BlockLayout
     ( sequenceTop )
@@ -22,7 +24,6 @@ import BlockId
 import Cmm
 import Hoopl.Collections
 import Hoopl.Label
-import Hoopl.Block
 
 import DynFlags (gopt, GeneralFlag(..), DynFlags, backendMaintainsCfg)
 import UniqFM
@@ -41,11 +42,30 @@ import ListSetOps (removeDups)
 import OrdList
 import Data.List
 import Data.Foldable (toList)
-import Hoopl.Graph
 
 import qualified Data.Set as Set
+import Data.STRef
+import Control.Monad.ST.Strict
+import Control.Monad (foldM)
 
 {-
+  Note [CFG based code layout]
+  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  The major steps in placing blocks are as follow:
+  * Compute a CFG based on the Cmm AST, see getCfgProc.
+    This CFG will have edge weights representing a guess
+    on how important they are.
+  * After we convert Cmm to Asm we run `optimizeCFG` which
+    adds a few more "educated guesses" to the equation.
+  * Then we run loop analysis on the CFG (`loopInfo`) which tells us
+    about loop headers, loop nesting levels and the sort.
+  * Based on the CFG and loop information refine the edge weights
+    in the CFG and normalize them relative to the most often visited
+    node. (See `mkGlobalWeights`)
+  * Feed this CFG into the block layout code (`sequenceTop`) in this
+    module. Which will then produce a code layout based on the input weights.
+
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   ~~~ Note [Chain based CFG serialization]
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -60,8 +80,8 @@ import qualified Data.Set as Set
   but also how much a block would benefit from being placed sequentially after
   it's predecessor.
   For example blocks which are preceeded by an info table are more likely to end
-  up in a different cache line than their predecessor. So there is less benefit
-  in placing them sequentially.
+  up in a different cache line than their predecessor and we can't eliminate the jump
+  so there is less benefit to placing them sequentially.
 
   For example consider this example:
 
@@ -81,56 +101,83 @@ import qualified Data.Set as Set
   Eg for our example we might end up with two chains like:
   [A->B->C->X],[D]. Blocks inside chains will always be placed sequentially.
   However there is no particular order in which chains are placed since
-  (hopefully) the blocks for which sequentially is important have already
+  (hopefully) the blocks for which sequentiality is important have already
   been placed in the same chain.
 
   -----------------------------------------------------------------------------
-      First try to create a lists of good chains.
+     1) First try to create a list of good chains.
   -----------------------------------------------------------------------------
 
-  We do so by taking a block not yet placed in a chain and
-  looking at these cases:
+  Good chains are these which allow us to eliminate jump instructions.
+  Which further eliminate often executed jumps first.
+
+  We do so by:
+
+  *)  Ignore edges which represent instructions which can not be replaced
+      by fall through control flow. Primarily calls and edges to blocks which
+      are prefixed by a info table we have to jump across.
+
+  *)  Then process remaining edges in order of frequency taken and:
+
+    +)  If source and target have not been placed build a new chain from them.
+
+    +)  If source and target have been placed, and are ends of differing chains
+        try to merge the two chains.
 
-  *)  Check if the best predecessor of the block is at the end of a chain.
-      If so add the current block to the end of that chain.
+    +)  If one side of the edge is a end/front of a chain, add the other block of
+        to edge to the same chain
 
-      Eg if we look at block C and already have the chain (A -> B)
-      then we extend the chain to (A -> B -> C).
+        Eg if we look at edge (B -> C) and already have the chain (A -> B)
+        then we extend the chain to (A -> B -> C).
 
-      Combined with the fact that we process blocks in reverse post order
-      this means loop bodies and trivially sequential control flow already
-      ends up as a single chain.
+    +)  If the edge was used to modify or build a new chain remove the edge from
+        our working list.
 
-  *)  Otherwise we create a singleton chain from the block we are looking at.
-      Eg if we have from the example above already constructed (A->B)
-      and look at D we create the chain (D) resulting in the chains [A->B, D]
+  *) If there any blocks not being placed into a chain after these steps we place
+     them into a chain consisting of only this block.
+
+  Ranking edges by their taken frequency, if
+  two edges compete for fall through on the same target block, the one taken
+  more often will automatically win out. Resulting in fewer instructions being
+  executed.
+
+  Creating singleton chains is required for situations where we have code of the
+  form:
+
+    A: goto B:
+    <infoTable>
+    B: goto C:
+    <infoTable>
+    C: ...
+
+  As the code in block B is only connected to the rest of the program via edges
+  which will be ignored in this step we make sure that B still ends up in a chain
+  this way.
 
   -----------------------------------------------------------------------------
-      We then try to fuse chains.
+     2) We also try to fuse chains.
   -----------------------------------------------------------------------------
 
-  There are edge cases which result in two chains being created which trivially
-  represent linear control flow. For example we might have the chains
-  [(A-B-C),(D-E)] with an cfg triangle:
+  As a result from the above step we still end up with multiple chains which
+  represent sequential control flow chunks. But they are not yet suitable for
+  code layout as we need to place *all* blocks into a single sequence.
 
-      A----->C->D->E
-       \->B-/
+  In this step we combine chains result from the above step via these steps:
 
-  We also get three independent chains if two branches end with a jump
-  to a common successor.
+  *)  Look at the ranked list of *all* edges, including calls/jumps across info tables
+      and the like.
 
-  We take care of these cases by fusing chains which are connected by an
-  edge.
+  *)  Look at each edge and
 
-  We do so by looking at the list of edges sorted by weight.
-  Given the edge (C -> D) we try to find two chains such that:
-      * C is at the end of chain one.
-      * D is in front of chain two.
-      * If two such chains exist we fuse them.
-  We then remove the edge and repeat the process for the rest of the edges.
+    +) Given an edge (A -> B) try to find two chains for which
+      * Block A is at the end of one chain
+      * Block B is at the front of the other chain.
+    +) If we find such a chain we "fuse" them into a single chain, remove the
+       edge from working set and continue.
+    +) If we can't find such chains we skip the edge and continue.
 
   -----------------------------------------------------------------------------
-      Place indirect successors (neighbours) after each other
+     3) Place indirect successors (neighbours) after each other
   -----------------------------------------------------------------------------
 
   We might have chains [A,B,C,X],[E] in a CFG of the sort:
@@ -141,15 +188,11 @@ import qualified Data.Set as Set
   While E does not follow X it's still beneficial to place them near each other.
   This can be advantageous if eg C,X,E will end up in the same cache line.
 
-  TODO: If we remove edges as we use them (eg if we build up A->B remove A->B
-        from the list) we could save some more work in later phases.
-
-
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   ~~~ Note [Triangle Control Flow]
   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  Checking if an argument is already evaluating leads to a somewhat
+  Checking if an argument is already evaluated leads to a somewhat
   special case  which looks like this:
 
     A:
@@ -204,11 +247,6 @@ import qualified Data.Set as Set
 neighbourOverlapp :: Int
 neighbourOverlapp = 2
 
--- | Only edges heavier than this are considered
---   for fusing two chains into a single chain.
-fuseEdgeThreshold :: EdgeWeight
-fuseEdgeThreshold = 0
-
 -- | Maps blocks near the end of a chain to it's chain AND
 -- the other blocks near the end.
 -- [A,B,C,D,E] Gives entries like (B -> ([A,B], [A,B,C,D,E]))
@@ -224,40 +262,24 @@ type FrontierMap = LabelMap ([BlockId],BlockChain)
 newtype BlockChain
     = BlockChain { chainBlocks :: (OrdList BlockId) }
 
-instance Eq (BlockChain) where
-    (BlockChain blks1) == (BlockChain blks2)
-        = fromOL blks1 == fromOL blks2
+-- All chains are constructed the same way so comparison
+-- including structure is faster.
+instance Eq BlockChain where
+    BlockChain b1 == BlockChain b2 = strictlyEqOL b1 b2
 
 -- Useful for things like sets and debugging purposes, sorts by blocks
 -- in the chain.
 instance Ord (BlockChain) where
    (BlockChain lbls1) `compare` (BlockChain lbls2)
-       = (fromOL lbls1) `compare` (fromOL lbls2)
+       = ASSERT(toList lbls1 /= toList lbls2 || lbls1 `strictlyEqOL` lbls2)
+         strictlyOrdOL lbls1 lbls2
 
 instance Outputable (BlockChain) where
     ppr (BlockChain blks) =
         parens (text "Chain:" <+> ppr (fromOL $ blks) )
 
-data WeightedEdge = WeightedEdge !BlockId !BlockId EdgeWeight deriving (Eq)
-
-
--- | Non deterministic! (Uniques) Sorts edges by weight and nodes.
-instance Ord WeightedEdge where
-  compare (WeightedEdge from1 to1 weight1)
-          (WeightedEdge from2 to2 weight2)
-    | weight1 < weight2 || weight1 == weight2 && from1 < from2 ||
-      weight1 == weight2 && from1 == from2 && to1 < to2
-    = LT
-    | from1 == from2 && to1 == to2 && weight1 == weight2
-    = EQ
-    | otherwise
-    = GT
-
-instance Outputable WeightedEdge where
-    ppr (WeightedEdge from to info) =
-        ppr from <> text "->" <> ppr to <> brackets (ppr info)
-
-type WeightedEdgeList = [WeightedEdge]
+chainFoldl :: (b -> BlockId -> b) -> b -> BlockChain -> b
+chainFoldl f z (BlockChain blocks) = foldl' f z blocks
 
 noDups :: [BlockChain] -> Bool
 noDups chains =
@@ -270,19 +292,21 @@ inFront :: BlockId -> BlockChain -> Bool
 inFront bid (BlockChain seq)
   = headOL seq == bid
 
-chainMember :: BlockId -> BlockChain -> Bool
-chainMember bid chain
-  = elem bid $ fromOL . chainBlocks $ chain
---   = setMember bid . chainMembers $ chain
-
 chainSingleton :: BlockId -> BlockChain
 chainSingleton lbl
     = BlockChain (unitOL lbl)
 
+chainFromList :: [BlockId] -> BlockChain
+chainFromList = BlockChain . toOL
+
 chainSnoc :: BlockChain -> BlockId -> BlockChain
 chainSnoc (BlockChain blks) lbl
   = BlockChain (blks `snocOL` lbl)
 
+chainCons :: BlockId -> BlockChain -> BlockChain
+chainCons lbl (BlockChain blks)
+  = BlockChain (lbl `consOL` blks)
+
 chainConcat :: BlockChain -> BlockChain -> BlockChain
 chainConcat (BlockChain blks1) (BlockChain blks2)
   = BlockChain (blks1 `appOL` blks2)
@@ -311,52 +335,14 @@ takeL :: Int -> BlockChain -> [BlockId]
 takeL n (BlockChain blks) =
     take n . fromOL $ blks
 
--- | For a given list of chains try to fuse chains with strong
---   edges between them into a single chain.
---   Returns the list of fused chains together with a set of
---   used edges. The set of edges is indirectly encoded in the
---   chains so doesn't need to be considered for later passes.
-fuseChains :: WeightedEdgeList -> LabelMap BlockChain
-           -> (LabelMap BlockChain, Set.Set WeightedEdge)
-fuseChains weights chains
-    = let fronts = mapFromList $
-                    map (\chain -> (headOL . chainBlocks $ chain,chain)) $
-                    mapElems chains :: LabelMap BlockChain
-          (chains', used, _) = applyEdges weights chains fronts Set.empty
-      in (chains', used)
-    where
-        applyEdges :: WeightedEdgeList -> LabelMap BlockChain
-                   -> LabelMap BlockChain -> Set.Set WeightedEdge
-                   -> (LabelMap BlockChain, Set.Set WeightedEdge, LabelMap BlockChain)
-        applyEdges [] chainsEnd chainsFront used
-            = (chainsEnd, used, chainsFront)
-        applyEdges (edge@(WeightedEdge from to w):edges) chainsEnd chainsFront used
-            --Since we order edges descending by weight we can stop here
-            | w <= fuseEdgeThreshold
-            = ( chainsEnd, used, chainsFront)
-            --Fuse the two chains
-            | Just c1 <- mapLookup from chainsEnd
-            , Just c2 <- mapLookup to chainsFront
-            , c1 /= c2
-            = let newChain = chainConcat c1 c2
-                  front = headOL . chainBlocks $ newChain
-                  end = lastOL . chainBlocks $ newChain
-                  chainsFront' = mapInsert front newChain $
-                                 mapDelete to chainsFront
-                  chainsEnd'   = mapInsert end newChain $
-                                 mapDelete from chainsEnd
-              in applyEdges edges chainsEnd' chainsFront'
-                            (Set.insert edge used)
-            | otherwise
-            --Check next edge
-            = applyEdges edges chainsEnd chainsFront used
-
+-- Note [Combining neighborhood chains]
+-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 -- See also Note [Chain based CFG serialization]
 -- We have the chains (A-B-C-D) and (E-F) and an Edge C->E.
 --
--- While placing the later after the former doesn't result in sequential
--- control flow it is still be benefical since block C and E might end
+-- While placing the latter after the former doesn't result in sequential
+-- control flow it is still benefical. As block C and E might end
 -- up in the same cache line.
 --
 -- So we place these chains next to each other even if we can't fuse them.
@@ -365,7 +351,7 @@ fuseChains weights chains
 --             v
 --             - -> E -> F ...
 --
--- Simple heuristic to chose which chains we want to combine:
+-- A simple heuristic to chose which chains we want to combine:
 --   * Process edges in descending priority.
 --   * Check if there is a edge near the end of one chain which goes
 --     to a block near the start of another edge.
@@ -375,14 +361,22 @@ fuseChains weights chains
 -- us to find all edges between two chains, check the distance for all edges,
 -- rank them based on the distance and and only then we can select two chains
 -- to combine. Which would add a lot of complexity for little gain.
+--
+-- So instead we just rank by the strength of the edge and use the first pair we
+-- find.
 
 -- | For a given list of chains and edges try to combine chains with strong
 --   edges between them.
-combineNeighbourhood :: WeightedEdgeList -> [BlockChain]
-                     -> [BlockChain]
+combineNeighbourhood  :: [CfgEdge] -- ^ Edges to consider
+                      -> [BlockChain] -- ^ Current chains of blocks
+                      -> ([BlockChain], Set.Set (BlockId,BlockId))
+                      -- ^ Resulting list of block chains, and a set of edges which
+                      -- were used to fuse chains and as such no longer need to be
+                      -- considered.
 combineNeighbourhood edges chains
     = -- pprTraceIt "Neigbours" $
-      applyEdges edges endFrontier startFrontier
+    --   pprTrace "combineNeighbours" (ppr edges) $
+      applyEdges edges endFrontier startFrontier (Set.empty)
     where
         --Build maps from chain ends to chains
         endFrontier, startFrontier :: FrontierMap
@@ -396,14 +390,14 @@ combineNeighbourhood edges chains
                                 let front = getFronts chain
                                     entry = (front,chain)
                                 in map (\x -> (x,entry)) front) chains
-        applyEdges :: WeightedEdgeList -> FrontierMap -> FrontierMap
-                   -> [BlockChain]
-        applyEdges [] chainEnds _chainFronts =
-            ordNub $ map snd $ mapElems chainEnds
-        applyEdges ((WeightedEdge from to _w):edges) chainEnds chainFronts
+        applyEdges :: [CfgEdge] -> FrontierMap -> FrontierMap -> Set.Set (BlockId, BlockId)
+                   -> ([BlockChain], Set.Set (BlockId,BlockId))
+        applyEdges [] chainEnds _chainFronts combined =
+            (ordNub $ map snd $ mapElems chainEnds, combined)
+        applyEdges ((CfgEdge from to _w):edges) chainEnds chainFronts combined
             | Just (c1_e,c1) <- mapLookup from chainEnds
             , Just (c2_f,c2) <- mapLookup to chainFronts
-            , c1 /= c2 -- Avoid trying to concat a short chain with itself.
+            , c1 /= c2 -- Avoid trying to concat a chain with itself.
             = let newChain = chainConcat c1 c2
                   newChainFrontier = getFronts newChain
                   newChainEnds = getEnds newChain
@@ -437,165 +431,299 @@ combineNeighbourhood edges chains
                 --   text "fronts" <+> ppr newFronts $$
                 --   text "ends" <+> ppr newEnds
                 --   )
-                 applyEdges edges newEnds newFronts
+                 applyEdges edges newEnds newFronts (Set.insert (from,to) combined)
             | otherwise
-            = --pprTrace "noNeigbours" (ppr ()) $
-              applyEdges edges chainEnds chainFronts
+            = applyEdges edges chainEnds chainFronts combined
          where
 
         getFronts chain = takeL neighbourOverlapp chain
         getEnds chain = takeR neighbourOverlapp chain
 
-
-
--- See [Chain based CFG serialization]
-buildChains :: CFG -> [BlockId]
-            -> ( LabelMap BlockChain  -- Resulting chains.
+-- In the last stop we combine all chains into a single one.
+-- Trying to place chains with strong edges next to each other.
+mergeChains :: [CfgEdge] -> [BlockChain]
+            -> (BlockChain)
+mergeChains edges chains
+    = -- pprTrace "combine" (ppr edges) $
+      runST $ do
+        let addChain m0 chain = do
+                ref <- newSTRef chain
+                return $ chainFoldl (\m' b -> mapInsert b ref m') m0 chain
+        chainMap' <- foldM (\m0 c -> addChain m0 c) mapEmpty chains
+        merge edges chainMap'
+    where
+        -- We keep a map from ALL blocks to their respective chain (sigh)
+        -- This is required since when looking at an edge we need to find
+        -- the associated chains quickly.
+        -- We use a map of STRefs, maintaining a invariant of one STRef per chain.
+        -- When merging chains we can update the
+        -- STRef of one chain once (instead of writing to the map for each block).
+        -- We then overwrite the STRefs for the other chain so there is again only
+        -- a single STRef for the combined chain.
+        -- The difference in terms of allocations saved is ~0.2% with -O so actually
+        -- significant compared to using a regular map.
+
+        merge :: forall s. [CfgEdge] -> LabelMap (STRef s BlockChain) -> ST s BlockChain
+        merge [] chains = do
+            chains' <- ordNub <$> (mapM readSTRef $ mapElems chains) :: ST s [BlockChain]
+            return $ foldl' chainConcat (head chains') (tail chains')
+        merge ((CfgEdge from to _):edges) chains
+        --   | pprTrace "merge" (ppr (from,to) <> ppr chains) False
+        --   = undefined
+          | cFrom == cTo
+          = merge edges chains
+          | otherwise
+          = do
+            chains' <- mergeComb cFrom cTo
+            merge edges chains'
+          where
+            mergeComb :: STRef s BlockChain -> STRef s BlockChain -> ST s (LabelMap (STRef s BlockChain))
+            mergeComb refFrom refTo = do
+                cRight <- readSTRef refTo
+                chain <- pure chainConcat <*> readSTRef refFrom <*> pure cRight
+                writeSTRef refFrom chain
+                return $ chainFoldl (\m b -> mapInsert b refFrom m) chains cRight
+
+            cFrom = expectJust "mergeChains:chainMap:from" $ mapLookup from chains
+            cTo = expectJust "mergeChains:chainMap:to"   $ mapLookup to   chains
+
+
+-- See Note [Chain based CFG serialization] for the general idea.
+-- This creates and fuses chains at the same time for performance reasons.
+
+-- Try to build chains from a list of edges.
+-- Edges must be sorted **descending** by their priority.
+-- Returns the constructed chains, along with all edges which
+-- are irrelevant past this point, this information doesn't need
+-- to be complete - it's only used to speed up the process.
+-- An Edge is irrelevant if the ends are part of the same chain.
+-- We say these edges are already linked
+buildChains :: [CfgEdge] -> [BlockId]
+            -> ( LabelMap BlockChain  -- Resulting chains, indexd by end if chain.
                , Set.Set (BlockId, BlockId)) --List of fused edges.
-buildChains succWeights blocks
-  = let (_, fusedEdges, chains) = buildNext setEmpty mapEmpty blocks Set.empty
-    in (chains, fusedEdges)
+buildChains edges blocks
+  = runST $ buildNext setEmpty mapEmpty mapEmpty edges Set.empty
   where
-    -- We keep a map from the last block in a chain to the chain itself.
-    -- This we we can easily check if an block should be appened to an
+    -- buildNext builds up chains from edges one at a time.
+
+    -- We keep a map from the ends of chains to the chains.
+    -- This we we can easily check if an block should be appended to an
     -- existing chain!
-    buildNext :: LabelSet
-              -> LabelMap BlockChain -- Map from last element to chain.
-              -> [BlockId] -- Blocks to place
-              -> Set.Set (BlockId, BlockId)
-              -> ( [BlockChain]  -- Placed Blocks
-                 , Set.Set (BlockId, BlockId) --List of fused edges
-                 , LabelMap BlockChain
-                 )
-    buildNext _placed chains [] linked =
-        ([], linked, chains)
-    buildNext placed chains (block:todo) linked
-        | setMember block placed
-        = buildNext placed chains todo linked
+    -- We store them using STRefs so we don't have to rebuild the spine of both
+    -- maps every time we update a chain.
+    buildNext :: forall s. LabelSet
+              -> LabelMap (STRef s BlockChain) -- Map from end of chain to chain.
+              -> LabelMap (STRef s BlockChain) -- Map from start of chain to chain.
+              -> [CfgEdge] -- Edges to check - ordered by decreasing weight
+              -> Set.Set (BlockId, BlockId) -- Used edges
+              -> ST s   ( LabelMap BlockChain -- Chains by end
+                        , Set.Set (BlockId, BlockId) --List of fused edges
+                        )
+    buildNext placed _chainStarts chainEnds  [] linked = do
+        ends' <- sequence $ mapMap readSTRef chainEnds :: ST s (LabelMap BlockChain)
+        -- Any remaining blocks have to be made to singleton chains.
+        -- They might be combined with other chains later on outside this function.
+        let unplaced = filter (\x -> not (setMember x placed)) blocks
+            singletons = map (\x -> (x,chainSingleton x)) unplaced :: [(BlockId,BlockChain)]
+        return (foldl' (\m (k,v) -> mapInsert k v m) ends' singletons , linked)
+    buildNext placed chainStarts chainEnds (edge:todo) linked
+        | from == to
+        -- We skip self edges
+        = buildNext placed chainStarts chainEnds todo (Set.insert (from,to) linked)
+        | not (alreadyPlaced from) &&
+          not (alreadyPlaced to)
+        = do
+            --pprTraceM "Edge-Chain:" (ppr edge)
+            chain' <- newSTRef $ chainFromList [from,to]
+            buildNext
+                (setInsert to (setInsert from placed))
+                (mapInsert from chain' chainStarts)
+                (mapInsert to chain' chainEnds)
+                todo
+                (Set.insert (from,to) linked)
+
+        | (alreadyPlaced from) &&
+          (alreadyPlaced to)
+        , Just predChain <- mapLookup from chainEnds
+        , Just succChain <- mapLookup to chainStarts
+        , predChain /= succChain -- Otherwise we try to create a cycle.
+        = do
+            -- pprTraceM "Fusing edge" (ppr edge)
+            fuseChain predChain succChain
+
+        | (alreadyPlaced from) &&
+          (alreadyPlaced to)
+        =   --pprTraceM "Skipping:" (ppr edge) >>
+            buildNext placed chainStarts chainEnds todo linked
+
         | otherwise
-        = buildNext placed' chains' todo linked'
+        = do -- pprTraceM "Finding chain for:" (ppr edge $$
+             --         text "placed" <+> ppr placed)
+             findChain
       where
-        placed' = (foldl' (flip setInsert) placed placedBlocks)
-        linked' = Set.union linked linkedEdges
-        (placedBlocks, chains', linkedEdges) = findChain block
-
-        --Add the block to a existing or new chain
-        --Returns placed blocks, list of resulting chains
-        --and fused edges
-        findChain :: BlockId
-                -> ([BlockId],LabelMap BlockChain, Set.Set (BlockId, BlockId))
-        findChain block
-        -- B) place block at end of existing chain if
-        -- there is no better block to append.
-          | (pred:_) <- preds
-          , alreadyPlaced pred
-          , Just predChain <- mapLookup pred chains
-          , (best:_) <- filter (not . alreadyPlaced) $ getSuccs pred
-          , best == lbl
-          = --pprTrace "B.2)" (ppr (pred,lbl)) $
-            let newChain = chainSnoc predChain block
-                chainMap = mapInsert lbl newChain $ mapDelete pred chains
-            in  ( [lbl]
-                , chainMap
-                , Set.singleton (pred,lbl) )
-
+        from = edgeFrom edge
+        to   = edgeTo   edge
+        alreadyPlaced blkId = (setMember blkId placed)
+
+        -- Combine two chains into a single one.
+        fuseChain :: STRef s BlockChain -> STRef s BlockChain
+                  -> ST s   ( LabelMap BlockChain -- Chains by end
+                            , Set.Set (BlockId, BlockId) --List of fused edges
+                            )
+        fuseChain fromRef toRef = do
+            fromChain <- readSTRef fromRef
+            toChain <- readSTRef toRef
+            let newChain = chainConcat fromChain toChain
+            ref <- newSTRef newChain
+            let start = head $ takeL 1 newChain
+            let end = head $ takeR 1 newChain
+            -- chains <- sequence $ mapMap readSTRef chainStarts
+            -- pprTraceM "pre-fuse chains:" $ ppr chains
+            buildNext
+                placed
+                (mapInsert start ref $ mapDelete to $ chainStarts)
+                (mapInsert end ref $ mapDelete from $ chainEnds)
+                todo
+                (Set.insert (from,to) linked)
+
+
+        --Add the block to a existing chain or creates a new chain
+        findChain :: ST s   ( LabelMap BlockChain -- Chains by end
+                            , Set.Set (BlockId, BlockId) --List of fused edges
+                            )
+        findChain
+          -- We can attach the block to the end of a chain
+          | alreadyPlaced from
+          , Just predChain <- mapLookup from chainEnds
+          = do
+            chain <- readSTRef predChain
+            let newChain = chainSnoc chain to
+            writeSTRef predChain newChain
+            let chainEnds' = mapInsert to predChain $ mapDelete from chainEnds
+            -- chains <- sequence $ mapMap readSTRef chainStarts
+            -- pprTraceM "from chains:" $ ppr chains
+            buildNext (setInsert to placed) chainStarts chainEnds' todo (Set.insert (from,to) linked)
+          -- We can attack it to the front of a chain
+          | alreadyPlaced to
+          , Just succChain <- mapLookup to chainStarts
+          = do
+            chain <- readSTRef succChain
+            let newChain = from `chainCons` chain
+            writeSTRef succChain newChain
+            let chainStarts' = mapInsert from succChain $ mapDelete to chainStarts
+            -- chains <- sequence $ mapMap readSTRef chainStarts'
+            -- pprTraceM "to chains:" $ ppr chains
+            buildNext (setInsert from placed) chainStarts' chainEnds todo (Set.insert (from,to) linked)
+          -- The placed end of the edge is part of a chain already and not an end.
           | otherwise
-          = --pprTrace "single" (ppr lbl)
-            ( [lbl]
-            , mapInsert lbl (chainSingleton lbl) chains
-            , Set.empty)
+          = do
+            let block    = if alreadyPlaced to then from else to
+            --pprTraceM "Singleton" $ ppr block
+            let newChain = chainSingleton block
+            ref <- newSTRef newChain
+            buildNext (setInsert block placed) (mapInsert block ref chainStarts)
+                      (mapInsert block ref chainEnds) todo (linked)
             where
               alreadyPlaced blkId = (setMember blkId placed)
-              lbl = block
-              getSuccs = map fst . getSuccEdgesSorted succWeights
-              preds = map fst $ getSuccEdgesSorted predWeights lbl
-    --For efficiency we also create the map to look up predecessors here
-    predWeights = reverseEdges succWeights
 
-
-
--- We make the CFG a Hoopl Graph, so we can reuse revPostOrder.
-newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])
-instance NonLocal (BlockNode) where
-  entryLabel (BN (lbl,_))   = lbl
-  successors (BN (_,succs)) = succs
-
-fromNode :: BlockNode C C -> BlockId
-fromNode (BN x) = fst x
-
-sequenceChain :: forall a i. (Instruction i, Outputable i) => LabelMap a -> CFG
-            -> [GenBasicBlock i] -> [GenBasicBlock i]
+-- | Place basic blocks based on the given CFG.
+-- See Note [Chain based CFG serialization]
+sequenceChain :: forall a i. (Instruction i, Outputable i)
+              => LabelMap a -- ^ Keys indicate an info table on the block.
+              -> CFG -- ^ Control flow graph and some meta data.
+              -> [GenBasicBlock i] -- ^ List of basic blocks to be placed.
+              -> [GenBasicBlock i] -- ^ Blocks placed in sequence.
 sequenceChain _info _weights    [] = []
 sequenceChain _info _weights    [x] = [x]
 sequenceChain  info weights'     blocks@((BasicBlock entry _):_) =
-    --Optimization, delete edges of weight <= 0.
-    --This significantly improves performance whenever
-    --we iterate over all edges, which is a few times!
     let weights :: CFG
-        weights
-            = filterEdges (\_f _t edgeInfo -> edgeWeight edgeInfo > 0) weights'
+        weights = --pprTrace "cfg'" (pprEdgeWeights cfg')
+                  cfg'
+          where
+            (_, globalEdgeWeights) = {-# SCC mkGlobalWeights #-} mkGlobalWeights entry weights'
+            cfg' = {-# SCC rewriteEdges #-}
+                    mapFoldlWithKey
+                        (\cfg from m ->
+                            mapFoldlWithKey
+                                (\cfg to w -> setEdgeWeight cfg (EdgeWeight w) from to )
+                                cfg m )
+                        weights'
+                        globalEdgeWeights
+
+        directEdges :: [CfgEdge]
+        directEdges = sortBy (flip compare) $ catMaybes . map relevantWeight $ (infoEdgeList weights)
+          where
+            relevantWeight :: CfgEdge -> Maybe CfgEdge
+            relevantWeight edge@(CfgEdge from to edgeInfo)
+                | (EdgeInfo CmmSource { trans_cmmNode = CmmCall {} } _) <- edgeInfo
+                -- Ignore edges across calls
+                = Nothing
+                | mapMember to info
+                , w <- edgeWeight edgeInfo
+                -- The payoff is small if we jump over an info table
+                = Just (CfgEdge from to edgeInfo { edgeWeight = w/8 })
+                | otherwise
+                = Just edge
+
         blockMap :: LabelMap (GenBasicBlock i)
         blockMap
             = foldl' (\m blk@(BasicBlock lbl _ins) ->
                         mapInsert lbl blk m)
                      mapEmpty blocks
 
-        toNode :: BlockId -> BlockNode C C
-        toNode bid =
-            -- sorted such that heavier successors come first.
-            BN (bid,map fst . getSuccEdgesSorted weights' $ bid)
-
-        orderedBlocks :: [BlockId]
-        orderedBlocks
-            = map fromNode $
-              revPostorderFrom (fmap (toNode . blockId) blockMap) entry
-
         (builtChains, builtEdges)
             = {-# SCC "buildChains" #-}
               --pprTraceIt "generatedChains" $
-              --pprTrace "orderedBlocks" (ppr orderedBlocks) $
-              buildChains weights orderedBlocks
+              --pprTrace "blocks" (ppr (mapKeys blockMap)) $
+              buildChains directEdges (mapKeys blockMap)
 
-        rankedEdges :: WeightedEdgeList
-        -- Sort edges descending, remove fused eges
+        rankedEdges :: [CfgEdge]
+        -- Sort descending by weight, remove fused edges
         rankedEdges =
-            map (\(from, to, weight) -> WeightedEdge from to weight) .
-            filter (\(from, to, _)
-                        -> not (Set.member (from,to) builtEdges)) .
-            sortWith (\(_,_,w) -> - w) $ weightedEdgeList weights
+            filter (\edge -> not (Set.member (edgeFrom edge,edgeTo edge) builtEdges)) $
+            directEdges
 
-        (fusedChains, fusedEdges)
+        (neighbourChains, combined)
             = ASSERT(noDups $ mapElems builtChains)
-              {-# SCC "fuseChains" #-}
-              --(pprTrace "RankedEdges" $ ppr rankedEdges) $
-              --pprTraceIt "FusedChains" $
-              fuseChains rankedEdges builtChains
-
-        rankedEdges' =
-            filter (\edge -> not $ Set.member edge fusedEdges) $ rankedEdges
-
-        neighbourChains
-            = ASSERT(noDups $ mapElems fusedChains)
               {-# SCC "groupNeighbourChains" #-}
-              --pprTraceIt "ResultChains" $
-              combineNeighbourhood rankedEdges' (mapElems fusedChains)
+            --   pprTraceIt "NeighbourChains" $
+              combineNeighbourhood rankedEdges (mapElems builtChains)
+
+
+        allEdges :: [CfgEdge]
+        allEdges = {-# SCC allEdges #-}
+                   sortOn (relevantWeight) $ filter (not . deadEdge) $ (infoEdgeList weights)
+          where
+            deadEdge :: CfgEdge -> Bool
+            deadEdge (CfgEdge from to _) = let e = (from,to) in Set.member e combined || Set.member e builtEdges
+            relevantWeight :: CfgEdge -> EdgeWeight
+            relevantWeight (CfgEdge _ _ edgeInfo)
+                | EdgeInfo (CmmSource { trans_cmmNode = CmmCall {}}) _ <- edgeInfo
+                -- Penalize edges across calls
+                = weight/(64.0)
+                | otherwise
+                = weight
+              where
+                -- negate to sort descending
+                weight = negate (edgeWeight edgeInfo)
+
+        masterChain =
+            {-# SCC "mergeChains" #-}
+            -- pprTraceIt "MergedChains" $
+            mergeChains allEdges neighbourChains
 
         --Make sure the first block stays first
-        ([entryChain],chains')
-            = ASSERT(noDups $ neighbourChains)
-              partition (chainMember entry) neighbourChains
-        (entryChain':entryRest)
-            | inFront entry entryChain = [entryChain]
-            | (rest,entry) <- breakChainAt entry entryChain
+        prepedChains
+            | inFront entry masterChain
+            = [masterChain]
+            | (rest,entry) <- breakChainAt entry masterChain
             = [entry,rest]
             | otherwise = pprPanic "Entry point eliminated" $
-                            ppr ([entryChain],chains')
+                            ppr masterChain
 
-        prepedChains
-            = entryChain':(entryRest++chains') :: [BlockChain]
         blockList
-            -- = (concatMap chainToBlocks prepedChains)
-            = (concatMap fromOL $ map chainBlocks prepedChains)
+            = ASSERT(noDups [masterChain])
+              (concatMap fromOL $ map chainBlocks prepedChains)
 
         --chainPlaced = setFromList $ map blockId blockList :: LabelSet
         chainPlaced = setFromList $ blockList :: LabelSet
@@ -605,14 +733,22 @@ sequenceChain  info weights'     blocks@((BasicBlock entry _):_) =
             in filter (\block -> not (isPlaced block)) blocks
 
         placedBlocks =
+            -- We want debug builds to catch this as it's a good indicator for
+            -- issues with CFG invariants. But we don't want to blow up production
+            -- builds if something slips through.
+            ASSERT(null unplaced)
             --pprTraceIt "placedBlocks" $
-            blockList ++ unplaced
+            -- ++ [] is stil kinda expensive
+            if null unplaced then blockList else blockList ++ unplaced
         getBlock bid = expectJust "Block placment" $ mapLookup bid blockMap
     in
         --Assert we placed all blocks given as input
         ASSERT(all (\bid -> mapMember bid blockMap) placedBlocks)
         dropJumps info $ map getBlock placedBlocks
 
+{-# SCC dropJumps #-}
+-- | Remove redundant jumps between blocks when we can rely on
+-- fall through.
 dropJumps :: forall a i. Instruction i => LabelMap a -> [GenBasicBlock i]
           -> [GenBasicBlock i]
 dropJumps _    [] = []
@@ -641,7 +777,8 @@ sequenceTop
     => DynFlags -- Determine which layout algo to use
     -> NcgImpl statics instr jumpDest
     -> Maybe CFG -- ^ CFG if we have one.
-    -> NatCmmDecl statics instr -> NatCmmDecl statics instr
+    -> NatCmmDecl statics instr -- ^ Function to serialize
+    -> NatCmmDecl statics instr
 
 sequenceTop _     _       _           top@(CmmData _ _) = top
 sequenceTop dflags ncgImpl edgeWeights
@@ -650,11 +787,13 @@ sequenceTop dflags ncgImpl edgeWeights
   --Use chain based algorithm
   , Just cfg <- edgeWeights
   = CmmProc info lbl live ( ListGraph $ ncgMakeFarBranches ncgImpl info $
+                            {-# SCC layoutBlocks #-}
                             sequenceChain info cfg blocks )
   | otherwise
   --Use old algorithm
   = let cfg = if dontUseCfg then Nothing else edgeWeights
     in  CmmProc info lbl live ( ListGraph $ ncgMakeFarBranches ncgImpl info $
+                                {-# SCC layoutBlocks #-}
                                 sequenceBlocks cfg info blocks)
   where
     dontUseCfg = gopt Opt_WeightlessBlocklayout dflags ||
index fee4718..8eb69a9 100644 (file)
@@ -6,31 +6,40 @@
 {-# LANGUAGE TupleSections #-}
 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
 {-# LANGUAGE CPP #-}
+{-# LANGUAGE Rank2Types #-}
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
 
 module CFG
     ( CFG, CfgEdge(..), EdgeInfo(..), EdgeWeight(..)
     , TransitionSource(..)
 
     --Modify the CFG
-    , addWeightEdge, addEdge, delEdge
+    , addWeightEdge, addEdge
+    , delEdge, delNode
     , addNodesBetween, shortcutWeightMap
     , reverseEdges, filterEdges
     , addImmediateSuccessor
-    , mkWeightInfo, adjustEdgeWeight
+    , mkWeightInfo, adjustEdgeWeight, setEdgeWeight
 
     --Query the CFG
     , infoEdgeList, edgeList
     , getSuccessorEdges, getSuccessors
-    , getSuccEdgesSorted, weightedEdgeList
+    , getSuccEdgesSorted
     , getEdgeInfo
     , getCfgNodes, hasNode
-    , loopMembers
+
+    -- Loop Information
+    , loopMembers, loopLevels, loopInfo
 
     --Construction/Misc
     , getCfg, getCfgProc, pprEdgeWeights, sanityCheckCfg
 
     --Find backedges and update their weight
-    , optimizeCFG )
+    , optimizeCFG
+    , mkGlobalWeights
+
+     )
 where
 
 #include "HsVersions.h"
@@ -38,9 +47,8 @@ where
 import GhcPrelude
 
 import BlockId
-import Cmm ( RawCmmDecl, GenCmmDecl( .. ), CmmBlock, succ, g_entry
-           , CmmGraph )
-import CmmNode
+import Cmm
+
 import CmmUtils
 import CmmSwitch
 import Hoopl.Collections
@@ -50,10 +58,24 @@ import qualified Hoopl.Graph as G
 
 import Util
 import Digraph
+import Maybes
+
+import Unique
+import qualified Dominators as Dom
+import Data.IntMap.Strict (IntMap)
+import Data.IntSet (IntSet)
+
+import qualified Data.IntMap.Strict as IM
+import qualified Data.Map as M
+import qualified Data.IntSet as IS
+import qualified Data.Set as S
+import Data.Tree
+import Data.Bifunctor
 
 import Outputable
 -- DEBUGGING ONLY
 --import Debug
+-- import Debug.Trace
 --import OrdList
 --import Debug.Trace
 import PprCmm () -- For Outputable instances
@@ -61,17 +83,28 @@ import qualified DynFlags as D
 
 import Data.List
 
--- import qualified Data.IntMap.Strict as M --TODO: LabelMap
+import Data.STRef.Strict
+import Control.Monad.ST
+
+import Data.Array.MArray
+import Data.Array.ST
+import Data.Array.IArray
+import Data.Array.Unsafe (unsafeFreeze)
+import Data.Array.Base (unsafeRead, unsafeWrite)
+
+import Control.Monad
+
+type Prob = Double
 
 type Edge = (BlockId, BlockId)
 type Edges = [Edge]
 
 newtype EdgeWeight
-  = EdgeWeight Int
-  deriving (Eq,Ord,Enum,Num,Real,Integral)
+  = EdgeWeight { weightToDouble :: Double }
+  deriving (Eq,Ord,Enum,Num,Real,Fractional)
 
 instance Outputable EdgeWeight where
-  ppr (EdgeWeight w) = ppr w
+  ppr (EdgeWeight w) = doublePrec 5 w
 
 type EdgeInfoMap edgeInfo = LabelMap (LabelMap edgeInfo)
 
@@ -108,15 +141,28 @@ instance Outputable CfgEdge where
     = parens (ppr from1 <+> text "-(" <> ppr edgeInfo <> text ")->" <+> ppr to1)
 
 -- | Can we trace back a edge to a specific Cmm Node
--- or has it been introduced for codegen. We use this to maintain
+-- or has it been introduced during assembly codegen. We use this to maintain
 -- some information which would otherwise be lost during the
 -- Cmm <-> asm transition.
 -- See also Note [Inverting Conditional Branches]
 data TransitionSource
-  = CmmSource (CmmNode O C)
+  = CmmSource { trans_cmmNode :: (CmmNode O C)
+              , trans_info :: BranchInfo }
   | AsmCodeGen
   deriving (Eq)
 
+data BranchInfo = NoInfo         -- ^ Unknown, but not heap or stack check.
+                | HeapStackCheck -- ^ Heap or stack check
+    deriving Eq
+
+instance Outputable BranchInfo where
+    ppr NoInfo = text "regular"
+    ppr HeapStackCheck = text "heap/stack"
+
+isHeapOrStackCheck :: TransitionSource -> Bool
+isHeapOrStackCheck (CmmSource { trans_info = HeapStackCheck}) = True
+isHeapOrStackCheck _ = False
+
 -- | Information about edges
 data EdgeInfo
   = EdgeInfo
@@ -127,12 +173,10 @@ data EdgeInfo
 instance Outputable EdgeInfo where
   ppr edgeInfo = text "weight:" <+> ppr (edgeWeight edgeInfo)
 
--- Allow specialization
-{-# INLINEABLE mkWeightInfo #-}
 -- | Convenience function, generate edge info based
 --   on weight not originating from cmm.
-mkWeightInfo :: Integral n => n -> EdgeInfo
-mkWeightInfo = EdgeInfo AsmCodeGen . fromIntegral
+mkWeightInfo :: EdgeWeight -> EdgeInfo
+mkWeightInfo = EdgeInfo AsmCodeGen
 
 -- | Adjust the weight between the blocks using the given function.
 --   If there is no such edge returns the original map.
@@ -140,12 +184,25 @@ adjustEdgeWeight :: CFG -> (EdgeWeight -> EdgeWeight)
                  -> BlockId -> BlockId -> CFG
 adjustEdgeWeight cfg f from to
   | Just info <- getEdgeInfo from to cfg
-  , weight <- edgeWeight info
-  = addEdge from to (info { edgeWeight = f weight}) cfg
+  , !weight <- edgeWeight info
+  , !newWeight <- f weight
+  = addEdge from to (info { edgeWeight = newWeight}) cfg
+  | otherwise = cfg
+
+-- | Set the weight between the blocks to the given weight.
+--   If there is no such edge returns the original map.
+setEdgeWeight :: CFG -> EdgeWeight
+              -> BlockId -> BlockId -> CFG
+setEdgeWeight cfg !weight from to
+  | Just info <- getEdgeInfo from to cfg
+  = addEdge from to (info { edgeWeight = weight}) cfg
   | otherwise = cfg
 
+
+
 getCfgNodes :: CFG -> LabelSet
-getCfgNodes m = mapFoldMapWithKey (\k v -> setFromList (k:mapKeys v)) m
+getCfgNodes m =
+    mapFoldlWithKey (\s k toMap -> mapFoldlWithKey (\s k _ -> setInsert k s) (setInsert k s) toMap ) setEmpty m
 
 hasNode :: CFG -> BlockId -> Bool
 hasNode m node = mapMember node m || any (mapMember node) m
@@ -294,6 +351,11 @@ delEdge from to m =
         remDest Nothing = Nothing
         remDest (Just wm) = Just $ mapDelete to wm
 
+delNode :: BlockId -> CFG -> CFG
+delNode node cfg =
+  fmap (mapDelete node)  -- < Edges to the node
+    (mapDelete node cfg) -- < Edges from the node
+
 -- | Destinations from bid ordered by weight (descending)
 getSuccEdgesSorted :: CFG -> BlockId -> [(BlockId,EdgeInfo)]
 getSuccEdgesSorted m bid =
@@ -315,36 +377,54 @@ getEdgeInfo from to m
     | otherwise
     = Nothing
 
+getEdgeWeight :: CFG -> BlockId -> BlockId -> EdgeWeight
+getEdgeWeight cfg from to =
+    edgeWeight $ expectJust "Edgeweight for noexisting block" $
+                 getEdgeInfo from to cfg
+
+getTransitionSource :: BlockId -> BlockId -> CFG -> TransitionSource
+getTransitionSource from to cfg = transitionSource $ expectJust "Source info for noexisting block" $
+                        getEdgeInfo from to cfg
+
 reverseEdges :: CFG -> CFG
-reverseEdges cfg = foldr add mapEmpty flatElems
+reverseEdges cfg = mapFoldlWithKey (\cfg from toMap -> go (addNode cfg from) from toMap) mapEmpty cfg
   where
-    elems = mapToList $ fmap mapToList cfg :: [(BlockId,[(BlockId,EdgeInfo)])]
-    flatElems =
-        concatMap (\(from,ws) -> map (\(to,info) -> (to,from,info)) ws ) elems
-    add (to,from,info) m = addEdge to from info m
+    -- We preserve nodes without outgoing edges!
+    addNode :: CFG -> BlockId -> CFG
+    addNode cfg b = mapInsertWith mapUnion b mapEmpty cfg
+    go :: CFG -> BlockId -> (LabelMap EdgeInfo) -> CFG
+    go cfg from toMap = mapFoldlWithKey (\cfg to info -> addEdge to from info cfg) cfg toMap  :: CFG
+
 
 -- | Returns a unordered list of all edges with info
 infoEdgeList :: CFG -> [CfgEdge]
 infoEdgeList m =
-  mapFoldMapWithKey
-    (\from toMap ->
-      map (\(to,info) -> CfgEdge from to info) (mapToList toMap))
-    m
-
--- | Unordered list of edges with weight as Tuple (from,to,weight)
-weightedEdgeList :: CFG -> [(BlockId,BlockId,EdgeWeight)]
-weightedEdgeList m =
-  mapFoldMapWithKey
-    (\from toMap ->
-      map (\(to,info) ->
-        (from,to, edgeWeight info)) (mapToList toMap))
-    m
-      --  (\(from, tos) -> map (\(to,info) -> (from,to, edgeWeight info)) tos )
+    go (mapToList m) []
+  where
+    -- We avoid foldMap to avoid thunk buildup
+    go :: [(BlockId,LabelMap EdgeInfo)] -> [CfgEdge] -> [CfgEdge]
+    go [] acc = acc
+    go ((from,toMap):xs) acc
+      = go' xs from (mapToList toMap) acc
+    go' :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [(BlockId,EdgeInfo)] -> [CfgEdge] -> [CfgEdge]
+    go' froms _    []              acc = go froms acc
+    go' froms from ((to,info):tos) acc
+      = go' froms from tos (CfgEdge from to info : acc)
 
 -- | Returns a unordered list of all edges without weights
 edgeList :: CFG -> [Edge]
 edgeList m =
-        mapFoldMapWithKey (\from toMap -> fmap (from,) (mapKeys toMap)) m
+    go (mapToList m) []
+  where
+    -- We avoid foldMap to avoid thunk buildup
+    go :: [(BlockId,LabelMap EdgeInfo)] -> [Edge] -> [Edge]
+    go [] acc = acc
+    go ((from,toMap):xs) acc
+      = go' xs from (mapKeys toMap) acc
+    go' :: [(BlockId,LabelMap EdgeInfo)] -> BlockId -> [BlockId] -> [Edge] -> [Edge]
+    go' froms _    []              acc = go froms acc
+    go' froms from (to:tos) acc
+      = go' froms from tos ((from,to) : acc)
 
 -- | Get successors of a given node without edge weights.
 getSuccessors :: CFG -> BlockId -> [BlockId]
@@ -355,8 +435,8 @@ getSuccessors m bid
 
 pprEdgeWeights :: CFG -> SDoc
 pprEdgeWeights m =
-    let edges = sort $ weightedEdgeList m
-        printEdge (from, to, weight)
+    let edges = sort $ infoEdgeList m :: [CfgEdge]
+        printEdge (CfgEdge from to (EdgeInfo { edgeWeight = weight }))
             = text "\t" <> ppr from <+> text "->" <+> ppr to <>
               text "[label=\"" <> ppr weight <> text "\",weight=\"" <>
               ppr weight <> text "\"];\n"
@@ -365,7 +445,7 @@ pprEdgeWeights m =
         --to immediately see it when it does.
         printNode node
             = text "\t" <> ppr node <> text ";\n"
-        getEdgeNodes (from, to, _weight) = [from,to]
+        getEdgeNodes (CfgEdge from to _) = [from,to]
         edgeNodes = setFromList $ concatMap getEdgeNodes edges :: LabelSet
         nodes = filter (\n -> (not . setMember n) edgeNodes) . mapKeys $ mapFilter null m
     in
@@ -378,8 +458,8 @@ pprEdgeWeights m =
 updateEdgeWeight :: (EdgeWeight -> EdgeWeight) -> Edge -> CFG -> CFG
 updateEdgeWeight f (from, to) cfg
     | Just oldInfo <- getEdgeInfo from to cfg
-    = let oldWeight = edgeWeight oldInfo
-          newWeight = f oldWeight
+    = let !oldWeight = edgeWeight oldInfo
+          !newWeight = f oldWeight
       in addEdge from to (oldInfo {edgeWeight = newWeight}) cfg
     | otherwise
     = panic "Trying to update invalid edge"
@@ -448,9 +528,7 @@ addNodesBetween m updates =
 
   Should A or B be placed in front of C? The block layout algorithm
   decides this based on which edge (A,C)/(B,C) is heavier. So we
-  make a educated guess how often execution will transer control
-  along each edge as well as how much we gain by placing eg A before
-  C.
+  make a educated guess on which branch should be preferred.
 
   We rank edges in this order:
   * Unconditional Control Transfer - They will always
@@ -479,7 +557,6 @@ addNodesBetween m updates =
           address. This reduces the chance that we return to the same
           cache line further.
 
-
 -}
 -- | Generate weights for a Cmm proc based on some simple heuristics.
 getCfgProc :: D.CfgWeights -> RawCmmDecl -> CFG
@@ -515,13 +592,24 @@ getCfg weights graph =
     getBlockEdges block =
       case branch of
         CmmBranch dest -> [mkEdge dest uncondWeight]
-        CmmCondBranch _c t f l
+        CmmCondBranch cond t f l
           | l == Nothing ->
               [mkEdge f condBranchWeight,   mkEdge t condBranchWeight]
           | l == Just True ->
               [mkEdge f unlikelyCondWeight, mkEdge t likelyCondWeight]
           | l == Just False ->
               [mkEdge f likelyCondWeight,   mkEdge t unlikelyCondWeight]
+          where
+            mkEdgeInfo = -- pprTrace "Info" (ppr branchInfo <+> ppr cond)
+                         EdgeInfo (CmmSource branch branchInfo) . fromIntegral
+            mkEdge target weight = ((bid,target), mkEdgeInfo weight)
+            branchInfo =
+              foldRegsUsed
+                (panic "foldRegsDynFlags")
+                (\info r -> if r == SpLim || r == HpLim || r == BaseReg
+                    then HeapStackCheck else info)
+                NoInfo cond
+
         (CmmSwitch _e ids) ->
           let switchTargets = switchTargetsToList ids
               --Compiler performance hack - for very wide switches don't
@@ -539,7 +627,7 @@ getCfg weights graph =
             map (\x -> ((bid,x),mkEdgeInfo 0)) $ G.successors other
       where
         bid = G.entryLabel block
-        mkEdgeInfo = EdgeInfo (CmmSource branch) . fromIntegral
+        mkEdgeInfo = EdgeInfo (CmmSource branch NoInfo) . fromIntegral
         mkEdge target weight = ((bid,target), mkEdgeInfo weight)
         branch = lastNode block :: CmmNode O C
 
@@ -561,6 +649,11 @@ findBackEdges root cfg =
 optimizeCFG :: D.CfgWeights -> RawCmmDecl -> CFG -> CFG
 optimizeCFG _ (CmmData {}) cfg = cfg
 optimizeCFG weights (CmmProc info _lab _live graph) cfg =
+    {-# SCC optimizeCFG #-}
+    -- pprTrace "Initial:" (pprEdgeWeights cfg) $
+    -- pprTrace "Initial:" (ppr $ mkGlobalWeights (g_entry graph) cfg) $
+
+    -- pprTrace "LoopInfo:" (ppr $ loopInfo cfg (g_entry graph)) $
     favourFewerPreds  .
     penalizeInfoTables info .
     increaseBackEdgeWeight (g_entry graph) $ cfg
@@ -590,12 +683,8 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg =
           = weight - (fromIntegral $ D.infoTablePenalty weights)
           | otherwise = weight
 
-
-{- Note [Optimize for Fallthrough]
-
--}
     -- | If a block has two successors, favour the one with fewer
-    -- predecessors. (As that one is more likely to become a fallthrough)
+    -- predecessors and/or the one allowing fall through.
     favourFewerPreds :: CFG -> CFG
     favourFewerPreds cfg =
         let
@@ -612,16 +701,17 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg =
               | preds1 == preds2 = ( 0, 0)
               | otherwise        = (-1, 1)
 
+            update :: CFG -> BlockId -> CFG
             update cfg node
               | [(s1,e1),(s2,e2)] <- getSuccessorEdges cfg node
-              , w1 <- edgeWeight e1
-              , w2 <- edgeWeight e2
+              , !w1 <- edgeWeight e1
+              , !w2 <- edgeWeight e2
               --Only change the weights if there isn't already a ordering.
               , w1 == w2
               , (mod1,mod2) <- modifiers (predCount s1) (predCount s2)
               = (\cfg' ->
                   (adjustEdgeWeight cfg' (+mod2) node s2))
-                  (adjustEdgeWeight cfg  (+mod1) node s1)
+                    (adjustEdgeWeight cfg  (+mod1) node s1)
               | otherwise
               = cfg
         in setFoldl update cfg nodes
@@ -630,13 +720,12 @@ optimizeCFG weights (CmmProc info _lab _live graph) cfg =
         fallthroughTarget to (EdgeInfo source _weight)
           | mapMember to info = False
           | AsmCodeGen <- source = True
-          | CmmSource (CmmBranch {}) <- source = True
-          | CmmSource (CmmCondBranch {}) <- source = True
+          | CmmSource { trans_cmmNode = CmmBranch {} } <- source = True
+          | CmmSource { trans_cmmNode = CmmCondBranch {} } <- source = True
           | otherwise = False
 
 -- | Determine loop membership of blocks based on SCC analysis
---   Ideally we would replace this with a variant giving us loop
---   levels instead but the SCC code will do for now.
+--   This is faster but only gives yes/no answers.
 loopMembers :: CFG -> LabelMap Bool
 loopMembers cfg =
     foldl' (flip setLevel) mapEmpty sccs
@@ -650,3 +739,534 @@ loopMembers cfg =
     setLevel :: SCC BlockId -> LabelMap Bool -> LabelMap Bool
     setLevel (AcyclicSCC bid) m = mapInsert bid False m
     setLevel (CyclicSCC bids) m = foldl' (\m k -> mapInsert k True m) m bids
+
+loopLevels :: CFG -> BlockId -> LabelMap Int
+loopLevels cfg root = liLevels $ loopInfo cfg root
+
+data LoopInfo = LoopInfo
+  { liBackEdges :: [(Edge)] -- ^ List of back edges
+  , liLevels :: LabelMap Int -- ^ BlockId -> LoopLevel mapping
+  , liLoops :: [(Edge, LabelSet)] -- ^ (backEdge, loopBody), body includes header
+  }
+
+instance Outputable LoopInfo where
+    ppr (LoopInfo _ _lvls loops) =
+        text "Loops:(backEdge, bodyNodes)" $$
+            (vcat $ map ppr loops)
+
+-- | Determine loop membership of blocks based on Dominator analysis.
+--   This is slower but gives loop levels instead of just loop membership.
+--   However it only detects natural loops. Irreducible control flow is not
+--   recognized even if it loops. But that is rare enough that we don't have
+--   to care about that special case.
+loopInfo :: CFG -> BlockId -> LoopInfo
+loopInfo cfg root = LoopInfo  { liBackEdges = backEdges
+                              , liLevels = mapFromList loopCounts
+                              , liLoops = loopBodies }
+  where
+    revCfg = reverseEdges cfg
+    graph = fmap (setFromList . mapKeys ) cfg :: LabelMap LabelSet
+
+    --TODO - This should be a no op: Export constructors? Use unsafeCoerce? ...
+    rooted = ( fromBlockId root
+              , toIntMap $ fmap toIntSet graph) :: (Int, IntMap IntSet)
+    -- rooted = unsafeCoerce (root, graph)
+    tree = fmap toBlockId $ Dom.domTree rooted :: Tree BlockId
+
+    -- Map from Nodes to their dominators
+    domMap :: LabelMap LabelSet
+    domMap = mkDomMap tree
+
+    edges = edgeList cfg :: [(BlockId, BlockId)]
+    -- We can't recompute this from the edges, there might be blocks not connected via edges.
+    nodes = getCfgNodes cfg :: LabelSet
+
+    -- identify back edges
+    isBackEdge (from,to)
+      | Just doms <- mapLookup from domMap
+      , setMember to doms
+      = True
+      | otherwise = False
+
+    -- determine the loop body for a back edge
+    findBody edge@(tail, head)
+      = ( edge, setInsert head $ go (setSingleton tail) (setSingleton tail) )
+      where
+        -- The reversed cfg makes it easier to look up predecessors
+        cfg' = delNode head revCfg
+        go :: LabelSet -> LabelSet -> LabelSet
+        go found current
+          | setNull current = found
+          | otherwise = go  (setUnion newSuccessors found)
+                            newSuccessors
+          where
+            newSuccessors = setFilter (\n -> not $ setMember n found) successors :: LabelSet
+            successors = setFromList $ concatMap
+                                      (getSuccessors cfg')
+                                      (setElems current) :: LabelSet
+
+    backEdges = filter isBackEdge edges
+    loopBodies = map findBody backEdges :: [(Edge, LabelSet)]
+
+    -- Block b is part of n loop bodies => loop nest level of n
+    loopCounts =
+      let bodies = map (first snd) loopBodies -- [(Header, Body)]
+          loopCount n = length $ nub . map fst . filter (setMember n . snd) $ bodies
+      in  map (\n -> (n, loopCount n)) $ setElems nodes :: [(BlockId, Int)]
+
+    toIntSet :: LabelSet -> IntSet
+    toIntSet s = IS.fromList . map fromBlockId . setElems $ s
+    toIntMap :: LabelMap a -> IntMap a
+    toIntMap m = IM.fromList $ map (\(x,y) -> (fromBlockId x,y)) $ mapToList m
+
+    mkDomMap :: Tree BlockId -> LabelMap LabelSet
+    mkDomMap root = mapFromList $ go setEmpty root
+      where
+        go :: LabelSet -> Tree BlockId -> [(Label,LabelSet)]
+        go parents (Node lbl [])
+          =  [(lbl, parents)]
+        go parents (Node _ leaves)
+          = let nodes = map rootLabel leaves
+                entries = map (\x -> (x,parents)) nodes
+            in  entries ++ concatMap
+                            (\n -> go (setInsert (rootLabel n) parents) n)
+                            leaves
+
+    fromBlockId :: BlockId -> Int
+    fromBlockId = getKey . getUnique
+
+    toBlockId :: Int -> BlockId
+    toBlockId = mkBlockId . mkUniqueGrimily
+
+-- We make the CFG a Hoopl Graph, so we can reuse revPostOrder.
+newtype BlockNode (e :: Extensibility) (x :: Extensibility) = BN (BlockId,[BlockId])
+
+instance G.NonLocal (BlockNode) where
+  entryLabel (BN (lbl,_))   = lbl
+  successors (BN (_,succs)) = succs
+
+revPostorderFrom :: CFG -> BlockId -> [BlockId]
+revPostorderFrom cfg root =
+    map fromNode $ G.revPostorderFrom hooplGraph root
+  where
+    nodes = getCfgNodes cfg
+    hooplGraph = setFoldl (\m n -> mapInsert n (toNode n) m) mapEmpty nodes
+
+    fromNode :: BlockNode C C -> BlockId
+    fromNode (BN x) = fst x
+
+    toNode :: BlockId -> BlockNode C C
+    toNode bid =
+        BN (bid,getSuccessors cfg $ bid)
+
+
+-- | We take in a CFG which has on its edges weights which are
+--   relative only to other edges originating from the same node.
+--
+--   We return a CFG for which each edge represents a GLOBAL weight.
+--   This means edge weights are comparable across the whole graph.
+--
+--   For irreducible control flow results might be imprecise, otherwise they
+--   are reliable.
+--
+--   The algorithm is based on the Paper
+--   "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus
+--   The only big change is that we go over the nodes in the body of loops in
+--   reverse post order. Which is required for diamond control flow to work probably.
+--
+--   We also apply a few prediction heuristics (based on the same paper)
+
+{-# SCC mkGlobalWeights #-}
+mkGlobalWeights :: BlockId -> CFG -> (LabelMap Double, LabelMap (LabelMap Double))
+mkGlobalWeights root localCfg
+  | null localCfg = panic "Error - Empty CFG"
+  | otherwise
+  = --pprTrace "revOrder" (ppr revOrder) $
+    -- undefined --propagate (mapSingleton root 1) (revOrder)
+    (blockFreqs', edgeFreqs')
+  where
+    -- Calculate fixpoints
+    (blockFreqs, edgeFreqs) = calcFreqs nodeProbs backEdges' bodies' revOrder'
+    blockFreqs' = mapFromList $ map (first fromVertex) (assocs blockFreqs) :: LabelMap Double
+    edgeFreqs' = fmap fromVertexMap $ fromVertexMap edgeFreqs
+
+    fromVertexMap :: IM.IntMap x -> LabelMap x
+    fromVertexMap m = mapFromList . map (first fromVertex) $ IM.toList m
+
+    revOrder = revPostorderFrom localCfg root :: [BlockId]
+    loopinfo@(LoopInfo backedges _levels bodies) = loopInfo localCfg root
+
+    revOrder' = map toVertex revOrder
+    backEdges' = map (bimap toVertex toVertex) backedges
+    bodies' = map calcBody bodies
+
+    estimatedCfg = staticBranchPrediction root loopinfo localCfg
+    -- Normalize the weights to probabilities and apply heuristics
+    nodeProbs = cfgEdgeProbabilities estimatedCfg toVertex
+
+    -- By mapping vertices to numbers in reverse post order we can bring any subset into reverse post
+    -- order simply by sorting.
+    -- TODO: The sort is redundant if we can guarantee that setElems returns elements ascending
+    calcBody (backedge, blocks) =
+        (toVertex $ snd backedge, sort . map toVertex $ (setElems blocks))
+
+    vertexMapping = mapFromList $ zip revOrder [0..] :: LabelMap Int
+    blockMapping = listArray (0,mapSize vertexMapping - 1) revOrder :: Array Int BlockId
+    -- Map from blockId to indicies starting at zero
+    toVertex :: BlockId -> Int
+    toVertex   blockId  = expectJust "mkGlobalWeights" $ mapLookup blockId vertexMapping
+    -- Map from indicies starting at zero to blockIds
+    fromVertex :: Int -> BlockId
+    fromVertex vertex   = blockMapping ! vertex
+
+{- Note [Static Branch Prediction]
+   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The work here has been based on the paper
+"Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus.
+
+The primary differences are that if we branch on the result of a heap
+check we do not apply any of the heuristics.
+The reason is simple: They look like loops in the control flow graph
+but are usually never entered, and if at most once.
+
+Currently implemented is a heuristic to predict that we do not exit
+loops (lehPredicts) and one to predict that backedges are more likely
+than any other edge.
+
+The back edge case is special as it superceeds any other heuristic if it
+applies.
+
+Do NOT rely solely on nofib results for benchmarking this. I recommend at least
+comparing megaparsec and container benchmarks. Nofib does not seeem to have
+many instances of "loopy" Cmm where these make a difference.
+
+TODO:
+* The paper containers more benchmarks which should be implemented.
+* If we turn the likelyhood on if/else branches into a probability
+  instead of true/false we could implement this as a Cmm pass.
+  + The complete Cmm code still exists and can be accessed by the heuristics
+  + There is no chance of register allocation/codegen inserting branches/blocks
+  + making the TransitionSource info wrong.
+  + potential to use this information in CmmPasses.
+  - Requires refactoring of all the code relying on the binary nature of likelyhood.
+  - Requires refactoring `loopInfo` to work on both, Cmm Graphs and the backend CFG.
+-}
+
+-- | Combination of target node id and information about the branch
+--   we are looking at.
+type TargetNodeInfo = (BlockId, EdgeInfo)
+
+
+-- | Update branch weights based on certain heuristics.
+-- See Note [Static Branch Prediction]
+-- TODO: This should be combined with optimizeCFG
+{-# SCC staticBranchPrediction #-}
+staticBranchPrediction :: BlockId -> LoopInfo -> CFG -> CFG
+staticBranchPrediction _root (LoopInfo l_backEdges loopLevels l_loops) cfg =
+    -- pprTrace "staticEstimatesOn" (ppr (cfg)) $
+    setFoldl update cfg nodes
+  where
+    nodes = getCfgNodes cfg
+    backedges = S.fromList $ l_backEdges
+    -- Loops keyed by their back edge
+    loops = M.fromList $ l_loops :: M.Map Edge LabelSet
+    loopHeads = S.fromList $ map snd $ M.keys loops
+
+    update :: CFG -> BlockId -> CFG
+    update cfg node
+        -- No successors, nothing to do.
+        | null successors = cfg
+
+        -- Mix of backedges and others:
+        -- Always predict the backedges.
+        | not (null m) && length m < length successors
+        -- Heap/Stack checks "loop", but only once.
+        -- So we simply exclude any case involving them.
+        , not $ any (isHeapOrStackCheck  . transitionSource . snd) successors
+        = let   loopChance = repeat $! pred_LBH / (fromIntegral $ length m)
+                exitChance = repeat $! (1 - pred_LBH) / fromIntegral (length not_m)
+                updates = zip (map fst m) loopChance ++ zip (map fst not_m) exitChance
+        in  -- pprTrace "mix" (ppr (node,successors)) $
+            foldl' (\cfg (to,weight) -> setEdgeWeight cfg weight node to) cfg updates
+
+        -- For (regular) non-binary branches we keep the weights from the STG -> Cmm translation.
+        | length successors /= 2
+        = cfg
+
+        -- Only backedges - no need to adjust
+        | length m > 0
+        = cfg
+
+        -- A regular binary branch, we can plug addition predictors in here.
+        | [(s1,s1_info),(s2,s2_info)] <- successors
+        , not $ any (isHeapOrStackCheck  . transitionSource . snd) successors
+        = -- Normalize weights to total of 1
+            let !w1 = max (edgeWeight s1_info) (0)
+                !w2 = max (edgeWeight s2_info) (0)
+                -- Of both weights are <= 0 we set both to 0.5
+                normalizeWeight w = if w1 + w2 == 0 then 0.5 else w/(w1+w2)
+                !cfg'  = setEdgeWeight cfg  (normalizeWeight w1) node s1
+                !cfg'' = setEdgeWeight cfg' (normalizeWeight w2) node s2
+
+                -- Figure out which heuristics apply to these successors
+                heuristics = map ($ ((s1,s1_info),(s2,s2_info)))
+                            [lehPredicts, phPredicts, ohPredicts, ghPredicts, lhhPredicts, chPredicts
+                            , shPredicts, rhPredicts]
+                -- Apply result of a heuristic. Argument is the likelyhood
+                -- predicted for s1.
+                applyHeuristic :: CFG -> Maybe Prob -> CFG
+                applyHeuristic cfg Nothing = cfg
+                applyHeuristic cfg (Just (s1_pred :: Double))
+                  | s1_old == 0 || s2_old == 0 ||
+                    isHeapOrStackCheck (transitionSource s1_info) ||
+                    isHeapOrStackCheck (transitionSource s2_info)
+                  = cfg
+                  | otherwise =
+                    let -- Predictions from heuristic
+                        s1_prob = EdgeWeight s1_pred :: EdgeWeight
+                        s2_prob = 1.0 - s1_prob
+                        -- Update
+                        d = (s1_old * s1_prob) + (s2_old * s2_prob) :: EdgeWeight
+                        s1_prob' = s1_old * s1_prob / d
+                        !s2_prob' = s2_old * s2_prob / d
+                        !cfg_s1 = setEdgeWeight cfg    s1_prob' node s1
+                    in  -- pprTrace "Applying heuristic!" (ppr (node,s1,s2) $$ ppr (s1_prob', s2_prob')) $
+                        setEdgeWeight cfg_s1 s2_prob' node s2
+                  where
+                    -- Old weights
+                    s1_old = getEdgeWeight cfg node s1
+                    s2_old = getEdgeWeight cfg node s2
+
+            in
+            -- pprTraceIt "RegularCfgResult" $
+            foldl' applyHeuristic cfg'' heuristics
+
+        -- Branch on heap/stack check
+        | otherwise = cfg
+
+      where
+        -- Chance that loops are taken.
+        pred_LBH = 0.875
+        -- successors
+        successors = getSuccessorEdges cfg node
+        -- backedges
+        (m,not_m) = partition (\succ -> S.member (node, fst succ) backedges) successors
+
+        -- Heuristics return nothing if they don't say anything about this branch
+        -- or Just (prob_s1) where prob_s1 is the likelyhood for s1 to be the
+        -- taken branch. s1 is the branch in the true case.
+
+        -- Loop exit heuristic.
+        -- We are unlikely to leave a loop unless it's to enter another one.
+        pred_LEH = 0.75
+        -- If and only if no successor is a loopheader,
+        -- then we will likely not exit the current loop body.
+        lehPredicts :: (TargetNodeInfo,TargetNodeInfo) -> Maybe Prob
+        lehPredicts ((s1,_s1_info),(s2,_s2_info))
+          | S.member s1 loopHeads || S.member s2 loopHeads
+          = Nothing
+
+          | otherwise
+          = --pprTrace "lehPredict:" (ppr $ compare s1Level s2Level) $
+            case compare s1Level s2Level of
+                EQ -> Nothing
+                LT -> Just (1-pred_LEH) --s1 exits to a shallower loop level (exits loop)
+                GT -> Just (pred_LEH)   --s1 exits to a deeper loop level
+            where
+                s1Level = mapLookup s1 loopLevels
+                s2Level = mapLookup s2 loopLevels
+
+        -- Comparing to a constant is unlikely to be equal.
+        ohPredicts (s1,_s2)
+            | CmmSource { trans_cmmNode = src1 } <- getTransitionSource node (fst s1) cfg
+            , CmmCondBranch cond ltrue _lfalse likely <- src1
+            , likely == Nothing
+            , CmmMachOp mop args <- cond
+            , MO_Eq {} <- mop
+            , not (null [x | x@CmmLit{} <- args])
+            = if fst s1 == ltrue then Just 0.3 else Just 0.7
+
+            | otherwise
+            = Nothing
+
+        -- TODO: These are all the other heuristics from the paper.
+        -- Not all will apply, for now we just stub them out as Nothing.
+        phPredicts = const Nothing
+        ghPredicts = const Nothing
+        lhhPredicts = const Nothing
+        chPredicts = const Nothing
+        shPredicts = const Nothing
+        rhPredicts = const Nothing
+
+-- We normalize all edge weights as probabilities between 0 and 1.
+-- Ignoring rounding errors all outgoing edges sum up to 1.
+cfgEdgeProbabilities :: CFG -> (BlockId -> Int) -> IM.IntMap (IM.IntMap Prob)
+cfgEdgeProbabilities cfg toVertex
+    = mapFoldlWithKey foldEdges IM.empty cfg
+  where
+    foldEdges = (\m from toMap -> IM.insert (toVertex from) (normalize toMap) m)
+
+    normalize :: (LabelMap EdgeInfo) -> (IM.IntMap Prob)
+    normalize weightMap
+        | edgeCount <= 1 = mapFoldlWithKey (\m k _ -> IM.insert (toVertex k) 1.0 m) IM.empty weightMap
+        | otherwise = mapFoldlWithKey (\m k _ -> IM.insert (toVertex k) (normalWeight k) m) IM.empty weightMap
+      where
+        edgeCount = mapSize weightMap
+        -- Negative weights are generally allowed but are mapped to zero.
+        -- We then check if there is at least one non-zero edge and if not
+        -- assign uniform weights to all branches.
+        minWeight = 0 :: Prob
+        weightMap' = fmap (\w -> max (weightToDouble . edgeWeight $ w) minWeight) weightMap
+        totalWeight = sum weightMap'
+
+        normalWeight :: BlockId -> Prob
+        normalWeight bid
+         | totalWeight == 0
+         = 1.0 / fromIntegral edgeCount
+         | Just w <- mapLookup bid weightMap'
+         = w/totalWeight
+         | otherwise = panic "impossible"
+
+-- This is the fixpoint algorithm from
+--   "Static Branch Prediction and Program Profile Analysis" by Y Wu, JR Larus
+-- The adaption to Haskell is my own.
+calcFreqs :: IM.IntMap (IM.IntMap Prob) -> [(Int,Int)] -> [(Int, [Int])] -> [Int]
+          -> (Array Int Double, IM.IntMap (IM.IntMap Prob))
+calcFreqs graph backEdges loops revPostOrder = runST $ do
+    visitedNodes <- newArray (0,nodeCount-1) False :: ST s (STUArray s Int Bool)
+    blockFreqs <- newArray (0,nodeCount-1) 0.0 :: ST s (STUArray s Int Double)
+    edgeProbs <- newSTRef graph
+    edgeBackProbs <- newSTRef graph
+
+    -- let traceArray a = do
+    --       vs <- forM [0..nodeCount-1] $ \i -> readArray a i >>= (\v -> return (i,v))
+          -- trace ("array: " ++ show vs) $ return ()
+
+    let  -- See #1600, we need to inline or unboxing makes perf worse.
+        -- {-# INLINE getFreq #-}
+        {-# INLINE visited #-}
+        visited b = unsafeRead visitedNodes b
+        getFreq b = unsafeRead blockFreqs b
+        -- setFreq :: forall s. Int -> Double -> ST s ()
+        setFreq b f = unsafeWrite blockFreqs b f
+        -- setVisited :: forall s. Node -> ST s ()
+        setVisited b = unsafeWrite visitedNodes b True
+        -- Frequency/probability that edge is taken.
+        getProb' arr b1 b2 = readSTRef arr >>=
+            (\graph ->
+                return .
+                        fromMaybe (error "getFreq 1") .
+                        IM.lookup b2 .
+                        fromMaybe (error "getFreq 2") $
+                        (IM.lookup b1 graph)
+            )
+        setProb' arr b1 b2 prob = do
+          g <- readSTRef arr
+          let !m = fromMaybe (error "Foo") $ IM.lookup b1 g
+              !m' = IM.insert b2 prob m
+          writeSTRef arr $! (IM.insert b1 m' g)
+
+        getEdgeFreq b1 b2 = getProb' edgeProbs b1 b2
+        setEdgeFreq b1 b2 = setProb' edgeProbs b1 b2
+        getProb b1 b2 = fromMaybe (error "getProb") $ do
+            m' <- IM.lookup b1 graph
+            IM.lookup b2 m'
+
+        getBackProb b1 b2 = getProb' edgeBackProbs b1 b2
+        setBackProb b1 b2 = setProb' edgeBackProbs b1 b2
+
+
+    let -- calcOutFreqs :: Node -> ST s ()
+        calcOutFreqs bhead block = do
+          !f <- getFreq block
+          forM (successors block) $ \bi -> do
+            let !prob = getProb block bi
+            let !succFreq = f * prob
+            setEdgeFreq block bi succFreq
+            -- traceM $ "SetOut: " ++ show (block, bi, f, prob, succFreq)
+            when (bi == bhead) $ setBackProb block bi succFreq
+
+
+    let propFreq block head = do
+            -- traceM ("prop:" ++ show (block,head))
+            -- traceShowM block
+
+            !v <- visited block
+            if v then
+                return () --Dont look at nodes twice
+            else if block == head then
+                setFreq block 1.0 -- Loop header frequency is always 1
+            else do
+                let preds = IS.elems $ predecessors block
+                irreducible <- (fmap or) $ forM preds $ \bp -> do
+                    !bp_visited <- visited bp
+                    let bp_backedge = isBackEdge bp block
+                    return (not bp_visited && not bp_backedge)
+
+                if irreducible
+                then return () -- Rare we don't care
+                else do
+                    setFreq block 0
+                    !cycleProb <- sum <$> (forM preds $ \pred -> do
+                        if isBackEdge pred block
+                            then
+                                getBackProb pred block
+                            else do
+                                !f <- getFreq block
+                                !prob <- getEdgeFreq pred block
+                                setFreq block $! f + prob
+                                return 0)
+                    -- traceM $ "cycleProb:" ++ show cycleProb
+                    let limit = 1 - 1/512 -- Paper uses 1 - epsilon, but this works.
+                                          -- determines how large likelyhoods in loops can grow.
+                    !cycleProb <- return $ min cycleProb limit -- <- return $ if cycleProb > limit then limit else cycleProb
+                    -- traceM $ "cycleProb:" ++ show cycleProb
+
+                    !f <- getFreq block
+                    setFreq block (f / (1.0 - cycleProb))
+
+            setVisited block
+            calcOutFreqs head block
+
+    -- Loops, by nesting, inner to outer
+    forM_ loops $ \(head, body) -> do
+        forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i True) -- Mark all nodes as visited.
+        forM_ body (\i -> unsafeWrite visitedNodes i False) -- Mark all blocks reachable from head as not visited
+        forM_ body $ \block -> propFreq block head
+
+    -- After dealing with all loops, deal with non-looping parts of the CFG
+    forM_ [0 .. nodeCount - 1] (\i -> unsafeWrite visitedNodes i False) -- Everything in revPostOrder is reachable
+    forM_ revPostOrder $ \block -> propFreq block (head revPostOrder)
+
+    -- trace ("Final freqs:") $ return ()
+    -- let freqString = pprFreqs freqs
+    -- trace (unlines freqString) $ return ()
+    -- trace (pprFre) $ return ()
+    graph' <- readSTRef edgeProbs
+    freqs' <- unsafeFreeze  blockFreqs
+
+    return (freqs', graph')
+  where
+    predecessors :: Int -> IS.IntSet
+    predecessors b = fromMaybe IS.empty $ IM.lookup b revGraph
+    successors b = fromMaybe (lookupError "succ" b graph)$ IM.keys <$> IM.lookup b graph
+    lookupError s b g = pprPanic ("Lookup error " ++ s) $
+                            ( text "node" <+> ppr b $$
+                                text "graph" <+>
+                                vcat (map (\(k,m) -> ppr (k,m :: IM.IntMap Double)) $ IM.toList g)
+                            )
+
+    nodeCount = IM.foldl' (\count toMap -> IM.foldlWithKey' countTargets count toMap) (IM.size graph) graph
+      where
+        countTargets = (\count k _ -> countNode k + count )
+        countNode n = if IM.member n graph then 0 else 1
+
+    isBackEdge from to = S.member (from,to) backEdgeSet
+    backEdgeSet = S.fromList backEdges
+
+    revGraph :: IntMap IntSet
+    revGraph = IM.foldlWithKey' (\m from toMap -> addEdges m from toMap) IM.empty graph
+        where
+            addEdges m0 from toMap = IM.foldlWithKey' (\m k _ -> addEdge m from k) m0 toMap
+            addEdge m0 from to = IM.insertWith IS.union to (IS.singleton from) m0
index 9c6e24d..52f5909 100644 (file)
@@ -1,4 +1,4 @@
-{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE ScopedTypeVariables, GADTs, BangPatterns #-}
 module RegAlloc.Graph.SpillCost (
         SpillCostRecord,
         plusSpillCostRecord,
@@ -23,6 +23,7 @@ import Reg
 import GraphBase
 
 import Hoopl.Collections (mapLookup)
+import Hoopl.Label
 import Cmm
 import UniqFM
 import UniqSet
@@ -49,9 +50,6 @@ type SpillCostRecord
 type SpillCostInfo
         = UniqFM SpillCostRecord
 
--- | Block membership in a loop
-type LoopMember = Bool
-
 type SpillCostState = State (UniqFM SpillCostRecord) ()
 
 -- | An empty map of spill costs.
@@ -88,45 +86,49 @@ slurpSpillCostInfo platform cfg cmm
  where
         countCmm CmmData{}              = return ()
         countCmm (CmmProc info _ _ sccs)
-                = mapM_ (countBlock info)
+                = mapM_ (countBlock info freqMap)
                 $ flattenSCCs sccs
+            where
+                LiveInfo _ entries _ _ = info
+                freqMap = (fst . mkGlobalWeights (head entries)) <$> cfg
 
         -- Lookup the regs that are live on entry to this block in
         --      the info table from the CmmProc.
-        countBlock info (BasicBlock blockId instrs)
+        countBlock info freqMap (BasicBlock blockId instrs)
                 | LiveInfo _ _ blockLive _ <- info
                 , Just rsLiveEntry  <- mapLookup blockId blockLive
                 , rsLiveEntry_virt  <- takeVirtuals rsLiveEntry
-                = countLIs (loopMember blockId) rsLiveEntry_virt instrs
+                = countLIs (ceiling $ blockFreq freqMap blockId) rsLiveEntry_virt instrs
 
                 | otherwise
                 = error "RegAlloc.SpillCost.slurpSpillCostInfo: bad block"
 
-        countLIs :: LoopMember -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState
+
+        countLIs :: Int -> UniqSet VirtualReg -> [LiveInstr instr] -> SpillCostState
         countLIs _      _      []
                 = return ()
 
         -- Skip over comment and delta pseudo instrs.
-        countLIs inLoop rsLive (LiveInstr instr Nothing : lis)
+        countLIs scale rsLive (LiveInstr instr Nothing : lis)
                 | isMetaInstr instr
-                = countLIs inLoop rsLive lis
+                = countLIs scale rsLive lis
 
                 | otherwise
                 = pprPanic "RegSpillCost.slurpSpillCostInfo"
                 $ text "no liveness information on instruction " <> ppr instr
 
-        countLIs inLoop rsLiveEntry (LiveInstr instr (Just live) : lis)
+        countLIs scale rsLiveEntry (LiveInstr instr (Just live) : lis)
          = do
                 -- Increment the lifetime counts for regs live on entry to this instr.
-                mapM_ (incLifetime (loopCount inLoop)) $ nonDetEltsUniqSet rsLiveEntry
+                mapM_ incLifetime $ nonDetEltsUniqSet rsLiveEntry
                     -- This is non-deterministic but we do not
                     -- currently support deterministic code-generation.
                     -- See Note [Unique Determinism and code generation]
 
                 -- Increment counts for what regs were read/written from.
                 let (RU read written)   = regUsageOfInstr platform instr
-                mapM_ (incUses (loopCount inLoop)) $ catMaybes $ map takeVirtualReg $ nub read
-                mapM_ (incDefs (loopCount inLoop)) $ catMaybes $ map takeVirtualReg $ nub written
+                mapM_ (incUses scale) $ catMaybes $ map takeVirtualReg $ nub read
+                mapM_ (incDefs scale) $ catMaybes $ map takeVirtualReg $ nub written
 
                 -- Compute liveness for entry to next instruction.
                 let liveDieRead_virt    = takeVirtuals (liveDieRead  live)
@@ -140,21 +142,18 @@ slurpSpillCostInfo platform cfg cmm
                         = (rsLiveAcross `unionUniqSets` liveBorn_virt)
                                         `minusUniqSet`  liveDieWrite_virt
 
-                countLIs inLoop rsLiveNext lis
+                countLIs scale rsLiveNext lis
 
-        loopCount inLoop
-          | inLoop = 10
-          | otherwise = 1
         incDefs     count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, count, 0, 0)
         incUses     count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, count, 0)
-        incLifetime count reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, count)
+        incLifetime       reg = modify $ \s -> addToUFM_C plusSpillCostRecord s reg (reg, 0, 0, 1)
 
-        loopBlocks = CFG.loopMembers <$> cfg
-        loopMember bid
-          | Just isMember <- join (mapLookup bid <$> loopBlocks)
-          = isMember
+        blockFreq :: Maybe (LabelMap Double) -> Label -> Double
+        blockFreq freqs bid
+          | Just freq <- join (mapLookup bid <$> freqs)
+          = max 1.0 (10000 * freq)
           | otherwise
-          = False
+          = 1.0 -- Only if no cfg given
 
 -- | Take all the virtual registers from this set.
 takeVirtuals :: UniqSet Reg -> UniqSet VirtualReg
@@ -215,31 +214,39 @@ chooseSpill info graph
 --  Without live range splitting, its's better to spill from the outside
 --  in so set the cost of very long live ranges to zero
 --
-{-
-spillCost_chaitin
-        :: SpillCostInfo
-        -> Graph Reg RegClass Reg
-        -> Reg
-        -> Float
 
-spillCost_chaitin info graph reg
-        -- Spilling a live range that only lives for 1 instruction
-        -- isn't going to help us at all - and we definitely want to avoid
-        -- trying to re-spill previously inserted spill code.
-        | lifetime <= 1         = 1/0
-
-        -- It's unlikely that we'll find a reg for a live range this long
-        -- better to spill it straight up and not risk trying to keep it around
-        -- and have to go through the build/color cycle again.
-        | lifetime > allocatableRegsInClass (regClass reg) * 10
-        = 0
+-- spillCost_chaitin
+--         :: SpillCostInfo
+--         -> Graph VirtualReg RegClass RealReg
+--         -> VirtualReg
+--         -> Float
+
+-- spillCost_chaitin info graph reg
+--         -- Spilling a live range that only lives for 1 instruction
+--         -- isn't going to help us at all - and we definitely want to avoid
+--         -- trying to re-spill previously inserted spill code.
+--         | lifetime <= 1         = 1/0
+
+--         -- It's unlikely that we'll find a reg for a live range this long
+--         -- better to spill it straight up and not risk trying to keep it around
+--         -- and have to go through the build/color cycle again.
+
+--         -- To facility this we scale down the spill cost of long ranges.
+--         -- This makes sure long ranges are still spilled first.
+--         -- But this way spill cost remains relevant for long live
+--         -- ranges.
+--         | lifetime >= 128
+--         = (spillCost / conflicts) / 10.0
+
+
+--         -- Otherwise revert to chaitin's regular cost function.
+--         | otherwise = (spillCost / conflicts)
+--         where
+--             !spillCost = fromIntegral (uses + defs) :: Float
+--             conflicts = fromIntegral (nodeDegree classOfVirtualReg graph reg)
+--             (_, defs, uses, lifetime)
+--                 = fromMaybe (reg, 0, 0, 0) $ lookupUFM info reg
 
-        -- Otherwise revert to chaitin's regular cost function.
-        | otherwise     = fromIntegral (uses + defs)
-                        / fromIntegral (nodeDegree graph reg)
-        where (_, defs, uses, lifetime)
-                = fromMaybe (reg, 0, 0, 0) $ lookupUFM info reg
--}
 
 -- Just spill the longest live range.
 spillCost_length
index 7a2d599..b1dd9c5 100644 (file)
@@ -3529,7 +3529,7 @@ invertCondBranches (Just cfg) keep bs =
       , Just edgeInfo2 <- getEdgeInfo lbl1 target2 cfg
       -- Both jumps come from the same cmm statement
       , transitionSource edgeInfo1 == transitionSource edgeInfo2
-      , (CmmSource cmmCondBranch) <- transitionSource edgeInfo1
+      , CmmSource {trans_cmmNode = cmmCondBranch} <- transitionSource edgeInfo1
 
       --Int comparisons are invertable
       , CmmCondBranch (CmmMachOp op _args) _ _ _ <- cmmCondBranch
diff --git a/compiler/utils/Dominators.hs b/compiler/utils/Dominators.hs
new file mode 100644 (file)
index 0000000..9877c2c
--- /dev/null
@@ -0,0 +1,588 @@
+{-# LANGUAGE RankNTypes, BangPatterns, FlexibleContexts, Strict #-}\r
+\r
+{- |\r
+  Module      :  Dominators\r
+  Copyright   :  (c) Matt Morrow 2009\r
+  License     :  BSD3\r
+  Maintainer  :  <morrow@moonpatio.com>\r
+  Stability   :  experimental\r
+  Portability :  portable\r
+\r
+  Taken from the dom-lt package.\r
+\r
+  The Lengauer-Tarjan graph dominators algorithm.\r
+\r
+    \[1\] Lengauer, Tarjan,\r
+      /A Fast Algorithm for Finding Dominators in a Flowgraph/, 1979.\r
+\r
+    \[2\] Muchnick,\r
+      /Advanced Compiler Design and Implementation/, 1997.\r
+\r
+    \[3\] Brisk, Sarrafzadeh,\r
+      /Interference Graphs for Procedures in Static Single/\r
+      /Information Form are Interval Graphs/, 2007.\r
+\r
+  Originally taken from the dom-lt package.\r
+-}\r
+\r
+module Dominators (\r
+   Node,Path,Edge\r
+  ,Graph,Rooted\r
+  ,idom,ipdom\r
+  ,domTree,pdomTree\r
+  ,dom,pdom\r
+  ,pddfs,rpddfs\r
+  ,fromAdj,fromEdges\r
+  ,toAdj,toEdges\r
+  ,asTree,asGraph\r
+  ,parents,ancestors\r
+) where\r
+\r
+import GhcPrelude\r
+\r
+import Data.Bifunctor\r
+import Data.Tuple (swap)\r
+\r
+import Data.Tree\r
+import Data.IntMap(IntMap)\r
+import Data.IntSet(IntSet)\r
+import qualified Data.IntMap.Strict as IM\r
+import qualified Data.IntSet as IS\r
+\r
+import Control.Monad\r
+import Control.Monad.ST.Strict\r
+\r
+import Data.Array.ST\r
+import Data.Array.Base\r
+  (unsafeNewArray_\r
+  ,unsafeWrite,unsafeRead)\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+type Node       = Int\r
+type Path       = [Node]\r
+type Edge       = (Node,Node)\r
+type Graph      = IntMap IntSet\r
+type Rooted     = (Node, Graph)\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+-- | /Dominators/.\r
+-- Complexity as for @idom@\r
+dom :: Rooted -> [(Node, Path)]\r
+dom = ancestors . domTree\r
+\r
+-- | /Post-dominators/.\r
+-- Complexity as for @idom@.\r
+pdom :: Rooted -> [(Node, Path)]\r
+pdom = ancestors . pdomTree\r
+\r
+-- | /Dominator tree/.\r
+-- Complexity as for @idom@.\r
+domTree :: Rooted -> Tree Node\r
+domTree a@(r,_) =\r
+  let is = filter ((/=r).fst) (idom a)\r
+      tg = fromEdges (fmap swap is)\r
+  in asTree (r,tg)\r
+\r
+-- | /Post-dominator tree/.\r
+-- Complexity as for @idom@.\r
+pdomTree :: Rooted -> Tree Node\r
+pdomTree a@(r,_) =\r
+  let is = filter ((/=r).fst) (ipdom a)\r
+      tg = fromEdges (fmap swap is)\r
+  in asTree (r,tg)\r
+\r
+-- | /Immediate dominators/.\r
+-- /O(|E|*alpha(|E|,|V|))/, where /alpha(m,n)/ is\r
+-- \"a functional inverse of Ackermann's function\".\r
+--\r
+-- This Complexity bound assumes /O(1)/ indexing. Since we're\r
+-- using @IntMap@, it has an additional /lg |V|/ factor\r
+-- somewhere in there. I'm not sure where.\r
+idom :: Rooted -> [(Node,Node)]\r
+idom rg = runST (evalS idomM =<< initEnv (pruneReach rg))\r
+\r
+-- | /Immediate post-dominators/.\r
+-- Complexity as for @idom@.\r
+ipdom :: Rooted -> [(Node,Node)]\r
+ipdom rg = runST (evalS idomM =<< initEnv (pruneReach (second predG rg)))\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+-- | /Post-dominated depth-first search/.\r
+pddfs :: Rooted -> [Node]\r
+pddfs = reverse . rpddfs\r
+\r
+-- | /Reverse post-dominated depth-first search/.\r
+rpddfs :: Rooted -> [Node]\r
+rpddfs = concat . levels . pdomTree\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+type Dom s a = S s (Env s) a\r
+type NodeSet    = IntSet\r
+type NodeMap a  = IntMap a\r
+data Env s = Env\r
+  {succE      :: !Graph\r
+  ,predE      :: !Graph\r
+  ,bucketE    :: !Graph\r
+  ,dfsE       :: {-# UNPACK #-}!Int\r
+  ,zeroE      :: {-# UNPACK #-}!Node\r
+  ,rootE      :: {-# UNPACK #-}!Node\r
+  ,labelE     :: {-# UNPACK #-}!(Arr s Node)\r
+  ,parentE    :: {-# UNPACK #-}!(Arr s Node)\r
+  ,ancestorE  :: {-# UNPACK #-}!(Arr s Node)\r
+  ,childE     :: {-# UNPACK #-}!(Arr s Node)\r
+  ,ndfsE      :: {-# UNPACK #-}!(Arr s Node)\r
+  ,dfnE       :: {-# UNPACK #-}!(Arr s Int)\r
+  ,sdnoE      :: {-# UNPACK #-}!(Arr s Int)\r
+  ,sizeE      :: {-# UNPACK #-}!(Arr s Int)\r
+  ,domE       :: {-# UNPACK #-}!(Arr s Node)\r
+  ,rnE        :: {-# UNPACK #-}!(Arr s Node)}\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+idomM :: Dom s [(Node,Node)]\r
+idomM = do\r
+  dfsDom =<< rootM\r
+  n <- gets dfsE\r
+  forM_ [n,n-1..1] (\i-> do\r
+    w <- ndfsM i\r
+    sw <- sdnoM w\r
+    ps <- predsM w\r
+    forM_ ps (\v-> do\r
+      u <- eval v\r
+      su <- sdnoM u\r
+      when (su < sw)\r
+        (store sdnoE w su))\r
+    z <- ndfsM =<< sdnoM w\r
+    modify(\e->e{bucketE=IM.adjust\r
+                      (w`IS.insert`)\r
+                      z (bucketE e)})\r
+    pw <- parentM w\r
+    link pw w\r
+    bps <- bucketM pw\r
+    forM_ bps (\v-> do\r
+      u <- eval v\r
+      su <- sdnoM u\r
+      sv <- sdnoM v\r
+      let dv = case su < sv of\r
+                True-> u\r
+                False-> pw\r
+      store domE v dv))\r
+  forM_ [1..n] (\i-> do\r
+    w <- ndfsM i\r
+    j <- sdnoM w\r
+    z <- ndfsM j\r
+    dw <- domM w\r
+    when (dw /= z)\r
+      (do ddw <- domM dw\r
+          store domE w ddw))\r
+  fromEnv\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+eval :: Node -> Dom s Node\r
+eval v = do\r
+  n0 <- zeroM\r
+  a  <- ancestorM v\r
+  case a==n0 of\r
+    True-> labelM v\r
+    False-> do\r
+      compress v\r
+      a   <- ancestorM v\r
+      l   <- labelM v\r
+      la  <- labelM a\r
+      sl  <- sdnoM l\r
+      sla <- sdnoM la\r
+      case sl <= sla of\r
+        True-> return l\r
+        False-> return la\r
+\r
+compress :: Node -> Dom s ()\r
+compress v = do\r
+  n0  <- zeroM\r
+  a   <- ancestorM v\r
+  aa  <- ancestorM a\r
+  when (aa /= n0) (do\r
+    compress a\r
+    a   <- ancestorM v\r
+    aa  <- ancestorM a\r
+    l   <- labelM v\r
+    la  <- labelM a\r
+    sl  <- sdnoM l\r
+    sla <- sdnoM la\r
+    when (sla < sl)\r
+      (store labelE v la)\r
+    store ancestorE v aa)\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+link :: Node -> Node -> Dom s ()\r
+link v w = do\r
+  n0  <- zeroM\r
+  lw  <- labelM w\r
+  slw <- sdnoM lw\r
+  let balance s = do\r
+        c   <- childM s\r
+        lc  <- labelM c\r
+        slc <- sdnoM lc\r
+        case slw < slc of\r
+          False-> return s\r
+          True-> do\r
+            zs  <- sizeM s\r
+            zc  <- sizeM c\r
+            cc  <- childM c\r
+            zcc <- sizeM cc\r
+            case 2*zc <= zs+zcc of\r
+              True-> do\r
+                store ancestorE c s\r
+                store childE s cc\r
+                balance s\r
+              False-> do\r
+                store sizeE c zs\r
+                store ancestorE s c\r
+                balance c\r
+  s   <- balance w\r
+  lw  <- labelM w\r
+  zw  <- sizeM w\r
+  store labelE s lw\r
+  store sizeE v . (+zw) =<< sizeM v\r
+  let follow s = do\r
+        when (s /= n0) (do\r
+          store ancestorE s v\r
+          follow =<< childM s)\r
+  zv  <- sizeM v\r
+  follow =<< case zv < 2*zw of\r
+              False-> return s\r
+              True-> do\r
+                cv <- childM v\r
+                store childE v s\r
+                return cv\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+dfsDom :: Node -> Dom s ()\r
+dfsDom i = do\r
+  _   <- go i\r
+  n0  <- zeroM\r
+  r   <- rootM\r
+  store parentE r n0\r
+  where go i = do\r
+          n <- nextM\r
+          store dfnE   i n\r
+          store sdnoE  i n\r
+          store ndfsE  n i\r
+          store labelE i i\r
+          ss <- succsM i\r
+          forM_ ss (\j-> do\r
+            s <- sdnoM j\r
+            case s==0 of\r
+              False-> return()\r
+              True-> do\r
+                store parentE j i\r
+                go j)\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+initEnv :: Rooted -> ST s (Env s)\r
+initEnv (r0,g0) = do\r
+  let (g,rnmap) = renum 1 g0\r
+      pred      = predG g\r
+      r         = rnmap IM.! r0\r
+      n         = IM.size g\r
+      ns        = [0..n]\r
+      m         = n+1\r
+\r
+  let bucket = IM.fromList\r
+        (zip ns (repeat mempty))\r
+\r
+  rna <- newI m\r
+  writes rna (fmap swap\r
+        (IM.toList rnmap))\r
+\r
+  doms      <- newI m\r
+  sdno      <- newI m\r
+  size      <- newI m\r
+  parent    <- newI m\r
+  ancestor  <- newI m\r
+  child     <- newI m\r
+  label     <- newI m\r
+  ndfs      <- newI m\r
+  dfn       <- newI m\r
+\r
+  forM_ [0..n] (doms.=0)\r
+  forM_ [0..n] (sdno.=0)\r
+  forM_ [1..n] (size.=1)\r
+  forM_ [0..n] (ancestor.=0)\r
+  forM_ [0..n] (child.=0)\r
+\r
+  (doms.=r) r\r
+  (size.=0) 0\r
+  (label.=0) 0\r
+\r
+  return (Env\r
+    {rnE        = rna\r
+    ,dfsE       = 0\r
+    ,zeroE      = 0\r
+    ,rootE      = r\r
+    ,labelE     = label\r
+    ,parentE    = parent\r
+    ,ancestorE  = ancestor\r
+    ,childE     = child\r
+    ,ndfsE      = ndfs\r
+    ,dfnE       = dfn\r
+    ,sdnoE      = sdno\r
+    ,sizeE      = size\r
+    ,succE      = g\r
+    ,predE      = pred\r
+    ,bucketE    = bucket\r
+    ,domE       = doms})\r
+\r
+fromEnv :: Dom s [(Node,Node)]\r
+fromEnv = do\r
+  dom   <- gets domE\r
+  rn    <- gets rnE\r
+  -- r     <- gets rootE\r
+  (_,n) <- st (getBounds dom)\r
+  forM [1..n] (\i-> do\r
+    j <- st(rn!:i)\r
+    d <- st(dom!:i)\r
+    k <- st(rn!:d)\r
+    return (j,k))\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+zeroM :: Dom s Node\r
+zeroM = gets zeroE\r
+domM :: Node -> Dom s Node\r
+domM = fetch domE\r
+rootM :: Dom s Node\r
+rootM = gets rootE\r
+succsM :: Node -> Dom s [Node]\r
+succsM i = gets (IS.toList . (!i) . succE)\r
+predsM :: Node -> Dom s [Node]\r
+predsM i = gets (IS.toList . (!i) . predE)\r
+bucketM :: Node -> Dom s [Node]\r
+bucketM i = gets (IS.toList . (!i) . bucketE)\r
+sizeM :: Node -> Dom s Int\r
+sizeM = fetch sizeE\r
+sdnoM :: Node -> Dom s Int\r
+sdnoM = fetch sdnoE\r
+-- dfnM :: Node -> Dom s Int\r
+-- dfnM = fetch dfnE\r
+ndfsM :: Int -> Dom s Node\r
+ndfsM = fetch ndfsE\r
+childM :: Node -> Dom s Node\r
+childM = fetch childE\r
+ancestorM :: Node -> Dom s Node\r
+ancestorM = fetch ancestorE\r
+parentM :: Node -> Dom s Node\r
+parentM = fetch parentE\r
+labelM :: Node -> Dom s Node\r
+labelM = fetch labelE\r
+nextM :: Dom s Int\r
+nextM = do\r
+  n <- gets dfsE\r
+  let n' = n+1\r
+  modify(\e->e{dfsE=n'})\r
+  return n'\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+type A = STUArray\r
+type Arr s a = A s Int a\r
+\r
+infixl 9 !:\r
+infixr 2 .=\r
+\r
+(.=) :: (MArray (A s) a (ST s))\r
+     => Arr s a -> a -> Int -> ST s ()\r
+(v .= x) i = unsafeWrite v i x\r
+\r
+(!:) :: (MArray (A s) a (ST s))\r
+     => A s Int a -> Int -> ST s a\r
+a !: i = do\r
+  o <- unsafeRead a i\r
+  return $! o\r
+\r
+new :: (MArray (A s) a (ST s))\r
+    => Int -> ST s (Arr s a)\r
+new n = unsafeNewArray_ (0,n-1)\r
+\r
+newI :: Int -> ST s (Arr s Int)\r
+newI = new\r
+\r
+-- newD :: Int -> ST s (Arr s Double)\r
+-- newD = new\r
+\r
+-- dump :: (MArray (A s) a (ST s)) => Arr s a -> ST s [a]\r
+-- dump a = do\r
+--   (m,n) <- getBounds a\r
+--   forM [m..n] (\i -> a!:i)\r
+\r
+writes :: (MArray (A s) a (ST s))\r
+     => Arr s a -> [(Int,a)] -> ST s ()\r
+writes a xs = forM_ xs (\(i,x) -> (a.=x) i)\r
+\r
+-- arr :: (MArray (A s) a (ST s)) => [a] -> ST s (Arr s a)\r
+-- arr xs = do\r
+--   let n = length xs\r
+--   a <- new n\r
+--   go a n 0 xs\r
+--   return a\r
+--   where go _ _ _    [] = return ()\r
+--         go a n i (x:xs)\r
+--           | i <= n = (a.=x) i >> go a n (i+1) xs\r
+--           | otherwise = return ()\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+(!) :: Monoid a => IntMap a -> Int -> a\r
+(!) g n = maybe mempty id (IM.lookup n g)\r
+\r
+fromAdj :: [(Node, [Node])] -> Graph\r
+fromAdj = IM.fromList . fmap (second IS.fromList)\r
+\r
+fromEdges :: [Edge] -> Graph\r
+fromEdges = collectI IS.union fst (IS.singleton . snd)\r
+\r
+toAdj :: Graph -> [(Node, [Node])]\r
+toAdj = fmap (second IS.toList) . IM.toList\r
+\r
+toEdges :: Graph -> [Edge]\r
+toEdges = concatMap (uncurry (fmap . (,))) . toAdj\r
+\r
+predG :: Graph -> Graph\r
+predG g = IM.unionWith IS.union (go g) g0\r
+  where g0 = fmap (const mempty) g\r
+        f :: IntMap IntSet -> Int -> IntSet -> IntMap IntSet\r
+        f m i a = foldl' (\m p -> IM.insertWith mappend p\r
+                                      (IS.singleton i) m)\r
+                        m\r
+                       (IS.toList a)\r
+        go :: IntMap IntSet -> IntMap IntSet\r
+        go = flip IM.foldlWithKey' mempty f\r
+\r
+pruneReach :: Rooted -> Rooted\r
+pruneReach (r,g) = (r,g2)\r
+  where is = reachable\r
+              (maybe mempty id\r
+                . flip IM.lookup g) $ r\r
+        g2 = IM.fromList\r
+            . fmap (second (IS.filter (`IS.member`is)))\r
+            . filter ((`IS.member`is) . fst)\r
+            . IM.toList $ g\r
+\r
+tip :: Tree a -> (a, [Tree a])\r
+tip (Node a ts) = (a, ts)\r
+\r
+parents :: Tree a -> [(a, a)]\r
+parents (Node i xs) = p i xs\r
+        ++ concatMap parents xs\r
+  where p i = fmap (flip (,) i . rootLabel)\r
+\r
+ancestors :: Tree a -> [(a, [a])]\r
+ancestors = go []\r
+  where go acc (Node i xs)\r
+          = let acc' = i:acc\r
+            in p acc' xs ++ concatMap (go acc') xs\r
+        p is = fmap (flip (,) is . rootLabel)\r
+\r
+asGraph :: Tree Node -> Rooted\r
+asGraph t@(Node a _) = let g = go t in (a, fromAdj g)\r
+  where go (Node a ts) = let as = (fst . unzip . fmap tip) ts\r
+                          in (a, as) : concatMap go ts\r
+\r
+asTree :: Rooted -> Tree Node\r
+asTree (r,g) = let go a = Node a (fmap go ((IS.toList . f) a))\r
+                   f = (g !)\r
+            in go r\r
+\r
+reachable :: (Node -> NodeSet) -> (Node -> NodeSet)\r
+reachable f a = go (IS.singleton a) a\r
+  where go seen a = let s = f a\r
+                        as = IS.toList (s `IS.difference` seen)\r
+                    in foldl' go (s `IS.union` seen) as\r
+\r
+collectI :: (c -> c -> c)\r
+        -> (a -> Int) -> (a -> c) -> [a] -> IntMap c\r
+collectI (<>) f g\r
+  = foldl' (\m a -> IM.insertWith (<>)\r
+                                  (f a)\r
+                                  (g a) m) mempty\r
+\r
+-- collect :: (Ord b) => (c -> c -> c)\r
+--         -> (a -> b) -> (a -> c) -> [a] -> Map b c\r
+-- collect (<>) f g\r
+--   = foldl' (\m a -> SM.insertWith (<>)\r
+--                                   (f a)\r
+--                                   (g a) m) mempty\r
+\r
+-- (renamed, old -> new)\r
+renum :: Int -> Graph -> (Graph, NodeMap Node)\r
+renum from = (\(_,m,g)->(g,m))\r
+  . IM.foldlWithKey'\r
+      f (from,mempty,mempty)\r
+  where\r
+    f :: (Int, NodeMap Node, IntMap IntSet) -> Node -> IntSet\r
+      -> (Int, NodeMap Node, IntMap IntSet)\r
+    f (!n,!env,!new) i ss =\r
+            let (j,n2,env2) = go n env i\r
+                (n3,env3,ss2) = IS.fold\r
+                  (\k (!n,!env,!new)->\r
+                      case go n env k of\r
+                        (l,n2,env2)-> (n2,env2,l `IS.insert` new))\r
+                  (n2,env2,mempty) ss\r
+                new2 = IM.insertWith IS.union j ss2 new\r
+            in (n3,env3,new2)\r
+    go :: Int\r
+        -> NodeMap Node\r
+        -> Node\r
+        -> (Node,Int,NodeMap Node)\r
+    go !n !env i =\r
+        case IM.lookup i env of\r
+        Just j -> (j,n,env)\r
+        Nothing -> (n,n+1,IM.insert i n env)\r
+\r
+-----------------------------------------------------------------------------\r
+\r
+newtype S z s a = S {unS :: forall o. (a -> s -> ST z o) -> s -> ST z o}\r
+instance Functor (S z s) where\r
+  fmap f (S g) = S (\k -> g (k . f))\r
+instance Monad (S z s) where\r
+  return = pure\r
+  S g >>= f = S (\k -> g (\a -> unS (f a) k))\r
+instance Applicative (S z s) where\r
+  pure a = S (\k -> k a)\r
+  (<*>) = ap\r
+-- get :: S z s s\r
+-- get = S (\k s -> k s s)\r
+gets :: (s -> a) -> S z s a\r
+gets f = S (\k s -> k (f s) s)\r
+-- set :: s -> S z s ()\r
+-- set s = S (\k _ -> k () s)\r
+modify :: (s -> s) -> S z s ()\r
+modify f = S (\k -> k () . f)\r
+-- runS :: S z s a -> s -> ST z (a, s)\r
+-- runS (S g) = g (\a s -> return (a,s))\r
+evalS :: S z s a -> s -> ST z a\r
+evalS (S g) = g ((return .) . const)\r
+-- execS :: S z s a -> s -> ST z s\r
+-- execS (S g) = g ((return .) . flip const)\r
+st :: ST z a -> S z s a\r
+st m = S (\k s-> do\r
+  a <- m\r
+  k a s)\r
+store :: (MArray (A z) a (ST z))\r
+      => (s -> Arr z a) -> Int -> a -> S z s ()\r
+store f i x = do\r
+  a <- gets f\r
+  st ((a.=x) i)\r
+fetch :: (MArray (A z) a (ST z))\r
+      => (s -> Arr z a) -> Int -> S z s a\r
+fetch f i = do\r
+  a <- gets f\r
+  st (a!:i)\r
+\r
index e8b50e5..8da5038 100644 (file)
@@ -10,14 +10,18 @@ can be appended in linear time.
 -}
 {-# LANGUAGE DeriveFunctor #-}
 
+{-# LANGUAGE BangPatterns #-}
+
 module OrdList (
         OrdList,
         nilOL, isNilOL, unitOL, appOL, consOL, snocOL, concatOL, lastOL,
         headOL,
-        mapOL, fromOL, toOL, foldrOL, foldlOL, reverseOL, fromOLReverse
+        mapOL, fromOL, toOL, foldrOL, foldlOL, reverseOL, fromOLReverse,
+        strictlyEqOL, strictlyOrdOL
 ) where
 
 import GhcPrelude
+import Data.Foldable
 
 import Outputable
 
@@ -49,7 +53,11 @@ instance Monoid (OrdList a) where
   mconcat = concatOL
 
 instance Foldable OrdList where
-  foldr = foldrOL
+  foldr   = foldrOL
+  foldl'  = foldlOL
+  toList  = fromOL
+  null    = isNilOL
+  length  = lengthOL
 
 instance Traversable OrdList where
   traverse f xs = toOL <$> traverse f (fromOL xs)
@@ -64,7 +72,7 @@ appOL    :: OrdList a   -> OrdList a -> OrdList a
 concatOL :: [OrdList a] -> OrdList a
 headOL   :: OrdList a   -> a
 lastOL   :: OrdList a   -> a
-
+lengthOL :: OrdList a   -> Int
 
 nilOL        = None
 unitOL as    = One as
@@ -86,6 +94,13 @@ lastOL (Cons _ as) = lastOL as
 lastOL (Snoc _ a)  = a
 lastOL (Two _ as)  = lastOL as
 
+lengthOL None        = 0
+lengthOL (One _)     = 1
+lengthOL (Many as)   = length as
+lengthOL (Cons _ as) = 1 + length as
+lengthOL (Snoc as _) = 1 + length as
+lengthOL (Two as bs) = length as + length bs
+
 isNilOL None = True
 isNilOL _    = False
 
@@ -126,13 +141,14 @@ foldrOL k z (Snoc xs x) = foldrOL k (k x z) xs
 foldrOL k z (Two b1 b2) = foldrOL k (foldrOL k z b2) b1
 foldrOL k z (Many xs)   = foldr k z xs
 
+-- | Strict left fold.
 foldlOL :: (b->a->b) -> b -> OrdList a -> b
 foldlOL _ z None        = z
 foldlOL k z (One x)     = k z x
-foldlOL k z (Cons x xs) = foldlOL k (k z x) xs
-foldlOL k z (Snoc xs x) = k (foldlOL k z xs) x
-foldlOL k z (Two b1 b2) = foldlOL k (foldlOL k z b1) b2
-foldlOL k z (Many xs)   = foldl k z xs
+foldlOL k z (Cons x xs) = let !z' = (k z x) in foldlOL k z' xs
+foldlOL k z (Snoc xs x) = let !z' = (foldlOL k z xs) in k z' x
+foldlOL k z (Two b1 b2) = let !z' = (foldlOL k z b1) in foldlOL k z' b2
+foldlOL k z (Many xs)   = foldl' k z xs
 
 toOL :: [a] -> OrdList a
 toOL [] = None
@@ -146,3 +162,33 @@ reverseOL (Cons a b) = Snoc (reverseOL b) a
 reverseOL (Snoc a b) = Cons b (reverseOL a)
 reverseOL (Two a b)  = Two (reverseOL b) (reverseOL a)
 reverseOL (Many xs)  = Many (reverse xs)
+
+-- | Compare not only the values but also the structure of two lists
+strictlyEqOL :: Eq a => OrdList a   -> OrdList a -> Bool
+strictlyEqOL None         None       = True
+strictlyEqOL (One x)     (One y)     = x == y
+strictlyEqOL (Cons a as) (Cons b bs) = a == b && as `strictlyEqOL` bs
+strictlyEqOL (Snoc as a) (Snoc bs b) = a == b && as `strictlyEqOL` bs
+strictlyEqOL (Two a1 a2) (Two b1 b2) = a1 `strictlyEqOL` b1 && a2 `strictlyEqOL` b2
+strictlyEqOL (Many as)   (Many bs)   = as == bs
+strictlyEqOL _            _          = False
+
+-- | Compare not only the values but also the structure of two lists
+strictlyOrdOL :: Ord a => OrdList a   -> OrdList a -> Ordering
+strictlyOrdOL None         None       = EQ
+strictlyOrdOL None         _          = LT
+strictlyOrdOL (One x)     (One y)     = compare x y
+strictlyOrdOL (One _)      _          = LT
+strictlyOrdOL (Cons a as) (Cons b bs) =
+  compare a b `mappend` strictlyOrdOL as bs
+strictlyOrdOL (Cons _ _)   _          = LT
+strictlyOrdOL (Snoc as a) (Snoc bs b) =
+  compare a b `mappend` strictlyOrdOL as bs
+strictlyOrdOL (Snoc _ _)   _          = LT
+strictlyOrdOL (Two a1 a2) (Two b1 b2) =
+  (strictlyOrdOL a1 b1) `mappend` (strictlyOrdOL a2 b2)
+strictlyOrdOL (Two _ _)    _          = LT
+strictlyOrdOL (Many as)   (Many bs)   = compare as bs
+strictlyOrdOL (Many _ )   _           = GT
+
+