Return results of Cmm streams in backends
authorÖmer Sinan Ağacan <omeragacan@gmail.com>
Wed, 21 Aug 2019 12:32:09 +0000 (15:32 +0300)
committerÖmer Sinan Ağacan <omeragacan@gmail.com>
Wed, 28 Aug 2019 09:51:12 +0000 (12:51 +0300)
This generalizes code generators (outputAsm, outputLlvm, outputC, and
the call site codeOutput) so that they'll return the return values of
the passed Cmm streams.

This allows accumulating data during Cmm generation and returning it to
the call site in HscMain.

Previously the Cmm streams were assumed to return (), so the code
generators returned () as well.

This change is required by !1304 and !1530.

Skipping CI as this was tested before and I only updated the commit
message.

[skip ci]

compiler/cmm/CmmInfo.hs
compiler/llvmGen/LlvmCodeGen.hs
compiler/llvmGen/LlvmCodeGen/Base.hs
compiler/main/CodeOutput.hs
compiler/main/HscMain.hs
compiler/nativeGen/AsmCodeGen.hs
compiler/utils/Stream.hs

index 138e7aa..60814f8 100644 (file)
@@ -67,16 +67,17 @@ mkEmptyContInfoTable info_lbl
                  , cit_srt  = Nothing
                  , cit_clo  = Nothing }
 
-cmmToRawCmm :: DynFlags -> Stream IO CmmGroup ()
-            -> IO (Stream IO RawCmmGroup ())
+cmmToRawCmm :: DynFlags -> Stream IO CmmGroup a
+            -> IO (Stream IO RawCmmGroup a)
 cmmToRawCmm dflags cmms
   = do { uniqs <- mkSplitUniqSupply 'i'
-       ; let do_one uniqs cmm =
+       ; let do_one :: UniqSupply -> [CmmDecl] -> IO (UniqSupply, [RawCmmDecl])
+             do_one uniqs cmm =
                -- NB. strictness fixes a space leak.  DO NOT REMOVE.
                withTiming (return dflags) (text "Cmm -> Raw Cmm") forceRes $
                  case initUs uniqs $ concatMapM (mkInfoTable dflags) cmm of
                    (b,uniqs') -> return (uniqs',b)
-       ; return (Stream.mapAccumL do_one uniqs cmms >> return ())
+       ; return (snd <$> Stream.mapAccumL_ do_one uniqs cmms)
        }
 
     where forceRes (uniqs, rawcmms) =
index 2a568f8..f649069 100644 (file)
@@ -42,8 +42,8 @@ import System.IO
 -- | Top-level of the LLVM Code generator
 --
 llvmCodeGen :: DynFlags -> Handle -> UniqSupply
-               -> Stream.Stream IO RawCmmGroup ()
-               -> IO ()
+               -> Stream.Stream IO RawCmmGroup a
+               -> IO a
 llvmCodeGen dflags h us cmm_stream
   = withTiming (pure dflags) (text "LLVM CodeGen") (const ()) $ do
        bufh <- newBufHandle h
@@ -66,12 +66,14 @@ llvmCodeGen dflags h us cmm_stream
                             $+$ text "We will try though...")
 
        -- run code generation
-       runLlvm dflags ver bufh us $
+       a <- runLlvm dflags ver bufh us $
          llvmCodeGen' (liftStream cmm_stream)
 
        bFlush bufh
 
-llvmCodeGen' :: Stream.Stream LlvmM RawCmmGroup () -> LlvmM ()
+       return a
+
+llvmCodeGen' :: Stream.Stream LlvmM RawCmmGroup a -> LlvmM a
 llvmCodeGen' cmm_stream
   = do  -- Preamble
         renderLlvm header
@@ -79,13 +81,15 @@ llvmCodeGen' cmm_stream
         cmmMetaLlvmPrelude
 
         -- Procedures
-        () <- Stream.consume cmm_stream llvmGroupLlvmGens
+        a <- Stream.consume cmm_stream llvmGroupLlvmGens
 
         -- Declare aliases for forward references
         renderLlvm . pprLlvmData =<< generateExternDecls
 
         -- Postamble
         cmmUsedLlvmGens
+
+        return a
   where
     header :: SDoc
     header = sdocWithDynFlags $ \dflags ->
index 81f3b9f..7bed4c7 100644 (file)
@@ -253,10 +253,10 @@ liftIO m = LlvmM $ \env -> do x <- m
                               return (x, env)
 
 -- | Get initial Llvm environment.
-runLlvm :: DynFlags -> LlvmVersion -> BufHandle -> UniqSupply -> LlvmM () -> IO ()
+runLlvm :: DynFlags -> LlvmVersion -> BufHandle -> UniqSupply -> LlvmM a -> IO a
 runLlvm dflags ver out us m = do
