From 9fab0bdbc8d9309de28a18ed4f167b8abccb765f Mon Sep 17 00:00:00 2001
From: Kenneth Giusti <kgiusti@gmail.com>
Date: Mon, 5 Dec 2016 08:46:43 -0500
Subject: [PATCH] [AMQP 1.0] Resend messages that are released or modified

A message ack status of 'RELEASED' or 'MODIFIED' indicates that the
message was not accepted by the destination due to some temporary
issue.  These status are used to indicate to the sender that the
message can be safely re-transmitted without risk of duplication
(i.e. the delivery is not 'in-doubt').  For example this may happen
during a message bus topology update if a message is sent before the
topology stabilizes.

This change implements re-send for these cases.

Closes-Bug: #1646586
Change-Id: I419e23b59e3eb90fda3f1c0e7ddf54ef98870e4b
---
 .../_drivers/amqp1_driver/controller.py       | 155 ++++++++++++------
 oslo_messaging/_drivers/amqp1_driver/opts.py  |   9 +-
 oslo_messaging/_drivers/impl_amqp1.py         |   6 +-
 .../tests/drivers/test_amqp_driver.py         |  90 ++++++++--
 4 files changed, 191 insertions(+), 69 deletions(-)

diff --git a/oslo_messaging/_drivers/amqp1_driver/controller.py b/oslo_messaging/_drivers/amqp1_driver/controller.py
index 49aba929e..56fd97708 100644
--- a/oslo_messaging/_drivers/amqp1_driver/controller.py
+++ b/oslo_messaging/_drivers/amqp1_driver/controller.py
@@ -104,10 +104,10 @@ class SendTask(Task):
         self.target = target() if isinstance(target, Target) else target
         self.message = message
         self.deadline = deadline
-        self.retry = retry
         self.wait_for_ack = wait_for_ack
         self.service = SERVICE_NOTIFY if notification else SERVICE_RPC
         self.timer = None
+        self._retry = None if retry is None or retry < 0 else retry
         self._wakeup = threading.Event()
         self._error = None
 
@@ -122,18 +122,15 @@ class SendTask(Task):
         """Called immediately before the message is handed off to the i/o
         system.  This implies that the sender link is up.
         """
-        if not self.wait_for_ack:
-            # sender is not concerned with waiting for acknowledgment
-            # "best effort at-most-once delivery"
-            self._cleanup()
-            self._wakeup.set()
+        pass
 
     def _on_ack(self, state, info):
-        """Called by eventloop thread when the ack/nack is received from the
-        peer.
+        """If wait_for_ack is True, this is called by the eventloop thread when
+        the ack/nack is received from the peer.  If wait_for_ack is False this
+        is called by the eventloop right after the message is written to the
+        link.  In the last case state will always be set to ACCEPTED.
         """
         if state != pyngus.SenderLink.ACCEPTED:
-            # TODO(kgiusti): could retry if deadline not hit
             msg = ("{name} message send to {target} failed: remote"
                    " disposition: {disp}, info:"
                    "{info}".format(name=self.name,
@@ -179,15 +176,23 @@ class SendTask(Task):
             self.timer.cancel()
             self.timer = None
 
+    @property
+    def _can_retry(self):
+        # has the retry count expired?
+        if self._retry is not None:
+            self._retry -= 1
+            if self._retry < 0:
+                return False
+        return True
+
 
 class RPCCallTask(SendTask):
     """Performs an RPC Call.  Sends the request and waits for a response from
     the destination.
     """
-
-    def __init__(self, target, message, deadline, retry, wait_for_ack):
+    def __init__(self, target, message, deadline, retry):
         super(RPCCallTask, self).__init__("RPC Call", message, target,
-                                          deadline, retry, wait_for_ack)
+                                          deadline, retry, wait_for_ack=True)
         self._reply_link = None
         self._reply_msg = None
         self._msg_id = None
@@ -198,32 +203,30 @@ class RPCCallTask(SendTask):
 
     def _prepare(self, sender):
         # reserve a message id for mapping the received response
+        if self._msg_id:
+            # already set so this is a re-transmit. To be safe cancel the old
+            # msg_id and allocate a fresh one.
+            self._reply_link.cancel_response(self._msg_id)
         self._reply_link = sender._reply_link
         rl = self._reply_link
         self._msg_id = rl.prepare_for_response(self.message, self._on_reply)
 
     def _on_reply(self, message):
         # called if/when the reply message arrives
-        if self._wakeup.is_set():
-            LOG.debug("RPC Reply received after call completed")
-            return
         self._reply_msg = message
-        self._reply_link = None
+        self._msg_id = None  # to prevent _cleanup() from cancelling it
         self._cleanup()
         self._wakeup.set()
 
     def _on_ack(self, state, info):
-        if self._wakeup.is_set():
-            LOG.debug("RPC ACKed after call completed: %s %s", state, info)
-            return
         if state != pyngus.SenderLink.ACCEPTED:
             super(RPCCallTask, self)._on_ack(state, info)
         # must wait for reply if ACCEPTED
 
     def _cleanup(self):
-        if self._reply_link and self._msg_id:
+        if self._msg_id:
             self._reply_link.cancel_response(self._msg_id)
-            self._msg_id = None
+        self._reply_link = None
         super(RPCCallTask, self)._cleanup()
 
 
@@ -260,18 +263,23 @@ class Sender(pyngus.SenderEventHandler):
         self._address = None
         self._link = None
         self._scheduler = scheduler
-        self._delay = delay  # for re-connecting
+        self._delay = delay  # for re-connecting/re-transmitting
         # holds all pending SendTasks
         self._pending_sends = collections.deque()
         # holds all messages sent but not yet acked
         self._unacked = set()
         self._reply_link = None
         self._connection = None
+        self._resend_timer = None
 
     @property
     def pending_messages(self):
         return len(self._pending_sends)
 
+    @property
+    def unacked_messages(self):
+        return len(self._unacked)
+
     def attach(self, connection, reply_link, addresser):
         """Open the link. Called by the Controller when the AMQP connection
         becomes active.
