Add -foptimal-applicative-do
authorSimon Marlow <marlowsd@gmail.com>
Fri, 4 Mar 2016 13:06:42 +0000 (13:06 +0000)
committerBartosz Nitka <niteria@gmail.com>
Mon, 25 Jul 2016 14:41:44 +0000 (07:41 -0700)
Summary:
The algorithm for ApplicativeDo rearrangement is based on a heuristic
that runs in O(n^2).  This patch adds the optimal algorithm, which is
O(n^3), selected by a flag (-foptimal-applicative-do).  It finds better
solutions in a small number of cases (about 2% of the cases where
ApplicativeDo makes a difference), but it can be very slow for large do
expressions.  I'm mainly adding it for experimental reasons.

ToDo: user guide docs

Test Plan: validate

Reviewers: simonpj, bgamari, austin, niteria, erikd

Subscribers: thomie

Differential Revision: https://phabricator.haskell.org/D1969

(cherry picked from commit 2f45cf3f48162a5f843005755dafa1c5c1b451a7)

compiler/main/DynFlags.hs
compiler/rename/RnExpr.hs
docs/users_guide/glasgow_exts.rst
testsuite/tests/ado/ado-optimal.hs [new file with mode: 0644]
testsuite/tests/ado/ado-optimal.stdout [new file with mode: 0644]
testsuite/tests/ado/ado004.hs
testsuite/tests/ado/ado004.stderr
testsuite/tests/ado/all.T
utils/mkUserGuidePart/Options/Optimizations.hs

index 6a4737e..f6598b9 100644 (file)
@@ -488,6 +488,7 @@ data GeneralFlag
    | Opt_FlatCache
    | Opt_ExternalInterpreter
    | Opt_VersionMacros
+   | Opt_OptimalApplicativeDo
 
    -- PreInlining is on by default. The option is there just to see how
    -- bad things get if you turn it off!
