diff --git a/oslo/messaging/_drivers/amqp.py b/oslo/messaging/_drivers/amqp.py index 325c3609a..e0d33b393 100644 --- a/oslo/messaging/_drivers/amqp.py +++ b/oslo/messaging/_drivers/amqp.py @@ -249,7 +249,3 @@ def _add_unique_id(msg): unique_id = uuid.uuid4().hex msg.update({UNIQUE_ID: unique_id}) LOG.debug('UNIQUE_ID is %s.' % (unique_id)) - - -def get_control_exchange(conf): - return conf.control_exchange diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index 16626d0b2..d990e90a7 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -297,10 +297,6 @@ class AMQPDriverBase(base.BaseDriver): self._default_exchange = default_exchange - # FIXME(markmc): temp hack - if self._default_exchange: - self.conf.set_override('control_exchange', self._default_exchange) - self._connection_pool = connection_pool self._reply_q_lock = threading.Lock() @@ -308,6 +304,9 @@ class AMQPDriverBase(base.BaseDriver): self._reply_q_conn = None self._waiter = None + def _get_exchange(self, target): + return target.exchange or self._default_exchange + def _get_connection(self, pooled=True): return rpc_amqp.ConnectionContext(self.conf, self._url, @@ -364,14 +363,16 @@ class AMQPDriverBase(base.BaseDriver): try: with self._get_connection() as conn: if notify: - conn.notify_send(target.topic, msg) + conn.notify_send(self._get_exchange(target), + target.topic, msg) elif target.fanout: conn.fanout_send(target.topic, msg) else: topic = target.topic if target.server: topic = '%s.%s' % (target.topic, target.server) - conn.topic_send(topic, msg, timeout=timeout) + conn.topic_send(exchange_name=self._get_exchange(target), + topic=topic, msg=msg, timeout=timeout) if wait_for_reply: result = self._waiter.wait(msg_id, timeout) @@ -394,9 +395,13 @@ class AMQPDriverBase(base.BaseDriver): listener = AMQPListener(self, conn) - conn.declare_topic_consumer(target.topic, listener) - conn.declare_topic_consumer('%s.%s' % (target.topic, target.server), - listener) + conn.declare_topic_consumer(exchange_name=self._get_exchange(target), + topic=target.topic, + callback=listener) + conn.declare_topic_consumer(exchange_name=self._get_exchange(target), + topic='%s.%s' % (target.topic, + target.server), + callback=listener) conn.declare_fanout_consumer(target.topic, listener) return listener @@ -406,9 +411,10 @@ class AMQPDriverBase(base.BaseDriver): listener = AMQPListener(self, conn) for target, priority in targets_and_priorities: - conn.declare_topic_consumer('%s.%s' % (target.topic, priority), - callback=listener, - exchange_name=target.exchange) + conn.declare_topic_consumer( + exchange_name=self._get_exchange(target), + topic='%s.%s' % (target.topic, priority), + callback=listener) return listener def cleanup(self): diff --git a/oslo/messaging/_drivers/impl_qpid.py b/oslo/messaging/_drivers/impl_qpid.py index 10fd7207c..def074baf 100644 --- a/oslo/messaging/_drivers/impl_qpid.py +++ b/oslo/messaging/_drivers/impl_qpid.py @@ -248,8 +248,8 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'.""" - def __init__(self, conf, session, topic, callback, name=None, - exchange_name=None): + def __init__(self, conf, session, topic, callback, exchange_name, + name=None): """Init a 'topic' queue. :param session: the amqp session to use @@ -259,7 +259,6 @@ class TopicConsumer(ConsumerBase): :param name: optional queue name, defaults to topic """ - exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) link_opts = { "auto-delete": conf.amqp_auto_delete, "durable": conf.amqp_durable_queues, @@ -376,14 +375,14 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'.""" - def __init__(self, conf, session, msg_id): + def __init__(self, conf, session, topic): """Init a 'direct' publisher.""" if conf.qpid_topology_version == 1: - node_name = "%s/%s" % (msg_id, msg_id) + node_name = "%s/%s" % (topic, topic) node_opts = {"type": "direct"} elif conf.qpid_topology_version == 2: - node_name = "amq.direct/%s" % msg_id + node_name = "amq.direct/%s" % topic node_opts = {} else: raise_invalid_topology_version(conf) @@ -394,11 +393,9 @@ class DirectPublisher(Publisher): class TopicPublisher(Publisher): """Publisher class for 'topic'.""" - def __init__(self, conf, session, topic): + def __init__(self, conf, session, exchange_name, topic): """Init a 'topic' publisher. """ - exchange_name = rpc_amqp.get_control_exchange(conf) - if conf.qpid_topology_version == 1: node_name = "%s/%s" % (exchange_name, topic) elif conf.qpid_topology_version == 2: @@ -430,10 +427,9 @@ class FanoutPublisher(Publisher): class NotifyPublisher(Publisher): """Publisher class for notifications.""" - def __init__(self, conf, session, topic): + def __init__(self, conf, session, exchange_name, topic): """Init a 'topic' publisher. """ - exchange_name = rpc_amqp.get_control_exchange(conf) node_opts = {"durable": True} if conf.qpid_topology_version == 1: @@ -618,7 +614,7 @@ class Connection(object): raise StopIteration yield self.ensure(_error_callback, _consume) - def publisher_send(self, cls, topic, msg): + def publisher_send(self, cls, topic, msg, **kwargs): """Send to a publisher based on the publisher class.""" def _connect_error(exc): @@ -627,7 +623,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publisher_send(): - publisher = cls(self.conf, self.session, topic) + publisher = cls(self.conf, self.session, topic=topic, **kwargs) publisher.send(msg) return self.ensure(_connect_error, _publisher_send) @@ -639,8 +635,8 @@ class Connection(object): """ self.declare_consumer(DirectConsumer, topic, callback) - def declare_topic_consumer(self, topic, callback=None, queue_name=None, - exchange_name=None): + def declare_topic_consumer(self, exchange_name, topic, callback=None, + queue_name=None): """Create a 'topic' consumer.""" self.declare_consumer(functools.partial(TopicConsumer, name=queue_name, @@ -654,9 +650,9 @@ class Connection(object): def direct_send(self, msg_id, msg): """Send a 'direct' message.""" - self.publisher_send(DirectPublisher, msg_id, msg) + self.publisher_send(DirectPublisher, topic=msg_id, msg=msg) - def topic_send(self, topic, msg, timeout=None): + def topic_send(self, exchange_name, topic, msg, timeout=None): """Send a 'topic' message.""" # # We want to create a message with attributes, e.g. a TTL. We @@ -669,15 +665,17 @@ class Connection(object): # will need to be altered accordingly. # qpid_message = qpid_messaging.Message(content=msg, ttl=timeout) - self.publisher_send(TopicPublisher, topic, qpid_message) + self.publisher_send(TopicPublisher, topic=topic, msg=qpid_message, + exchange_name=exchange_name) def fanout_send(self, topic, msg): """Send a 'fanout' message.""" - self.publisher_send(FanoutPublisher, topic, msg) + self.publisher_send(FanoutPublisher, topic=topic, msg=msg) - def notify_send(self, topic, msg, **kwargs): + def notify_send(self, exchange_name, topic, msg, **kwargs): """Send a notify message on a topic.""" - self.publisher_send(NotifyPublisher, topic, msg) + self.publisher_send(NotifyPublisher, topic=topic, msg=msg, + exchange_name=exchange_name) def consume(self, limit=None, timeout=None): """Consume from all queues/consumers.""" diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index 29d101a07..f7ca5e41c 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -247,8 +247,8 @@ class DirectConsumer(ConsumerBase): class TopicConsumer(ConsumerBase): """Consumer class for 'topic'.""" - def __init__(self, conf, channel, topic, callback, tag, name=None, - exchange_name=None, **kwargs): + def __init__(self, conf, channel, topic, callback, tag, exchange_name, + name=None, **kwargs): """Init a 'topic' queue. :param channel: the amqp channel to use @@ -256,6 +256,7 @@ class TopicConsumer(ConsumerBase): :paramtype topic: str :param callback: the callback to call when messages are received :param tag: a unique ID for the consumer on the channel + :param exchange_name: the exchange name to use :param name: optional queue name, defaults to topic :paramtype name: str @@ -267,7 +268,6 @@ class TopicConsumer(ConsumerBase): 'auto_delete': conf.amqp_auto_delete, 'exclusive': False} options.update(kwargs) - exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) exchange = kombu.entity.Exchange(name=exchange_name, type='topic', durable=options['durable'], @@ -347,7 +347,7 @@ class Publisher(object): class DirectPublisher(Publisher): """Publisher class for 'direct'.""" - def __init__(self, conf, channel, msg_id, **kwargs): + def __init__(self, conf, channel, topic, **kwargs): """Init a 'direct' publisher. Kombu options may be passed as keyword args to override defaults @@ -357,13 +357,13 @@ class DirectPublisher(Publisher): 'auto_delete': True, 'exclusive': False} options.update(kwargs) - super(DirectPublisher, self).__init__(channel, msg_id, msg_id, + super(DirectPublisher, self).__init__(channel, topic, topic, type='direct', **options) class TopicPublisher(Publisher): """Publisher class for 'topic'.""" - def __init__(self, conf, channel, topic, **kwargs): + def __init__(self, conf, channel, exchange_name, topic, **kwargs): """Init a 'topic' publisher. Kombu options may be passed as keyword args to override defaults @@ -372,7 +372,6 @@ class TopicPublisher(Publisher): 'auto_delete': conf.amqp_auto_delete, 'exclusive': False} options.update(kwargs) - exchange_name = rpc_amqp.get_control_exchange(conf) super(TopicPublisher, self).__init__(channel, exchange_name, topic, @@ -398,10 +397,11 @@ class FanoutPublisher(Publisher): class NotifyPublisher(TopicPublisher): """Publisher class for 'notify'.""" - def __init__(self, conf, channel, topic, **kwargs): + def __init__(self, conf, channel, exchange_name, topic, **kwargs): self.durable = kwargs.pop('durable', conf.amqp_durable_queues) self.queue_arguments = _get_queue_arguments(conf) - super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) + super(NotifyPublisher, self).__init__(conf, channel, exchange_name, + topic, **kwargs) def reconnect(self, channel): super(NotifyPublisher, self).reconnect(channel) @@ -731,7 +731,7 @@ class Connection(object): "'%(topic)s': %(err_str)s") % log_info) def _publish(): - publisher = cls(self.conf, self.channel, topic, **kwargs) + publisher = cls(self.conf, self.channel, topic=topic, **kwargs) publisher.send(msg, timeout) self.ensure(_error_callback, _publish) @@ -743,8 +743,8 @@ class Connection(object): """ self.declare_consumer(DirectConsumer, topic, callback) - def declare_topic_consumer(self, topic, callback=None, queue_name=None, - exchange_name=None): + def declare_topic_consumer(self, exchange_name, topic, callback=None, + queue_name=None): """Create a 'topic' consumer.""" self.declare_consumer(functools.partial(TopicConsumer, name=queue_name, @@ -760,17 +760,19 @@ class Connection(object): """Send a 'direct' message.""" self.publisher_send(DirectPublisher, msg_id, msg) - def topic_send(self, topic, msg, timeout=None): + def topic_send(self, exchange_name, topic, msg, timeout=None): """Send a 'topic' message.""" - self.publisher_send(TopicPublisher, topic, msg, timeout) + self.publisher_send(TopicPublisher, topic, msg, timeout, + exchange_name=exchange_name) def fanout_send(self, topic, msg): """Send a 'fanout' message.""" self.publisher_send(FanoutPublisher, topic, msg) - def notify_send(self, topic, msg, **kwargs): + def notify_send(self, exchange_name, topic, msg, **kwargs): """Send a notify message on a topic.""" - self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) + self.publisher_send(NotifyPublisher, topic, msg, timeout=None, + exchange_name=exchange_name, **kwargs) def consume(self, limit=None, timeout=None): """Consume from all queues/consumers.""" diff --git a/tests/test_qpid.py b/tests/test_qpid.py index 23145518a..976d5eb37 100644 --- a/tests/test_qpid.py +++ b/tests/test_qpid.py @@ -167,11 +167,17 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): scenarios = [ ('direct', dict(consumer_cls=qpid_driver.DirectConsumer, - publisher_cls=qpid_driver.DirectPublisher)), + consumer_kwargs={}, + publisher_cls=qpid_driver.DirectPublisher, + publisher_kwargs={})), ('topic', dict(consumer_cls=qpid_driver.TopicConsumer, - publisher_cls=qpid_driver.TopicPublisher)), + consumer_kwargs={'exchange_name': 'openstack'}, + publisher_cls=qpid_driver.TopicPublisher, + publisher_kwargs={'exchange_name': 'openstack'})), ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer, - publisher_cls=qpid_driver.FanoutPublisher)), + consumer_kwargs={}, + publisher_cls=qpid_driver.FanoutPublisher, + publisher_kwargs={})), ] def setUp(self): @@ -195,7 +201,8 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): self.consumer_cls(self.conf, self.session_receive, msgid_or_topic, - consumer_callback) + consumer_callback, + **self.consumer_kwargs) except Exception as e: recvd_exc_msg = e.message @@ -205,7 +212,8 @@ class TestQpidInvalidTopologyVersion(_QpidBaseTestCase): try: self.publisher_cls(self.conf, self.session_send, - msgid_or_topic) + topic=msgid_or_topic, + **self.publisher_kwargs) except Exception as e: recvd_exc_msg = e.message @@ -307,11 +315,15 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): ] _exchange_class = [ ('topic', dict(consumer_cls=qpid_driver.TopicConsumer, + consumer_kwargs={'exchange_name': 'openstack'}, publisher_cls=qpid_driver.TopicPublisher, + publisher_kwargs={'exchange_name': 'openstack'}, topic='topictest.test', receive_topic='topictest.test')), ('fanout', dict(consumer_cls=qpid_driver.FanoutConsumer, + consumer_kwargs={}, publisher_cls=qpid_driver.FanoutPublisher, + publisher_kwargs={}, topic='fanouttest', receive_topic='fanouttest')), ] @@ -404,7 +416,8 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): consumer = self.consumer_cls(self.conf, self.session_receive, self.receive_topic, - self.consumer_callback) + self.consumer_callback, + **self.consumer_kwargs) self._receivers.append(consumer) # create receivers threads @@ -415,7 +428,8 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): for sender_id in range(self.no_senders): publisher = self.publisher_cls(self.conf, self.session_send, - self.topic) + topic=self.topic, + **self.publisher_kwargs) self._senders.append(publisher) # create sender threads @@ -450,6 +464,75 @@ class TestQpidTopicAndFanout(_QpidBaseTestCase): TestQpidTopicAndFanout.generate_scenarios() +class AddressNodeMatcher(object): + def __init__(self, node): + self.node = node + + def __eq__(self, address): + return address.split(';')[0].strip() == self.node + + +class TestDriverInterface(_QpidBaseTestCase): + """Unit Test cases to test the amqpdriver with qpid + """ + + def setUp(self): + super(TestDriverInterface, self).setUp() + self.config(qpid_topology_version=2) + transport = messaging.get_transport(self.conf) + self.driver = transport._driver + + def test_listen_and_direct_send(self): + target = messaging.Target(exchange="exchange_test", + topic="topic_test", + server="server_test") + + with mock.patch('qpid.messaging.Connection') as conn_cls: + conn = conn_cls.return_value + session = conn.session.return_value + session.receiver.side_effect = [mock.Mock(), mock.Mock(), + mock.Mock()] + + listener = self.driver.listen(target) + listener.conn.direct_send("msg_id", {}) + + self.assertEqual(3, len(listener.conn.consumers)) + + expected_calls = [ + mock.call(AddressNodeMatcher( + 'amq.topic/topic/exchange_test/topic_test')), + mock.call(AddressNodeMatcher( + 'amq.topic/topic/exchange_test/topic_test.server_test')), + mock.call(AddressNodeMatcher('amq.topic/fanout/topic_test')), + ] + session.receiver.assert_has_calls(expected_calls) + session.sender.assert_called_with( + AddressNodeMatcher("amq.direct/msg_id")) + + def test_send(self): + target = messaging.Target(exchange="exchange_test", + topic="topic_test", + server="server_test") + with mock.patch('qpid.messaging.Connection') as conn_cls: + conn = conn_cls.return_value + session = conn.session.return_value + + self.driver.send(target, {}, {}) + session.sender.assert_called_with(AddressNodeMatcher( + "amq.topic/topic/exchange_test/topic_test.server_test")) + + def test_send_notification(self): + target = messaging.Target(exchange="exchange_test", + topic="topic_test.info") + with mock.patch('qpid.messaging.Connection') as conn_cls: + conn = conn_cls.return_value + session = conn.session.return_value + + self.driver.send_notification(target, {}, {}, "2.0") + session.sender.assert_called_with(AddressNodeMatcher( + "amq.topic/topic/exchange_test/topic_test.info")) + + class TestQpidReconnectOrder(test_utils.BaseTestCase): """Unit Test cases to test reconnection """