From c127594de69c1443993ddff171738d8d7ad058aa Mon Sep 17 00:00:00 2001
From: Kenneth Giusti <kgiusti@gmail.com>
Date: Mon, 1 Jun 2015 10:24:45 -0400
Subject: [PATCH] Provide better detection of failures during message send

This change causes the message sender to block until the messaging
infrastructure (e.g. the broker) assumes ownership of the message (or
fails to accept it).  If the message is accepted, then the sender will
either pend for a response (in the case of RPC), or simply return (in
the case of notification).  If the message is rejected by the
messaging infrastructure a MessagingException will be raised at the
sender.

Change-Id: I3f4a1ed1c17e18f6d629f16e6b5c99de45b083d6
Closes-Bug: #1377228
---
 .../_drivers/protocols/amqp/controller.py     | 72 ++++++++++++++-----
 .../_drivers/protocols/amqp/driver.py         | 12 ++--
 .../_drivers/protocols/amqp/drivertasks.py    | 57 +++++++++------
 oslo_messaging/tests/test_amqp_driver.py      | 20 ++++--
 tests/test_amqp_driver.py                     | 41 +++++++----
 5 files changed, 141 insertions(+), 61 deletions(-)

diff --git a/oslo_messaging/_drivers/protocols/amqp/controller.py b/oslo_messaging/_drivers/protocols/amqp/controller.py
index 73f5cf112..8d949ed2e 100644
--- a/oslo_messaging/_drivers/protocols/amqp/controller.py
+++ b/oslo_messaging/_drivers/protocols/amqp/controller.py
@@ -36,6 +36,7 @@ from six import moves
 
 from oslo_messaging._drivers.protocols.amqp import eventloop
 from oslo_messaging._drivers.protocols.amqp import opts
+from oslo_messaging import exceptions
 from oslo_messaging import transport
 
 LOG = logging.getLogger(__name__)
@@ -87,6 +88,14 @@ class Replies(pyngus.ReceiverEventHandler):
         request.reply_to = self._receiver.source_address
         LOG.debug("Reply for msg id=%s expected on link %s",
                   request.id, request.reply_to)
+        return request.id
+
+    def cancel_response(self, msg_id):
+        """Abort waiting for a response message.  This can be used if the
+        request fails and no reply is expected.
+        """
+        if msg_id in self._correlation:
+            del self._correlation[msg_id]
 
     # Pyngus ReceiverLink event callbacks:
 
@@ -121,16 +130,20 @@ class Replies(pyngus.ReceiverEventHandler):
         key = message.correlation_id
         if key in self._correlation:
             LOG.debug("Received response for msg id=%s", key)
-            self._correlation[key].put(message)
+            result = {"status": "OK",
+                      "response": message}
+            self._correlation[key].put(result)
             # cleanup (only need one response per request)
             del self._correlation[key]
+            receiver.message_accepted(handle)
         else:
             LOG.warn("Can't find receiver for response msg id=%s, dropping!",
                      key)
-        receiver.message_accepted(handle)
+            receiver.message_modified(handle, True, True, None)
 
     def _update_credit(self):
-        if self.capacity > self._credit:
+        # ensure we have enough credit
+        if self._credit < self.capacity / 2:
             self._receiver.add_capacity(self.capacity - self._credit)
             self._credit = self.capacity
 
@@ -143,6 +156,7 @@ class Server(pyngus.ReceiverEventHandler):
     def __init__(self, addresses, incoming):
         self._incoming = incoming
         self._addresses = addresses
