From 30e0aea8775744733571f48b7bac3c8ea8b1b9d0 Mon Sep 17 00:00:00 2001
From: Mehdi Abaakouk <mehdi.abaakouk@enovance.com>
Date: Tue, 30 Sep 2014 13:45:40 +0200
Subject: [PATCH] Notification listener pools

We can now set the pool name of a notification listener
to create multiple groups/pools of listeners consuming notifications
and that each group/pool only receives one copy of each notification.

The AMQP implementation of that is to set queue_name with the pool name.

Implements blueprint notification-listener-pools
Closes-bug: #1356226

Change-Id: I8dc0549f5550f684a261c78c58737b798fcdd656
---
 oslo/messaging/_drivers/amqpdriver.py         |   4 +-
 oslo/messaging/_drivers/impl_fake.py          |  37 +++--
 oslo/messaging/_drivers/impl_zmq.py           |   6 +-
 .../_drivers/protocols/amqp/driver.py         |   5 +-
 oslo/messaging/_executors/impl_blocking.py    |   5 +-
 oslo/messaging/notify/dispatcher.py           |   7 +-
 oslo/messaging/notify/listener.py             |  14 +-
 oslo/messaging/transport.py                   |   5 +-
 tests/executors/test_executor.py              |   2 +-
 tests/notify/test_dispatcher.py               |   4 +-
 tests/notify/test_listener.py                 | 133 +++++++++++++++---
 11 files changed, 177 insertions(+), 45 deletions(-)

diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py
index 8f63ec8a7..767e5e2ca 100644
--- a/oslo/messaging/_drivers/amqpdriver.py
+++ b/oslo/messaging/_drivers/amqpdriver.py
@@ -431,7 +431,7 @@ class AMQPDriverBase(base.BaseDriver):
 
         return listener
 
-    def listen_for_notifications(self, targets_and_priorities):
+    def listen_for_notifications(self, targets_and_priorities, pool):
         conn = self._get_connection(pooled=False)
 
         listener = AMQPListener(self, conn)
@@ -439,7 +439,7 @@ class AMQPDriverBase(base.BaseDriver):
             conn.declare_topic_consumer(
                 exchange_name=self._get_exchange(target),
                 topic='%s.%s' % (target.topic, priority),
-                callback=listener)
+                callback=listener, queue_name=pool)
         return listener
 
     def cleanup(self):
diff --git a/oslo/messaging/_drivers/impl_fake.py b/oslo/messaging/_drivers/impl_fake.py
index dfce5a4a8..457ef0a22 100644
--- a/oslo/messaging/_drivers/impl_fake.py
+++ b/oslo/messaging/_drivers/impl_fake.py
@@ -13,6 +13,7 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import copy
 import json
 import threading
 import time
@@ -40,16 +41,17 @@ class FakeIncomingMessage(base.IncomingMessage):
 
 class FakeListener(base.Listener):
 
-    def __init__(self, driver, exchange_manager, targets):
+    def __init__(self, driver, exchange_manager, targets, pool=None):
         super(FakeListener, self).__init__(driver)
         self._exchange_manager = exchange_manager
         self._targets = targets
+        self._pool = pool
 
         # NOTE(sileht): Ensure that all needed queues exists even the listener
         # have not been polled yet
         for target in self._targets:
             exchange = self._exchange_manager.get_exchange(target.exchange)
-            exchange.ensure_queue(target)
+            exchange.ensure_queue(target, pool)
 
     def poll(self, timeout=None):
         if timeout is not None:
@@ -59,7 +61,8 @@ class FakeListener(base.Listener):
         while True:
             for target in self._targets:
                 exchange = self._exchange_manager.get_exchange(target.exchange)
-                (ctxt, message, reply_q, requeue) = exchange.poll(target)
+                (ctxt, message, reply_q, requeue) = exchange.poll(target,
+                                                                  self._pool)
                 if message is not None:
                     message = FakeIncomingMessage(self, ctxt, message,
                                                   reply_q, requeue)
@@ -83,15 +86,21 @@ class FakeExchange(object):
         self._topic_queues = {}
         self._server_queues = {}
 
-    def ensure_queue(self, target):
+    def ensure_queue(self, target, pool):
         with self._queues_lock:
             if target.server:
                 self._get_server_queue(target.topic, target.server)
             else:
-                self._get_topic_queue(target.topic)
+                self._get_topic_queue(target.topic, pool)
 
