diff --git a/oslo/messaging/_drivers/amqpdriver.py b/oslo/messaging/_drivers/amqpdriver.py index b68ae3e56..3142aca1e 100644 --- a/oslo/messaging/_drivers/amqpdriver.py +++ b/oslo/messaging/_drivers/amqpdriver.py @@ -106,7 +106,7 @@ class ReplyWaiters(object): self._wrn_threshhold = 10 def get(self, msg_id): - return self._queues.get(msg_id) + return self._queues[msg_id].get() def put(self, msg_id, message_data): queue = self._queues.get(msg_id) @@ -177,10 +177,12 @@ class ReplyWaiter(object): while True: while self.incoming: message_data = self.incoming.pop(0) - if message_data.pop('_msg_id', None) == msg_id: + + incoming_msg_id = message_data.pop('_msg_id', None) + if incoming_msg_id == msg_id: return self._process_reply(message_data) - self.waiters.put(msg_id, message_data) + self.waiters.put(incoming_msg_id, message_data) # FIXME(markmc): timeout? self.conn.consume(limit=1) @@ -205,11 +207,12 @@ class ReplyWaiter(object): while True: if self.conn_lock.acquire(False): try: - reply, ending = self._poll_connection(msg_id) - if reply: - final_reply = reply - elif ending: - return final_reply + while True: + reply, ending = self._poll_connection(msg_id) + if reply: + final_reply = reply + elif ending: + return final_reply finally: self.conn_lock.release() self.waiters.wake_all(msg_id) diff --git a/tests/test_rabbit.py b/tests/test_rabbit.py index 9ea084547..bb2880e15 100644 --- a/tests/test_rabbit.py +++ b/tests/test_rabbit.py @@ -55,28 +55,40 @@ class TestRabbitDriver(test_utils.BaseTestCase): listener = driver.listen(target) + senders = [] replies = [] + msgs = [] - def send_and_wait_for_reply(): + def send_and_wait_for_reply(i): replies.append(driver.send(target, {}, - {'foo': 'bar'}, + {'foo': i}, wait_for_reply=True)) - sender = threading.Thread(target=send_and_wait_for_reply) - sender.start() + while len(senders) < 10: + senders.append(threading.Thread(target=send_and_wait_for_reply, + args=(len(senders), ))) - received = listener.poll() - self.assertTrue(received is not None) - self.assertEqual(received.ctxt, {}) - self.assertEqual(received.message, {'foo': 'bar'}) + for i in range(len(senders)): + senders[i].start() - received.reply({'bar': 'foo'}) + received = listener.poll() + self.assertTrue(received is not None) + self.assertEqual(received.ctxt, {}) + self.assertEqual(received.message, {'foo': i}) + msgs.append(received) - sender.join() + # reply in reverse, except reply to the first guy second from last + order = range(len(senders)-1, -1, -1) + order[-1], order[-2] = order[-2], order[-1] - self.assertEqual(len(replies), 1) - self.assertEqual(replies[0], {'bar': 'foo'}) + for i in order: + msgs[i].reply({'bar': msgs[i].message['foo']}) + senders[i].join() + + self.assertEqual(len(replies), len(senders)) + for i, reply in enumerate(replies): + self.assertEqual(reply, {'bar': order[i]}) def _declare_queue(target):