-    _ <- runLlvmM m env
-    return ()
+    (a, _) <- runLlvmM m env
+    return a
   where env = LlvmEnv { envFunMap = emptyUFM
                       , envVarMap = emptyUFM
                       , envStackRegs = []
index 66c11f0..839999a 100644 (file)
@@ -54,10 +54,11 @@ codeOutput :: DynFlags
            -> [(ForeignSrcLang, FilePath)]
            -- ^ additional files to be compiled with with the C compiler
            -> [InstalledUnitId]
-           -> Stream IO RawCmmGroup ()                       -- Compiled C--
+           -> Stream IO RawCmmGroup a                       -- Compiled C--
            -> IO (FilePath,
                   (Bool{-stub_h_exists-}, Maybe FilePath{-stub_c_exists-}),
-                  [(ForeignSrcLang, FilePath)]{-foreign_fps-})
+                  [(ForeignSrcLang, FilePath)]{-foreign_fps-},
+                  a)
 
 codeOutput dflags this_mod filenm location foreign_stubs foreign_fps pkg_deps
   cmm_stream
@@ -87,15 +88,14 @@ codeOutput dflags this_mod filenm location foreign_stubs foreign_fps pkg_deps
                 }
 
         ; stubs_exist <- outputForeignStubs dflags this_mod location foreign_stubs
-        ; case hscTarget dflags of {
-             HscAsm         -> outputAsm dflags this_mod location filenm
-                                         linted_cmm_stream;
-             HscC           -> outputC dflags filenm linted_cmm_stream pkg_deps;
-             HscLlvm        -> outputLlvm dflags filenm linted_cmm_stream;
-             HscInterpreted -> panic "codeOutput: HscInterpreted";
-             HscNothing     -> panic "codeOutput: HscNothing"
-          }
-        ; return (filenm, stubs_exist, foreign_fps)
+        ; a <- case hscTarget dflags of
+                 HscAsm         -> outputAsm dflags this_mod location filenm
+                                             linted_cmm_stream
+                 HscC           -> outputC dflags filenm linted_cmm_stream pkg_deps
+                 HscLlvm        -> outputLlvm dflags filenm linted_cmm_stream
+                 HscInterpreted -> panic "codeOutput: HscInterpreted"
+                 HscNothing     -> panic "codeOutput: HscNothing"
+        ; return (filenm, stubs_exist, foreign_fps, a)
         }
 
 doOutput :: String -> (Handle -> IO a) -> IO a
@@ -111,13 +111,13 @@ doOutput filenm io_action = bracket (openFile filenm WriteMode) hClose io_action
 
 outputC :: DynFlags
         -> FilePath
-        -> Stream IO RawCmmGroup ()
+        -> Stream IO RawCmmGroup a
         -> [InstalledUnitId]
-        -> IO ()
+        -> IO a
 
 outputC dflags filenm cmm_stream packages
   = do
-       withTiming (return dflags) (text "C codegen") id $ do
+       withTiming (return dflags) (text "C codegen") (\a -> seq a () {- FIXME -}) $ do
 
          -- figure out which header files to #include in the generated .hc file:
          --
@@ -150,18 +150,17 @@ outputC dflags filenm cmm_stream packages
 -}
 
 outputAsm :: DynFlags -> Module -> ModLocation -> FilePath
-          -> Stream IO RawCmmGroup ()
-          -> IO ()
+          -> Stream IO RawCmmGroup a
+          -> IO a
 outputAsm dflags this_mod location filenm cmm_stream
  | platformMisc_ghcWithNativeCodeGen $ platformMisc dflags
   = do ncg_uniqs <- mkSplitUniqSupply 'n'
 
        debugTraceMsg dflags 4 (text "Outputing asm to" <+> text filenm)
 