@@ -290,6 +298,9 @@ class Sender(pyngus.SenderEventHandler):
         LOG.debug("Sender %s detached", self._address)
         self._connection = None
         self._reply_link = None
+        if self._resend_timer:
+            self._resend_timer.cancel()
+            self._resend_timer = None
         if self._link:
             self._link.close()
 
@@ -376,11 +387,9 @@ class Sender(pyngus.SenderEventHandler):
         # sends that have exhausted their retry count:
         expired = set()
         for send_task in self._pending_sends:
-            if send_task.retry is not None:
-                send_task.retry -= 1
-                if send_task.retry <= 0:
-                    expired.add(send_task)
-                    send_task._on_error("Message send failed: %s" % reason)
+            if not send_task._can_retry:
+                expired.add(send_task)
+                send_task._on_error("Message send failed: %s" % reason)
         while expired:
             self._pending_sends.remove(expired.pop())
 
@@ -401,26 +410,75 @@ class Sender(pyngus.SenderEventHandler):
     def _can_send(self):
         return self._link and self._link.active
 
+    # acknowledge status
+    _TIMED_OUT = pyngus.SenderLink.TIMED_OUT
+    _ACCEPTED = pyngus.SenderLink.ACCEPTED
+    _RELEASED = pyngus.SenderLink.RELEASED
+    _MODIFIED = pyngus.SenderLink.MODIFIED
+
     def _send(self, send_task):
         send_task._prepare(self)
         send_task.message.address = self._address
+        if send_task.wait_for_ack:
+            self._unacked.add(send_task)
 
-        def pyngus_callback(link, handle, state, info):
-            # invoked when the message bus (n)acks this message
-            if state == pyngus.SenderLink.TIMED_OUT:
-                # ignore pyngus timeout - we maintain our own timer
-                return
-            self._unacked.discard(send_task)
-            send_task._on_ack(state, info)
+            def pyngus_callback(link, handle, state, info):
+                # invoked when the message bus (n)acks this message
+                if state == Sender._TIMED_OUT:
+                    # ignore pyngus timeout - we maintain our own timer
+                    # which will properly deal with this case
+                    return
+                self._unacked.discard(send_task)
+                if state == Sender._ACCEPTED:
+                    send_task._on_ack(Sender._ACCEPTED, info)
+                elif (state == Sender._RELEASED
+                      or (state == Sender._MODIFIED and
+                          # assuming delivery-failed means in-doubt:
+                          not info.get("delivery-failed") and
+                          not info.get("undeliverable-here"))):
+                    # These states indicate that the message was never
+                    # forwarded beyond the next hop so they can be
+                    # re-transmitted without risk of duplication
+                    self._resend(send_task)
+                else:
+                    # some error - let task figure it out...
+                    send_task._on_ack(state, info)
 