+        self._capacity = 500   # credit per link
 
     def attach(self, connection):
         """Create receiver links over the given connection for all the
@@ -162,7 +176,7 @@ class Server(pyngus.ReceiverEventHandler):
             # approach would monitor for a back-up of inbound messages to be
             # processed by the consuming application and backpressure the
             # sender based on configured thresholds.
-            r.add_capacity(500)
+            r.add_capacity(self._capacity)
             r.open()
             self._receivers.append(r)
 
@@ -183,9 +197,8 @@ class Server(pyngus.ReceiverEventHandler):
         """This is a Pyngus callback, invoked by Pyngus when a new message
         arrives on this receiver link from the peer.
         """
-        # TODO(kgiusti) Sub-optimal to grant one credit each time a message
-        # arrives.  A better approach would grant batches of credit on demand.
-        receiver.add_capacity(1)
+        if receiver.capacity < self._capacity / 2:
+            receiver.add_capacity(self._capacity - receiver.capacity)
         self._incoming.put(message)
         LOG.debug("message received: %s", message)
         receiver.message_accepted(handle)
@@ -304,17 +317,41 @@ class Controller(pyngus.ConnectionEventHandler):
 
     # methods executed by Tasks created by the driver:
 
-    def request(self, target, request, reply_queue=None):
-        """Send a request message to the given target, and arrange for a
-        response to be put on the optional reply_queue if specified
+    def request(self, target, request, result_queue, reply_expected=False):
+        """Send a request message to the given target and arrange for a
+        result to be put on the result_queue. If reply_expected, the result
+        will include the reply message (if successful).
         """
         address = self._resolve(target)
         LOG.debug("Sending request for %s to %s", target, address)
-        if reply_queue is not None:
-            self._replies.prepare_for_response(request, reply_queue)
-        self._send(address, request)
+        if reply_expected:
+            msg_id = self._replies.prepare_for_response(request, result_queue)
+
+        def _callback(link, handle, state, info):
+            if state == pyngus.SenderLink.ACCEPTED:  # message received
+                if not reply_expected:
+                    # can wake up the sender now
+                    result = {"status": "OK"}
+                    result_queue.put(result)
+                else:
+                    # we will wake up the sender when the reply message is
+                    # received.  See Replies.message_received()
+                    pass
+            else:  # send failed/rejected/etc
+                msg = "Message send failed: remote disposition: %s, info: %s"
+                exc = exceptions.MessageDeliveryFailure(msg % (state, info))
+                result = {"status": "ERROR", "error": exc}
+                if reply_expected:
+                    # no response will be received, so cancel the correlation
+                    self._replies.cancel_response(msg_id)
+                result_queue.put(result)
+        self._send(address, request, _callback)
 
     def response(self, address, response):
+        """Send a response message to the client listening on 'address'.
+        To prevent a misbehaving client from blocking a server indefinitely,
+        the message is send asynchronously.
+        """
         LOG.debug("Sending response to %s", address)
         self._send(address, response)
 
@@ -366,11 +403,14 @@ class Controller(pyngus.ConnectionEventHandler):
             self._senders[address] = sender
         return sender
 
-    def _send(self, addr, message):
-        """Send the message out the link addressed by 'addr'."""
+    def _send(self, addr, message, callback=None, handle=None):
+        """Send the message out the link addressed by 'addr'.  If a
+        delivery_callback is given it will be invoked when the send has
+        completed (whether successfully or in error).
+        """
         address = str(addr)
         message.address = address
