diff --git a/oslo_messaging/_drivers/protocols/amqp/controller.py b/oslo_messaging/_drivers/protocols/amqp/controller.py index 73f5cf112..8d949ed2e 100644 --- a/oslo_messaging/_drivers/protocols/amqp/controller.py +++ b/oslo_messaging/_drivers/protocols/amqp/controller.py @@ -36,6 +36,7 @@ from six import moves from oslo_messaging._drivers.protocols.amqp import eventloop from oslo_messaging._drivers.protocols.amqp import opts +from oslo_messaging import exceptions from oslo_messaging import transport LOG = logging.getLogger(__name__) @@ -87,6 +88,14 @@ class Replies(pyngus.ReceiverEventHandler): request.reply_to = self._receiver.source_address LOG.debug("Reply for msg id=%s expected on link %s", request.id, request.reply_to) + return request.id + + def cancel_response(self, msg_id): + """Abort waiting for a response message. This can be used if the + request fails and no reply is expected. + """ + if msg_id in self._correlation: + del self._correlation[msg_id] # Pyngus ReceiverLink event callbacks: @@ -121,16 +130,20 @@ class Replies(pyngus.ReceiverEventHandler): key = message.correlation_id if key in self._correlation: LOG.debug("Received response for msg id=%s", key) - self._correlation[key].put(message) + result = {"status": "OK", + "response": message} + self._correlation[key].put(result) # cleanup (only need one response per request) del self._correlation[key] + receiver.message_accepted(handle) else: LOG.warn("Can't find receiver for response msg id=%s, dropping!", key) - receiver.message_accepted(handle) + receiver.message_modified(handle, True, True, None) def _update_credit(self): - if self.capacity > self._credit: + # ensure we have enough credit + if self._credit < self.capacity / 2: self._receiver.add_capacity(self.capacity - self._credit) self._credit = self.capacity @@ -143,6 +156,7 @@ class Server(pyngus.ReceiverEventHandler): def __init__(self, addresses, incoming): self._incoming = incoming self._addresses = addresses + self._capacity = 500 # credit per link def attach(self, connection): """Create receiver links over the given connection for all the @@ -162,7 +176,7 @@ class Server(pyngus.ReceiverEventHandler): # approach would monitor for a back-up of inbound messages to be # processed by the consuming application and backpressure the # sender based on configured thresholds. - r.add_capacity(500) + r.add_capacity(self._capacity) r.open() self._receivers.append(r) @@ -183,9 +197,8 @@ class Server(pyngus.ReceiverEventHandler): """This is a Pyngus callback, invoked by Pyngus when a new message arrives on this receiver link from the peer. """ - # TODO(kgiusti) Sub-optimal to grant one credit each time a message - # arrives. A better approach would grant batches of credit on demand. - receiver.add_capacity(1) + if receiver.capacity < self._capacity / 2: + receiver.add_capacity(self._capacity - receiver.capacity) self._incoming.put(message) LOG.debug("message received: %s", message) receiver.message_accepted(handle) @@ -304,17 +317,41 @@ class Controller(pyngus.ConnectionEventHandler): # methods executed by Tasks created by the driver: - def request(self, target, request, reply_queue=None): - """Send a request message to the given target, and arrange for a - response to be put on the optional reply_queue if specified + def request(self, target, request, result_queue, reply_expected=False): + """Send a request message to the given target and arrange for a + result to be put on the result_queue. If reply_expected, the result + will include the reply message (if successful). """ address = self._resolve(target) LOG.debug("Sending request for %s to %s", target, address) - if reply_queue is not None: - self._replies.prepare_for_response(request, reply_queue) - self._send(address, request) + if reply_expected: + msg_id = self._replies.prepare_for_response(request, result_queue) + + def _callback(link, handle, state, info): + if state == pyngus.SenderLink.ACCEPTED: # message received + if not reply_expected: + # can wake up the sender now + result = {"status": "OK"} + result_queue.put(result) + else: + # we will wake up the sender when the reply message is + # received. See Replies.message_received() + pass + else: # send failed/rejected/etc + msg = "Message send failed: remote disposition: %s, info: %s" + exc = exceptions.MessageDeliveryFailure(msg % (state, info)) + result = {"status": "ERROR", "error": exc} + if reply_expected: + # no response will be received, so cancel the correlation + self._replies.cancel_response(msg_id) + result_queue.put(result) + self._send(address, request, _callback) def response(self, address, response): + """Send a response message to the client listening on 'address'. + To prevent a misbehaving client from blocking a server indefinitely, + the message is send asynchronously. + """ LOG.debug("Sending response to %s", address) self._send(address, response) @@ -366,11 +403,14 @@ class Controller(pyngus.ConnectionEventHandler): self._senders[address] = sender return sender - def _send(self, addr, message): - """Send the message out the link addressed by 'addr'.""" + def _send(self, addr, message, callback=None, handle=None): + """Send the message out the link addressed by 'addr'. If a + delivery_callback is given it will be invoked when the send has + completed (whether successfully or in error). + """ address = str(addr) message.address = address - self._sender(address).send(message) + self._sender(address).send(message, delivery_callback=callback) def _server_address(self, target): return self._concatenate([self.server_request_prefix, diff --git a/oslo_messaging/_drivers/protocols/amqp/driver.py b/oslo_messaging/_drivers/protocols/amqp/driver.py index cc2e1be8e..22fdc9e3b 100644 --- a/oslo_messaging/_drivers/protocols/amqp/driver.py +++ b/oslo_messaging/_drivers/protocols/amqp/driver.py @@ -185,16 +185,16 @@ class ProtonDriver(base.BaseDriver): LOG.debug("Send to %s", target) task = drivertasks.SendTask(target, request, wait_for_reply, expire) self._ctrl.add_task(task) - result = None - if wait_for_reply: - # the following can raise MessagingTimeout if no reply received: - reply = task.get_reply(timeout) + # wait for the eventloop to process the command. If the command is + # an RPC call retrieve the reply message + reply = task.wait(timeout) + if reply: # TODO(kgiusti) how to handle failure to un-marshal? Must log, and # determine best way to communicate this failure back up to the # caller - result = unmarshal_response(reply, self._allowed_remote_exmods) + reply = unmarshal_response(reply, self._allowed_remote_exmods) LOG.debug("Send to %s returning", target) - return result + return reply @_ensure_connect_called def send_notification(self, target, ctxt, message, version, diff --git a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py index 63a32292e..5d9e2ed4d 100644 --- a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py +++ b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py @@ -13,10 +13,11 @@ # under the License. import logging +import threading import time -import oslo_messaging from oslo_messaging._drivers.protocols.amqp import controller +from oslo_messaging import exceptions from six import moves @@ -24,36 +25,44 @@ LOG = logging.getLogger(__name__) class SendTask(controller.Task): - """A task that sends a message to a target, and optionally allows for the - calling thread to wait for a reply. + """A task that sends a message to a target, and optionally waits for a + reply message. The caller may block until the remote confirms receipt or + the reply message has arrived. """ - def __init__(self, target, request, reply_expected, deadline): + def __init__(self, target, request, wait_for_reply, deadline): super(SendTask, self).__init__() self._target = target self._request = request self._deadline = deadline - if reply_expected: - self._reply_queue = moves.queue.Queue() - else: - self._reply_queue = None + self._wait_for_reply = wait_for_reply + self._results_queue = moves.queue.Queue() + + def wait(self, timeout): + """Wait for the send to complete, and, optionally, a reply message from + the remote. Will raise MessagingTimeout if the send does not complete + or no reply is received within timeout seconds. If the request has + failed for any other reason, a MessagingException is raised." + """ + try: + result = self._results_queue.get(timeout=timeout) + except moves.queue.Empty: + if self._wait_for_reply: + reason = "Timed out waiting for a reply." + else: + reason = "Timed out waiting for send to complete." + raise exceptions.MessagingTimeout(reason) + if result["status"] == "OK": + return result.get("response", None) + raise result["error"] def execute(self, controller): """Runs on eventloop thread - sends request.""" if not self._deadline or self._deadline > time.time(): - controller.request(self._target, self._request, self._reply_queue) + controller.request(self._target, self._request, + self._results_queue, self._wait_for_reply) else: LOG.warn("Send request to %s aborted: TTL expired.", self._target) - def get_reply(self, timeout): - """Retrieve the reply.""" - if not self._reply_queue: - return None - try: - return self._reply_queue.get(timeout=timeout) - except moves.queue.Empty: - raise oslo_messaging.MessagingTimeout( - 'Timed out waiting for a reply') - class ListenTask(controller.Task): """A task that creates a subscription to the given target. Messages @@ -78,13 +87,21 @@ class ListenTask(controller.Task): class ReplyTask(controller.Task): - """A task that sends 'response' message to address.""" + """A task that sends 'response' message to 'address'. + """ def __init__(self, address, response, log_failure): super(ReplyTask, self).__init__() self._address = address self._response = response self._log_failure = log_failure + self._wakeup = threading.Event() + + def wait(self): + """Wait for the controller to send the message. + """ + self._wakeup.wait() def execute(self, controller): """Run on the eventloop thread - send the response message.""" controller.response(self._address, self._response) + self._wakeup.set() diff --git a/oslo_messaging/tests/test_amqp_driver.py b/oslo_messaging/tests/test_amqp_driver.py index 0d56f9639..379033d95 100644 --- a/oslo_messaging/tests/test_amqp_driver.py +++ b/oslo_messaging/tests/test_amqp_driver.py @@ -264,26 +264,34 @@ class TestAmqpNotification(_AmqpBrokerTestCase): msg_count = len(notifications) * 2 listener = _ListenerThread(nl, msg_count) targets = ['topic-1.info', - 'topic-1.bad', # should be dropped - 'bad-topic.debug', # should be dropped + 'topic-1.bad', # will raise MessageDeliveryFailure + 'bad-topic.debug', # will raise MessageDeliveryFailure 'topic-1.error', 'topic-2.debug'] + excepted_targets = [] + exception_count = 0 for version in (1.0, 2.0): for t in targets: - driver.send_notification(oslo_messaging.Target(topic=t), - "context", {'target': t}, - version) + try: + driver.send_notification(oslo_messaging.Target(topic=t), + "context", {'target': t}, + version) + except oslo_messaging.MessageDeliveryFailure: + exception_count += 1 + excepted_targets.append(t) listener.join(timeout=30) self.assertFalse(listener.isAlive()) topics = [x.message.get('target') for x in listener.get_messages()] - self.assertEqual(len(topics), msg_count) self.assertEqual(topics.count('topic-1.info'), 2) self.assertEqual(topics.count('topic-1.error'), 2) self.assertEqual(topics.count('topic-2.debug'), 2) self.assertEqual(self._broker.dropped_count, 4) + self.assertEqual(exception_count, 4) + self.assertEqual(excepted_targets.count('topic-1.bad'), 2) + self.assertEqual(excepted_targets.count('bad-topic.debug'), 2) driver.cleanup() diff --git a/tests/test_amqp_driver.py b/tests/test_amqp_driver.py index 6fcba0dea..b2ebc3347 100644 --- a/tests/test_amqp_driver.py +++ b/tests/test_amqp_driver.py @@ -258,25 +258,40 @@ class TestAmqpNotification(_AmqpBrokerTestCase): notifications = [(messaging.Target(topic="topic-1"), 'info'), (messaging.Target(topic="topic-1"), 'error'), (messaging.Target(topic="topic-2"), 'debug')] - nl = driver.listen_for_notifications(notifications) + nl = driver.listen_for_notifications(notifications, None) - listener = _ListenerThread(nl, 3) + # send one for each support version: + msg_count = len(notifications) * 2 + listener = _ListenerThread(nl, msg_count) targets = ['topic-1.info', - 'topic-1.bad', # should be dropped - 'bad-topic.debug', # should be dropped - 'topic-1.error', 'topic-2.debug'] + 'topic-1.bad', # will raise MessagingDeliveryFailure + 'bad-topic.debug', # will raise MessagingDeliveryFailure + 'topic-1.error', + 'topic-2.debug'] + + excepted_targets = [] + exception_count = 0 + for version in (1.0, 2.0): + for t in targets: + try: + driver.send_notification(messaging.Target(topic=t), + "context", {'target': t}, + version) + except messaging.MessageDeliveryFailure: + exception_count += 1 + excepted_targets.append(t) - for t in targets: - driver.send_notification(messaging.Target(topic=t), - "context", {'target': t}, - 1.0) listener.join(timeout=30) self.assertFalse(listener.isAlive()) topics = [x.message.get('target') for x in listener.get_messages()] - self.assertTrue('topic-1.info' in topics) - self.assertTrue('topic-1.error' in topics) - self.assertTrue('topic-2.debug' in topics) - self.assertEqual(self._broker.dropped_count, 2) + self.assertEqual(len(topics), msg_count) + self.assertEqual(topics.count('topic-1.info'), 2) + self.assertEqual(topics.count('topic-1.error'), 2) + self.assertEqual(topics.count('topic-2.debug'), 2) + self.assertEqual(self._broker.dropped_count, 4) + self.assertEqual(exception_count, 4) + self.assertEqual(excepted_targets.count('topic-1.bad'), 2) + self.assertEqual(excepted_targets.count('bad-topic.debug'), 2) driver.cleanup()