From 684e3f0e410e969e00d0acb6f0a5a56f8a856f84 Mon Sep 17 00:00:00 2001
From: Kenneth Giusti <kgiusti@gmail.com>
Date: Mon, 11 Jun 2018 10:02:12 -0400
Subject: [PATCH] Enable RPC call monitoring in AMQP 1.0 driver

The call monitoring feature was introduced in commit
b34ab8b1cc9f4d513a2927c102dbbe82031d9c2a for RabbitMQ.  This patch
enables the feature on the AMQP 1.0 driver - currently the only other
driver that supports RPC.

Change-Id: Ic787696852690b59779fb4716aec1e78c48bbe6a
---
 .../_drivers/amqp1_driver/controller.py       |  97 +++++++++---
 oslo_messaging/_drivers/impl_amqp1.py         |  73 ++++++---
 .../tests/drivers/test_amqp_driver.py         | 145 ++++++++++++++----
 .../tests/functional/test_functional.py       |   3 +-
 .../RPC-call-monitoring-7977f047d069769a.yaml |  11 ++
 5 files changed, 262 insertions(+), 67 deletions(-)
 create mode 100644 releasenotes/notes/RPC-call-monitoring-7977f047d069769a.yaml

diff --git a/oslo_messaging/_drivers/amqp1_driver/controller.py b/oslo_messaging/_drivers/amqp1_driver/controller.py
index 0583d0e13..6ef60fdcb 100644
--- a/oslo_messaging/_drivers/amqp1_driver/controller.py
+++ b/oslo_messaging/_drivers/amqp1_driver/controller.py
@@ -110,19 +110,24 @@ class SendTask(Task):
         self._retry = None if retry is None or retry < 0 else retry
         self._wakeup = threading.Event()
         self._error = None
+        self._sender = None
 
     def wait(self):
         self._wakeup.wait()
         return self._error
 
     def _execute(self, controller):
+        if self.deadline:
+            # time out the send
+            self.timer = controller.processor.alarm(self._on_timeout,
+                                                    self.deadline)
         controller.send(self)
 
     def _prepare(self, sender):
         """Called immediately before the message is handed off to the i/o
         system.  This implies that the sender link is up.
         """
-        pass
+        self._sender = sender
 
     def _on_ack(self, state, info):
         """If wait_for_ack is True, this is called by the eventloop thread when
@@ -143,10 +148,10 @@ class SendTask(Task):
         self._wakeup.set()
 
     def _on_timeout(self):
-        """Invoked by the eventloop when the send fails to complete before the
-        timeout is reached.
+        """Invoked by the eventloop when our timer expires
         """
         self.timer = None
+        self._sender and self._sender.cancel_send(self)
         msg = ("{name} message sent to {target} failed: timed"
                " out".format(name=self.name, target=self.target))
         LOG.warning("%s", msg)
@@ -172,6 +177,7 @@ class SendTask(Task):
         self._wakeup.set()
 
     def _cleanup(self):
+        self._sender = None
         if self.timer:
             self.timer.cancel()
             self.timer = None
@@ -202,6 +208,7 @@ class RPCCallTask(SendTask):
         return error or self._reply_msg
 
     def _prepare(self, sender):
+        super(RPCCallTask, self)._prepare(sender)
         # reserve a message id for mapping the received response
         if self._msg_id:
             # already set so this is a re-transmit. To be safe cancel the old
@@ -214,7 +221,6 @@ class RPCCallTask(SendTask):
     def _on_reply(self, message):
         # called if/when the reply message arrives
         self._reply_msg = message
-        self._msg_id = None  # to prevent _cleanup() from cancelling it
         self._cleanup()
         self._wakeup.set()
 
@@ -226,10 +232,60 @@ class RPCCallTask(SendTask):
     def _cleanup(self):
         if self._msg_id:
             self._reply_link.cancel_response(self._msg_id)
+            self._msg_id = None
         self._reply_link = None
         super(RPCCallTask, self)._cleanup()
 
 