-    def _get_topic_queue(self, topic):
-        return self._topic_queues.setdefault(topic, [])
+    def _get_topic_queue(self, topic, pool=None):
+        if pool and (topic, pool) not in self._topic_queues:
+            # NOTE(sileht): if the pool name is set, we need to
+            # copy all the already delivered messages from the
+            # default queue to this queue
+            self._topic_queues[(topic, pool)] = copy.deepcopy(
+                self._get_topic_queue(topic))
+        return self._topic_queues.setdefault((topic, pool), [])
 
     def _get_server_queue(self, topic, server):
         return self._server_queues.setdefault((topic, server), [])
@@ -105,7 +114,11 @@ class FakeExchange(object):
             elif server is not None:
                 queues = [self._get_server_queue(topic, server)]
             else:
-                queues = [self._get_topic_queue(topic)]
+                # NOTE(sileht): ensure at least the queue without
+                # pool name exists
+                self._get_topic_queue(topic)
+                queues = [q for t, q in self._topic_queues.items()
+                          if t[0] == topic]
 
             def requeue():
                 self.deliver_message(topic, ctxt, message, server=server,
@@ -114,12 +127,12 @@ class FakeExchange(object):
             for queue in queues:
                 queue.append((ctxt, message, reply_q, requeue))
 
-    def poll(self, target):
+    def poll(self, target, pool):
         with self._queues_lock:
             if target.server:
                 queue = self._get_server_queue(target.topic, target.server)
             else:
-                queue = self._get_topic_queue(target.topic)
+                queue = self._get_topic_queue(target.topic, pool)
             return queue.pop(0) if queue else (None, None, None, None)
 
 
@@ -208,11 +221,11 @@ class FakeDriver(base.BaseDriver):
                                                   exchange=exchange)])
         return listener
 
-    def listen_for_notifications(self, targets_and_priorities):
+    def listen_for_notifications(self, targets_and_priorities, pool):
         targets = [messaging.Target(topic='%s.%s' % (target.topic, priority),
                                     exchange=target.exchange)
                    for target, priority in targets_and_priorities]
-        listener = FakeListener(self, self._exchange_manager, targets)
+        listener = FakeListener(self, self._exchange_manager, targets, pool)
 
         return listener
 
diff --git a/oslo/messaging/_drivers/impl_zmq.py b/oslo/messaging/_drivers/impl_zmq.py
index a5b9bba8e..b6dccc3aa 100644
--- a/oslo/messaging/_drivers/impl_zmq.py
+++ b/oslo/messaging/_drivers/impl_zmq.py
@@ -937,9 +937,11 @@ class ZmqDriver(base.BaseDriver):
 
         return listener
 
-    def listen_for_notifications(self, targets_and_priorities):
+    def listen_for_notifications(self, targets_and_priorities, pool):
         # NOTE(sileht): this listener implementation is limited
-        # because zeromq doesn't support requeing message
+        # because zeromq doesn't support:
+        #  * requeing message
+        #  * pool
         conn = Connection(self.conf)
 
         listener = ZmqListener(self)
diff --git a/oslo/messaging/_drivers/protocols/amqp/driver.py b/oslo/messaging/_drivers/protocols/amqp/driver.py
index 455e32382..645a05ae0 100644
--- a/oslo/messaging/_drivers/protocols/amqp/driver.py
+++ b/oslo/messaging/_drivers/protocols/amqp/driver.py
@@ -265,8 +265,11 @@ class ProtonDriver(base.BaseDriver):
         return listener
 
     @_ensure_connect_called
-    def listen_for_notifications(self, targets_and_priorities):
+    def listen_for_notifications(self, targets_and_priorities, pool):
         LOG.debug("Listen for notifications %s", targets_and_priorities)
+        if pool:
+            raise NotImplementedError('"pool" not implemented by'
+                                      'this transport driver')
         listener = ProtonListener(self)
         for target, priority in targets_and_priorities:
             topic = '%s.%s' % (target.topic, priority)
diff --git a/oslo/messaging/_executors/impl_blocking.py b/oslo/messaging/_executors/impl_blocking.py
index 8e463a0c7..68c1632e1 100644
--- a/oslo/messaging/_executors/impl_blocking.py
+++ b/oslo/messaging/_executors/impl_blocking.py
@@ -36,7 +36,10 @@ class BlockingExecutor(base.ExecutorBase):
     def start(self):
         self._running = True
         while self._running:
