Fix testsuite threading, timeout, encoding and performance issues on Windows
Author:    Tamar Christina <tamar@zhox.com>    Tue, 29 Nov 2016 21:56:08 +0000 (16:56 -0500)
Committer: Ben Gamari <ben@smart-cactus.org>   Wed, 30 Nov 2016 01:38:43 +0000 (20:38 -0500)
In a land far far away, a project called Cygwin was born.
Cygwin used newlib as its standard C library implementation.

But Cygwin wanted to emulate POSIX systems as closely as possible.
So it implemented `execv` using the Windows function `spawnve`.

Specifically

```
spawnve (_P_OVERLAY, path, argv, cur_environ ())
```

`_P_OVERLAY` is crucial, as it makes the function behave *sort of*
like `execv` on Linux: the child process replaces the original process.

There is one major difference, though, because of the difference in
process models on Windows: the original process signals the caller that
it's done.

This is why the file is still locked: the child is still running.
Control was returned because the parent process was destroyed, but the
child keeps running.

I think it's just pure dumb luck that the older runtimes are slow
enough to give the process time to terminate before we try to delete
the file. That also explains why you still see sporadic failures of a
test or two (like T7307) even on older runtimes like 2.5.0.

So this patch fixes a couple of things. I leverage the existing
`timeout.exe` to implement a workaround for this issue.

a) The old timeout used to start the process and then assign it to the
   job. This is slightly faulty, since child processes are only assigned
   to a job if their parent was assigned to it when they started, so this
   was a race condition. I now create the process suspended, assign it to
   the job and then resume it, which means all child processes now run
   under the same job.

b) To prevent dangling child processes, I mark the job with
   `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE`, which ensures that when the last
   handle to the job is closed, all processes under the job are killed.

c) I also change the way we wait for results. Instead of waiting for
   the parent process to terminate, I wait for the job itself to
   complete.

   There is a slight subtlety here: we can't wait on the job itself.
   Instead we have to create an I/O completion port and wait for signals
   on it.  See
   https://blogs.msdn.microsoft.com/oldnewthing/20130405-00/?p=4743

This fixes the issues on all runtimes for me and makes T7307 pass
consistently.

The threading was also simplified by hiding all the locking behind a
single semaphore and a completion class. Furthermore, some additional
error reporting was added.
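
In outline, a bounded semaphore caps the number of concurrently running
test threads, and a small completion object lets the main thread block
until every scheduled test has reported back. A minimal sketch of that
idea follows; the `schedule` and `run_test` names are illustrative
only, the real code lives in testlib.py and testutil.py:

```
import threading

pool_sema = threading.BoundedSemaphore(value=4)   # config.threads in the real driver

class Watcher(object):
    def __init__(self, count):
        self.pool = count
        self.evt = threading.Event()
        self.lock = threading.Lock()
        if count <= 0:
            self.evt.set()

    def wait(self):
        self.evt.wait()

    def notify(self):
        with self.lock:
            self.pool -= 1
            if self.pool <= 0:
                self.evt.set()

def run_test(watcher, work):
    try:
        work()                  # the actual test body
    finally:
        pool_sema.release()     # free a slot for the next test
        watcher.notify()        # tell the main thread this test is done

def schedule(tests):
    watcher = Watcher(len(tests))
    for work in tests:
        pool_sema.acquire()     # blocks while 'value' tests are already running
        threading.Thread(target=run_test, args=(watcher, work)).start()
    watcher.wait()              # returns once every test has notified

schedule([lambda: None] * 3)    # toy usage
```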

Regarding encoding, the testsuite no longer passes a file handle to the
subprocess, since on Windows sh.exe seems to acquire a lock on the file
that is not released in a timely fashion.

I suspect this is because Cygwin seems to emulate console handles by
creating file handles and using those for the std handles. So when we
give it an existing file handle it just locks the file. I think what's
happening is that it does not release the handle until all shared
Cygwin processes are dead, which explains why it worked in
single-threaded mode.

So now instead we pass a pipe and do not interpret the resulting data.

Any bytes written to stdin or read from stdout/stderr are handled in
binary mode and the data is not interpreted. The reason for this is
that GHC has encoding tests which deliberately pass invalid UTF-8; if
we tried to handle the data as text, Python would throw an exception
instead of just failing the test comparison.
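
As an illustration, the subprocess interaction reduces to the pattern
below. `run_with_pipes` is a hypothetical helper name for this sketch;
the real logic lives in `runCmd` in testlib.py, which also appends to
the capture files rather than overwriting them:

```
import io
import subprocess

def run_with_pipes(cmd, stdin_file=None, stdout_file=None):
    stdin_bytes = None
    if stdin_file:
        with io.open(stdin_file, 'rb') as f:   # read input as raw bytes
            stdin_bytes = f.read()

    p = subprocess.Popen(cmd,
                         stdin=subprocess.PIPE,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    out, err = p.communicate(stdin_bytes)      # bytes in, bytes out; never decoded

    if stdout_file:
        with io.open(stdout_file, 'wb') as f:  # write captured bytes verbatim
            f.write(out)
    return p.returncode, out, err

# Toy usage: capture whatever 'echo' produces without interpreting it.
print(run_with_pipes(['echo', 'hello']))
```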

Also I have fixed the ability to override `PYTHON` when calling `make
tests`. This now works the same as with `.\validate`.

Finally, after cleaning up the locks I was able to make the abort
behavior work correctly, as I believe it was intended: when you press
Ctrl+C and send an interrupt signal, the testsuite finishes the active
tests and then gracefully exits, showing you a report of the progress
it did make. So using Ctrl+C no longer just *dies* as it did before.
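
A small, self-contained sketch of that interrupt handling (the real
handler in runtests.py just calls `stopNow()`, and the scheduling loop
checks `stopping()` before starting each new test):

```
import signal
import time

want_to_stop = False

def stop_now(signum, frame):
    # Do not kill anything on Ctrl+C; just stop scheduling further tests.
    global want_to_stop
    want_to_stop = True

signal.signal(signal.SIGINT, stop_now)

def run_all(tests):
    for one_test in tests:
        if want_to_stop:
            break          # active tests finish; a summary is still printed
        one_test()

# Toy usage: five "tests" that each sleep for a second.
run_all([lambda: time.sleep(1) for _ in range(5)])
```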

These changes lift the restriction on which Python you use (msys or
mingw), which runtime you are on, and whether it is Python 2 or
Python 3. All combinations should now be supported.

Test Plan:
PATH=/usr/local/bin:/mingw64/bin:$APPDATA/cabal/bin:$PATH &&
PYTHON=/usr/bin/python THREADS=9 make test
THREADS=9 make test
PATH=/usr/local/bin:/mingw64/bin:$APPDATA/cabal/bin:$PATH &&
PYTHON=/usr/bin/python ./validate --quiet --testsuite-only

Reviewers: erikd, RyanGlScott, bgamari, austin

Subscribers: jrtc27, mpickering, thomie, #ghc_windows_task_force

Differential Revision: https://phabricator.haskell.org/D2684

GHC Trac Issues: #12725, #12554, #12661, #12004

testsuite/driver/runtests.py
testsuite/driver/testlib.py
testsuite/driver/testutil.py
testsuite/mk/boilerplate.mk
testsuite/timeout/WinCBindings.hsc
testsuite/timeout/timeout.hs

index c97323b..1b6fe12 100644
@@ -4,6 +4,7 @@
 
 from __future__ import print_function
 
+import signal
 import sys
 import os
 import string
@@ -38,6 +39,9 @@ os.environ['TERM'] = 'vt100'
 global config
 config = getConfig() # get it from testglobals
 
+def signal_handler(signal, frame):
+        stopNow()
+
 # -----------------------------------------------------------------------------
 # cmd-line options
 
@@ -173,6 +177,9 @@ if windows:
         raise Exception("Failure calling SetConsoleCP(65001)")
     if kernel32.SetConsoleOutputCP(65001) == 0:
         raise Exception("Failure calling SetConsoleOutputCP(65001)")
+
+    # register the interrupt handler
+    signal.signal(signal.SIGINT, signal_handler)
 else:
     # Try and find a utf8 locale to use
     # First see if we already have a UTF8 locale
@@ -237,12 +244,6 @@ if windows or darwin:
 global testopts_local
 testopts_local.x = TestOptions()
 
-if config.use_threads:
-    t.lock = threading.Lock()
-    t.thread_pool = threading.Condition(t.lock)
-    t.lockFilesWritten = threading.Lock()
-    t.running_threads = 0
-
 # if timeout == -1 then we try to calculate a sensible value
 if config.timeout == -1:
     config.timeout = int(read_no_crs(config.top + '/timeout/calibrate.out'))
@@ -302,9 +303,11 @@ for file in t_files:
     newTestDir(tempdir, os.path.dirname(file))
     try:
         if PYTHON3:
-            src = io.open(file, encoding='utf8').read()
+            with io.open(file, encoding='utf8') as f:
+                src = f.read()
         else:
-            src = open(file).read()
+            with open(file) as f:
+                src = f.read()
 
         exec(src)
     except Exception as e:
@@ -333,28 +336,34 @@ if config.list_broken:
         print('WARNING:', len(framework_failures), 'framework failures!')
         print('')
 else:
+    # completion watcher
+    watcher = Watcher(len(parallelTests))
+
     # Now run all the tests
-    if config.use_threads:
-        t.running_threads=0
     for oneTest in parallelTests:
         if stopping():
             break
-        oneTest()
-    if config.use_threads:
-        t.thread_pool.acquire()
-        while t.running_threads>0:
-            t.thread_pool.wait()
-        t.thread_pool.release()
+        oneTest(watcher)
+
+    # wait for parallel tests to finish
+    if not stopping():
+        watcher.wait()
+
+    # Run the following tests purely sequential
     config.use_threads = False
     for oneTest in aloneTests:
         if stopping():
             break
-        oneTest()
-        
+        oneTest(watcher)
+
+    # flush everything before we continue
+    sys.stdout.flush()
+
     summary(t, sys.stdout, config.no_print_summary)
 
     if config.summary_file != '':
-        summary(t, open(config.summary_file, 'w'))
+        with open(config.summary_file, 'w') as file:
+            summary(t, file)
 
 cleanup_and_exit(0)
 
index d9d3335..b0252de 100644
@@ -38,6 +38,11 @@ if config.use_threads:
 
 global wantToStop
 wantToStop = False
+
+global pool_sema
+if config.use_threads:
+    pool_sema = threading.BoundedSemaphore(value=config.threads)
+
 def stopNow():
     global wantToStop
     wantToStop = True
@@ -601,27 +606,20 @@ parallelTests = []
 aloneTests = []
 allTestNames = set([])
 
-def runTest (opts, name, func, args):
-    ok = 0
-
+def runTest(watcher, opts, name, func, args):
     if config.use_threads:
-        t.thread_pool.acquire()
-        try:
-            while config.threads<(t.running_threads+1):
-                t.thread_pool.wait()
-            t.running_threads = t.running_threads+1
-            ok=1
-            t.thread_pool.release()
-            thread.start_new_thread(test_common_thread, (name, opts, func, args))
-        except:
-            if not ok:
-                t.thread_pool.release()
+        pool_sema.acquire()
+        t = threading.Thread(target=test_common_thread,
+                             name=name,
+                             args=(watcher, name, opts, func, args))
+        t.daemon = False
+        t.start()
     else:
-        test_common_work (name, opts, func, args)
+        test_common_work(watcher, name, opts, func, args)
 
 # name  :: String
 # setup :: TestOpts -> IO ()
-def test (name, setup, func, args):
+def test(name, setup, func, args):
     global aloneTests
     global parallelTests
     global allTestNames
@@ -649,7 +647,7 @@ def test (name, setup, func, args):
 
     executeSetups([thisdir_settings, setup], name, myTestOpts)
 
-    thisTest = lambda : runTest(myTestOpts, name, func, args)
+    thisTest = lambda watcher: runTest(watcher, myTestOpts, name, func, args)
     if myTestOpts.alone:
         aloneTests.append(thisTest)
     else:
@@ -657,16 +655,11 @@ def test (name, setup, func, args):
     allTestNames.add(name)
 
 if config.use_threads:
-    def test_common_thread(name, opts, func, args):
-        t.lock.acquire()
-        try:
-            test_common_work(name,opts,func,args)
-        finally:
-            t.lock.release()
-            t.thread_pool.acquire()
-            t.running_threads = t.running_threads - 1
-            t.thread_pool.notify()
-            t.thread_pool.release()
+    def test_common_thread(watcher, name, opts, func, args):
+            try:
+                test_common_work(watcher, name, opts, func, args)
+            finally:
+                pool_sema.release()
 
 def get_package_cache_timestamp():
     if config.package_conf_cache_file == '':
@@ -679,7 +672,7 @@ def get_package_cache_timestamp():
 
 do_not_copy = ('.hi', '.o', '.dyn_hi', '.dyn_o', '.out') # 12112
 
-def test_common_work (name, opts, func, args):
+def test_common_work(watcher, name, opts, func, args):
     try:
         t.total_tests += 1
         setLocalTestOpts(opts)
@@ -779,6 +772,8 @@ def test_common_work (name, opts, func, args):
 
     except Exception as e:
         framework_fail(name, 'runTest', 'Unhandled exception: ' + str(e))
+    finally:
+        watcher.notify()
 
 def do_test(name, way, func, args, files):
     opts = getTestOpts()
@@ -831,9 +826,6 @@ def do_test(name, way, func, args, files):
                 with io.open(dst_makefile, 'w', encoding='utf8') as dst:
                     dst.write(makefile)
 
-    if config.use_threads:
-        t.lock.release()
-
     if opts.pre_cmd:
         exit_code = runCmd('cd "{0}" && {1}'.format(opts.testdir, opts.pre_cmd))
         if exit_code != 0:
@@ -841,9 +833,8 @@ def do_test(name, way, func, args, files):
 
     try:
         result = func(*[name,way] + args)
-    finally:
-        if config.use_threads:
-            t.lock.acquire()
+    except:
+        pass
 
     if opts.expect not in ['pass', 'fail', 'missing-lib']:
         framework_fail(name, way, 'bad expected ' + opts.expect)
@@ -1346,21 +1337,18 @@ def interpreter_run(name, way, extra_hc_opts, top_mod):
 
 def split_file(in_fn, delimiter, out1_fn, out2_fn):
     # See Note [Universal newlines].
-    infile = io.open(in_fn, 'r', encoding='utf8', errors='replace', newline=None)
-    out1 = io.open(out1_fn, 'w', encoding='utf8', newline='')
-    out2 = io.open(out2_fn, 'w', encoding='utf8', newline='')
-
-    line = infile.readline()
-    while (re.sub('^\s*','',line) != delimiter and line != ''):
-        out1.write(line)
-        line = infile.readline()
-    out1.close()
-
-    line = infile.readline()
-    while (line != ''):
-        out2.write(line)
-        line = infile.readline()
-    out2.close()
+    with io.open(in_fn, 'r', encoding='utf8', errors='replace', newline=None) as infile:
+        with io.open(out1_fn, 'w', encoding='utf8', newline='') as out1:
+            with io.open(out2_fn, 'w', encoding='utf8', newline='') as out2:
+                line = infile.readline()
+                while re.sub('^\s*','',line) != delimiter and line != '':
+                    out1.write(line)
+                    line = infile.readline()
+
+                line = infile.readline()
+                while line != '':
+                    out2.write(line)
+                    line = infile.readline()
 
 # -----------------------------------------------------------------------------
 # Utils
@@ -1392,7 +1380,8 @@ def stdout_ok(name, way):
 
 def dump_stdout( name ):
    print('Stdout:')
-   print(open(in_testdir(name, 'run.stdout')).read())
+   with open(in_testdir(name, 'run.stdout')) as f:
+       print(f.read())
 
 def stderr_ok(name, way):
    actual_stderr_file = add_suffix(name, 'run.stderr')
@@ -1405,15 +1394,15 @@ def stderr_ok(name, way):
 
 def dump_stderr( name ):
    print("Stderr:")
-   print(open(in_testdir(name, 'run.stderr')).read())
+   with open(in_testdir(name, 'run.stderr')) as f:
+       print(f.read())
 
 def read_no_crs(file):
     str = ''
     try:
         # See Note [Universal newlines].
-        h = io.open(file, 'r', encoding='utf8', errors='replace', newline=None)
-        str = h.read()
-        h.close
+        with io.open(file, 'r', encoding='utf8', errors='replace', newline=None) as h:
+            str = h.read()
     except:
         # On Windows, if the program fails very early, it seems the
         # files stdout/stderr are redirected to may not get created
@@ -1422,9 +1411,8 @@ def read_no_crs(file):
 
 def write_file(file, str):
     # See Note [Universal newlines].
-    h = io.open(file, 'w', encoding='utf8', newline='')
-    h.write(str)
-    h.close
+    with io.open(file, 'w', encoding='utf8', newline='') as h:
+        h.write(str)
 
 # Note [Universal newlines]
 #
@@ -1734,7 +1722,8 @@ def if_verbose( n, s ):
 def if_verbose_dump( n, f ):
     if config.verbose >= n:
         try:
-            print(open(f).read())
+            with io.open(f) as file:
+                print(file.read())
         except:
             print('')
 
@@ -1746,34 +1735,61 @@ def runCmd(cmd, stdin=None, stdout=None, stderr=None, timeout_multiplier=1.0):
     cmd = cmd.format(**config.__dict__)
     if_verbose(3, cmd + ('< ' + os.path.basename(stdin) if stdin else ''))
 
-    if stdin:
-        stdin = open(stdin, 'r')
-    if stdout:
-        stdout = open(stdout, 'w')
-    if stderr and stderr is not subprocess.STDOUT:
-        stderr = open(stderr, 'w')
-
-    # cmd is a complex command in Bourne-shell syntax
-    # e.g (cd . && 'C:/users/simonpj/HEAD/inplace/bin/ghc-stage2' ...etc)
-    # Hence it must ultimately be run by a Bourne shell. It's timeout's job
-    # to invoke the Bourne shell
-    r = subprocess.call([timeout_prog, timeout, cmd],
-                        stdin=stdin, stdout=stdout, stderr=stderr)
+    # declare the buffers to a default
+    stdin_buffer  = None
 
+    # ***** IMPORTANT *****
+    # We have to treat input and output as
+    # just binary data here. Don't try to decode
+    # it to a string, since we have tests that actually
+    # feed malformed utf-8 to see how GHC handles it.
     if stdin:
-        stdin.close()
-    if stdout:
-        stdout.close()
-    if stderr and stderr is not subprocess.STDOUT:
-        stderr.close()
+        with io.open(stdin, 'rb') as f:
+            stdin_buffer = f.read()
+
+    stdout_buffer = u''
+    stderr_buffer = u''
+
+    hStdErr = subprocess.PIPE
+    if stderr is subprocess.STDOUT:
+        hStdErr = subprocess.STDOUT
+
+    try:
+        # cmd is a complex command in Bourne-shell syntax
+        # e.g (cd . && 'C:/users/simonpj/HEAD/inplace/bin/ghc-stage2' ...etc)
+        # Hence it must ultimately be run by a Bourne shell. It's timeout's job
+        # to invoke the Bourne shell
+
+        r = subprocess.Popen([timeout_prog, timeout, cmd],
+                             stdin=subprocess.PIPE,
+                             stdout=subprocess.PIPE,
+                             stderr=hStdErr)
 
-    if r == 98:
+        stdout_buffer, stderr_buffer = r.communicate(stdin_buffer)
+    except Exception as e:
+        traceback.print_exc()
+        framework_fail(name, way, str(e))
+    finally:
+        try:
+            if stdout:
+                with io.open(stdout, 'ab') as f:
+                    f.write(stdout_buffer)
+            if stderr:
+                if stderr is not subprocess.STDOUT:
+                    with io.open(stderr, 'ab') as f:
+                        f.write(stderr_buffer)
+
+        except Exception as e:
+            traceback.print_exc()
+            framework_fail(name, way, str(e))
+
+    if r.returncode == 98:
         # The python timeout program uses 98 to signal that ^C was pressed
         stopNow()
-    if r == 99 and getTestOpts().exit_code != 99:
+    if r.returncode == 99 and getTestOpts().exit_code != 99:
         # Only print a message when timeout killed the process unexpectedly.
         if_verbose(1, 'Timeout happened...killed process "{0}"...\n'.format(cmd))
-    return r
+    return r.returncode
 
 # -----------------------------------------------------------------------------
 # checking if ghostscript is available for checking the output of hp2ps
index b4159d1..d35fb81 100644
@@ -4,6 +4,8 @@ import platform
 import subprocess
 import shutil
 
+import threading
+
 def strip_quotes(s):
     # Don't wrap commands to subprocess.call/Popen in quotes.
     return s.strip('\'"')
@@ -56,3 +58,25 @@ if platform.system() == 'Windows':
     link_or_copy_file = shutil.copyfile
 else:
     link_or_copy_file = os.symlink
+
+class Watcher(object):
+    global pool
+    global evt
+    global sync_lock
+    
+    def __init__(self, count):
+        self.pool = count
+        self.evt = threading.Event()
+        self.sync_lock = threading.Lock()
+        if count <= 0:
+            self.evt.set()
+
+    def wait(self):
+        self.evt.wait()
+
+    def notify(self):
+        self.sync_lock.acquire()
+        self.pool -= 1
+        if self.pool <= 0:
+            self.evt.set()
+        self.sync_lock.release()
index 09c61a4..1aa58ab 100644
@@ -217,9 +217,14 @@ $(eval $(call canonicalise,TOP_ABS))
 GS = gs
 CP = cp
 RM = rm -f
-PYTHON = python
+# Allow the user to override the python version, just like with validate
+ifeq "$(shell $(SHELL) -c '$(PYTHON) -c 0' 2> /dev/null && echo exists)" "exists"
+else
 ifeq "$(shell $(SHELL) -c 'python2 -c 0' 2> /dev/null && echo exists)" "exists"
 PYTHON = python2
+else
+PYTHON = python
+endif
 endif
 
 CHECK_API_ANNOTATIONS := $(abspath $(TOP)/../inplace/bin/check-api-annotations)
index 51764dc..87e4341 100644
@@ -3,7 +3,16 @@ module WinCBindings where
 
 #if defined(mingw32_HOST_OS)
 
+##if defined(i386_HOST_ARCH)
+## define WINDOWS_CCONV stdcall
+##elif defined(x86_64_HOST_ARCH)
+## define WINDOWS_CCONV ccall
+##else
+## error Unknown mingw32 arch
+##endif
+
 import Foreign
+import Foreign.C.Types
 import System.Win32.File
 import System.Win32.Types
 
@@ -109,9 +118,169 @@ instance Storable STARTUPINFO where
             siStdOutput     =  vhStdOutput,
             siStdError      =  vhStdError}
 
-foreign import stdcall unsafe "windows.h WaitForSingleObject"
+data JOBOBJECT_EXTENDED_LIMIT_INFORMATION = JOBOBJECT_EXTENDED_LIMIT_INFORMATION
+    { jeliBasicLimitInformation :: JOBOBJECT_BASIC_LIMIT_INFORMATION
+    , jeliIoInfo                :: IO_COUNTERS
+    , jeliProcessMemoryLimit    :: SIZE_T
+    , jeliJobMemoryLimit        :: SIZE_T
+    , jeliPeakProcessMemoryUsed :: SIZE_T
+    , jeliPeakJobMemoryUsed     :: SIZE_T
+    } deriving Show
+
+instance Storable JOBOBJECT_EXTENDED_LIMIT_INFORMATION where
+    sizeOf = const #size JOBOBJECT_EXTENDED_LIMIT_INFORMATION
+    alignment = const #alignment JOBOBJECT_EXTENDED_LIMIT_INFORMATION
+    poke buf jeli = do
+        (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, BasicLimitInformation) buf (jeliBasicLimitInformation jeli)
+        (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, IoInfo)                buf (jeliIoInfo jeli)
+        (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, ProcessMemoryLimit)    buf (jeliProcessMemoryLimit jeli)
+        (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobMemoryLimit)        buf (jeliJobMemoryLimit jeli)
+        (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, PeakProcessMemoryUsed) buf (jeliPeakProcessMemoryUsed jeli)
+        (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, PeakJobMemoryUsed)     buf (jeliPeakJobMemoryUsed jeli)
+    peek buf = do
+        vBasicLimitInformation <- (#peek JOBOBJECT_EXTENDED_LIMIT_INFORMATION, BasicLimitInformation) buf
+        vIoInfo                <- (#peek JOBOBJECT_EXTENDED_LIMIT_INFORMATION, IoInfo)                buf
+        vProcessMemoryLimit    <- (#peek JOBOBJECT_EXTENDED_LIMIT_INFORMATION, ProcessMemoryLimit)    buf
+        vJobMemoryLimit        <- (#peek JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JobMemoryLimit)        buf
+        vPeakProcessMemoryUsed <- (#peek JOBOBJECT_EXTENDED_LIMIT_INFORMATION, PeakProcessMemoryUsed) buf
+        vPeakJobMemoryUsed     <- (#peek JOBOBJECT_EXTENDED_LIMIT_INFORMATION, PeakJobMemoryUsed)     buf
+        return $ JOBOBJECT_EXTENDED_LIMIT_INFORMATION {
+            jeliBasicLimitInformation = vBasicLimitInformation,
+            jeliIoInfo                = vIoInfo,
+            jeliProcessMemoryLimit    = vProcessMemoryLimit,
+            jeliJobMemoryLimit        = vJobMemoryLimit,
+            jeliPeakProcessMemoryUsed = vPeakProcessMemoryUsed,
+            jeliPeakJobMemoryUsed     = vPeakJobMemoryUsed}
+
+type ULONGLONG = #type ULONGLONG
+
+data IO_COUNTERS = IO_COUNTERS
+    { icReadOperationCount  :: ULONGLONG
+    , icWriteOperationCount :: ULONGLONG
+    , icOtherOperationCount :: ULONGLONG
+    , icReadTransferCount   :: ULONGLONG
+    , icWriteTransferCount  :: ULONGLONG
+    , icOtherTransferCount  :: ULONGLONG
+    } deriving Show
+
+instance Storable IO_COUNTERS where
+    sizeOf = const #size IO_COUNTERS
+    alignment = const #alignment IO_COUNTERS
+    poke buf ic = do
+        (#poke IO_COUNTERS, ReadOperationCount)  buf (icReadOperationCount ic)
+        (#poke IO_COUNTERS, WriteOperationCount) buf (icWriteOperationCount ic)
+        (#poke IO_COUNTERS, OtherOperationCount) buf (icOtherOperationCount ic)
+        (#poke IO_COUNTERS, ReadTransferCount)   buf (icReadTransferCount ic)
+        (#poke IO_COUNTERS, WriteTransferCount)  buf (icWriteTransferCount ic)
+        (#poke IO_COUNTERS, OtherTransferCount)  buf (icOtherTransferCount ic)
+    peek buf = do
+        vReadOperationCount  <- (#peek IO_COUNTERS, ReadOperationCount)  buf
+        vWriteOperationCount <- (#peek IO_COUNTERS, WriteOperationCount) buf
+        vOtherOperationCount <- (#peek IO_COUNTERS, OtherOperationCount) buf
+        vReadTransferCount   <- (#peek IO_COUNTERS, ReadTransferCount)   buf
+        vWriteTransferCount  <- (#peek IO_COUNTERS, WriteTransferCount)  buf
+        vOtherTransferCount  <- (#peek IO_COUNTERS, OtherTransferCount)  buf
+        return $ IO_COUNTERS {
+            icReadOperationCount  = vReadOperationCount,
+            icWriteOperationCount = vWriteOperationCount,
+            icOtherOperationCount = vOtherOperationCount,
+            icReadTransferCount   = vReadTransferCount,
+            icWriteTransferCount  = vWriteTransferCount,
+            icOtherTransferCount  = vOtherTransferCount}
+
+data JOBOBJECT_BASIC_LIMIT_INFORMATION = JOBOBJECT_BASIC_LIMIT_INFORMATION
+    { jbliPerProcessUserTimeLimit :: LARGE_INTEGER
+    , jbliPerJobUserTimeLimit     :: LARGE_INTEGER
+    , jbliLimitFlags              :: DWORD
+    , jbliMinimumWorkingSetSize   :: SIZE_T
+    , jbliMaximumWorkingSetSize   :: SIZE_T
+    , jbliActiveProcessLimit      :: DWORD
+    , jbliAffinity                :: ULONG_PTR
+    , jbliPriorityClass           :: DWORD
+    , jbliSchedulingClass         :: DWORD
+    } deriving Show
+
+instance Storable JOBOBJECT_BASIC_LIMIT_INFORMATION where
+    sizeOf = const #size JOBOBJECT_BASIC_LIMIT_INFORMATION
+    alignment = const #alignment JOBOBJECT_BASIC_LIMIT_INFORMATION
+    poke buf jbli = do
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, PerProcessUserTimeLimit) buf (jbliPerProcessUserTimeLimit jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, PerJobUserTimeLimit)     buf (jbliPerJobUserTimeLimit jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, LimitFlags)              buf (jbliLimitFlags jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, MinimumWorkingSetSize)   buf (jbliMinimumWorkingSetSize jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, MaximumWorkingSetSize)   buf (jbliMaximumWorkingSetSize jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, ActiveProcessLimit)      buf (jbliActiveProcessLimit jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, Affinity)                buf (jbliAffinity jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, PriorityClass)           buf (jbliPriorityClass jbli)
+        (#poke JOBOBJECT_BASIC_LIMIT_INFORMATION, SchedulingClass)         buf (jbliSchedulingClass jbli)
+    peek buf = do
+        vPerProcessUserTimeLimit <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, PerProcessUserTimeLimit) buf
+        vPerJobUserTimeLimit     <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, PerJobUserTimeLimit)     buf
+        vLimitFlags              <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, LimitFlags)              buf
+        vMinimumWorkingSetSize   <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, MinimumWorkingSetSize)   buf
+        vMaximumWorkingSetSize   <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, MaximumWorkingSetSize)   buf
+        vActiveProcessLimit      <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, ActiveProcessLimit)      buf
+        vAffinity                <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, Affinity)                buf
+        vPriorityClass           <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, PriorityClass)           buf
+        vSchedulingClass         <- (#peek JOBOBJECT_BASIC_LIMIT_INFORMATION, SchedulingClass)         buf
+        return $ JOBOBJECT_BASIC_LIMIT_INFORMATION {
+            jbliPerProcessUserTimeLimit = vPerProcessUserTimeLimit,
+            jbliPerJobUserTimeLimit     = vPerJobUserTimeLimit,
+            jbliLimitFlags              = vLimitFlags,
+            jbliMinimumWorkingSetSize   = vMinimumWorkingSetSize,
+            jbliMaximumWorkingSetSize   = vMaximumWorkingSetSize,
+            jbliActiveProcessLimit      = vActiveProcessLimit,
+            jbliAffinity                = vAffinity,
+            jbliPriorityClass           = vPriorityClass,
+            jbliSchedulingClass         = vSchedulingClass}
+
+data JOBOBJECT_ASSOCIATE_COMPLETION_PORT = JOBOBJECT_ASSOCIATE_COMPLETION_PORT
+    { jacpCompletionKey  :: PVOID
+    , jacpCompletionPort :: HANDLE
+    } deriving Show
+
+instance Storable JOBOBJECT_ASSOCIATE_COMPLETION_PORT where
+    sizeOf = const #size JOBOBJECT_ASSOCIATE_COMPLETION_PORT
+    alignment = const #alignment JOBOBJECT_ASSOCIATE_COMPLETION_PORT
+    poke buf jacp = do
+        (#poke JOBOBJECT_ASSOCIATE_COMPLETION_PORT, CompletionKey)  buf (jacpCompletionKey jacp)
+        (#poke JOBOBJECT_ASSOCIATE_COMPLETION_PORT, CompletionPort) buf (jacpCompletionPort jacp)
+    peek buf = do
+        vCompletionKey  <- (#peek JOBOBJECT_ASSOCIATE_COMPLETION_PORT, CompletionKey)  buf
+        vCompletionPort <- (#peek JOBOBJECT_ASSOCIATE_COMPLETION_PORT, CompletionPort) buf
+        return $ JOBOBJECT_ASSOCIATE_COMPLETION_PORT {
+            jacpCompletionKey  = vCompletionKey,
+            jacpCompletionPort = vCompletionPort}
+
+
+foreign import WINDOWS_CCONV unsafe "windows.h WaitForSingleObject"
     waitForSingleObject :: HANDLE -> DWORD -> IO DWORD
 
+type JOBOBJECTINFOCLASS = CInt
+
+type PVOID = Ptr ()
+
+type ULONG_PTR  = CUIntPtr
+type PULONG_PTR = Ptr ULONG_PTR
+
+jobObjectExtendedLimitInformation :: JOBOBJECTINFOCLASS
+jobObjectExtendedLimitInformation = #const JobObjectExtendedLimitInformation
+
+jobObjectAssociateCompletionPortInformation :: JOBOBJECTINFOCLASS
+jobObjectAssociateCompletionPortInformation = #const JobObjectAssociateCompletionPortInformation
+
+cJOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE :: DWORD
+cJOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE = #const JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
+
+cJOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO :: DWORD
+cJOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO = #const JOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO
+
+cJOB_OBJECT_MSG_EXIT_PROCESS :: DWORD
+cJOB_OBJECT_MSG_EXIT_PROCESS = #const JOB_OBJECT_MSG_EXIT_PROCESS
+
+cJOB_OBJECT_MSG_NEW_PROCESS :: DWORD
+cJOB_OBJECT_MSG_NEW_PROCESS = #const JOB_OBJECT_MSG_NEW_PROCESS
+
 cWAIT_ABANDONED :: DWORD
 cWAIT_ABANDONED = #const WAIT_ABANDONED
 
@@ -121,23 +290,100 @@ cWAIT_OBJECT_0 = #const WAIT_OBJECT_0
 cWAIT_TIMEOUT :: DWORD
 cWAIT_TIMEOUT = #const WAIT_TIMEOUT
 
-foreign import stdcall unsafe "windows.h GetExitCodeProcess"
+cCREATE_SUSPENDED :: DWORD
+cCREATE_SUSPENDED = #const CREATE_SUSPENDED
+
+foreign import WINDOWS_CCONV unsafe "windows.h GetExitCodeProcess"
     getExitCodeProcess :: HANDLE -> LPDWORD -> IO BOOL
 
-foreign import stdcall unsafe "windows.h TerminateJobObject"
+foreign import WINDOWS_CCONV unsafe "windows.h CloseHandle"
+    closeHandle :: HANDLE -> IO BOOL
+
+foreign import WINDOWS_CCONV unsafe "windows.h TerminateJobObject"
     terminateJobObject :: HANDLE -> UINT -> IO BOOL
 
-foreign import stdcall unsafe "windows.h AssignProcessToJobObject"
+foreign import WINDOWS_CCONV unsafe "windows.h AssignProcessToJobObject"
     assignProcessToJobObject :: HANDLE -> HANDLE -> IO BOOL
 
-foreign import stdcall unsafe "windows.h CreateJobObjectW"
+foreign import WINDOWS_CCONV unsafe "windows.h CreateJobObjectW"
     createJobObjectW :: LPSECURITY_ATTRIBUTES -> LPCTSTR -> IO HANDLE
 
-foreign import stdcall unsafe "windows.h CreateProcessW"
+foreign import WINDOWS_CCONV unsafe "windows.h CreateProcessW"
     createProcessW :: LPCTSTR -> LPTSTR
                    -> LPSECURITY_ATTRIBUTES -> LPSECURITY_ATTRIBUTES
                    -> BOOL -> DWORD -> LPVOID -> LPCTSTR -> LPSTARTUPINFO
                    -> LPPROCESS_INFORMATION -> IO BOOL
 
+foreign import WINDOWS_CCONV unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO (Ptr a)
+
+foreign import WINDOWS_CCONV unsafe "windows.h SetInformationJobObject"
+    setInformationJobObject :: HANDLE -> JOBOBJECTINFOCLASS -> LPVOID -> DWORD -> IO BOOL
+
+foreign import WINDOWS_CCONV unsafe "windows.h CreateIoCompletionPort"
+    createIoCompletionPort :: HANDLE -> HANDLE -> ULONG_PTR -> DWORD -> IO HANDLE
+
+foreign import WINDOWS_CCONV unsafe "windows.h GetQueuedCompletionStatus"
+    getQueuedCompletionStatus :: HANDLE -> LPDWORD -> PULONG_PTR -> Ptr LPOVERLAPPED -> DWORD -> IO BOOL
+
+setJobParameters :: HANDLE -> IO BOOL
+setJobParameters hJob = alloca $ \p_jeli -> do
+    let jeliSize = sizeOf (undefined :: JOBOBJECT_EXTENDED_LIMIT_INFORMATION)
+    _ <- memset p_jeli 0 $ fromIntegral jeliSize
+    -- Configure all child processes associated with the job to terminate when the
+    -- Last process in the job terminates. This prevent half dead processes and that
+    -- hanging ghc-iserv.exe process that happens when you interrupt the testsuite.
+    (#poke JOBOBJECT_EXTENDED_LIMIT_INFORMATION, BasicLimitInformation.LimitFlags)
+      p_jeli cJOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
+    setInformationJobObject hJob jobObjectExtendedLimitInformation
+                            p_jeli (fromIntegral jeliSize)
+
+createCompletionPort :: HANDLE -> IO HANDLE
+createCompletionPort hJob = do
+    ioPort <- createIoCompletionPort iNVALID_HANDLE_VALUE nullPtr 0 1
+    if ioPort == nullPtr
+       then do err_code <- getLastError
+               putStrLn $ "CreateIoCompletionPort error: " ++ show err_code
+               return nullPtr
+       else with (JOBOBJECT_ASSOCIATE_COMPLETION_PORT {
+                    jacpCompletionKey  = hJob,
+                    jacpCompletionPort = ioPort}) $ \p_Port -> do
+              res <- setInformationJobObject hJob jobObjectAssociateCompletionPortInformation
+                         (castPtr p_Port) (fromIntegral (sizeOf (undefined :: JOBOBJECT_ASSOCIATE_COMPLETION_PORT)))
+              if res
+                 then return ioPort
+                 else do err_code <- getLastError
+                         putStrLn $ "SetInformation, error: " ++ show err_code
+                         return nullPtr
+
+waitForJobCompletion :: HANDLE -> HANDLE -> DWORD -> IO BOOL
+waitForJobCompletion hJob ioPort timeout
+  = alloca $ \p_CompletionCode ->
+    alloca $ \p_CompletionKey ->
+    alloca $ \p_Overlapped -> do
+
+    -- getQueuedCompletionStatus is a blocking call,
+    -- it will wake up for each completion event. So if it's
+    -- not the one we want, sleep again.
+    let loop :: IO ()
+        loop = do
+          res <- getQueuedCompletionStatus ioPort p_CompletionCode p_CompletionKey
+                                           p_Overlapped timeout
+          completionCode <- peek p_CompletionCode
+
+          if completionCode == cJOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO
+                     then return ()
+             else if completionCode == cJOB_OBJECT_MSG_EXIT_PROCESS
+                     then loop
+             else if completionCode == cJOB_OBJECT_MSG_NEW_PROCESS
+                     then loop
+                     else loop
+
+    loop
+
+    overlapped    <- peek p_Overlapped
+    completionKey <- peek $ castPtr p_CompletionKey
+    return $ if overlapped == nullPtr && completionKey /= hJob
+                then False -- Timeout occurred. *dark voice* YOU HAVE FAILED THIS TEST!.
+                else True
 #endif
 
index c015eb6..cf6c448 100644
@@ -103,28 +103,41 @@ run secs cmd =
     alloca $ \p_pi ->
     withTString cmd' $ \cmd'' ->
     do job <- createJobObjectW nullPtr nullPtr
-       let creationflags = 0
+       b_info <- setJobParameters job
+       unless b_info $ errorWin "setJobParameters"
+
+       ioPort <- createCompletionPort job
+       when (ioPort == nullPtr) $ errorWin "createCompletionPort, cannot continue."
+
+       let creationflags = cCREATE_SUSPENDED
        b <- createProcessW nullPtr cmd'' nullPtr nullPtr True
                            creationflags
                            nullPtr nullPtr p_startupinfo p_pi
        unless b $ errorWin "createProcessW"
+
        pi <- peek p_pi
-       assignProcessToJobObject job (piProcess pi)
+       b_assign <- assignProcessToJobObject job (piProcess pi)
+       unless b_assign $ errorWin "assignProcessToJobObject, cannot continue."
+
        let handleInterrupt action =
                action `onException` terminateJobObject job 99
+
        handleInterrupt $ do
           resumeThread (piThread pi)
-
           -- The program is now running
-
           let handle = piProcess pi
           let millisecs = secs * 1000
-          rc <- waitForSingleObject handle (fromIntegral millisecs)
-          if rc == cWAIT_TIMEOUT
+          rc <- waitForJobCompletion job ioPort (fromIntegral millisecs)
+          closeHandle ioPort
+
+          if not rc
               then do terminateJobObject job 99
+                      closeHandle job
                       exitWith (ExitFailure 99)
               else alloca $ \p_exitCode ->
-                    do r <- getExitCodeProcess handle p_exitCode
+                    do terminateJobObject job 0 -- Ensure it's all really dead.
+                       closeHandle job
+                       r <- getExitCodeProcess handle p_exitCode
                        if r then do ec <- peek p_exitCode
                                     let ec' = if ec == 0
                                               then ExitSuccess