diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index f9f1b06e8..f0028251e 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -127,6 +127,10 @@ class AMQPListener(base.Listener): else: self.conn.consume(limit=1) + def cleanup(self): + # Closes listener connection + self.conn.close() + class ReplyWaiters(object): diff --git a/oslo/messaging/_drivers/base.py b/oslo/messaging/_drivers/base.py index 2b0303ec6..ec24460f5 100644 --- a/oslo/messaging/_drivers/base.py +++ b/oslo/messaging/_drivers/base.py @@ -59,6 +59,15 @@ class Listener(object): ending. """ + def cleanup(self): + """Cleanup listener. + Close connection used by listener if any. For some listeners like + zmq there is no connection so no need to close connection. + As this is listener specific method, overwrite it in to derived class + if cleanup of listener required. + """ + pass + @six.add_metaclass(abc.ABCMeta) class BaseDriver(object): diff --git a/oslo/messaging/server.py b/oslo/messaging/server.py index 9d354e4b6..f36ab314a 100644 --- a/oslo/messaging/server.py +++ b/oslo/messaging/server.py @@ -140,4 +140,7 @@ class MessageHandlingServer(object): """ if self._executor is not None: self._executor.wait() + # Close listener connection after processing all messages + self._executor.listener.cleanup() + self._executor = None diff --git a/tests/rpc/test_server.py b/tests/rpc/test_server.py index 6ba84a1a5..b4bdcccec 100644 --- a/tests/rpc/test_server.py +++ b/tests/rpc/test_server.py @@ -15,6 +15,7 @@ import threading +import mock import testscenarios from oslo.config import cfg @@ -110,6 +111,25 @@ class TestRPCServer(test_utils.BaseTestCase, ServerSetupMixin): self.assertIs(server.dispatcher.serializer, serializer) self.assertEqual('blocking', server.executor) + def test_server_wait_method(self): + transport = messaging.get_transport(self.conf, url='fake:') + target = messaging.Target(topic='foo', server='bar') + endpoints = [object()] + serializer = object() + + server = messaging.get_rpc_server(transport, target, endpoints, + serializer=serializer) + # Mocking executor + server._executor = mock.Mock() + # Here assigning executor's listener object to listener variable + # before calling wait method, beacuse in wait method we are + # setting executor to None. + listener = server._executor.listener + # call server wait method + server.wait() + self.assertIsNone(server._executor) + self.assertEqual(1, listener.cleanup.call_count) + def test_no_target_server(self): transport = messaging.get_transport(self.conf, url='fake:')