+class RPCMonitoredCallTask(RPCCallTask):
+    """An RPC call which expects a periodic heartbeat until the response is
+    received.  There are two timeouts:
+    deadline - overall hard timeout, implemented in RPCCallTask
+    call_monitor_timeout - keep alive timeout, reset when heartbeat arrives
+    """
+    def __init__(self, target, message, deadline, call_monitor_timeout,
+                 retry, wait_for_ack):
+        super(RPCMonitoredCallTask, self).__init__(target, message, deadline,
+                                                   retry, wait_for_ack)
+        assert call_monitor_timeout is not None  # nosec
+        self._monitor_timeout = call_monitor_timeout
+        self._monitor_timer = None
+        self._set_alarm = None
+
+    def _execute(self, controller):
+        self._set_alarm = controller.processor.defer
+        self._monitor_timer = self._set_alarm(self._call_timeout,
+                                              self._monitor_timeout)
+        super(RPCMonitoredCallTask, self)._execute(controller)
+
+    def _call_timeout(self):
+        # monitor_timeout expired
+        self._monitor_timer = None
+        self._sender and self._sender.cancel_send(self)
+        msg = ("{name} message sent to {target} failed: call monitor timed"
+               " out".format(name=self.name, target=self.target))
+        LOG.warning("%s", msg)
+        self._error = exceptions.MessagingTimeout(msg)
+        self._cleanup()
+        self._wakeup.set()
+
+    def _on_reply(self, message):
+        # if reply is null, then this is the call monitor heartbeat
+        if message.body is None:
+            self._monitor_timer.cancel()
+            self._monitor_timer = self._set_alarm(self._call_timeout,
+                                                  self._monitor_timeout)
+        else:
+            super(RPCMonitoredCallTask, self)._on_reply(message)
+
+    def _cleanup(self):
+        self._set_alarm = None
+        if self._monitor_timer:
+            self._monitor_timer.cancel()
+            self._monitor_timer = None
+        super(RPCMonitoredCallTask, self)._cleanup()
+
+
 class MessageDispositionTask(Task):
     """A task that updates the message disposition as accepted or released
     for a Server
@@ -329,23 +385,22 @@ class Sender(pyngus.SenderEventHandler):
     def send_message(self, send_task):
         """Send a message out the link.
         """
-        if send_task.deadline:
-            def timer_callback():
-                # may be in either list, or none
-                self._unacked.discard(send_task)
-                try:
-                    self._pending_sends.remove(send_task)
-                except ValueError:
-                    pass
-                send_task._on_timeout()
-            send_task.timer = self._scheduler.alarm(timer_callback,
-                                                    send_task.deadline)
-
         if not self._can_send or self._pending_sends:
             self._pending_sends.append(send_task)
         else:
             self._send(send_task)
 
+    def cancel_send(self, send_task):
+        """Attempts to cancel a send request.  It is possible that the send has
+        already completed, so this is best-effort.
+        """
+        # may be in either list, or none
+        self._unacked.discard(send_task)
+        try:
+            self._pending_sends.remove(send_task)
+        except ValueError:
+            pass
+
     # Pyngus callbacks:
 
     def sender_active(self, sender_link):
@@ -537,10 +592,12 @@ class Replies(pyngus.ReceiverEventHandler):
 
     def prepare_for_response(self, request, callback):
         """Apply a unique message identifier to this request message. This will
-        be used to identify messages sent in reply.  The identifier is placed
-        in the 'id' field of the request message.  It is expected that the
-        identifier will appear in the 'correlation-id' field of the
+        be used to identify messages received in reply.  The identifier is
+        placed in the 'id' field of the request message.  It is expected that
+        the identifier will appear in the 'correlation-id' field of the
         corresponding response message.
