From f4da21353956def06cb40e790b3a6f5275a68814 Mon Sep 17 00:00:00 2001
From: Mehdi Abaakouk <mehdi.abaakouk@enovance.com>
Date: Thu, 24 Apr 2014 12:04:20 +0200
Subject: [PATCH] Remove amqp default exchange hack

This change remove the hack to set the default exchange of a transport in the
amqp driver, by removing the usage of the configuration object to get the
default exchange in rabbit and qpid driver, and instead use the value
passed to the driver constructor into all amqp publishers and consumers
class/method that needs it.

Closes-bug: #1256345
Change-Id: Iba54ca79a49f8545854205c1451b2403735c1006
---
 oslo/messaging/_drivers/amqp.py        |  4 --
 oslo/messaging/_drivers/amqpdriver.py  | 30 ++++----
 oslo/messaging/_drivers/impl_qpid.py   | 40 +++++------
 oslo/messaging/_drivers/impl_rabbit.py | 34 ++++-----
 tests/test_qpid.py                     | 97 ++++++++++++++++++++++++--
 5 files changed, 145 insertions(+), 60 deletions(-)

diff --git a/oslo/messaging/_drivers/amqp.py b/oslo/messaging/_drivers/amqp.py
index 325c3609a..e0d33b393 100644
--- a/oslo/messaging/_drivers/amqp.py
+++ b/oslo/messaging/_drivers/amqp.py
@@ -249,7 +249,3 @@ def _add_unique_id(msg):
     unique_id = uuid.uuid4().hex
     msg.update({UNIQUE_ID: unique_id})
     LOG.debug('UNIQUE_ID is %s.' % (unique_id))
-
-
-def get_control_exchange(conf):
-    return conf.control_exchange
diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py
index 16626d0b2..d990e90a7 100644
--- a/oslo/messaging/_drivers/amqpdriver.py
+++ b/oslo/messaging/_drivers/amqpdriver.py
@@ -297,10 +297,6 @@ class AMQPDriverBase(base.BaseDriver):
 
         self._default_exchange = default_exchange
 
-        # FIXME(markmc): temp hack
-        if self._default_exchange:
-            self.conf.set_override('control_exchange', self._default_exchange)
-
         self._connection_pool = connection_pool
 
         self._reply_q_lock = threading.Lock()
@@ -308,6 +304,9 @@ class AMQPDriverBase(base.BaseDriver):
         self._reply_q_conn = None
         self._waiter = None
 
