[zmq] Refactor consumers and incoming messages

Change-Id: Ib9b5f1fbc184cc0364e3f742fab7b59bc6a7c03e
This commit is contained in:
Gevorg Davoian 2016-09-19 17:58:56 +03:00
parent 27594bd40f
commit 2b47281a7e
7 changed files with 198 additions and 195 deletions

@ -43,7 +43,6 @@ class ConsumerBase(object):
def stop(self): def stop(self):
"""Stop consumer polling/updates""" """Stop consumer polling/updates"""
pass
@abc.abstractmethod @abc.abstractmethod
def receive_message(self, target): def receive_message(self, target):

@ -18,6 +18,7 @@ import uuid
import six import six
from oslo_messaging._drivers import common as rpc_common from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver.client import zmq_response
from oslo_messaging._drivers.zmq_driver.client import zmq_senders from oslo_messaging._drivers.zmq_driver.client import zmq_senders
from oslo_messaging._drivers.zmq_driver.client import zmq_sockets_manager from oslo_messaging._drivers.zmq_driver.client import zmq_sockets_manager
from oslo_messaging._drivers.zmq_driver.server.consumers \ from oslo_messaging._drivers.zmq_driver.server.consumers \
@ -38,11 +39,7 @@ zmq = zmq_async.import_zmq()
class DealerConsumer(zmq_consumer_base.SingleSocketConsumer): class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server): def __init__(self, conf, poller, server):
self.ack_sender = zmq_senders.AckSenderProxy(conf)
self.reply_sender = zmq_senders.ReplySenderProxy(conf) self.reply_sender = zmq_senders.ReplySenderProxy(conf)
self.messages_cache = zmq_ttl_cache.TTLCache(
ttl=conf.oslo_messaging_zmq.rpc_message_ttl
)
self.sockets_manager = zmq_sockets_manager.SocketsManager( self.sockets_manager = zmq_sockets_manager.SocketsManager(
conf, server.matchmaker, zmq.ROUTER, zmq.DEALER) conf, server.matchmaker, zmq.ROUTER, zmq.DEALER)
self.host = None self.host = None
@ -68,77 +65,117 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
LOG.error(_LE("Failed connecting to ROUTER socket %(e)s") % e) LOG.error(_LE("Failed connecting to ROUTER socket %(e)s") % e)
raise rpc_common.RPCException(str(e)) raise rpc_common.RPCException(str(e))
def _receive_request(self, socket): def _reply(self, rpc_message, reply, failure):
empty = socket.recv() if failure is not None:
assert empty == b'', 'Bad format: empty delimiter expected' failure = rpc_common.serialize_remote_exception(failure)
reply_id = socket.recv() reply = zmq_response.Reply(message_id=rpc_message.message_id,
msg_type = int(socket.recv()) reply_id=rpc_message.reply_id,
message_id = socket.recv_string() reply_body=reply,
context, message = socket.recv_loaded() failure=failure)
return reply_id, msg_type, message_id, context, message self.reply_sender.send(rpc_message.socket, reply)
return reply
def _create_message(self, context, message, reply_id, message_id, socket,
message_type):
if message_type == zmq_names.CALL_TYPE:
message = zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id=reply_id, message_id=message_id,
socket=socket, reply_method=self._reply
)
else:
message = zmq_incoming_message.ZmqIncomingMessage(context, message)
LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
"msg_type": zmq_names.message_type_str(message_type),
"msg_id": message_id})
return message
def receive_message(self, socket): def receive_message(self, socket):
try: try:
reply_id, msg_type, message_id, context, message = \ empty = socket.recv()
self._receive_request(socket) assert empty == b'', "Empty delimiter expected!"
reply_id = socket.recv()
assert reply_id != b'', "Valid reply id expected!"
message_type = int(socket.recv())
assert message_type in zmq_names.REQUEST_TYPES, \
"Request message type expected!"
message_id = socket.recv_string()
assert message_id != '', "Valid message id expected!"
context, message = socket.recv_loaded()
if msg_type == zmq_names.CALL_TYPE or \ return self._create_message(context, message, reply_id,
msg_type in zmq_names.NON_BLOCKING_TYPES: message_id, socket, message_type)
ack_sender = self.ack_sender \
if self.conf.oslo_messaging_zmq.rpc_use_acks else None
reply_sender = self.reply_sender \
if msg_type == zmq_names.CALL_TYPE else None
message = zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id, message_id, socket,
ack_sender, reply_sender, self.messages_cache
)
# drop a duplicate message
if message_id in self.messages_cache:
LOG.warning(
_LW("[%(host)s] Dropping duplicate %(msg_type)s "
"message %(msg_id)s"),
{"host": self.host,
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id}
)
# NOTE(gdavoian): send yet another ack for the non-CALL
# message, since the old one might be lost;
# for the CALL message also try to resend its reply
# (of course, if it was already obtained and cached).
message._acknowledge()
if msg_type == zmq_names.CALL_TYPE:
message._reply_from_cache()
return None
self.messages_cache.add(message_id)
LOG.debug(
"[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id}
)
# NOTE(gdavoian): send an immediate ack, since it may
# be too late to wait until the message will be
# dispatched and processed by a RPC server
message._acknowledge()
return message
else:
LOG.error(_LE("Unknown message type: %s"),
zmq_names.message_type_str(msg_type))
except (zmq.ZMQError, AssertionError, ValueError) as e: except (zmq.ZMQError, AssertionError, ValueError) as e:
LOG.error(_LE("Receiving message failure: %s"), str(e)) LOG.error(_LE("Receiving message failure: %s"), str(e))
def cleanup(self): def cleanup(self):
LOG.info(_LI("[%s] Destroy DEALER consumer"), self.host) LOG.info(_LI("[%s] Destroy DEALER consumer"), self.host)
self.messages_cache.cleanup()
self.connection_updater.cleanup() self.connection_updater.cleanup()
super(DealerConsumer, self).cleanup() super(DealerConsumer, self).cleanup()
class DealerConsumerWithAcks(DealerConsumer):
def __init__(self, conf, poller, server):
super(DealerConsumerWithAcks, self).__init__(conf, poller, server)
self.ack_sender = zmq_senders.AckSenderProxy(conf)
self.messages_cache = zmq_ttl_cache.TTLCache(
ttl=conf.oslo_messaging_zmq.rpc_message_ttl
)
def _acknowledge(self, reply_id, message_id, socket):
ack = zmq_response.Ack(message_id=message_id,
reply_id=reply_id)
self.ack_sender.send(socket, ack)
def _reply(self, rpc_message, reply, failure):
reply = super(DealerConsumerWithAcks, self)._reply(rpc_message,
reply, failure)
self.messages_cache.add(rpc_message.message_id, reply)
return reply
def _reply_from_cache(self, message_id, socket):
reply = self.messages_cache.get(message_id)
if reply is not None:
self.reply_sender.send(socket, reply)
def _create_message(self, context, message, reply_id, message_id, socket,
message_type):
# drop a duplicate message
if message_id in self.messages_cache:
LOG.warning(
_LW("[%(host)s] Dropping duplicate %(msg_type)s "
"message %(msg_id)s"),
{"host": self.host,
"msg_type": zmq_names.message_type_str(message_type),
"msg_id": message_id}
)
# NOTE(gdavoian): send yet another ack for the non-CALL
# message, since the old one might be lost;
# for the CALL message also try to resend its reply
# (of course, if it was already obtained and cached).
self._acknowledge(reply_id, message_id, socket)
if message_type == zmq_names.CALL_TYPE:
self._reply_from_cache(message_id, socket)
return None
self.messages_cache.add(message_id)
# NOTE(gdavoian): send an immediate ack, since it may
# be too late to wait until the message will be
# dispatched and processed by a RPC server
self._acknowledge(reply_id, message_id, socket)
return super(DealerConsumerWithAcks, self)._create_message(
context, message, reply_id, message_id, socket, message_type
)
def cleanup(self):
self.messages_cache.cleanup()
super(DealerConsumerWithAcks, self).cleanup()
class ConsumerConnectionUpdater(zmq_updater.ConnectionUpdater): class ConsumerConnectionUpdater(zmq_updater.ConnectionUpdater):
def _update_connection(self): def _update_connection(self):

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -14,6 +14,8 @@
import logging import logging
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver.client import zmq_response
from oslo_messaging._drivers.zmq_driver.client import zmq_senders from oslo_messaging._drivers.zmq_driver.client import zmq_senders
from oslo_messaging._drivers.zmq_driver.server.consumers \ from oslo_messaging._drivers.zmq_driver.server.consumers \
import zmq_consumer_base import zmq_consumer_base
@ -30,43 +32,51 @@ zmq = zmq_async.import_zmq()
class RouterConsumer(zmq_consumer_base.SingleSocketConsumer): class RouterConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server): def __init__(self, conf, poller, server):
self.ack_sender = zmq_senders.AckSenderDirect(conf)
self.reply_sender = zmq_senders.ReplySenderDirect(conf) self.reply_sender = zmq_senders.ReplySenderDirect(conf)
super(RouterConsumer, self).__init__(conf, poller, server, zmq.ROUTER) super(RouterConsumer, self).__init__(conf, poller, server, zmq.ROUTER)
LOG.info(_LI("[%s] Run ROUTER consumer"), self.host) LOG.info(_LI("[%s] Run ROUTER consumer"), self.host)
def _receive_request(self, socket): def _reply(self, rpc_message, reply, failure):
reply_id = socket.recv() if failure is not None:
empty = socket.recv() failure = rpc_common.serialize_remote_exception(failure)
assert empty == b'', 'Bad format: empty delimiter expected' reply = zmq_response.Reply(message_id=rpc_message.message_id,
msg_type = int(socket.recv()) reply_id=rpc_message.reply_id,
message_id = socket.recv_string() reply_body=reply,
context, message = socket.recv_loaded() failure=failure)
return reply_id, msg_type, message_id, context, message self.reply_sender.send(rpc_message.socket, reply)
return reply
def _create_message(self, context, message, reply_id, message_id, socket,
message_type):
if message_type == zmq_names.CALL_TYPE:
message = zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id=reply_id, message_id=message_id,
socket=socket, reply_method=self._reply
)
else:
message = zmq_incoming_message.ZmqIncomingMessage(context, message)
LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s",
{"host": self.host,
"msg_type": zmq_names.message_type_str(message_type),
"msg_id": message_id})
return message
def receive_message(self, socket): def receive_message(self, socket):
try: try:
reply_id, msg_type, message_id, context, message = \ reply_id = socket.recv()
self._receive_request(socket) assert reply_id != b'', "Valid reply id expected!"
empty = socket.recv()
assert empty == b'', "Empty delimiter expected!"
message_type = int(socket.recv())
assert message_type in zmq_names.REQUEST_TYPES, \
"Request message type expected!"
message_id = socket.recv_string()
assert message_id != '', "Valid message id expected!"
context, message = socket.recv_loaded()
LOG.debug("[%(host)s] Received %(msg_type)s message %(msg_id)s", return self._create_message(context, message, reply_id,
{"host": self.host, message_id, socket, message_type)
"msg_type": zmq_names.message_type_str(msg_type),
"msg_id": message_id})
if msg_type == zmq_names.CALL_TYPE or \
msg_type in zmq_names.NON_BLOCKING_TYPES:
ack_sender = self.ack_sender \
if self.conf.oslo_messaging_zmq.rpc_use_acks else None
reply_sender = self.reply_sender \
if msg_type == zmq_names.CALL_TYPE else None
return zmq_incoming_message.ZmqIncomingMessage(
context, message, reply_id, message_id, socket,
ack_sender, reply_sender
)
else:
LOG.error(_LE("Unknown message type: %s"),
zmq_names.message_type_str(msg_type))
except (zmq.ZMQError, AssertionError, ValueError) as e: except (zmq.ZMQError, AssertionError, ValueError) as e:
LOG.error(_LE("Receiving message failed: %s"), str(e)) LOG.error(_LE("Receiving message failed: %s"), str(e))

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -12,68 +12,29 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import logging import six
from oslo_messaging._drivers import base from oslo_messaging._drivers import base
from oslo_messaging._drivers import common as rpc_common
from oslo_messaging._drivers.zmq_driver.client import zmq_response
from oslo_messaging._drivers.zmq_driver import zmq_async
LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq()
class ZmqIncomingMessage(base.RpcIncomingMessage): class ZmqIncomingMessage(base.RpcIncomingMessage):
"""Base class for RPC-messages via ZMQ-driver. """Base class for RPC-messages via ZMQ-driver.
Each message may send either acks/replies or just nothing Behaviour of messages is fully defined by consumers
(if acks are disabled and replies are not supported). which produced them from obtained raw data.
""" """
def __init__(self, context, message, reply_id=None, message_id=None, def __init__(self, context, message, **kwargs):
socket=None, ack_sender=None, reply_sender=None,
replies_cache=None):
if ack_sender is not None or reply_sender is not None:
assert socket is not None, "Valid socket expected!"
assert message_id is not None, "Valid message ID expected!"
assert reply_id is not None, "Valid reply ID expected!"
super(ZmqIncomingMessage, self).__init__(context, message) super(ZmqIncomingMessage, self).__init__(context, message)
self._reply_method = kwargs.pop('reply_method',
self.reply_id = reply_id lambda self, reply, failure: None)
self.message_id = message_id for key, value in six.iteritems(kwargs):
self.socket = socket setattr(self, key, value)
self.ack_sender = ack_sender
self.reply_sender = reply_sender
self.replies_cache = replies_cache
def acknowledge(self): def acknowledge(self):
"""Acknowledge is not supported publicly (used only internally).""" """Acknowledge is not supported."""
def _acknowledge(self):
if self.ack_sender is not None:
ack = zmq_response.Ack(message_id=self.message_id,
reply_id=self.reply_id)
self.ack_sender.send(self.socket, ack)
def reply(self, reply=None, failure=None): def reply(self, reply=None, failure=None):
if self.reply_sender is not None: self._reply_method(self, reply=reply, failure=failure)
if failure is not None:
failure = rpc_common.serialize_remote_exception(failure)
reply = zmq_response.Reply(message_id=self.message_id,
reply_id=self.reply_id,
reply_body=reply,
failure=failure)
self.reply_sender.send(self.socket, reply)
if self.replies_cache is not None:
self.replies_cache.add(self.message_id, reply)
def _reply_from_cache(self):
if self.reply_sender is not None and self.replies_cache is not None:
reply = self.replies_cache.get(self.message_id)
if reply is not None:
self.reply_sender.send(self.socket, reply)
def requeue(self): def requeue(self):
"""Requeue is not supported.""" """Requeue is not supported."""

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -44,14 +44,20 @@ class ZmqServer(base.PollStyleListener):
{'host': self.conf.oslo_messaging_zmq.rpc_zmq_host, {'host': self.conf.oslo_messaging_zmq.rpc_zmq_host,
'target': self.target}) 'target': self.target})
self.router_consumer = zmq_router_consumer.RouterConsumer( if conf.oslo_messaging_zmq.use_router_proxy:
conf, self.poller, self) \ self.router_consumer = None
if not conf.oslo_messaging_zmq.use_router_proxy else None dealer_consumer_cls = \
self.dealer_consumer = zmq_dealer_consumer.DealerConsumer( zmq_dealer_consumer.DealerConsumerWithAcks \
conf, self.poller, self) \ if conf.oslo_messaging_zmq.rpc_use_acks else \
if conf.oslo_messaging_zmq.use_router_proxy else None zmq_dealer_consumer.DealerConsumer
self.sub_consumer = zmq_sub_consumer.SubConsumer( self.dealer_consumer = dealer_consumer_cls(conf, self.poller, self)
conf, self.poller, self) \ else:
self.router_consumer = \
zmq_router_consumer.RouterConsumer(conf, self.poller, self)
self.dealer_consumer = None
self.sub_consumer = \
zmq_sub_consumer.SubConsumer(conf, self.poller, self) \
if conf.oslo_messaging_zmq.use_pub_sub else None if conf.oslo_messaging_zmq.use_pub_sub else None
self.consumers = [] self.consumers = []

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -48,7 +48,7 @@ RESPONSE_TYPES = (REPLY_TYPE, ACK_TYPE)
MESSAGE_TYPES = REQUEST_TYPES + RESPONSE_TYPES MESSAGE_TYPES = REQUEST_TYPES + RESPONSE_TYPES
MULTISEND_TYPES = (CAST_FANOUT_TYPE, NOTIFY_TYPE) MULTISEND_TYPES = (CAST_FANOUT_TYPE, NOTIFY_TYPE)
DIRECT_TYPES = (CALL_TYPE, CAST_TYPE, REPLY_TYPE, ACK_TYPE) DIRECT_TYPES = (CALL_TYPE, CAST_TYPE) + RESPONSE_TYPES
CAST_TYPES = (CAST_TYPE, CAST_FANOUT_TYPE) CAST_TYPES = (CAST_TYPE, CAST_FANOUT_TYPE)
NOTIFY_TYPES = (NOTIFY_TYPE,) NOTIFY_TYPES = (NOTIFY_TYPE,)
NON_BLOCKING_TYPES = CAST_TYPES + NOTIFY_TYPES NON_BLOCKING_TYPES = CAST_TYPES + NOTIFY_TYPES
@ -73,4 +73,4 @@ def message_type_str(message_type):
NOTIFY_TYPE: "NOTIFY", NOTIFY_TYPE: "NOTIFY",
REPLY_TYPE: "REPLY", REPLY_TYPE: "REPLY",
ACK_TYPE: "ACK"} ACK_TYPE: "ACK"}
return msg_type_str[message_type] return msg_type_str.get(message_type, "UNKNOWN")