-            with self.dispatcher(self.listener.poll()) as callback:
+            message = self.listener.poll(timeout=0.01)
+            if not message:
+                continue
+            with self.dispatcher(message) as callback:
                 callback()
 
     def stop(self):
diff --git a/oslo/messaging/notify/dispatcher.py b/oslo/messaging/notify/dispatcher.py
index f3944095b..549e2829e 100644
--- a/oslo/messaging/notify/dispatcher.py
+++ b/oslo/messaging/notify/dispatcher.py
@@ -44,11 +44,13 @@ class NotificationDispatcher(object):
     message to the endpoints
     """
 
-    def __init__(self, targets, endpoints, serializer, allow_requeue):
+    def __init__(self, targets, endpoints, serializer, allow_requeue,
+                 pool=None):
         self.targets = targets
         self.endpoints = endpoints
         self.serializer = serializer or msg_serializer.NoOpSerializer()
         self.allow_requeue = allow_requeue
+        self.pool = pool
 
         self._callbacks_by_priority = {}
         for endpoint, prio in itertools.product(endpoints, PRIORITIES):
@@ -61,7 +63,8 @@ class NotificationDispatcher(object):
                                                          priorities))
 
     def _listen(self, transport):
-        return transport._listen_for_notifications(self._targets_priorities)
+        return transport._listen_for_notifications(self._targets_priorities,
+                                                   pool=self.pool)
 
     @contextlib.contextmanager
     def __call__(self, incoming):
diff --git a/oslo/messaging/notify/listener.py b/oslo/messaging/notify/listener.py
index feaf90e66..4e7c184ff 100644
--- a/oslo/messaging/notify/listener.py
+++ b/oslo/messaging/notify/listener.py
@@ -61,7 +61,9 @@ A simple example of a notification listener with multiple endpoints might be::
         NotificationEndpoint(),
         ErrorEndpoint(),
     ]
-    server = messaging.get_notification_listener(transport, targets, endpoints)
+    pool = "listener-workers"
+    server = messaging.get_notification_listener(transport, targets, endpoints,
+                                                 pool)
     server.start()
     server.wait()
 
@@ -78,6 +80,10 @@ and a timestamp.
 By supplying a serializer object, a listener can deserialize a request context
 and arguments from - and serialize return values to - primitive types.
 
+By supplying a pool name you can create multiple groups of listeners consuming
+notifications and that each group only receives one copy of each
+notification.
+
 An endpoint method can explicitly return messaging.NotificationResult.HANDLED
 to acknowledge a message or messaging.NotificationResult.REQUEUE to requeue the
 message.
@@ -97,7 +103,7 @@ from oslo.messaging import server as msg_server
 
 def get_notification_listener(transport, targets, endpoints,
                               executor='blocking', serializer=None,
-                              allow_requeue=False):
+                              allow_requeue=False, pool=None):
     """Construct a notification listener
 
     The executor parameter controls how incoming messages will be received and