@@ -3340,6 +3341,7 @@ fFlagsDeps = [
   flagSpec "loopification"                    Opt_Loopification,
   flagSpec "omit-interface-pragmas"           Opt_OmitInterfacePragmas,
   flagSpec "omit-yields"                      Opt_OmitYields,
+  flagSpec "optimal-applicative-do"           Opt_OptimalApplicativeDo,
   flagSpec "pedantic-bottoms"                 Opt_PedanticBottoms,
   flagSpec "pre-inlining"                     Opt_SimplPreInlining,
   flagGhciSpec "print-bind-contents"          Opt_PrintBindContents,
index 4c01315..5921718 100644 (file)
@@ -48,6 +48,9 @@ import Control.Monad
 import TysWiredIn       ( nilDataConName )
 import qualified GHC.LanguageExtensions as LangExt
 
+import Data.Ord
+import Data.Array
+
 {-
 ************************************************************************
 *                                                                      *
@@ -1439,25 +1442,120 @@ rearrangeForApplicativeDo
   -> RnM ([ExprLStmt Name], FreeVars)
 
 rearrangeForApplicativeDo _ [] = return ([], emptyNameSet)
+rearrangeForApplicativeDo _ [(one,_)] = return ([one], emptyNameSet)
 rearrangeForApplicativeDo ctxt stmts0 = do
-  (stmts', fvs) <- ado ctxt stmts [last] last_fvs
-  return (stmts', fvs)
-  where (stmts,(last,last_fvs)) = findLast stmts0
-        findLast [] = error "findLast"
-        findLast [last] = ([],last)
-        findLast (x:xs) = (x:rest,last) where (rest,last) = findLast xs
-
--- | The ApplicativeDo transformation.
-ado
+  optimal_ado <- goptM Opt_OptimalApplicativeDo
+  let stmt_tree | optimal_ado = mkStmtTreeOptimal stmts
+                | otherwise = mkStmtTreeHeuristic stmts
+  stmtTreeToStmts ctxt stmt_tree [last] last_fvs
+  where
+    (stmts,(last,last_fvs)) = findLast stmts0
+    findLast [] = error "findLast"
+    findLast [last] = ([],last)
+    findLast (x:xs) = (x:rest,last) where (rest,last) = findLast xs
+
+-- | A tree of statements using a mixture of applicative and bind constructs.
+data StmtTree a
+  = StmtTreeOne a
+  | StmtTreeBind (StmtTree a) (StmtTree a)
+  | StmtTreeApplicative [StmtTree a]
+
+flattenStmtTree :: StmtTree a -> [a]
+flattenStmtTree t = go t []
+ where
+  go (StmtTreeOne a) as = a : as
+  go (StmtTreeBind l r) as = go l (go r as)
+  go (StmtTreeApplicative ts) as = foldr go as ts
+
+type ExprStmtTree = StmtTree (ExprLStmt Name, FreeVars)
+type Cost = Int
+
+-- | Turn a sequence of statements into an ExprStmtTree using a
+-- heuristic algorithm.  /O(n^2)/
+mkStmtTreeHeuristic :: [(ExprLStmt Name, FreeVars)] -> ExprStmtTree
+mkStmtTreeHeuristic [one] = StmtTreeOne one
+mkStmtTreeHeuristic stmts =
+  case segments stmts of
+    [one] -> split one
+    segs -> StmtTreeApplicative (map split segs)
+ where
+  split [one] = StmtTreeOne one
+  split stmts =
+    StmtTreeBind (mkStmtTreeHeuristic before) (mkStmtTreeHeuristic after)
+    where (before, after) = splitSegment stmts
+
+-- | Turn a sequence of statements into an ExprStmtTree optimally,
+-- using dynamic programming.  /O(n^3)/
+mkStmtTreeOptimal :: [(ExprLStmt Name, FreeVars)] -> ExprStmtTree
+mkStmtTreeOptimal stmts =
+  ASSERT(not (null stmts)) -- the empty case is handled by the caller;
+                           -- we don't support empty StmtTrees.
+  fst (arr ! (0,n))
+  where
+    n = length stmts - 1
+    stmt_arr = listArray (0,n) stmts
+
+    -- lazy cache of optimal trees for subsequences of the input
+    arr :: Array (Int,Int) (ExprStmtTree, Cost)
+    arr = array ((0,0),(n,n))
+             [ ((lo,hi), tree lo hi)
+             | lo <- [0..n]
+             , hi <- [lo..n] ]
+
+    -- compute the optimal tree for the sequence [lo..hi]
+    tree lo hi
+      | hi == lo = (StmtTreeOne (stmt_arr ! lo), 1)
+      | otherwise =
+         case segments [ stmt_arr ! i | i <- [lo..hi] ] of
+           [] -> panic "mkStmtTree"
+           [_one] -> split lo hi
+           segs -> (StmtTreeApplicative trees, maximum costs)
+             where
+               bounds = scanl (\(_,hi) a -> (hi+1, hi + length a)) (0,lo-1) segs
+               (trees,costs) = unzip (map (uncurry split) (tail bounds))
+
+    -- find the best place to split the segment [lo..hi]
+    split :: Int -> Int -> (ExprStmtTree, Cost)
+    split lo hi
+      | hi == lo = (StmtTreeOne (stmt_arr ! lo), 1)
+      | otherwise = (StmtTreeBind before after, c1+c2)
+        where
+         -- As per the paper, for a sequence s1...sn, we want to find
+         -- the split with the minimum cost, where the cost is the
+         -- sum of the cost of the left and right subsequences.
+         --
+         -- As an optimisation (also in the paper) if the cost of
+         -- s1..s(n-1) is different from the cost of s2..sn, we know
+         -- that the optimal solution is the lower of the two.  Only
+         -- in the case that these two have the same cost do we need
+         -- to do the exhaustive search.
+         --
+         ((before,c1),(after,c2))
+           | hi - lo == 1
+           = ((StmtTreeOne (stmt_arr ! lo), 1),
+              (StmtTreeOne (stmt_arr ! hi), 1))
+           | left_cost < right_cost
+           = ((left,left_cost), (StmtTreeOne (stmt_arr ! hi), 1))
+           | otherwise -- left_cost > right_cost
+           = ((StmtTreeOne (stmt_arr ! lo), 1), (right,right_cost))
+           | otherwise = minimumBy (comparing cost) alternatives
+           where
+             (left, left_cost) = arr ! (lo,hi-1)
+             (right, right_cost) = arr ! (lo+1,hi)
+             cost ((_,c1),(_,c2)) = c1 + c2
+             alternatives = [ (arr ! (lo,k), arr ! (k+1,hi))
+                            | k <- [lo .. hi-1] ]
+
+
+-- | Turn the ExprStmtTree back into a sequence of statements, using
+-- ApplicativeStmt where necessary.
+stmtTreeToStmts
   :: HsStmtContext Name
-  -> [(ExprLStmt Name, FreeVars)] -- ^ input statements
+  -> ExprStmtTree
   -> [ExprLStmt Name]             -- ^ the "tail"
-  -> FreeVars                                -- ^ free variables of the tail
+  -> FreeVars                     -- ^ free variables of the tail
   -> RnM ( [ExprLStmt Name]       -- ( output statements,
-         , FreeVars )                        -- , things we needed
-                                             --    e.g. <$>, <*>, join )
-
-ado _ctxt []        tail _ = return (tail, emptyNameSet)
+         , FreeVars )             -- , things we needed
 
 -- If we have a single bind, and we can do it without a join, transform
 -- to an ApplicativeStmt.  This corresponds to the rule
@@ -1465,7 +1563,8 @@ ado _ctxt []        tail _ = return (tail, emptyNameSet)
 -- In the spec, but we do it here rather than in the desugarer,
 -- because we need the typechecker to typecheck the <$> form rather than
 -- the bind form, which would give rise to a Monad constraint.
-ado ctxt [(L _ (BindStmt pat rhs _ _ _),_)] tail _
+stmtTreeToStmts ctxt (StmtTreeOne (L _ (BindStmt pat rhs _ _ _),_))
+                tail _tail_fvs
   | isIrrefutableHsPat pat, (False,tail') <- needJoin tail
     -- WARNING: isIrrefutableHsPat on (HsPat Name) doesn't have enough info
     --          to know which types have only one constructor.  So only
@@ -1474,65 +1573,41 @@ ado ctxt [(L _ (BindStmt pat rhs _ _ _),_)] tail _
     --          isIrrefuatableHsPat
   = mkApplicativeStmt ctxt [ApplicativeArgOne pat rhs] False tail'
 
-ado _ctxt [(one,_)] tail _ = return (one:tail, emptyNameSet)
-
-ado ctxt stmts tail tail_fvs =
-  case segments stmts of  -- chop into segments
-    [] -> panic "ado"
-    [one] ->
-      -- one indivisible segment, divide it by adding a bind
-      adoSegment ctxt one tail tail_fvs
-    segs ->
-      -- multiple segments; recursively transform the segments, and
-      -- combine into an ApplicativeStmt
-      do { pairs <- mapM (adoSegmentArg ctxt tail_fvs) segs
-         ; let (stmts', fvss) = unzip pairs
-         ; let (need_join, tail') = needJoin tail
-         ; (stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail'
-         ; return (stmts, unionNameSets (fvs:fvss)) }
-
--- | Deal with an indivisible segment.  We pick a place to insert a
--- bind (it will actually be a join), and recursively transform the
--- two halves.
-adoSegment
-  :: HsStmtContext Name
-  -> [(ExprLStmt Name, FreeVars)]
-  -> [ExprLStmt Name]
-  -> FreeVars
-  -> RnM ( [ExprLStmt Name], FreeVars )
-adoSegment ctxt stmts tail tail_fvs
- = do {  -- choose somewhere to put a bind
-        let (before,after) = splitSegment stmts
-      ; (stmts1, fvs1) <- ado ctxt after tail tail_fvs
-      ; let tail1_fvs = unionNameSets (tail_fvs : map snd after)
-      ; (stmts2, fvs2) <- ado ctxt before stmts1 tail1_fvs
-      ; return (stmts2, fvs1 `plusFV` fvs2) }
-
--- | Given a segment, make an ApplicativeArg.  Here we recursively
--- call adoSegment on the segment's contents to extract any further
--- available parallelism.
-adoSegmentArg
-  :: HsStmtContext Name
-  -> FreeVars
-  -> [(ExprLStmt Name, FreeVars)]
-  -> RnM (ApplicativeArg Name Name, FreeVars)
-adoSegmentArg _ _ [(L _ (BindStmt pat exp _ _ _),_)] =
-  return (ApplicativeArgOne pat exp, emptyFVs)
-adoSegmentArg ctxt tail_fvs stmts =
-  do { let pvarset = mkNameSet (concatMap (collectStmtBinders.unLoc.fst) stmts)
-                      `intersectNameSet` tail_fvs
-           pvars = nameSetElems pvarset
-           pat = mkBigLHsVarPatTup pvars
-           tup = mkBigLHsVarTup pvars
-     ; (stmts',fvs2) <- adoSegment ctxt stmts [] pvarset
-     ; (mb_ret, fvs1) <-
-          if | L _ ApplicativeStmt{} <- last stmts' ->
-               return (unLoc tup, emptyNameSet)
-             | otherwise -> do
-               (ret,fvs) <- lookupStmtNamePoly ctxt returnMName
-               return (HsApp (noLoc ret) tup, fvs)
-     ; return ( ApplicativeArgMany stmts' mb_ret pat
-              , fvs1 `plusFV` fvs2) }
+stmtTreeToStmts _ctxt (StmtTreeOne (s,_)) tail _tail_fvs =
+  return (s : tail, emptyNameSet)
+
+stmtTreeToStmts ctxt (StmtTreeBind before after) tail tail_fvs = do
+  (stmts1, fvs1) <- stmtTreeToStmts ctxt after tail tail_fvs
+  let tail1_fvs = unionNameSets (tail_fvs : map snd (flattenStmtTree after))
+  (stmts2, fvs2) <- stmtTreeToStmts ctxt before stmts1 tail1_fvs
+  return (stmts2, fvs1 `plusFV` fvs2)
+
+stmtTreeToStmts ctxt (StmtTreeApplicative trees) tail tail_fvs = do
+   pairs <- mapM (stmtTreeArg ctxt tail_fvs) trees
+   let (stmts', fvss) = unzip pairs
+   let (need_join, tail') = needJoin tail
+   (stmts, fvs) <- mkApplicativeStmt ctxt stmts' need_join tail'
+   return (stmts, unionNameSets (fvs:fvss))
+ where
+   stmtTreeArg _ctxt _tail_fvs (StmtTreeOne (L _ (BindStmt pat exp _ _ _), _)) =
+     return (ApplicativeArgOne pat exp, emptyFVs)
+   stmtTreeArg ctxt tail_fvs tree = do
+     let stmts = flattenStmtTree tree
+         pvarset = mkNameSet (concatMap (collectStmtBinders.unLoc.fst) stmts)
+                     `intersectNameSet` tail_fvs
+         pvars = nameSetElems pvarset
+         pat = mkBigLHsVarPatTup pvars
+         tup = mkBigLHsVarTup pvars
+     (stmts',fvs2) <- stmtTreeToStmts ctxt tree [] pvarset
+     (mb_ret, fvs1) <-
+        if | L _ ApplicativeStmt{} <- last stmts' ->
+             return (unLoc tup, emptyNameSet)
+           | otherwise -> do
+             (ret,fvs) <- lookupStmtNamePoly ctxt returnMName
+             return (HsApp (noLoc ret) tup, fvs)
+     return ( ApplicativeArgMany stmts' mb_ret pat
+            , fvs1 `plusFV` fvs2)
+
 
 -- | Divide a sequence of statements into segments, where no segment
 -- depends on any variables defined by a statement in another segment.
@@ -1689,7 +1764,6 @@ isReturnApp (L _ (HsApp f arg))
   is_return _ = False
 isReturnApp _ = Nothing
 
-
 {-
 ************************************************************************
 *                                                                      *
index 408f63f..cbecda7 100644 (file)
@@ -939,6 +939,23 @@ then the expression will only require ``Applicative``. Otherwise, the expression
 will require ``Monad``. The block may return a pure expression ``E`` depending
 upon the results ``p1...pn`` with either ``return`` or ``pure``.
 
+When the statements of a ``do`` expression have dependencies between
+them, and ``ApplicativeDo`` cannot infer an ``Applicative`` type, it
+uses a heuristic algorithm to try to use ``<*>`` as much as possible.
+This algorithm usually finds the best solution, but in rare complex
+cases it might miss an opportunity.  There is an algorithm that finds
+the optimal solution, provided as an option:
+
+.. ghc-flag:: -foptimal-applicative-do
+
+    :since: 8.0.1
+
+    Enables an alternative algorithm for choosing where to use ``<*>``
+    in conjunction with the ``ApplicativeDo`` language extension.
+    This algorithm always finds the optimal solution, but it is
+    expensive: ``O(n^3)``, so this option can lead to long compile
+    times when there are very large ``do`` expressions (over 100
+    statements).  The default ``ApplicativeDo`` algorithm is ``O(n^2)``.
 
 .. _applicative-do-pitfall:
 
diff --git a/testsuite/tests/ado/ado-optimal.hs b/testsuite/tests/ado/ado-optimal.hs
new file mode 100644 (file)
index 0000000..aab8d53
--- /dev/null
@@ -0,0 +1,59 @@
+{-# LANGUAGE ScopedTypeVariables, ExistentialQuantification, ApplicativeDo #-}
+{-# OPTIONS_GHC -foptimal-applicative-do #-}
+module Main where
+
+import Control.Applicative
+import Text.PrettyPrint
+
+(a:b:c:d:e:f:g:h:_) = map (\c -> doc [c]) ['a'..]
+
+-- This one requires -foptimal-applicative-do to find the best solution
+-- ((a; b) | (c; d)); e
+test1 :: M ()
+test1 = do
+  x1 <- a
+  x2 <- const b x1
+  x3 <- c
+  x4 <- const d x3
+  x5 <- const e (x1,x4)
+  return (const () x5)
+
+main = mapM_ run
+ [ test1
+ ]
+
+-- Testing code, prints out the structure of a monad/applicative expression
+
+newtype M a = M (Bool -> (Maybe Doc, a))
+
+maybeParen True d = parens d
+maybeParen _ d = d
+
+run :: M a -> IO ()
+run (M m) = print d where (Just d,_) = m False
+
+instance Functor M where
+  fmap f m = m >>= return . f
+
+instance Applicative M where
+  pure a = M $ \_ -> (Nothing, a)
+  M f <*> M a = M $ \p ->
+    let (Just d1, f') = f True
+        (Just d2, a') = a True
+    in
+        (Just (maybeParen p (d1 <+> char '|' <+> d2)), f' a')
+
+instance Monad M where
+  return = pure
+  M m >>= k = M $ \p ->
+    let (d1, a) = m True
+        (d2, b) = case k a of M f -> f True
+    in
+    case (d1,d2) of
+      (Nothing,Nothing) -> (Nothing, b)
+      (Just d, Nothing) -> (Just d, b)
+      (Nothing, Just d) -> (Just d, b)
+      (Just d1, Just d2) -> (Just (maybeParen p (d1 <> semi <+> d2)), b)
+
+doc :: String -> M ()
+doc d = M $ \_ -> (Just (text d), ())
diff --git a/testsuite/tests/ado/ado-optimal.stdout b/testsuite/tests/ado/ado-optimal.stdout
new file mode 100644 (file)
index 0000000..29f9856
--- /dev/null
@@ -0,0 +1 @@
+((a; b) | (c; d)); e
index 67e04c1..6ddc839 100644 (file)
@@ -15,6 +15,15 @@ test2 f = do
   y <- f 4
   return (x + y)
 
+-- Test we can also infer the Functor version of the type
+test2a f = do
+  x <- f 3
+  return (x + 1)
+
+-- Test for just one statement
+test2b f = do
+  return (f 3)
+
 -- This one will use join
 test3 f g = do
   x <- f 3
index f1cc36c..2bb2e6d 100644 (file)
@@ -5,6 +5,12 @@ TYPE SIGNATURES
     forall t b (f :: * -> *).
     (Num b, Num t, Applicative f) =>
     (t -> f b) -> f b
+  test2a ::
+    forall (f :: * -> *) b t.
+    (Num t, Num b, Functor f) =>
+    (t -> f b) -> f b
+  test2b ::
+    forall (m :: * -> *) a t. (Num t, Monad m) => (t -> a) -> m a
   test3 ::
     forall a t (m :: * -> *) t1.
     (Num t1, Monad m) =>
index e1efdf2..06cdbf9 100644 (file)
@@ -6,3 +6,4 @@ test('ado005', normal, compile_fail, [''])
 test('ado006', normal, compile, [''])
 test('ado007', normal, compile, [''])
 test('T11607', normal, compile_and_run, [''])
+test('ado-optimal', normal, compile_and_run, [''])
index 389cd37..dd9ffd9 100644 (file)
@@ -216,6 +216,12 @@ optimizationsOptions =
          , flagType = DynamicFlag
          , flagReverse = "-fno-omit-yields"
          }
+  , flag { flagName = "-foptimal-applicative-do"
+         , flagDescription =
+           "Use a slower but better algorithm for ApplicativeDo"
+         , flagType = DynamicFlag
+         , flagReverse = "-fno-optimal-applicative-do"
+         }
   , flag { flagName = "-fpedantic-bottoms"
          , flagDescription =
            "Make GHC be more precise about its treatment of bottom (but see "++