-        self._unacked.add(send_task)
-        self._link.send(send_task.message,
-                        delivery_callback=pyngus_callback,
-                        handle=self,
-                        deadline=send_task.deadline)
+            self._link.send(send_task.message,
+                            delivery_callback=pyngus_callback,
+                            handle=self,
+                            deadline=send_task.deadline)
+        else:  # do not wait for ack
+            self._link.send(send_task.message,
+                            delivery_callback=None,
+                            handle=self,
+                            deadline=send_task.deadline)
+            send_task._on_ack(pyngus.SenderLink.ACCEPTED, {})
+
+    def _resend(self, send_task):
+        # the message bus returned the message without forwarding it. Wait a
+        # bit for other outstanding sends to finish - most likely ending up
+        # here since they are all going to the same destination - then resend
+        # this message
+        if send_task._can_retry:
+            # note well: once there is something on the pending list no further
+            # messages will be sent (they will all queue up behind this one).
+            self._pending_sends.append(send_task)
+            if self._resend_timer is None:
+                sched = self._scheduler
+                # this will get the pending sends going again
+                self._resend_timer = sched.defer(self._resend_pending,
+                                                 self._delay)
+        else:
+            send_task._on_error("Send retries exhausted")
+
+    def _resend_pending(self):
+        # run from the _resend_timer, attempt to resend pending messages
+        self._resend_timer = None
+        self._send_pending()
 
     def _send_pending(self):
-        # send all pending messages
+        # flush all pending messages out
         if self._can_send:
             while self._pending_sends:
                 self._send(self._pending_sends.popleft())
@@ -472,7 +530,7 @@ class Replies(pyngus.ReceiverEventHandler):
             self._receiver.close()
 
     def destroy(self):
-        self._correlation = None
+        self._correlation.clear()
         if self._receiver:
             self._receiver.destroy()
             self._receiver = None
@@ -494,11 +552,10 @@ class Replies(pyngus.ReceiverEventHandler):
         """Abort waiting for the response message corresponding to msg_id.
         This can be used if the request fails and no reply is expected.
         """
-        if self._correlation:
-            try:
-                del self._correlation[msg_id]
-            except KeyError:
-                pass
+        try:
+            del self._correlation[msg_id]
+        except KeyError:
+            pass
 
     @property
     def active(self):
@@ -864,8 +921,6 @@ class Controller(pyngus.ConnectionEventHandler):
         if send_task.deadline and send_task.deadline <= now():
             send_task._on_timeout()
             return
-        if send_task.retry is None or send_task.retry < 0:
-            send_task.retry = None
         key = keyify(send_task.target, send_task.service)
         sender = self._all_senders.get(key)
         if not sender:
@@ -1142,7 +1197,7 @@ class Controller(pyngus.ConnectionEventHandler):
         self._active_senders.clear()
         unused = []
         for key, sender in iteritems(self._all_senders):
-            # clean up any unused sender links
+            # clean up any sender links that no longer have messages to send
             if sender.pending_messages == 0:
                 unused.append(key)
             else:
@@ -1183,7 +1238,7 @@ class Controller(pyngus.ConnectionEventHandler):
             purge = set(self._all_senders.keys()) - self._active_senders
             for key in purge:
                 sender = self._all_senders[key]
-                if sender.pending_messages == 0:
+                if not sender.pending_messages and not sender.unacked_messages:
                     sender.detach()
                     self._purged_senders.append(self._all_senders.pop(key))
             self._active_senders.clear()
diff --git a/oslo_messaging/_drivers/amqp1_driver/opts.py b/oslo_messaging/_drivers/amqp1_driver/opts.py
index 9278ebdd0..127dd0d6c 100644
--- a/oslo_messaging/_drivers/amqp1_driver/opts.py
+++ b/oslo_messaging/_drivers/amqp1_driver/opts.py
@@ -109,11 +109,16 @@ amqp1_opts = [
                help='Time to pause between re-connecting an AMQP 1.0 link that'
                ' failed due to a recoverable error.'),
 