+    def _get_exchange(self, target):
+        return target.exchange or self._default_exchange
+
     def _get_connection(self, pooled=True):
         return rpc_amqp.ConnectionContext(self.conf,
                                           self._url,
@@ -364,14 +363,16 @@ class AMQPDriverBase(base.BaseDriver):
         try:
             with self._get_connection() as conn:
                 if notify:
-                    conn.notify_send(target.topic, msg)
+                    conn.notify_send(self._get_exchange(target),
+                                     target.topic, msg)
                 elif target.fanout:
                     conn.fanout_send(target.topic, msg)
                 else:
                     topic = target.topic
                     if target.server:
                         topic = '%s.%s' % (target.topic, target.server)
-                    conn.topic_send(topic, msg, timeout=timeout)
+                    conn.topic_send(exchange_name=self._get_exchange(target),
+                                    topic=topic, msg=msg, timeout=timeout)
 
             if wait_for_reply:
                 result = self._waiter.wait(msg_id, timeout)
@@ -394,9 +395,13 @@ class AMQPDriverBase(base.BaseDriver):
 
         listener = AMQPListener(self, conn)
 
-        conn.declare_topic_consumer(target.topic, listener)
-        conn.declare_topic_consumer('%s.%s' % (target.topic, target.server),
-                                    listener)
+        conn.declare_topic_consumer(exchange_name=self._get_exchange(target),
+                                    topic=target.topic,
+                                    callback=listener)
+        conn.declare_topic_consumer(exchange_name=self._get_exchange(target),
+                                    topic='%s.%s' % (target.topic,
+                                                     target.server),
+                                    callback=listener)
         conn.declare_fanout_consumer(target.topic, listener)
 
         return listener
@@ -406,9 +411,10 @@ class AMQPDriverBase(base.BaseDriver):
 
         listener = AMQPListener(self, conn)
         for target, priority in targets_and_priorities:
-            conn.declare_topic_consumer('%s.%s' % (target.topic, priority),
-                                        callback=listener,
-                                        exchange_name=target.exchange)
+            conn.declare_topic_consumer(
+                exchange_name=self._get_exchange(target),
+                topic='%s.%s' % (target.topic, priority),
+                callback=listener)
         return listener
 
     def cleanup(self):
diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py
index 10fd7207c..def074baf 100644
--- a/oslo/messaging/_drivers/impl_qpid.py
+++ b/oslo/messaging/_drivers/impl_qpid.py
@@ -248,8 +248,8 @@ class DirectConsumer(ConsumerBase):
 class TopicConsumer(ConsumerBase):
     """Consumer class for 'topic'."""
 
-    def __init__(self, conf, session, topic, callback, name=None,
-                 exchange_name=None):
+    def __init__(self, conf, session, topic, callback, exchange_name,
+                 name=None):
         """Init a 'topic' queue.
 
         :param session: the amqp session to use
@@ -259,7 +259,6 @@ class TopicConsumer(ConsumerBase):
         :param name: optional queue name, defaults to topic
         """
 
-        exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf)
         link_opts = {
             "auto-delete": conf.amqp_auto_delete,
             "durable": conf.amqp_durable_queues,
@@ -376,14 +375,14 @@ class Publisher(object):
 
 class DirectPublisher(Publisher):
     """Publisher class for 'direct'."""
-    def __init__(self, conf, session, msg_id):
+    def __init__(self, conf, session, topic):
         """Init a 'direct' publisher."""
 
         if conf.qpid_topology_version == 1:
-            node_name = "%s/%s" % (msg_id, msg_id)
+            node_name = "%s/%s" % (topic, topic)
             node_opts = {"type": "direct"}
         elif conf.qpid_topology_version == 2:
-            node_name = "amq.direct/%s" % msg_id
+            node_name = "amq.direct/%s" % topic
             node_opts = {}
         else:
             raise_invalid_topology_version(conf)
@@ -394,11 +393,9 @@ class DirectPublisher(Publisher):
 
 class TopicPublisher(Publisher):
     """Publisher class for 'topic'."""
-    def __init__(self, conf, session, topic):
+    def __init__(self, conf, session, exchange_name, topic):
         """Init a 'topic' publisher.
         """
-        exchange_name = rpc_amqp.get_control_exchange(conf)
-
         if conf.qpid_topology_version == 1:
             node_name = "%s/%s" % (exchange_name, topic)
         elif conf.qpid_topology_version == 2:
@@ -430,10 +427,9 @@ class FanoutPublisher(Publisher):
 
 class NotifyPublisher(Publisher):
     """Publisher class for notifications."""
-    def __init__(self, conf, session, topic):
+    def __init__(self, conf, session, exchange_name, topic):
         """Init a 'topic' publisher.
         """
-        exchange_name = rpc_amqp.get_control_exchange(conf)
         node_opts = {"durable": True}
 
         if conf.qpid_topology_version == 1:
@@ -618,7 +614,7 @@ class Connection(object):
                 raise StopIteration
             yield self.ensure(_error_callback, _consume)
 
-    def publisher_send(self, cls, topic, msg):
+    def publisher_send(self, cls, topic, msg, **kwargs):
         """Send to a publisher based on the publisher class."""
 
         def _connect_error(exc):
@@ -627,7 +623,7 @@ class Connection(object):
                           "'%(topic)s': %(err_str)s") % log_info)
 
         def _publisher_send():
-            publisher = cls(self.conf, self.session, topic)
+            publisher = cls(self.conf, self.session, topic=topic, **kwargs)
             publisher.send(msg)
 
         return self.ensure(_connect_error, _publisher_send)