@@ -117,10 +123,12 @@ def get_notification_listener(transport, targets, endpoints,
     :type serializer: Serializer
     :param allow_requeue: whether NotificationResult.REQUEUE support is needed
     :type allow_requeue: bool
+    :param pool: the pool name
+    :type pool: str
     :raises: NotImplementedError
     """
     transport._require_driver_features(requeue=allow_requeue)
     dispatcher = notify_dispatcher.NotificationDispatcher(targets, endpoints,
                                                           serializer,
-                                                          allow_requeue)
+                                                          allow_requeue, pool)
     return msg_server.MessageHandlingServer(transport, dispatcher, executor)
diff --git a/oslo/messaging/transport.py b/oslo/messaging/transport.py
index 9c9844b06..cfe47f8f8 100644
--- a/oslo/messaging/transport.py
+++ b/oslo/messaging/transport.py
@@ -103,13 +103,14 @@ class Transport(object):
                                            target)
         return self._driver.listen(target)
 
-    def _listen_for_notifications(self, targets_and_priorities):
+    def _listen_for_notifications(self, targets_and_priorities, pool):
         for target, priority in targets_and_priorities:
             if not target.topic:
                 raise exceptions.InvalidTarget('A target must have '
                                                'topic specified',
                                                target)
-        return self._driver.listen_for_notifications(targets_and_priorities)
+        return self._driver.listen_for_notifications(
+            targets_and_priorities, pool)
 
     def cleanup(self):
         """Release all resources associated with this transport."""
diff --git a/tests/executors/test_executor.py b/tests/executors/test_executor.py
index 6b5b9bd0b..ef247c288 100644
--- a/tests/executors/test_executor.py
+++ b/tests/executors/test_executor.py
@@ -71,7 +71,7 @@ class TestExecutor(test_utils.BaseTestCase):
         incoming_message = mock.MagicMock(ctxt={},
                                           message={'payload': 'data'})
 
-        def fake_poll():
+        def fake_poll(timeout=None):
             if self.stop_before_return:
                 executor.stop()
                 return incoming_message
diff --git a/tests/notify/test_dispatcher.py b/tests/notify/test_dispatcher.py
index 36f74078a..dacc6dde5 100644
--- a/tests/notify/test_dispatcher.py
+++ b/tests/notify/test_dispatcher.py
@@ -98,7 +98,7 @@ class TestDispatcher(test_utils.BaseTestCase):
 
         targets = [messaging.Target(topic='notifications')]
         dispatcher = notify_dispatcher.NotificationDispatcher(
-            targets, endpoints, None, allow_requeue=True)
+            targets, endpoints, None, allow_requeue=True, pool=None)
 
         # check it listen on wanted topics
         self.assertEqual(sorted(set((targets[0], prio)
@@ -142,7 +142,7 @@ class TestDispatcher(test_utils.BaseTestCase):
         msg = notification_msg.copy()
         msg['priority'] = 'what???'
         dispatcher = notify_dispatcher.NotificationDispatcher(
-            [mock.Mock()], [mock.Mock()], None, allow_requeue=True)
+            [mock.Mock()], [mock.Mock()], None, allow_requeue=True, pool=None)
         with dispatcher(mock.Mock(ctxt={}, message=msg)) as callback:
             callback()
         mylog.warning.assert_called_once_with('Unknown priority "%s"',
diff --git a/tests/notify/test_listener.py b/tests/notify/test_listener.py
index 2186d47a0..70643a7ac 100644
--- a/tests/notify/test_listener.py
+++ b/tests/notify/test_listener.py
@@ -28,39 +28,60 @@ load_tests = testscenarios.load_tests_apply_scenarios
 
 class ListenerSetupMixin(object):
 
-    class Listener(object):
-        def __init__(self, transport, targets, endpoints, expect_messages):
+    class ListenerTracker(object):
+        def __init__(self, expect_messages):
             self._expect_messages = expect_messages
             self._received_msgs = 0
-            self._listener = messaging.get_notification_listener(
-                transport, targets, [self] + endpoints, allow_requeue=True)
+            self.listeners = []
 
         def info(self, ctxt, publisher_id, event_type, payload, metadata):
             self._received_msgs += 1
             if self._expect_messages == self._received_msgs:
-                # Check start() does nothing with a running listener
-                self._listener.start()
-                self._listener.stop()
-                self._listener.wait()
+                self.stop()
 
-        def start(self):
-            self._listener.start()
+        def stop(self):
+            for listener in self.listeners:
+                # Check start() does nothing with a running listener
+                listener.start()
+                listener.stop()
+                listener.wait()
+            self.listeners = []
+
+    def setUp(self):
+        self.trackers = {}
+        self.addCleanup(self._stop_trackers)
+
+    def _stop_trackers(self):
+        for pool in self.trackers:
+            self.trackers[pool].stop()
+        self.trackers = {}
 
     def _setup_listener(self, transport, endpoints, expect_messages,
-                        targets=None):
-        listener = self.Listener(transport,
-                                 targets=targets or [
-                                     messaging.Target(topic='testtopic')],
-                                 expect_messages=expect_messages,
-                                 endpoints=endpoints)
+                        targets=None, pool=None):
+
+        if pool is None:
+            tracker_name = '__default__'
+        else:
+            tracker_name = pool
+
+        if targets is None:
+            targets = [messaging.Target(topic='testtopic')]
+
+        tracker = self.trackers.setdefault(
+            tracker_name, self.ListenerTracker(expect_messages))
+        listener = messaging.get_notification_listener(
+            transport, targets=targets, endpoints=[tracker] + endpoints,
+            allow_requeue=True, pool=pool)
+        tracker.listeners.append(listener)
 
         thread = threading.Thread(target=listener.start)
         thread.daemon = True
+        thread.listener = listener
         thread.start()
         return thread
 
     def _stop_listener(self, thread):
-        thread.join(timeout=5)
+        thread.join(timeout=15)
         return thread.isAlive()
 
     def _setup_notifier(self, transport, topic='testtopic',
@@ -78,6 +99,7 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
 
     def setUp(self):
         super(TestNotifyListener, self).setUp(conf=cfg.ConfigOpts())
+        ListenerSetupMixin.setUp(self)
 
     def test_constructor(self):
         transport = messaging.get_transport(self.conf, url='fake:')
@@ -250,3 +272,80 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
                       {'timestamp': mock.ANY, 'message_id': mock.ANY}),
             mock.call({}, 'testpublisher', 'an_event.start', 'test',
                       {'timestamp': mock.ANY, 'message_id': mock.ANY})])
+
+    def test_two_pools(self):
+        transport = messaging.get_transport(self.conf, url='fake:')
+
+        endpoint1 = mock.Mock()
+        endpoint1.info.return_value = None
+        endpoint2 = mock.Mock()
+        endpoint2.info.return_value = None
+
+        targets = [messaging.Target(topic="topic")]
+        listener1_thread = self._setup_listener(transport, [endpoint1], 2,
+                                                targets=targets, pool="pool1")
+        listener2_thread = self._setup_listener(transport, [endpoint2], 2,
+                                                targets=targets, pool="pool2")
+
+        notifier = self._setup_notifier(transport, topic="topic")
+        notifier.info({'ctxt': '0'}, 'an_event.start', 'test message0')
+        notifier.info({'ctxt': '1'}, 'an_event.start', 'test message1')
+
+        self.assertFalse(self._stop_listener(listener2_thread))
+        self.assertFalse(self._stop_listener(listener1_thread))
+
+        def mocked_endpoint_call(i):
+            return mock.call({'ctxt': '%d' % i}, 'testpublisher',
+                             'an_event.start', 'test message%d' % i,
+                             {'timestamp': mock.ANY, 'message_id': mock.ANY})
+
+        endpoint1.info.assert_has_calls([mocked_endpoint_call(0),
+                                         mocked_endpoint_call(1)])
+        endpoint2.info.assert_has_calls([mocked_endpoint_call(0),
+                                         mocked_endpoint_call(1)])
+
+    def test_two_pools_three_listener(self):
+        transport = messaging.get_transport(self.conf, url='fake:')
+
+        endpoint1 = mock.Mock()
+        endpoint1.info.return_value = None
+        endpoint2 = mock.Mock()
+        endpoint2.info.return_value = None
+        endpoint3 = mock.Mock()
+        endpoint3.info.return_value = None
+
+        targets = [messaging.Target(topic="topic")]
+        listener1_thread = self._setup_listener(transport, [endpoint1], 100,
+                                                targets=targets, pool="pool1")
+        listener2_thread = self._setup_listener(transport, [endpoint2], 100,
+                                                targets=targets, pool="pool2")
+        listener3_thread = self._setup_listener(transport, [endpoint3], 100,
+                                                targets=targets, pool="pool2")
+
+        def mocked_endpoint_call(i):
+            return mock.call({'ctxt': '%d' % i}, 'testpublisher',
+                             'an_event.start', 'test message%d' % i,
+                             {'timestamp': mock.ANY, 'message_id': mock.ANY})
+
+        notifier = self._setup_notifier(transport, topic="topic")
+        mocked_endpoint1_calls = []
+        for i in range(0, 100):
+            notifier.info({'ctxt': '%d' % i}, 'an_event.start',
+                          'test message%d' % i)
+            mocked_endpoint1_calls.append(mocked_endpoint_call(i))
+
+        self.assertFalse(self._stop_listener(listener3_thread))
+        self.assertFalse(self._stop_listener(listener2_thread))
+        self.assertFalse(self._stop_listener(listener1_thread))
+
+        self.assertEqual(100, endpoint1.info.call_count)
+        endpoint1.info.assert_has_calls(mocked_endpoint1_calls)
+
+        self.assertNotEqual(0, endpoint2.info.call_count)
+        self.assertNotEqual(0, endpoint3.info.call_count)
+
+        self.assertEqual(100, endpoint2.info.call_count +
+                         endpoint3.info.call_count)
+        for call in mocked_endpoint1_calls:
+            self.assertIn(call, endpoint2.info.mock_calls +
+                          endpoint3.info.mock_calls)