+
+        When the caller is done receiving replies, it must call cancel_response
         """
         request.id = uuid.uuid4().hex
         # reply is placed on reply_queue
@@ -597,8 +654,6 @@ class Replies(pyngus.ReceiverEventHandler):
         key = message.correlation_id
         try:
             self._correlation[key](message)
-            # cleanup (only need one response per request)
-            del self._correlation[key]
             receiver.message_accepted(handle)
         except KeyError:
             LOG.warning(_LW("Can't find receiver for response msg id=%s, "
diff --git a/oslo_messaging/_drivers/impl_amqp1.py b/oslo_messaging/_drivers/impl_amqp1.py
index b56cee066..a5a229636 100644
--- a/oslo_messaging/_drivers/impl_amqp1.py
+++ b/oslo_messaging/_drivers/impl_amqp1.py
@@ -45,6 +45,11 @@ controller = importutils.try_import(
 )
 LOG = logging.getLogger(__name__)
 
+# Build/Decode RPC Response messages
+# Body Format - json string containing a map with keys:
+# 'failure' - (optional) serialized exception from remote
+# 'response' - (if no failure provided) data returned by call
+
 
 def marshal_response(reply, failure):
     # TODO(grs): do replies have a context?
@@ -70,7 +75,14 @@ def unmarshal_response(message, allowed):
     return data.get("response")
 
 
-def marshal_request(request, context, envelope):
+# Build/Decode RPC Request and Notification messages
+# Body Format: json string containing a map with keys:
+# 'request' - possibly serialized application data
+# 'context' - context provided by the application
+# 'call_monitor_timeout' - optional time in seconds for RPC call monitoring
+
+def marshal_request(request, context, envelope=False,
+                    call_monitor_timeout=None):
     # NOTE(flaper87): Set inferred to True since rabbitmq-amqp-1.0 doesn't
     # have support for vbin8.
     msg = proton.Message(inferred=True)
@@ -80,6 +92,8 @@ def marshal_request(request, context, envelope):
         "request": request,
         "context": context
     }
+    if call_monitor_timeout is not None:
+        data["call_monitor_timeout"] = call_monitor_timeout
     msg.body = jsonutils.dumps(data)
     return msg
 
@@ -87,19 +101,36 @@ def marshal_request(request, context, envelope):
 def unmarshal_request(message):
     data = jsonutils.loads(message.body)
     msg = common.deserialize_msg(data.get("request"))
-    return (msg, data.get("context"))
+    return (msg, data.get("context"), data.get("call_monitor_timeout"))
 
 
 class ProtonIncomingMessage(base.RpcIncomingMessage):
-    def __init__(self, listener, ctxt, request, message, disposition):
+    def __init__(self, listener, message, disposition):
+        request, ctxt, client_timeout = unmarshal_request(message)
         super(ProtonIncomingMessage, self).__init__(ctxt, request)
         self.listener = listener
+        self.client_timeout = client_timeout
         self._reply_to = message.reply_to
         self._correlation_id = message.id
         self._disposition = disposition
 
     def heartbeat(self):
-        LOG.debug("Message heartbeat not implemented")
+        # heartbeats are sent "worst effort": non-blocking, no retries,
+        # pre-settled (no blocking for acks). We don't want the server thread
+        # being blocked because it is unable to send a heartbeat.
+        if not self._reply_to:
+            LOG.warning("Cannot send RPC heartbeat: no reply-to provided")
+            return
+        # send a null msg (no body). This will cause the client to simply reset
+        # its timeout (the null message is dropped).  Use time-to-live to
+        # prevent stale heartbeats from building up on the message bus
+        msg = proton.Message()
+        msg.correlation_id = self._correlation_id
+        msg.ttl = self.client_timeout
+        task = controller.SendTask("RPC KeepAlive", msg, self._reply_to,
+                                   deadline=None, retry=0, wait_for_ack=False)
+        self.listener.driver._ctrl.add_task(task)
+        task.wait()
 
     def reply(self, reply=None, failure=None):
         """Schedule an RPCReplyTask to send the reply."""
@@ -179,10 +210,9 @@ class ProtonListener(base.PollStyleListener):
         qentry = self.incoming.pop(timeout)
         if qentry is None:
             return None
-        message = qentry['message']
-        request, ctxt = unmarshal_request(message)
-        disposition = qentry['disposition']
-        return ProtonIncomingMessage(self, ctxt, request, message, disposition)
+        return ProtonIncomingMessage(self,
+                                     qentry['message'],
+                                     qentry['disposition'])
 
 
 class ProtonDriver(base.BaseDriver):
@@ -268,7 +298,8 @@ class ProtonDriver(base.BaseDriver):
 
     @_ensure_connect_called
     def send(self, target, ctxt, message,
-             wait_for_reply=False, timeout=None, call_monitor_timeout=None,
+             wait_for_reply=False,
+             timeout=None, call_monitor_timeout=None,
              retry=None):
         """Send a message to the given target.
 