-        self._sender(address).send(message)
+        self._sender(address).send(message, delivery_callback=callback)
 
     def _server_address(self, target):
         return self._concatenate([self.server_request_prefix,
diff --git a/oslo_messaging/_drivers/protocols/amqp/driver.py b/oslo_messaging/_drivers/protocols/amqp/driver.py
index cc2e1be8e..22fdc9e3b 100644
--- a/oslo_messaging/_drivers/protocols/amqp/driver.py
+++ b/oslo_messaging/_drivers/protocols/amqp/driver.py
@@ -185,16 +185,16 @@ class ProtonDriver(base.BaseDriver):
         LOG.debug("Send to %s", target)
         task = drivertasks.SendTask(target, request, wait_for_reply, expire)
         self._ctrl.add_task(task)
-        result = None
-        if wait_for_reply:
-            # the following can raise MessagingTimeout if no reply received:
-            reply = task.get_reply(timeout)
+        # wait for the eventloop to process the command. If the command is
+        # an RPC call retrieve the reply message
+        reply = task.wait(timeout)
+        if reply:
             # TODO(kgiusti) how to handle failure to un-marshal?  Must log, and
             # determine best way to communicate this failure back up to the
             # caller
-            result = unmarshal_response(reply, self._allowed_remote_exmods)
+            reply = unmarshal_response(reply, self._allowed_remote_exmods)
         LOG.debug("Send to %s returning", target)
-        return result
+        return reply
 
     @_ensure_connect_called
     def send_notification(self, target, ctxt, message, version,
diff --git a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py
index 63a32292e..5d9e2ed4d 100644
--- a/oslo_messaging/_drivers/protocols/amqp/drivertasks.py
+++ b/oslo_messaging/_drivers/protocols/amqp/drivertasks.py
@@ -13,10 +13,11 @@
 #    under the License.
 
 import logging
+import threading
 import time
 
-import oslo_messaging
 from oslo_messaging._drivers.protocols.amqp import controller
+from oslo_messaging import exceptions
 
 from six import moves
 
@@ -24,36 +25,44 @@ LOG = logging.getLogger(__name__)
 
 
 class SendTask(controller.Task):
-    """A task that sends a message to a target, and optionally allows for the
-    calling thread to wait for a reply.
+    """A task that sends a message to a target, and optionally waits for a
+    reply message.  The caller may block until the remote confirms receipt or
+    the reply message has arrived.
     """
-    def __init__(self, target, request, reply_expected, deadline):
+    def __init__(self, target, request, wait_for_reply, deadline):
         super(SendTask, self).__init__()
         self._target = target
         self._request = request
         self._deadline = deadline
-        if reply_expected:
-            self._reply_queue = moves.queue.Queue()
-        else:
-            self._reply_queue = None
+        self._wait_for_reply = wait_for_reply
+        self._results_queue = moves.queue.Queue()
+
+    def wait(self, timeout):
+        """Wait for the send to complete, and, optionally, a reply message from
+        the remote.  Will raise MessagingTimeout if the send does not complete
+        or no reply is received within timeout seconds. If the request has
+        failed for any other reason, a MessagingException is raised."
+        """
+        try:
+            result = self._results_queue.get(timeout=timeout)
+        except moves.queue.Empty:
+            if self._wait_for_reply:
+                reason = "Timed out waiting for a reply."
+            else:
+                reason = "Timed out waiting for send to complete."
+            raise exceptions.MessagingTimeout(reason)
+        if result["status"] == "OK":
+            return result.get("response", None)
+        raise result["error"]
 
     def execute(self, controller):
         """Runs on eventloop thread - sends request."""
         if not self._deadline or self._deadline > time.time():
-            controller.request(self._target, self._request, self._reply_queue)
+            controller.request(self._target, self._request,
+                               self._results_queue, self._wait_for_reply)
         else:
             LOG.warn("Send request to %s aborted: TTL expired.", self._target)
 
-    def get_reply(self, timeout):
-        """Retrieve the reply."""
-        if not self._reply_queue:
-            return None
-        try:
-            return self._reply_queue.get(timeout=timeout)
-        except moves.queue.Empty:
-            raise oslo_messaging.MessagingTimeout(
-                'Timed out waiting for a reply')
-
 
 class ListenTask(controller.Task):
     """A task that creates a subscription to the given target.  Messages
@@ -78,13 +87,21 @@ class ListenTask(controller.Task):
 
 
 class ReplyTask(controller.Task):
-    """A task that sends 'response' message to address."""
+    """A task that sends 'response' message to 'address'.
+    """
     def __init__(self, address, response, log_failure):
         super(ReplyTask, self).__init__()
         self._address = address
         self._response = response
         self._log_failure = log_failure
+        self._wakeup = threading.Event()
+
+    def wait(self):
+        """Wait for the controller to send the message.
+        """
+        self._wakeup.wait()
 
     def execute(self, controller):
         """Run on the eventloop thread - send the response message."""
         controller.response(self._address, self._response)
+        self._wakeup.set()
diff --git a/oslo_messaging/tests/test_amqp_driver.py b/oslo_messaging/tests/test_amqp_driver.py
index 0d56f9639..379033d95 100644
--- a/oslo_messaging/tests/test_amqp_driver.py
+++ b/oslo_messaging/tests/test_amqp_driver.py
@@ -264,26 +264,34 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
         msg_count = len(notifications) * 2
         listener = _ListenerThread(nl, msg_count)
         targets = ['topic-1.info',
-                   'topic-1.bad',  # should be dropped
-                   'bad-topic.debug',  # should be dropped
+                   'topic-1.bad',  # will raise MessageDeliveryFailure
+                   'bad-topic.debug',  # will raise MessageDeliveryFailure
                    'topic-1.error',
                    'topic-2.debug']
 
+        excepted_targets = []
+        exception_count = 0
         for version in (1.0, 2.0):
             for t in targets:
-                driver.send_notification(oslo_messaging.Target(topic=t),
-                                         "context", {'target': t},
-                                         version)
+                try:
+                    driver.send_notification(oslo_messaging.Target(topic=t),
+                                             "context", {'target': t},
+                                             version)
+                except oslo_messaging.MessageDeliveryFailure:
+                    exception_count += 1
+                    excepted_targets.append(t)
 
         listener.join(timeout=30)
         self.assertFalse(listener.isAlive())
         topics = [x.message.get('target') for x in listener.get_messages()]
-
         self.assertEqual(len(topics), msg_count)
         self.assertEqual(topics.count('topic-1.info'), 2)
         self.assertEqual(topics.count('topic-1.error'), 2)
         self.assertEqual(topics.count('topic-2.debug'), 2)
         self.assertEqual(self._broker.dropped_count, 4)
+        self.assertEqual(exception_count, 4)
+        self.assertEqual(excepted_targets.count('topic-1.bad'), 2)
+        self.assertEqual(excepted_targets.count('bad-topic.debug'), 2)
         driver.cleanup()
 
 
diff --git a/tests/test_amqp_driver.py b/tests/test_amqp_driver.py
index 6fcba0dea..b2ebc3347 100644
--- a/tests/test_amqp_driver.py
+++ b/tests/test_amqp_driver.py
@@ -258,25 +258,40 @@ class TestAmqpNotification(_AmqpBrokerTestCase):
         notifications = [(messaging.Target(topic="topic-1"), 'info'),
                          (messaging.Target(topic="topic-1"), 'error'),
                          (messaging.Target(topic="topic-2"), 'debug')]
-        nl = driver.listen_for_notifications(notifications)
+        nl = driver.listen_for_notifications(notifications, None)
 
-        listener = _ListenerThread(nl, 3)
+        # send one for each support version:
+        msg_count = len(notifications) * 2
+        listener = _ListenerThread(nl, msg_count)
         targets = ['topic-1.info',
-                   'topic-1.bad',  # should be dropped
-                   'bad-topic.debug',  # should be dropped
-                   'topic-1.error', 'topic-2.debug']
+                   'topic-1.bad',  # will raise MessagingDeliveryFailure
+                   'bad-topic.debug',  # will raise MessagingDeliveryFailure
+                   'topic-1.error',
+                   'topic-2.debug']
+
+        excepted_targets = []
+        exception_count = 0
+        for version in (1.0, 2.0):
+            for t in targets:
+                try:
+                    driver.send_notification(messaging.Target(topic=t),
+                                             "context", {'target': t},
+                                             version)
+                except messaging.MessageDeliveryFailure:
+                    exception_count += 1
+                    excepted_targets.append(t)
 
-        for t in targets:
-            driver.send_notification(messaging.Target(topic=t),
-                                     "context", {'target': t},
-                                     1.0)
         listener.join(timeout=30)
         self.assertFalse(listener.isAlive())
         topics = [x.message.get('target') for x in listener.get_messages()]
-        self.assertTrue('topic-1.info' in topics)
-        self.assertTrue('topic-1.error' in topics)
-        self.assertTrue('topic-2.debug' in topics)
-        self.assertEqual(self._broker.dropped_count, 2)
+        self.assertEqual(len(topics), msg_count)
+        self.assertEqual(topics.count('topic-1.info'), 2)
+        self.assertEqual(topics.count('topic-1.error'), 2)
+        self.assertEqual(topics.count('topic-2.debug'), 2)
+        self.assertEqual(self._broker.dropped_count, 4)
+        self.assertEqual(exception_count, 4)
+        self.assertEqual(excepted_targets.count('topic-1.bad'), 2)
+        self.assertEqual(excepted_targets.count('bad-topic.debug'), 2)
         driver.cleanup()