-       _ <- {-# SCC "OutputAsm" #-} doOutput filenm $
+       {-# SCC "OutputAsm" #-} doOutput filenm $
            \h -> {-# SCC "NativeCodeGen" #-}
                  nativeCodeGen dflags this_mod location h ncg_uniqs cmm_stream
-       return ()
 
  | otherwise
   = panic "This compiler was built without a native code generator"
@@ -174,7 +173,7 @@ outputAsm dflags this_mod location filenm cmm_stream
 ************************************************************************
 -}
 
-outputLlvm :: DynFlags -> FilePath -> Stream IO RawCmmGroup () -> IO ()
+outputLlvm :: DynFlags -> FilePath -> Stream IO RawCmmGroup a -> IO a
 outputLlvm dflags filenm cmm_stream
   = do ncg_uniqs <- mkSplitUniqSupply 'n'
 
index d12ff03..a9e443c 100644 (file)
@@ -1426,7 +1426,7 @@ hscGenHardCode hsc_env cgguts mod_summary output_filename = do
                             return a
                 rawcmms1 = Stream.mapM dump rawcmms0
 
-            (output_filename, (_stub_h_exists, stub_c_exists), foreign_fps)
+            (output_filename, (_stub_h_exists, stub_c_exists), foreign_fps, ())
                 <- {-# SCC "codeOutput" #-}
                   codeOutput dflags this_mod output_filename location
                   foreign_stubs foreign_files dependencies rawcmms1
index 40a1e0b..fe59a4d 100644 (file)
@@ -157,14 +157,14 @@ The machine-dependent bits break down as follows:
 -}
 
 --------------------
-nativeCodeGen :: DynFlags -> Module -> ModLocation -> Handle -> UniqSupply
-              -> Stream IO RawCmmGroup ()
-              -> IO UniqSupply
+nativeCodeGen :: forall a . DynFlags -> Module -> ModLocation -> Handle -> UniqSupply
+              -> Stream IO RawCmmGroup a
+              -> IO a
 nativeCodeGen dflags this_mod modLoc h us cmms
  = let platform = targetPlatform dflags
        nCG' :: ( Outputable statics, Outputable instr
                , Outputable jumpDest, Instruction instr)
-            => NcgImpl statics instr jumpDest -> IO UniqSupply
+            => NcgImpl statics instr jumpDest -> IO a
        nCG' ncgImpl = nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms
    in case platformArch platform of
       ArchX86       -> nCG' (x86NcgImpl    dflags)
@@ -314,8 +314,8 @@ nativeCodeGen' :: (Outputable statics, Outputable instr,Outputable jumpDest,
                -> NcgImpl statics instr jumpDest
                -> Handle
                -> UniqSupply
-               -> Stream IO RawCmmGroup ()
-               -> IO UniqSupply
+               -> Stream IO RawCmmGroup a
+               -> IO a
 nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms
  = do
         -- BufHandle is a performance hack.  We could hide it inside
@@ -323,9 +323,10 @@ nativeCodeGen' dflags this_mod modLoc ncgImpl h us cmms
         -- printDocs here (in order to do codegen in constant space).
         bufh <- newBufHandle h
         let ngs0 = NGS [] [] [] [] [] [] emptyUFM mapEmpty
-        (ngs, us') <- cmmNativeGenStream dflags this_mod modLoc ncgImpl bufh us
+        (ngs, us', a) <- cmmNativeGenStream dflags this_mod modLoc ncgImpl bufh us
                                          cmms ngs0
-        finishNativeGen dflags modLoc bufh us' ngs
+        _ <- finishNativeGen dflags modLoc bufh us' ngs
+        return a
 
 finishNativeGen :: Instruction instr
                 => DynFlags
@@ -386,20 +387,21 @@ cmmNativeGenStream :: (Outputable statics, Outputable instr
               -> NcgImpl statics instr jumpDest
               -> BufHandle
               -> UniqSupply
-              -> Stream IO RawCmmGroup ()
+              -> Stream IO RawCmmGroup a
               -> NativeGenAcc statics instr
-              -> IO (NativeGenAcc statics instr, UniqSupply)
+              -> IO (NativeGenAcc statics instr, UniqSupply, a)
 
 cmmNativeGenStream dflags this_mod modLoc ncgImpl h us cmm_stream ngs
  = do r <- Stream.runStream cmm_stream
       case r of
-        Left () ->
+        Left a ->
           return (ngs { ngs_imports = reverse $ ngs_imports ngs
                       , ngs_natives = reverse $ ngs_natives ngs
                       , ngs_colorStats = reverse $ ngs_colorStats ngs
                       , ngs_linearStats = reverse $ ngs_linearStats ngs
                       },
-                  us)
+                  us,
+                  a)
         Right (cmms, cmm_stream') -> do
           (us', ngs'') <-
             withTiming (return dflags)
index 2ad2b8c..7eabbe1 100644 (file)
@@ -7,8 +7,8 @@
 -- -----------------------------------------------------------------------------
 module Stream (
     Stream(..), yield, liftIO,
-    collect, consume, fromList,
-    Stream.map, Stream.mapM, Stream.mapAccumL
+    collect, collect_, consume, fromList,
+    Stream.map, Stream.mapM, Stream.mapAccumL, Stream.mapAccumL_
   ) where
 
 import GhcPrelude
@@ -71,6 +71,16 @@ collect str = go str []
       Left () -> return (reverse acc)
       Right (a, str') -> go str' (a:acc)
 
+-- | Turn a Stream into an ordinary list, by demanding all the elements.
+collect_ :: Monad m => Stream m a r -> m ([a], r)
+collect_ str = go str []
+ where
+  go str acc = do
+    r <- runStream str
+    case r of
+      Left r -> return (reverse acc, r)
+      Right (a, str') -> go str' (a:acc)
+
 consume :: Monad m => Stream m a b -> (a -> m ()) -> m b
 consume str f = do
     r <- runStream str
@@ -113,3 +123,13 @@ mapAccumL f c str = Stream $ do
     Right (a, str') -> do
       (c',b) <- f c a
       return (Right (b, mapAccumL f c' str'))
+
+mapAccumL_ :: Monad m => (c -> a -> m (c,b)) -> c -> Stream m a r
+           -> Stream m b (c, r)
+mapAccumL_ f c str = Stream $ do
+  r <- runStream str
+  case r of
+    Left  r -> return (Left (c, r))
+    Right (a, str') -> do
+      (c',b) <- f c a
+      return (Right (b, mapAccumL_ f c' str'))