+    cfg.IntOpt('default_reply_retry',
+               default=0,
+               min=-1,
+               help='The maximum number of attempts to re-send a reply message'
+               ' which failed due to a recoverable error.'),
+
     cfg.IntOpt('default_reply_timeout',
                default=30,
                min=5,
-               help='The deadline for an rpc reply message delivery.'
-               ' Only used when caller does not provide a timeout expiry.'),
+               help='The deadline for an rpc reply message delivery.'),
 
     cfg.IntOpt('default_send_timeout',
                default=30,
diff --git a/oslo_messaging/_drivers/impl_amqp1.py b/oslo_messaging/_drivers/impl_amqp1.py
index e48c47824..44b99ff14 100644
--- a/oslo_messaging/_drivers/impl_amqp1.py
+++ b/oslo_messaging/_drivers/impl_amqp1.py
@@ -109,7 +109,7 @@ class ProtonIncomingMessage(base.RpcIncomingMessage):
             task = controller.SendTask("RPC Reply", response, self._reply_to,
                                        # analogous to kombu missing dest t/o:
                                        deadline,
-                                       retry=0,
+                                       retry=driver._default_reply_retry,
                                        wait_for_ack=ack)
             driver._ctrl.add_task(task)
             rc = task.wait()
@@ -216,6 +216,7 @@ class ProtonDriver(base.BaseDriver):
         self._default_reply_timeout = opt_name.default_reply_timeout
         self._default_send_timeout = opt_name.default_send_timeout
         self._default_notify_timeout = opt_name.default_notify_timeout
+        self._default_reply_retry = opt_name.default_reply_retry
 
         # which message types should be sent pre-settled?
         ps = [s.lower() for s in opt_name.pre_settled]
@@ -301,8 +302,7 @@ class ProtonDriver(base.BaseDriver):
             expire = compute_timeout(self._default_send_timeout)
         if wait_for_reply:
             ack = not self._pre_settle_call
-            task = controller.RPCCallTask(target, request, expire, retry,
-                                          wait_for_ack=ack)
+            task = controller.RPCCallTask(target, request, expire, retry)
         else:
             ack = not self._pre_settle_cast
             task = controller.SendTask("RPC Cast", request, target, expire,
diff --git a/oslo_messaging/tests/drivers/test_amqp_driver.py b/oslo_messaging/tests/drivers/test_amqp_driver.py
index 12db37a60..30b796b32 100644
--- a/oslo_messaging/tests/drivers/test_amqp_driver.py
+++ b/oslo_messaging/tests/drivers/test_amqp_driver.py
@@ -288,7 +288,7 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
         driver.cleanup()
 
     def test_send_timeout(self):
-        """Verify send timeout."""
+        """Verify send timeout - no reply sent."""
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         target = oslo_messaging.Target(topic="test-topic")
         listener = _ListenerThread(
@@ -310,17 +310,19 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         target = oslo_messaging.Target(topic="no listener")
 
-        # the broker will send a nack:
+        # the broker will send a nack (released) since there is no active
+        # listener for the target:
         self.assertRaises(oslo_messaging.MessageDeliveryFailure,
                           driver.send, target,
                           {"context": "whatever"},
                           {"method": "drop"},
                           wait_for_reply=True,
+                          retry=0,
                           timeout=1.0)
         driver.cleanup()
 
     def test_send_not_acked(self):
-        """Verify exception thrown if send Nacked."""
+        """Verify exception thrown ack dropped."""
         self.config(pre_settled=[],
                     group="oslo_messaging_amqp")
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
@@ -333,7 +335,8 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
                           driver.send, target,
                           {"context": "whatever"},
                           {"method": "drop"},
-                          wait_for_reply=False)
+                          retry=0,
+                          wait_for_reply=True)
         driver.cleanup()
 
     def test_no_ack_cast(self):
@@ -393,7 +396,7 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
         driver.cleanup()
 
     def test_call_failed_reply(self):
-        """Send back an exception"""
+        """Send back an exception generated at the listener"""
         class _FailedResponder(_ListenerThread):
             def __init__(self, listener):
                 super(_FailedResponder, self).__init__(listener, 1)
@@ -434,7 +437,7 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
                 self.started.set()
                 while not self._done:
                     for in_msg in self.listener.poll(timeout=0.5):
-                        # reply will never be acked:
+                        # reply will never be acked (simulate drop):
                         in_msg._reply_to = "!no-ack!"
                         in_msg.reply(reply={'correlation-id':
                                             in_msg.message.get("id")})
@@ -458,6 +461,7 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
 
     def test_listener_requeue(self):
         "Emulate Server requeue on listener incoming messages"
+        self.config(pre_settled=[], group="oslo_messaging_amqp")
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         driver.require_features(requeue=True)
         target = oslo_messaging.Target(topic="test-topic")
@@ -472,10 +476,6 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
         listener.join(timeout=30)
         self.assertFalse(listener.isAlive())
 
-        for x in listener.get_messages():
-            x.requeue()
-            self.assertEqual(x.message, {"msg": "value"})
-
         predicate = lambda: (self._broker.sender_link_requeue_count == 1)
         _wait_until(predicate, 30)
         self.assertTrue(predicate())
@@ -575,7 +575,7 @@ class TestAmqpNotification(_AmqpBrokerTestCaseAuto):
                 try:
                     driver.send_notification(oslo_messaging.Target(topic=t),
                                              "context", {'target': t},
-                                             version)
+                                             version, retry=0)
                 except oslo_messaging.MessageDeliveryFailure:
                     excepted_targets.append(t)
 