@@ -639,8 +635,8 @@ class Connection(object):
         """
         self.declare_consumer(DirectConsumer, topic, callback)
 
-    def declare_topic_consumer(self, topic, callback=None, queue_name=None,
-                               exchange_name=None):
+    def declare_topic_consumer(self, exchange_name, topic, callback=None,
+                               queue_name=None):
         """Create a 'topic' consumer."""
         self.declare_consumer(functools.partial(TopicConsumer,
                                                 name=queue_name,
@@ -654,9 +650,9 @@ class Connection(object):
 
     def direct_send(self, msg_id, msg):
         """Send a 'direct' message."""
-        self.publisher_send(DirectPublisher, msg_id, msg)
+        self.publisher_send(DirectPublisher, topic=msg_id, msg=msg)
 
-    def topic_send(self, topic, msg, timeout=None):
+    def topic_send(self, exchange_name, topic, msg, timeout=None):
         """Send a 'topic' message."""
         #
         # We want to create a message with attributes, e.g. a TTL. We
@@ -669,15 +665,17 @@ class Connection(object):
         # will need to be altered accordingly.
         #
         qpid_message = qpid_messaging.Message(content=msg, ttl=timeout)
-        self.publisher_send(TopicPublisher, topic, qpid_message)
+        self.publisher_send(TopicPublisher, topic=topic, msg=qpid_message,
+                            exchange_name=exchange_name)
 
     def fanout_send(self, topic, msg):
         """Send a 'fanout' message."""
-        self.publisher_send(FanoutPublisher, topic, msg)
+        self.publisher_send(FanoutPublisher, topic=topic, msg=msg)
 
-    def notify_send(self, topic, msg, **kwargs):
+    def notify_send(self, exchange_name, topic, msg, **kwargs):
         """Send a notify message on a topic."""
-        self.publisher_send(NotifyPublisher, topic, msg)
+        self.publisher_send(NotifyPublisher, topic=topic, msg=msg,
+                            exchange_name=exchange_name)
 
     def consume(self, limit=None, timeout=None):
         """Consume from all queues/consumers."""
diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py
index 29d101a07..f7ca5e41c 100644
--- a/oslo/messaging/_drivers/impl_rabbit.py
+++ b/oslo/messaging/_drivers/impl_rabbit.py
@@ -247,8 +247,8 @@ class DirectConsumer(ConsumerBase):
 class TopicConsumer(ConsumerBase):
     """Consumer class for 'topic'."""
 
-    def __init__(self, conf, channel, topic, callback, tag, name=None,
-                 exchange_name=None, **kwargs):
+    def __init__(self, conf, channel, topic, callback, tag, exchange_name,
+                 name=None, **kwargs):
         """Init a 'topic' queue.
 
         :param channel: the amqp channel to use
@@ -256,6 +256,7 @@ class TopicConsumer(ConsumerBase):
         :paramtype topic: str
         :param callback: the callback to call when messages are received
         :param tag: a unique ID for the consumer on the channel
+        :param exchange_name: the exchange name to use
         :param name: optional queue name, defaults to topic
         :paramtype name: str
 
@@ -267,7 +268,6 @@ class TopicConsumer(ConsumerBase):
                    'auto_delete': conf.amqp_auto_delete,
                    'exclusive': False}
         options.update(kwargs)
-        exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf)
         exchange = kombu.entity.Exchange(name=exchange_name,
                                          type='topic',
                                          durable=options['durable'],
@@ -347,7 +347,7 @@ class Publisher(object):
 
 class DirectPublisher(Publisher):
     """Publisher class for 'direct'."""
-    def __init__(self, conf, channel, msg_id, **kwargs):
+    def __init__(self, conf, channel, topic, **kwargs):
         """Init a 'direct' publisher.
 
         Kombu options may be passed as keyword args to override defaults
@@ -357,13 +357,13 @@ class DirectPublisher(Publisher):
                    'auto_delete': True,
                    'exclusive': False}
         options.update(kwargs)
-        super(DirectPublisher, self).__init__(channel, msg_id, msg_id,
+        super(DirectPublisher, self).__init__(channel, topic, topic,
                                               type='direct', **options)
 
 
 class TopicPublisher(Publisher):
     """Publisher class for 'topic'."""
-    def __init__(self, conf, channel, topic, **kwargs):
+    def __init__(self, conf, channel, exchange_name, topic, **kwargs):
         """Init a 'topic' publisher.
 
         Kombu options may be passed as keyword args to override defaults
@@ -372,7 +372,6 @@ class TopicPublisher(Publisher):
                    'auto_delete': conf.amqp_auto_delete,
                    'exclusive': False}
         options.update(kwargs)