@@ -292,14 +323,13 @@ class ProtonDriver(base.BaseDriver):
                       0 means no retry
                       N means N retries
         :type retry: int
-"""
-        request = marshal_request(message, ctxt, envelope=False)
-        expire = 0
+        """
+        request = marshal_request(message, ctxt, None,
+                                  call_monitor_timeout)
         if timeout:
-            expire = compute_timeout(timeout)  # when the caller times out
-            # amqp uses millisecond time values, timeout is seconds
-            request.ttl = int(timeout * 1000)
-            request.expiry_time = int(expire * 1000)
+            expire = compute_timeout(timeout)
+            request.ttl = timeout
+            request.expiry_time = compute_timeout(timeout)
         else:
             # no timeout provided by application.  If the backend is queueless
             # this could lead to a hang - provide a default to prevent this
@@ -307,8 +337,13 @@ class ProtonDriver(base.BaseDriver):
             expire = compute_timeout(self._default_send_timeout)
         if wait_for_reply:
             ack = not self._pre_settle_call
-            task = controller.RPCCallTask(target, request, expire, retry,
-                                          wait_for_ack=ack)
+            if call_monitor_timeout is None:
+                task = controller.RPCCallTask(target, request, expire, retry,
+                                              wait_for_ack=ack)
+            else:
+                task = controller.RPCMonitoredCallTask(target, request, expire,
+                                                       call_monitor_timeout,
+                                                       retry, wait_for_ack=ack)
         else:
             ack = not self._pre_settle_cast
             task = controller.SendTask("RPC Cast", request, target, expire,
@@ -344,7 +379,7 @@ class ProtonDriver(base.BaseDriver):
                       N means N retries
         :type retry: int
         """
-        request = marshal_request(message, ctxt, (version == 2.0))
+        request = marshal_request(message, ctxt, envelope=(version == 2.0))
         # no timeout is applied to notifications, however if the backend is
         # queueless this could lead to a hang - provide a default to prevent
         # this
diff --git a/oslo_messaging/tests/drivers/test_amqp_driver.py b/oslo_messaging/tests/drivers/test_amqp_driver.py
index 1bab73fce..2a8e97f62 100644
--- a/oslo_messaging/tests/drivers/test_amqp_driver.py
+++ b/oslo_messaging/tests/drivers/test_amqp_driver.py
@@ -76,18 +76,18 @@ class _ListenerThread(threading.Thread):
         self.messages = moves.queue.Queue()
         self.daemon = True
         self.started = threading.Event()
-        self._done = False
+        self._done = threading.Event()
         self.start()
         self.started.wait()
 
     def run(self):
         LOG.debug("Listener started")
         self.started.set()
-        while not self._done:
+        while not self._done.is_set():
             for in_msg in self.listener.poll(timeout=0.5):
                 self.messages.put(in_msg)
                 self.msg_count -= 1
-                self._done = self.msg_count == 0
+                self.msg_count == 0 and self._done.set()
                 if self._msg_ack:
                     in_msg.acknowledge()
                     if in_msg.message.get('method') == 'echo':
@@ -110,10 +110,59 @@ class _ListenerThread(threading.Thread):
         return msgs
 
     def kill(self, timeout=30):
-        self._done = True
+        self._done.set()
         self.join(timeout)
 
 
+class _SlowResponder(_ListenerThread):
+    # an RPC listener that pauses delay seconds before replying
+    def __init__(self, listener, delay, msg_count=1):
+        self._delay = delay
+        super(_SlowResponder, self).__init__(listener, msg_count)
+
+    def run(self):
+        LOG.debug("_SlowResponder started")
+        self.started.set()
+        while not self._done.is_set():
+            for in_msg in self.listener.poll(timeout=0.5):
+                time.sleep(self._delay)
+                in_msg.acknowledge()
+                in_msg.reply(reply={'correlation-id':
+                                    in_msg.message.get('id')})
+                self.messages.put(in_msg)
+                self.msg_count -= 1
+                self.msg_count == 0 and self._done.set()
+
+
+class _CallMonitor(_ListenerThread):
+    # an RPC listener that generates heartbeats before
+    # replying.
+    def __init__(self, listener, delay, hb_count, msg_count=1):
+        self._delay = delay
+        self._hb_count = hb_count
+        super(_CallMonitor, self).__init__(listener, msg_count)
+
+    def run(self):
+        LOG.debug("_CallMonitor started")
+        self.started.set()
+        while not self._done.is_set():
+            for in_msg in self.listener.poll(timeout=0.5):
+                hb_rate = in_msg.client_timeout / 2.0
+                deadline = time.time() + self._delay
+                while deadline > time.time():
+                    if self._done.wait(hb_rate):
+                        return
+                    if self._hb_count > 0:
+                        in_msg.heartbeat()
+                        self._hb_count -= 1
+                in_msg.acknowledge()
+                in_msg.reply(reply={'correlation-id':
+                                    in_msg.message.get('id')})
+                self.messages.put(in_msg)
+                self.msg_count -= 1
+                self.msg_count == 0 and self._done.set()
+
+
 @testtools.skipUnless(pyngus, "proton modules not present")
 class TestProtonDriverLoad(test_utils.BaseTestCase):
 
@@ -365,27 +414,11 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
 
     def test_call_late_reply(self):
         """What happens if reply arrives after timeout?"""
-
-        class _SlowResponder(_ListenerThread):
-            def __init__(self, listener, delay):
-                self._delay = delay
-                super(_SlowResponder, self).__init__(listener, 1)
-
-            def run(self):
-                self.started.set()
-                while not self._done:
-                    for in_msg in self.listener.poll(timeout=0.5):
-                        time.sleep(self._delay)
-                        in_msg.acknowledge()
-                        in_msg.reply(reply={'correlation-id':
-                                            in_msg.message.get('id')})
-                        self.messages.put(in_msg)
-                        self._done = True
-
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         target = oslo_messaging.Target(topic="test-topic")
         listener = _SlowResponder(
-            driver.listen(target, None, None)._poll_style_listener, 3)
+            driver.listen(target, None, None)._poll_style_listener,
+            delay=3)
 
         self.assertRaises(oslo_messaging.MessagingTimeout,
                           driver.send, target,
@@ -410,14 +443,14 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
 
             def run(self):
                 self.started.set()
-                while not self._done:
+                while not self._done.is_set():
                     for in_msg in self.listener.poll(timeout=0.5):
                         try:
                             raise RuntimeError("Oopsie!")
                         except RuntimeError:
                             in_msg.reply(reply=None,
                                          failure=sys.exc_info())
-                        self._done = True
+                        self._done.set()
 
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         target = oslo_messaging.Target(topic="test-topic")
@@ -442,13 +475,13 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
 
             def run(self):
                 self.started.set()
-                while not self._done:
+                while not self._done.is_set():
                     for in_msg in self.listener.poll(timeout=0.5):
                         # reply will never be acked (simulate drop):
                         in_msg._reply_to = "!no-ack!"
                         in_msg.reply(reply={'correlation-id':
                                             in_msg.message.get("id")})
-                        self._done = True
+                        self._done.set()
 
         driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
         driver._default_reply_timeout = 1
@@ -555,6 +588,66 @@ class TestAmqpSend(_AmqpBrokerTestCaseAuto):
 
         driver.cleanup()
 
+    def test_call_monitor_ok(self):
+        # verify keepalive by delaying the reply > heartbeat interval
+        driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
+        target = oslo_messaging.Target(topic="test-topic")
+        listener = _CallMonitor(
+            driver.listen(target, None, None)._poll_style_listener,
+            delay=11,
+            hb_count=100)
+        rc = driver.send(target,
+                         {"context": True},
+                         {"method": "echo", "id": "1"},
+                         wait_for_reply=True,
+                         timeout=60,
+                         call_monitor_timeout=5)
+        self.assertIsNotNone(rc)
+        self.assertEqual("1", rc.get('correlation-id'))
+        listener.join(timeout=30)
+        self.assertFalse(listener.isAlive())
+        driver.cleanup()
+
+    def test_call_monitor_bad_no_heartbeat(self):
+        # verify call fails if keepalives stop coming
+        driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
+        target = oslo_messaging.Target(topic="test-topic")
+        listener = _CallMonitor(
+            driver.listen(target, None, None)._poll_style_listener,
+            delay=11,
+            hb_count=1)
+        self.assertRaises(oslo_messaging.MessagingTimeout,
+                          driver.send,
+                          target,
+                          {"context": True},
+                          {"method": "echo", "id": "1"},
+                          wait_for_reply=True,
+                          timeout=60,
+                          call_monitor_timeout=5)
+        listener.kill()
+        self.assertFalse(listener.isAlive())
+        driver.cleanup()
+
+    def test_call_monitor_bad_call_timeout(self):
+        # verify call fails if deadline hit regardless of heartbeat activity
+        driver = amqp_driver.ProtonDriver(self.conf, self._broker_url)
+        target = oslo_messaging.Target(topic="test-topic")
+        listener = _CallMonitor(
+            driver.listen(target, None, None)._poll_style_listener,
+            delay=20,
+            hb_count=100)
+        self.assertRaises(oslo_messaging.MessagingTimeout,
+                          driver.send,
+                          target,
+                          {"context": True},
+                          {"method": "echo", "id": "1"},
+                          wait_for_reply=True,
+                          timeout=11,
+                          call_monitor_timeout=5)
+        listener.kill()
+        self.assertFalse(listener.isAlive())
+        driver.cleanup()
+
 
 class TestAmqpNotification(_AmqpBrokerTestCaseAuto):
     """Test sending and receiving notifications."""
diff --git a/oslo_messaging/tests/functional/test_functional.py b/oslo_messaging/tests/functional/test_functional.py
index e2eb6d73a..69e26bdbe 100644
--- a/oslo_messaging/tests/functional/test_functional.py
+++ b/oslo_messaging/tests/functional/test_functional.py
@@ -153,7 +153,8 @@ class CallTestCase(utils.SkipIfNoTransportURL):
         self.assertEqual(10, server.endpoint.ival)
 
     def test_monitor_long_call(self):
-        if not self.url.startswith("rabbit://"):
+        if not (self.url.startswith("rabbit://")
+                or self.url.startswith("amqp://")):
             self.skipTest("backend does not support call monitoring")
 
         transport = self.useFixture(utils.RPCTransportFixture(self.conf,
diff --git a/releasenotes/notes/RPC-call-monitoring-7977f047d069769a.yaml b/releasenotes/notes/RPC-call-monitoring-7977f047d069769a.yaml
new file mode 100644
index 000000000..352f03ed0
--- /dev/null
+++ b/releasenotes/notes/RPC-call-monitoring-7977f047d069769a.yaml
@@ -0,0 +1,11 @@
+---
+prelude: >
+    RPCClient now supports RPC call monitoring for detecting the loss
+    of a server during an RPC call.
+features:
+  - |
+    RPC call monitoring is a new RPCClient feature.  Call monitoring
+    causes the RPC server to periodically send keepalive messages back
+    to the RPCClient while the RPC call is being processed.  This can
+    be used for early detection of a server failure without having to
+    wait for the full call timeout to expire.