@@ -592,15 +592,18 @@ class TestAmqpNotification(_AmqpBrokerTestCaseAuto):
         driver.cleanup()
 
     def test_released_notification(self):
+        """Broker sends a Nack (released)"""
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         self.assertRaises(oslo_messaging.MessageDeliveryFailure,
                           driver.send_notification,
                           oslo_messaging.Target(topic="bad address"),
                           "context", {'target': "bad address"},
-                          2.0)
+                          2.0,
+                          retry=0)
         driver.cleanup()
 
     def test_notification_not_acked(self):
+        """Simulate drop of ack from broker"""
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         # set this directly so we can use a value < minimum allowed
         driver._default_notify_timeout = 2
@@ -608,7 +611,7 @@ class TestAmqpNotification(_AmqpBrokerTestCaseAuto):
                           driver.send_notification,
                           oslo_messaging.Target(topic="!no-ack!"),
                           "context", {'target': "!no-ack!"},
-                          2.0)
+                          2.0, retry=0)
         driver.cleanup()
 
     def test_no_ack_notification(self):
@@ -1388,6 +1391,64 @@ class TestAddressing(test_utils.BaseTestCase):
                               LegacyAddresser)
 
 
+@testtools.skipUnless(pyngus, "proton modules not present")
+class TestMessageRetransmit(_AmqpBrokerTestCase):
+    # test message is retransmitted if safe to do so
+    def _test_retransmit(self, nack_method):
+        self._nack_count = 2
+
+        def _on_message(message, handle, link):
+            if self._nack_count:
+                self._nack_count -= 1
+                nack_method(link, handle)
+            else:
+                self._broker.forward_message(message, handle, link)
+
+        self._broker.on_message = _on_message
+        self._broker.start()
+        self.config(link_retry_delay=1, pre_settled=[],
+                    group="oslo_messaging_amqp")
+        driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
+        target = oslo_messaging.Target(topic="test-topic")
+        listener = _ListenerThread(driver.listen(target,
+                                                 None,
+                                                 None)._poll_style_listener,
+                                   1)
+        rc = driver.send(target, {"context": "whatever"},
+                         {"method": "echo", "id": "blah"},
+                         wait_for_reply=True,
+                         retry=2)  # initial send + up to 2 resends
+        self.assertIsNotNone(rc)
+        self.assertEqual(0, self._nack_count)
+        self.assertEqual(rc.get('correlation-id'), 'blah')
+        listener.join(timeout=30)
+        self.assertFalse(listener.isAlive())
+        driver.cleanup()
+
+    def test_released(self):
+        # should retry and succeed
+        self._test_retransmit(lambda l, h: l.message_released(h))
+
+    def test_modified(self):
+        # should retry and succeed
+        self._test_retransmit(lambda l, h: l.message_modified(h,
+                                                              False,
+                                                              False,
+                                                              {}))
+
+    def test_modified_failed(self):
+        # since delivery_failed is set to True, should fail
+        self.assertRaises(oslo_messaging.MessageDeliveryFailure,
+                          self._test_retransmit,
+                          lambda l, h: l.message_modified(h, True, False, {}))
+
+    def test_rejected(self):
+        # rejected - should fail
+        self.assertRaises(oslo_messaging.MessageDeliveryFailure,
+                          self._test_retransmit,
+                          lambda l, h: l.message_rejected(h, {}))
+
+
 class FakeBroker(threading.Thread):
     """A test AMQP message 'broker'."""
 
@@ -1609,7 +1670,7 @@ class FakeBroker(threading.Thread):
 
             def message_received(self, receiver_link, message, handle):
                 """Forward this message out the proper sending link."""
-                self.server.forward_message(message, handle, receiver_link)
+                self.server.on_message(message, handle, receiver_link)
                 if self.link.capacity < 1:
                     self.server.on_credit_exhausted(self.link)
 
@@ -1674,6 +1735,7 @@ class FakeBroker(threading.Thread):
         self.on_sender_active = lambda link: None
         self.on_receiver_active = lambda link: link.add_capacity(10)
         self.on_credit_exhausted = lambda link: link.add_capacity(10)
+        self.on_message = lambda m, h, l: self.forward_message(m, h, l)
 
     def start(self):
         """Start the server."""