From f1c7e78a56794f016ad63c28b036dff12e259953 Mon Sep 17 00:00:00 2001
From: Joshua Harlow <harlowja@yahoo-inc.com>
Date: Wed, 26 Nov 2014 11:40:02 -0800
Subject: [PATCH] Have the timeout decrement inside the wait() method

Currently it appears the timeout provided to the wait()
method will be reused X number of times (where X may be
indeterminate depending on reconnections, loss of messages
and so-on) so instead of reusing it which can potentially
result in a infinite number of calls have a new object be
used that will cause the subsequent timeouts used elsewhere
in the wait function to actually decay correctly.

Closes-bug: #1379394

Change-Id: I12c4ea1eef6b857d12246db0483adaf7c87e740c
---
 oslo/messaging/_drivers/amqpdriver.py | 37 ++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 6 deletions(-)

diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py
index 5a522d92f..f82c1eee6 100644
--- a/oslo/messaging/_drivers/amqpdriver.py
+++ b/oslo/messaging/_drivers/amqpdriver.py
@@ -31,6 +31,30 @@ from oslo.messaging._i18n import _LI
 LOG = logging.getLogger(__name__)
 
 
+class _DecayingTimer(object):
+    def __init__(self, duration=None):
+        self._duration = duration
+        self._ends_at = None
+
+    def start(self):
+        if self._duration is not None:
+            self._ends_at = time.time() + max(0, self._duration)
+        return self
+
+    def check_return(self, msg_id):
+        if self._duration is None:
+            return None
+        if self._ends_at is None:
+            raise RuntimeError("Can not check/return a timeout from a timer"
+                               " that has not been started")
+        left = self._ends_at - time.time()
+        if left <= 0:
+            raise messaging.MessagingTimeout('Timed out waiting for a '
+                                             'reply to message ID %s'
+                                             % msg_id)
+        return left
+
+
 class AMQPIncomingMessage(base.IncomingMessage):
 
     def __init__(self, listener, ctxt, message, unique_id, msg_id, reply_q):
@@ -220,7 +244,7 @@ class ReplyWaiter(object):
             result = data['result']
         return result, ending
 
-    def _poll_connection(self, msg_id, timeout):
+    def _poll_connection(self, msg_id, timer):
         while True:
             while self.incoming:
                 message_data = self.incoming.pop(0)
@@ -232,14 +256,14 @@ class ReplyWaiter(object):
                 self.waiters.put(incoming_msg_id, message_data)
 
             try:
-                self.conn.consume(limit=1, timeout=timeout)
+                self.conn.consume(limit=1, timeout=timer.check_return(msg_id))
             except rpc_common.Timeout:
                 raise messaging.MessagingTimeout('Timed out waiting for a '
                                                  'reply to message ID %s'
                                                  % msg_id)
 
-    def _poll_queue(self, msg_id, timeout):
-        message = self.waiters.get(msg_id, timeout)
+    def _poll_queue(self, msg_id, timer):
+        message = self.waiters.get(msg_id, timeout=timer.check_return(msg_id))
         if message is self.waiters.WAKE_UP:
             return None, None, True  # lock was released
 
@@ -268,6 +292,7 @@ class ReplyWaiter(object):
         # have the first thread take responsibility for passing replies not
         # intended for itself to the appropriate thread.
         #
+        timer = _DecayingTimer(duration=timeout).start()
         final_reply = None
         while True:
             if self.conn_lock.acquire(False):
@@ -286,7 +311,7 @@ class ReplyWaiter(object):
 
                     # Now actually poll the connection
                     while True:
-                        reply, ending = self._poll_connection(msg_id, timeout)
+                        reply, ending = self._poll_connection(msg_id, timer)
                         if not ending:
                             final_reply = reply
                         else:
@@ -299,7 +324,7 @@ class ReplyWaiter(object):
                     self.waiters.wake_all(msg_id)
             else:
                 # We're going to wait for the first thread to pass us our reply
-                reply, ending, trylock = self._poll_queue(msg_id, timeout)
+                reply, ending, trylock = self._poll_queue(msg_id, timer)
                 if trylock:
                     # The first thread got its reply, let's try and take over
                     # the responsibility for polling