@ -20,7 +20,8 @@ import oslo_messaging
from oslo_messaging._drivers.zmq_driver.client import zmq_receivers from oslo_messaging._drivers.zmq_driver.client import zmq_receivers
from oslo_messaging._drivers.zmq_driver.client import zmq_senders from oslo_messaging._drivers.zmq_driver.client import zmq_senders
from oslo_messaging._drivers.zmq_driver.proxy import zmq_proxy from oslo_messaging._drivers.zmq_driver.proxy import zmq_proxy
from oslo_messaging._drivers.zmq_driver.server import zmq_incoming_message from oslo_messaging._drivers.zmq_driver.server.consumers.zmq_dealer_consumer \
import DealerConsumerWithAcks
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_options from oslo_messaging._drivers.zmq_driver import zmq_options
from oslo_messaging.tests.drivers.zmq import zmq_common from oslo_messaging.tests.drivers.zmq import zmq_common
@ -100,11 +101,9 @@ class TestZmqAckManager(test_utils.BaseTestCase):
# and all parties to be ready for messaging # and all parties to be ready for messaging
time.sleep(1) time.sleep(1)
@mock.patch.object( @mock.patch.object(DealerConsumerWithAcks, '_acknowledge',
zmq_incoming_message.ZmqIncomingMessage, '_acknowledge', side_effect=DealerConsumerWithAcks._acknowledge,
side_effect=zmq_incoming_message.ZmqIncomingMessage._acknowledge, autospec=True)
autospec=True
)
def test_cast_success_without_retries(self, received_ack_mock): def test_cast_success_without_retries(self, received_ack_mock):
result = self.driver.send( result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False self.target, {}, self.message, wait_for_reply=False
@ -118,7 +117,7 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(2, self.set_result.call_count) self.assertEqual(2, self.set_result.call_count)
def test_cast_success_with_one_retry(self): def test_cast_success_with_one_retry(self):
with mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, with mock.patch.object(DealerConsumerWithAcks,
'_acknowledge') as lost_ack_mock: '_acknowledge') as lost_ack_mock:
result = self.driver.send( result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False self.target, {}, self.message, wait_for_reply=False
@ -131,11 +130,9 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(1, lost_ack_mock.call_count) self.assertEqual(1, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count) self.assertEqual(0, self.set_result.call_count)
self.listener._received.clear() self.listener._received.clear()
with mock.patch.object( with mock.patch.object(DealerConsumerWithAcks, '_acknowledge',
zmq_incoming_message.ZmqIncomingMessage, '_acknowledge', side_effect=DealerConsumerWithAcks._acknowledge,
side_effect=zmq_incoming_message.ZmqIncomingMessage._acknowledge, autospec=True) as received_ack_mock:
autospec=True
) as received_ack_mock:
self.ack_manager._pool.shutdown(wait=True) self.ack_manager._pool.shutdown(wait=True)
self.assertFalse(self.listener._received.isSet()) self.assertFalse(self.listener._received.isSet())
self.assertEqual(2, self.send.call_count) self.assertEqual(2, self.send.call_count)
@ -143,7 +140,7 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(2, self.set_result.call_count) self.assertEqual(2, self.set_result.call_count)
def test_cast_success_with_two_retries(self): def test_cast_success_with_two_retries(self):
with mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, with mock.patch.object(DealerConsumerWithAcks,
'_acknowledge') as lost_ack_mock: '_acknowledge') as lost_ack_mock:
result = self.driver.send( result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False self.target, {}, self.message, wait_for_reply=False
@ -161,18 +158,16 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(2, self.send.call_count) self.assertEqual(2, self.send.call_count)
self.assertEqual(2, lost_ack_mock.call_count) self.assertEqual(2, lost_ack_mock.call_count)
self.assertEqual(0, self.set_result.call_count) self.assertEqual(0, self.set_result.call_count)
with mock.patch.object( with mock.patch.object(DealerConsumerWithAcks, '_acknowledge',
zmq_incoming_message.ZmqIncomingMessage, '_acknowledge', side_effect=DealerConsumerWithAcks._acknowledge,
side_effect=zmq_incoming_message.ZmqIncomingMessage._acknowledge, autospec=True) as received_ack_mock:
autospec=True
) as received_ack_mock:
self.ack_manager._pool.shutdown(wait=True) self.ack_manager._pool.shutdown(wait=True)
self.assertFalse(self.listener._received.isSet()) self.assertFalse(self.listener._received.isSet())
self.assertEqual(3, self.send.call_count) self.assertEqual(3, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count) self.assertEqual(1, received_ack_mock.call_count)
self.assertEqual(2, self.set_result.call_count) self.assertEqual(2, self.set_result.call_count)
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, '_acknowledge') @mock.patch.object(DealerConsumerWithAcks, '_acknowledge')
def test_cast_failure_exhausted_retries(self, lost_ack_mock): def test_cast_failure_exhausted_retries(self, lost_ack_mock):
result = self.driver.send( result = self.driver.send(
self.target, {}, self.message, wait_for_reply=False self.target, {}, self.message, wait_for_reply=False
@ -185,21 +180,15 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(3, lost_ack_mock.call_count) self.assertEqual(3, lost_ack_mock.call_count)
self.assertEqual(1, self.set_result.call_count) self.assertEqual(1, self.set_result.call_count)
@mock.patch.object( @mock.patch.object(DealerConsumerWithAcks, '_acknowledge',
zmq_incoming_message.ZmqIncomingMessage, '_acknowledge', side_effect=DealerConsumerWithAcks._acknowledge,
side_effect=zmq_incoming_message.ZmqIncomingMessage._acknowledge, autospec=True)
autospec=True @mock.patch.object(DealerConsumerWithAcks, '_reply',
) side_effect=DealerConsumerWithAcks._reply,
@mock.patch.object( autospec=True)
zmq_incoming_message.ZmqIncomingMessage, 'reply', @mock.patch.object(DealerConsumerWithAcks, '_reply_from_cache',
side_effect=zmq_incoming_message.ZmqIncomingMessage.reply, side_effect=DealerConsumerWithAcks._reply_from_cache,
autospec=True autospec=True)
)
@mock.patch.object(
zmq_incoming_message.ZmqIncomingMessage, '_reply_from_cache',
side_effect=zmq_incoming_message.ZmqIncomingMessage._reply_from_cache,
autospec=True
)
def test_call_success_without_retries(self, unused_reply_from_cache_mock, def test_call_success_without_retries(self, unused_reply_from_cache_mock,
received_reply_mock, received_reply_mock,
received_ack_mock): received_ack_mock):
@ -213,13 +202,13 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(1, self.send.call_count) self.assertEqual(1, self.send.call_count)
self.assertEqual(1, received_ack_mock.call_count) self.assertEqual(1, received_ack_mock.call_count)
self.assertEqual(3, self.set_result.call_count) self.assertEqual(3, self.set_result.call_count)
received_reply_mock.assert_called_once_with(mock.ANY, reply=True) received_reply_mock.assert_called_once_with(mock.ANY, mock.ANY,
reply=True, failure=None)
self.assertEqual(0, unused_reply_from_cache_mock.call_count) self.assertEqual(0, unused_reply_from_cache_mock.call_count)
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, '_acknowledge') @mock.patch.object(DealerConsumerWithAcks, '_acknowledge')
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, 'reply') @mock.patch.object(DealerConsumerWithAcks, '_reply')
@mock.patch.object(zmq_incoming_message.ZmqIncomingMessage, @mock.patch.object(DealerConsumerWithAcks, '_reply_from_cache')
'_reply_from_cache')
def test_call_failure_exhausted_retries(self, lost_reply_from_cache_mock, def test_call_failure_exhausted_retries(self, lost_reply_from_cache_mock,
lost_reply_mock, lost_ack_mock): lost_reply_mock, lost_ack_mock):
self.assertRaises(oslo_messaging.MessagingTimeout, self.assertRaises(oslo_messaging.MessagingTimeout,
@ -232,5 +221,6 @@ class TestZmqAckManager(test_utils.BaseTestCase):
self.assertEqual(3, self.send.call_count) self.assertEqual(3, self.send.call_count)
self.assertEqual(3, lost_ack_mock.call_count) self.assertEqual(3, lost_ack_mock.call_count)
self.assertEqual(2, self.set_result.call_count) self.assertEqual(2, self.set_result.call_count)
lost_reply_mock.assert_called_once_with(reply=True) lost_reply_mock.assert_called_once_with(mock.ANY,
reply=True, failure=None)
self.assertEqual(2, lost_reply_from_cache_mock.call_count) self.assertEqual(2, lost_reply_from_cache_mock.call_count)