From 3f3c489aafc1461835b6c266c8c1e742d88d725b Mon Sep 17 00:00:00 2001
From: Matthew Booth <mbooth@redhat.com>
Date: Mon, 19 Oct 2015 13:04:37 +0100
Subject: [PATCH] Fix a race calling blocking MessageHandlingServer.start()

This fixes a race due to the quirkiness of the blocking executor. The
blocking executor does not create a separate thread, but is instead
explicitly executed in the calling thread. Other threads will,
however, continue to interact with it.

In the non-blocking case, the executor will have done certain
initialisation in start() before starting a worker thread and
returning control to the caller. That is, the caller can be sure that
this initialisation has occurred when control is returned. However, in
the blocking case, control is never returned. We currently work round
this by setting self._running to True before executing executor.start,
and by not doing any locking whatsoever in MessageHandlingServer.
However, this current means there is a race whereby executor.stop()
can run before executor.start(). This is fragile and extremely
difficult to reason about robustly, if not currently broken.

The solution is to split the initialisation from the execution in the
blocking case. executor.start() is no longer a blocking operation for
the blocking executor. As for the non-blocking case, executor.start()
returns as soon as initialisation is complete, indicating that it is
safe to subsequently call stop(). Actual execution is done explicitly
via the new execute() method, which blocks.

In doing this, we also make FakeBlockingThread a more complete
implementation of threading.Thread. This fixes a related issue in
that, previously, calling server.wait() on a blocking executor from
another thread would not wait for the completion of the executor. This
has a knock-on effect in test_server's ServerSetupMixin. This mixin
created an endpoint with a stop method which called server.stop().
However, as this is executed by the executor, and also joins the
executor thread, which is now blocking, this results in a deadlock. I
am satisfied that, in general, this is not a sane thing to do.
However, it is useful for these tests. We fix the tests by making the
stop method non-blocking, and do the actual stop and wait calls from
the main thread.

Change-Id: I0d332f74c06c22b44179319432153e15b69f2f45
---
 oslo_messaging/_executors/impl_blocking.py    | 68 ++++++++++++++++---
 oslo_messaging/server.py                      |  5 +-
 .../tests/executors/test_executor.py          | 15 +++-
 oslo_messaging/tests/rpc/test_server.py       | 47 +++++++++----
 4 files changed, 108 insertions(+), 27 deletions(-)

diff --git a/oslo_messaging/_executors/impl_blocking.py b/oslo_messaging/_executors/impl_blocking.py
index b59818f5c..b788c47f4 100644
--- a/oslo_messaging/_executors/impl_blocking.py
+++ b/oslo_messaging/_executors/impl_blocking.py
@@ -14,28 +14,57 @@
 #    under the License.
 
 import futurist
+import threading
 
 from oslo_messaging._executors import impl_pooledexecutor
+from oslo_utils import timeutils
 
 
 class FakeBlockingThread(object):
+    '''A minimal implementation of threading.Thread which does not create a
+    thread or start executing the target when start() is called. Instead, the
+    caller must explicitly execute the non-blocking thread.execute() method
+    after start() has been called.
+    '''
+
     def __init__(self, target):
         self._target = target
+        self._running = False
+        self._running_cond = threading.Condition()
 
     def start(self):
-        self._target()
+        if self._running:
+            # Not a user error. No need to translate.
+            raise RuntimeError('FakeBlockingThread already started')
 
-    @staticmethod
-    def join(timeout=None):
-        pass
+        with self._running_cond:
+            self._running = True
+            self._running_cond.notify_all()
 
-    @staticmethod
-    def stop():
-        pass
+    def join(self, timeout=None):
+        with timeutils.StopWatch(duration=timeout) as w, self._running_cond:
+            while self._running:
+                self._running_cond.wait(w.leftover(return_none=True))
 
-    @staticmethod
-    def is_alive():
-        return False
+                # Thread.join() does not raise an exception on timeout. It is
+                # the caller's responsibility to check is_alive().
+                if w.expired():
+                    return
+
+    def is_alive(self):
+        return self._running
+
+    def execute(self):
+        if not self._running:
+            # Not a user error. No need to translate.
+            raise RuntimeError('FakeBlockingThread not started')
+
+        try:
+            self._target()
+        finally:
+            with self._running_cond:
+                self._running = False
+                self._running_cond.notify_all()
 
 
 class BlockingExecutor(impl_pooledexecutor.PooledExecutor):
@@ -52,3 +81,22 @@ class BlockingExecutor(impl_pooledexecutor.PooledExecutor):
 
     _executor_cls = lambda __, ___: futurist.SynchronousExecutor()
     _thread_cls = FakeBlockingThread