-        exchange_name = rpc_amqp.get_control_exchange(conf)
         super(TopicPublisher, self).__init__(channel,
                                              exchange_name,
                                              topic,
@@ -398,10 +397,11 @@ class FanoutPublisher(Publisher):
 class NotifyPublisher(TopicPublisher):
     """Publisher class for 'notify'."""
 
-    def __init__(self, conf, channel, topic, **kwargs):
+    def __init__(self, conf, channel, exchange_name, topic, **kwargs):
         self.durable = kwargs.pop('durable', conf.amqp_durable_queues)
         self.queue_arguments = _get_queue_arguments(conf)
-        super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs)
+        super(NotifyPublisher, self).__init__(conf, channel, exchange_name,
+                                              topic, **kwargs)
 
     def reconnect(self, channel):
         super(NotifyPublisher, self).reconnect(channel)
@@ -731,7 +731,7 @@ class Connection(object):
                           "'%(topic)s': %(err_str)s") % log_info)
 
         def _publish():
-            publisher = cls(self.conf, self.channel, topic, **kwargs)
+            publisher = cls(self.conf, self.channel, topic=topic, **kwargs)
             publisher.send(msg, timeout)
 
         self.ensure(_error_callback, _publish)
@@ -743,8 +743,8 @@ class Connection(object):
         """
         self.declare_consumer(DirectConsumer, topic, callback)
 
-    def declare_topic_consumer(self, topic, callback=None, queue_name=None,
-                               exchange_name=None):
+    def declare_topic_consumer(self, exchange_name, topic, callback=None,
+                               queue_name=None):
         """Create a 'topic' consumer."""
         self.declare_consumer(functools.partial(TopicConsumer,
                                                 name=queue_name,
@@ -760,17 +760,19 @@ class Connection(object):
         """Send a 'direct' message."""
         self.publisher_send(DirectPublisher, msg_id, msg)
 
-    def topic_send(self, topic, msg, timeout=None):
+    def topic_send(self, exchange_name, topic, msg, timeout=None):
         """Send a 'topic' message."""
-        self.publisher_send(TopicPublisher, topic, msg, timeout)
+        self.publisher_send(TopicPublisher, topic, msg, timeout,
+                            exchange_name=exchange_name)
 
     def fanout_send(self, topic, msg):
         """Send a 'fanout' message."""
         self.publisher_send(FanoutPublisher, topic, msg)
 
-    def notify_send(self, topic, msg, **kwargs):
+    def notify_send(self, exchange_name, topic, msg, **kwargs):
         """Send a notify message on a topic."""
-        self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs)
+        self.publisher_send(NotifyPublisher, topic, msg, timeout=None,
+                            exchange_name=exchange_name, **kwargs)
 
     def consume(self, limit=None, timeout=None):
         """Consume from all queues/consumers."""
diff --git a/tests/test_qpid.py b/tests/test_qpid.py
index 23145518a..976d5eb37 100644
--- a/tests/test_qpid.py
+++ b/tests/test_qpid.py
@@ -167,11 +167,17 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
 
     scenarios = [
         ('direct', dict(consumer_cls=qpid_driver.DirectConsumer,
-                        publisher_cls=qpid_driver.DirectPublisher)),
+                        consumer_kwargs={},
+                        publisher_cls=qpid_driver.DirectPublisher,
+                        publisher_kwargs={})),
         ('topic', dict(consumer_cls=qpid_driver.TopicConsumer,
-                       publisher_cls=qpid_driver.TopicPublisher)),
+                       consumer_kwargs={'exchange_name': 'openstack'},
+                       publisher_cls=qpid_driver.TopicPublisher,
+                       publisher_kwargs={'exchange_name': 'openstack'})),
         ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer,
-                        publisher_cls=qpid_driver.FanoutPublisher)),
+                        consumer_kwargs={},
+                        publisher_cls=qpid_driver.FanoutPublisher,
+                        publisher_kwargs={})),
     ]
 
     def setUp(self):
@@ -195,7 +201,8 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
             self.consumer_cls(self.conf,
                               self.session_receive,
                               msgid_or_topic,
-                              consumer_callback)
+                              consumer_callback,
+                              **self.consumer_kwargs)
         except Exception as e:
             recvd_exc_msg = e.message
 
@@ -205,7 +212,8 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase):
         try:
             self.publisher_cls(self.conf,
                                self.session_send,
-                               msgid_or_topic)
+                               topic=msgid_or_topic,
+                               **self.publisher_kwargs)
         except Exception as e:
             recvd_exc_msg = e.message
 
@@ -307,11 +315,15 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase):
     ]
     _exchange_class = [
         ('topic', dict(consumer_cls=qpid_driver.TopicConsumer,
+                       consumer_kwargs={'exchange_name': 'openstack'},
                        publisher_cls=qpid_driver.TopicPublisher,
+                       publisher_kwargs={'exchange_name': 'openstack'},
                        topic='topictest.test',
                        receive_topic='topictest.test')),
         ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer,
+                        consumer_kwargs={},
                         publisher_cls=qpid_driver.FanoutPublisher,
+                        publisher_kwargs={},
                         topic='fanouttest',
                         receive_topic='fanouttest')),
     ]
@@ -404,7 +416,8 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase):
             consumer = self.consumer_cls(self.conf,
                                          self.session_receive,
                                          self.receive_topic,
-                                         self.consumer_callback)
+                                         self.consumer_callback,
+                                         **self.consumer_kwargs)
             self._receivers.append(consumer)
 
             # create receivers threads
@@ -415,7 +428,8 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase):
         for sender_id in range(self.no_senders):
             publisher = self.publisher_cls(self.conf,
                                            self.session_send,
-                                           self.topic)
+                                           topic=self.topic,
+                                           **self.publisher_kwargs)
             self._senders.append(publisher)
 
             # create sender threads
@@ -450,6 +464,75 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase):
 TestQpidTopicAndFanout.generate_scenarios()
 
 
+class AddressNodeMatcher(object):
+    def __init__(self, node):
+        self.node = node
+
+    def __eq__(self, address):
+        return address.split(';')[0].strip() == self.node
+
+
+class TestDriverInterface(_QpidBaseTestCase):
+    """Unit Test cases to test the amqpdriver with qpid
+    """
+
+    def setUp(self):
+        super(TestDriverInterface, self).setUp()
+        self.config(qpid_topology_version=2)
+        transport = messaging.get_transport(self.conf)
+        self.driver = transport._driver
+
+    def test_listen_and_direct_send(self):
+        target = messaging.Target(exchange="exchange_test",
+                                  topic="topic_test",
+                                  server="server_test")
+
+        with mock.patch('qpid.messaging.Connection') as conn_cls:
+            conn = conn_cls.return_value
+            session = conn.session.return_value
+            session.receiver.side_effect = [mock.Mock(), mock.Mock(),
+                                            mock.Mock()]
+
+            listener = self.driver.listen(target)
+            listener.conn.direct_send("msg_id", {})
+
+        self.assertEqual(3, len(listener.conn.consumers))
+
+        expected_calls = [
+            mock.call(AddressNodeMatcher(
+                'amq.topic/topic/exchange_test/topic_test')),
+            mock.call(AddressNodeMatcher(
+                'amq.topic/topic/exchange_test/topic_test.server_test')),
+            mock.call(AddressNodeMatcher('amq.topic/fanout/topic_test')),
+        ]
+        session.receiver.assert_has_calls(expected_calls)
+        session.sender.assert_called_with(
+            AddressNodeMatcher("amq.direct/msg_id"))
+
+    def test_send(self):
+        target = messaging.Target(exchange="exchange_test",
+                                  topic="topic_test",
+                                  server="server_test")
+        with mock.patch('qpid.messaging.Connection') as conn_cls:
+            conn = conn_cls.return_value
+            session = conn.session.return_value
+
+            self.driver.send(target, {}, {})
+            session.sender.assert_called_with(AddressNodeMatcher(
+                "amq.topic/topic/exchange_test/topic_test.server_test"))
+
+    def test_send_notification(self):
+        target = messaging.Target(exchange="exchange_test",
+                                  topic="topic_test.info")
+        with mock.patch('qpid.messaging.Connection') as conn_cls:
+            conn = conn_cls.return_value
+            session = conn.session.return_value
+
+            self.driver.send_notification(target, {}, {}, "2.0")
+            session.sender.assert_called_with(AddressNodeMatcher(
+                "amq.topic/topic/exchange_test/topic_test.info"))
+
+
 class TestQpidReconnectOrder(test_utils.BaseTestCase):
     """Unit Test cases to test reconnection
     """