+
+    def __init__(self, *args, **kwargs):
+        super(BlockingExecutor, self).__init__(*args, **kwargs)
+
+    def execute(self):
+        '''Explicitly run the executor in the current context.'''
+        # NOTE(mdbooth): Splitting start into start and execute for the
+        # blocking executor closes a potential race. On a non-blocking
+        # executor, calling start performs some initialisation synchronously
+        # before starting the executor and returning control to the caller. In
+        # the non-blocking caller there was no externally visible boundary
+        # between the completion of initialisation and the start of execution,
+        # meaning the caller cannot indicate to another thread that
+        # initialisation is complete. With the split, the start call for the
+        # blocking executor becomes analogous to the non-blocking case,
+        # indicating that initialisation is complete. The caller can then
+        # synchronously call execute.
+        if self._poller is not None:
+            self._poller.execute()
diff --git a/oslo_messaging/server.py b/oslo_messaging/server.py
index 02bae191a..491ccbf52 100644
--- a/oslo_messaging/server.py
+++ b/oslo_messaging/server.py
@@ -140,12 +140,15 @@ class MessageHandlingServer(service.ServiceBase):
                 listener = self.dispatcher._listen(self.transport)
             except driver_base.TransportDriverError as ex:
                 raise ServerListenError(self.target, ex)
-            self._running = True
             self._executor_obj = self._executor_cls(self.conf, listener,
                                                     self.dispatcher)
             self._executor_obj.start()
+            self._running = True
             self._state_cond.notify_all()
 
+        if self.executor == 'blocking':
+            self._executor_obj.execute()
+
     def stop(self):
         """Stop handling incoming messages.
 
diff --git a/oslo_messaging/tests/executors/test_executor.py b/oslo_messaging/tests/executors/test_executor.py
index 007d3ac6a..1e175fdf8 100644
--- a/oslo_messaging/tests/executors/test_executor.py
+++ b/oslo_messaging/tests/executors/test_executor.py
@@ -81,6 +81,12 @@ class TestExecutor(test_utils.BaseTestCase):
             aioeventlet_class = None
         is_aioeventlet = (self.executor == aioeventlet_class)
 
+        if impl_blocking is not None:
+            blocking_class = impl_blocking.BlockingExecutor
+        else:
+            blocking_class = None
+        is_blocking = (self.executor == blocking_class)
+
         if is_aioeventlet:
             policy = aioeventlet.EventLoopPolicy()
             trollius.set_event_loop_policy(policy)
@@ -110,8 +116,15 @@ class TestExecutor(test_utils.BaseTestCase):
 
             endpoint = mock.MagicMock(return_value=simple_coroutine('result'))
             event = eventlet.event.Event()
-        else:
+        elif is_blocking:
+            def run_executor(executor):
+                executor.start()
+                executor.execute()
+                executor.wait()
 
+            endpoint = mock.MagicMock(return_value='result')
+            event = None
+        else:
             def run_executor(executor):
                 executor.start()
                 executor.wait()
diff --git a/oslo_messaging/tests/rpc/test_server.py b/oslo_messaging/tests/rpc/test_server.py
index 9a2b53b24..258dacb24 100644
--- a/oslo_messaging/tests/rpc/test_server.py
+++ b/oslo_messaging/tests/rpc/test_server.py
@@ -27,22 +27,38 @@ load_tests = testscenarios.load_tests_apply_scenarios
 
 class ServerSetupMixin(object):
 
-    class Server(object):
+    class Server(threading.Thread):
         def __init__(self, transport, topic, server, endpoint, serializer):
+            self.controller = ServerSetupMixin.ServerController()
             target = oslo_messaging.Target(topic=topic, server=server)
-            self._server = oslo_messaging.get_rpc_server(transport,
-                                                         target,
-                                                         [endpoint, self],
-                                                         serializer=serializer)
+            self.server = oslo_messaging.get_rpc_server(transport,
+                                                        target,
+                                                        [endpoint,
+                                                         self.controller],
+                                                        serializer=serializer)
+
+            super(ServerSetupMixin.Server, self).__init__()
+            self.daemon = True
+
+        def wait(self):
+            # Wait for the executor to process the stop message, indicating all
+            # test messages have been processed
+            self.controller.stopped.wait()
+
+            # Check start() does nothing with a running server
+            self.server.start()
+            self.server.stop()
+            self.server.wait()
+
+        def run(self):
+            self.server.start()
+
+    class ServerController(object):
+        def __init__(self):
+            self.stopped = threading.Event()
 
         def stop(self, ctxt):
-            # Check start() does nothing with a running server
-            self._server.start()
-            self._server.stop()
-            self._server.wait()
-
-        def start(self):
-            self._server.start()
+            self.stopped.set()
 
     class TestSerializer(object):
 
@@ -72,13 +88,14 @@ class ServerSetupMixin(object):
         thread.daemon = True
         thread.start()
 
-        return thread
+        return server
 
-    def _stop_server(self, client, server_thread, topic=None):
+    def _stop_server(self, client, server, topic=None):
         if topic is not None:
             client = client.prepare(topic=topic)
         client.cast({}, 'stop')
-        server_thread.join(timeout=30)
+        server.wait()
+
 
     def _setup_client(self, transport, topic='testtopic'):
         return oslo_messaging.RPCClient(transport,