Merge "batch notification listener"
This commit is contained in:
commit
213176657d
@ -203,6 +203,7 @@ class AMQPListener(base.Listener):
|
||||
ctxt.reply_q,
|
||||
self._obsolete_reply_queues))
|
||||
|
||||
@base.batch_poll_helper
|
||||
def poll(self, timeout=None):
|
||||
while not self._stopped.is_set():
|
||||
if self.incoming:
|
||||
|
@ -15,9 +15,12 @@
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_utils import timeutils
|
||||
import six
|
||||
from six.moves import range as compat_range
|
||||
|
||||
|
||||
from oslo_messaging import exceptions
|
||||
|
||||
base_opts = [
|
||||
@ -28,6 +31,27 @@ base_opts = [
|
||||
]
|
||||
|
||||
|
||||
def batch_poll_helper(func):
|
||||
"""Decorator to poll messages in batch
|
||||
|
||||
This decorator helps driver that polls message one by one,
|
||||
to returns a list of message.
|
||||
"""
|
||||
def wrapper(in_self, timeout=None, prefetch_size=1):
|
||||
incomings = []
|
||||
watch = timeutils.StopWatch(duration=timeout)
|
||||
with watch:
|
||||
for __ in compat_range(prefetch_size):
|
||||
msg = func(in_self, timeout=watch.leftover(return_none=True))
|
||||
if msg is not None:
|
||||
incomings.append(msg)
|
||||
else:
|
||||
# timeout reached or listener stopped
|
||||
break
|
||||
return incomings
|
||||
return wrapper
|
||||
|
||||
|
||||
class TransportDriverError(exceptions.MessagingException):
|
||||
"""Base class for transport driver specific exceptions."""
|
||||
|
||||
@ -61,8 +85,9 @@ class Listener(object):
|
||||
self.driver = driver
|
||||
|
||||
@abc.abstractmethod
|
||||
def poll(self, timeout=None):
|
||||
"""Blocking until a message is pending and return IncomingMessage.
|
||||
def poll(self, timeout=None, prefetch_size=1):
|
||||
"""Blocking until 'prefetch_size' message is pending and return
|
||||
[IncomingMessage].
|
||||
Return None after timeout seconds if timeout is set and no message is
|
||||
ending or if the listener have been stopped.
|
||||
"""
|
||||
|
@ -54,6 +54,7 @@ class FakeListener(base.Listener):
|
||||
exchange = self._exchange_manager.get_exchange(target.exchange)
|
||||
exchange.ensure_queue(target, pool)
|
||||
|
||||
@base.batch_poll_helper
|
||||
def poll(self, timeout=None):
|
||||
if timeout is not None:
|
||||
deadline = time.time() + timeout
|
||||
|
@ -252,6 +252,7 @@ class KafkaListener(base.Listener):
|
||||
self.conn = conn
|
||||
self.incoming_queue = []
|
||||
|
||||
@base.batch_poll_helper
|
||||
def poll(self, timeout=None):
|
||||
while not self._stopped.is_set():
|
||||
if self.incoming_queue:
|
||||
|
@ -859,7 +859,8 @@ class Connection(object):
|
||||
raise rpc_common.Timeout()
|
||||
|
||||
def _recoverable_error_callback(exc):
|
||||
self._new_consumers = self._consumers
|
||||
if not isinstance(exc, rpc_common.Timeout):
|
||||
self._new_consumers = self._consumers
|
||||
timer.check_return(_raise_timeout, exc)
|
||||
|
||||
def _error_callback(exc):
|
||||
|
@ -117,8 +117,12 @@ class ProtonListener(base.Listener):
|
||||
super(ProtonListener, self).__init__(driver)
|
||||
self.incoming = moves.queue.Queue()
|
||||
|
||||
def poll(self):
|
||||
message = self.incoming.get()
|
||||
@base.batch_poll_helper
|
||||
def poll(self, timeout=None):
|
||||
try:
|
||||
message = self.incoming.get(True, timeout)
|
||||
except moves.queue.Empty:
|
||||
return
|
||||
request, ctxt = unmarshal_request(message)
|
||||
LOG.debug("Returning incoming message")
|
||||
return ProtonIncomingMessage(self, ctxt, request, message)
|
||||
|
@ -40,6 +40,7 @@ class ZmqServer(base.Listener):
|
||||
self.notify_consumer = self.rpc_consumer
|
||||
self.consumers = [self.rpc_consumer]
|
||||
|
||||
@base.batch_poll_helper
|
||||
def poll(self, timeout=None):
|
||||
message, socket = self.poller.poll(
|
||||
timeout or self.conf.rpc_poll_timeout)
|
||||
|
@ -93,8 +93,11 @@ class PooledExecutor(base.ExecutorBase):
|
||||
@excutils.forever_retry_uncaught_exceptions
|
||||
def _runner(self):
|
||||
while not self._tombstone.is_set():
|
||||
incoming = self.listener.poll()
|
||||
if incoming is None:
|
||||
incoming = self.listener.poll(
|
||||
timeout=self.dispatcher.batch_timeout,
|
||||
prefetch_size=self.dispatcher.batch_size)
|
||||
|
||||
if not incoming:
|
||||
continue
|
||||
callback = self.dispatcher(incoming, self._executor_callback)
|
||||
was_submitted = self._do_submit(callback)
|
||||
|
@ -79,6 +79,12 @@ class DispatcherExecutorContext(object):
|
||||
class DispatcherBase(object):
|
||||
"Base class for dispatcher"
|
||||
|
||||
batch_size = 1
|
||||
"Number of messages to wait before calling endpoints callacks"
|
||||
|
||||
batch_timeout = None
|
||||
"Number of seconds to wait before calling endpoints callacks"
|
||||
|
||||
@abc.abstractmethod
|
||||
def _listen(self, transport):
|
||||
"""Initiate the driver Listener
|
||||
@ -98,7 +104,7 @@ class DispatcherBase(object):
|
||||
def __call__(self, incoming, executor_callback=None):
|
||||
"""Called by the executor to get the DispatcherExecutorContext
|
||||
|
||||
:param incoming: message or list of messages
|
||||
:param incoming: list of messages
|
||||
:type incoming: oslo_messging._drivers.base.IncomingMessage
|
||||
:returns: DispatcherExecutorContext
|
||||
:rtype: DispatcherExecutorContext
|
||||
|
@ -17,6 +17,7 @@ __all__ = ['Notifier',
|
||||
'LoggingNotificationHandler',
|
||||
'get_notification_transport',
|
||||
'get_notification_listener',
|
||||
'get_batch_notification_listener',
|
||||
'NotificationResult',
|
||||
'NotificationFilter',
|
||||
'PublishErrorsHandler',
|
||||
|
@ -16,7 +16,8 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import six
|
||||
|
||||
from oslo_messaging import dispatcher
|
||||
from oslo_messaging import localcontext
|
||||
@ -33,17 +34,7 @@ class NotificationResult(object):
|
||||
REQUEUE = 'requeue'
|
||||
|
||||
|
||||
class NotificationDispatcher(dispatcher.DispatcherBase):
|
||||
"""A message dispatcher which understands Notification messages.
|
||||
|
||||
A MessageHandlingServer is constructed by passing a callable dispatcher
|
||||
which is invoked with context and message dictionaries each time a message
|
||||
is received.
|
||||
|
||||
NotifcationDispatcher is one such dispatcher which pass a raw notification
|
||||
message to the endpoints
|
||||
"""
|
||||
|
||||
class _NotificationDispatcherBase(dispatcher.DispatcherBase):
|
||||
def __init__(self, targets, endpoints, serializer, allow_requeue,
|
||||
pool=None):
|
||||
self.targets = targets
|
||||
@ -74,12 +65,15 @@ class NotificationDispatcher(dispatcher.DispatcherBase):
|
||||
executor_callback=executor_callback,
|
||||
post=self._post_dispatch)
|
||||
|
||||
@staticmethod
|
||||
def _post_dispatch(incoming, result):
|
||||
if result == NotificationResult.HANDLED:
|
||||
incoming.acknowledge()
|
||||
else:
|
||||
incoming.requeue()
|
||||
def _post_dispatch(self, incoming, requeues):
|
||||
for m in incoming:
|
||||
try:
|
||||
if requeues and m in requeues:
|
||||
m.requeue()
|
||||
else:
|
||||
m.acknowledge()
|
||||
except Exception:
|
||||
LOG.error("Fail to ack/requeue message", exc_info=True)
|
||||
|
||||
def _dispatch_and_handle_error(self, incoming, executor_callback):
|
||||
"""Dispatch a notification message to the appropriate endpoint method.
|
||||
@ -88,24 +82,59 @@ class NotificationDispatcher(dispatcher.DispatcherBase):
|
||||
:type ctxt: IncomingMessage
|
||||
"""
|
||||
try:
|
||||
return self._dispatch(incoming.ctxt, incoming.message,
|
||||
executor_callback)
|
||||
return self._dispatch(incoming, executor_callback)
|
||||
except Exception:
|
||||
# sys.exc_info() is deleted by LOG.exception().
|
||||
exc_info = sys.exc_info()
|
||||
LOG.error('Exception during message handling',
|
||||
exc_info=exc_info)
|
||||
return NotificationResult.HANDLED
|
||||
LOG.error('Exception during message handling', exc_info=True)
|
||||
|
||||
def _dispatch(self, ctxt, message, executor_callback=None):
|
||||
"""Dispatch an RPC message to the appropriate endpoint method.
|
||||
|
||||
:param ctxt: the request context
|
||||
:type ctxt: dict
|
||||
:param message: the message payload
|
||||
:type message: dict
|
||||
def _dispatch(self, incoming, executor_callback=None):
|
||||
"""Dispatch notification messages to the appropriate endpoint method.
|
||||
"""
|
||||
ctxt = self.serializer.deserialize_context(ctxt)
|
||||
|
||||
messages_grouped = itertools.groupby((
|
||||
self._extract_user_message(m)
|
||||
for m in incoming), lambda x: x[0])
|
||||
|
||||
requeues = set()
|
||||
for priority, messages in messages_grouped:
|
||||
__, raw_messages, messages = six.moves.zip(*messages)
|
||||
raw_messages = list(raw_messages)
|
||||
messages = list(messages)
|
||||
if priority not in PRIORITIES:
|
||||
LOG.warning('Unknown priority "%s"', priority)
|
||||
continue
|
||||
for screen, callback in self._callbacks_by_priority.get(priority,
|
||||
[]):
|
||||
if screen:
|
||||
filtered_messages = [message for message in messages
|
||||
if screen.match(
|
||||
message["ctxt"],
|
||||
message["publisher_id"],
|
||||
message["event_type"],
|
||||
message["metadata"],
|
||||
message["payload"])]
|
||||
else:
|
||||
filtered_messages = messages
|
||||
|
||||
if not filtered_messages:
|
||||
continue
|
||||
|
||||
ret = self._exec_callback(executor_callback, callback,
|
||||
filtered_messages)
|
||||
if self.allow_requeue and ret == NotificationResult.REQUEUE:
|
||||
requeues.update(raw_messages)
|
||||
break
|
||||
return requeues
|
||||
|
||||
def _exec_callback(self, executor_callback, callback, *args):
|
||||
if executor_callback:
|
||||
ret = executor_callback(callback, *args)
|
||||
else:
|
||||
ret = callback(*args)
|
||||
return NotificationResult.HANDLED if ret is None else ret
|
||||
|
||||
def _extract_user_message(self, incoming):
|
||||
ctxt = self.serializer.deserialize_context(incoming.ctxt)
|
||||
message = incoming.message
|
||||
|
||||
publisher_id = message.get('publisher_id')
|
||||
event_type = message.get('event_type')
|
||||
@ -114,28 +143,50 @@ class NotificationDispatcher(dispatcher.DispatcherBase):
|
||||
'timestamp': message.get('timestamp')
|
||||
}
|
||||
priority = message.get('priority', '').lower()
|
||||
if priority not in PRIORITIES:
|
||||
LOG.warning('Unknown priority "%s"', priority)
|
||||
return
|
||||
|
||||
payload = self.serializer.deserialize_entity(ctxt,
|
||||
message.get('payload'))
|
||||
return priority, incoming, dict(ctxt=ctxt,
|
||||
publisher_id=publisher_id,
|
||||
event_type=event_type,
|
||||
payload=payload,
|
||||
metadata=metadata)
|
||||
|
||||
for screen, callback in self._callbacks_by_priority.get(priority, []):
|
||||
if screen and not screen.match(ctxt, publisher_id, event_type,
|
||||
metadata, payload):
|
||||
continue
|
||||
localcontext._set_local_context(ctxt)
|
||||
try:
|
||||
if executor_callback:
|
||||
ret = executor_callback(callback, ctxt, publisher_id,
|
||||
event_type, payload, metadata)
|
||||
else:
|
||||
ret = callback(ctxt, publisher_id, event_type, payload,
|
||||
metadata)
|
||||
ret = NotificationResult.HANDLED if ret is None else ret
|
||||
if self.allow_requeue and ret == NotificationResult.REQUEUE:
|
||||
return ret
|
||||
finally:
|
||||
localcontext._clear_local_context()
|
||||
return NotificationResult.HANDLED
|
||||
|
||||
class NotificationDispatcher(_NotificationDispatcherBase):
|
||||
"""A message dispatcher which understands Notification messages.
|
||||
|
||||
A MessageHandlingServer is constructed by passing a callable dispatcher
|
||||
which is invoked with context and message dictionaries each time a message
|
||||
is received.
|
||||
"""
|
||||
def _exec_callback(self, executor_callback, callback, messages):
|
||||
localcontext._set_local_context(
|
||||
messages[0]["ctxt"])
|
||||
try:
|
||||
return super(NotificationDispatcher, self)._exec_callback(
|
||||
executor_callback, callback,
|
||||
messages[0]["ctxt"],
|
||||
messages[0]["publisher_id"],
|
||||
messages[0]["event_type"],
|
||||
messages[0]["payload"],
|
||||
messages[0]["metadata"])
|
||||
finally:
|
||||
localcontext._clear_local_context()
|
||||
|
||||
|
||||
class BatchNotificationDispatcher(_NotificationDispatcherBase):
|
||||
"""A message dispatcher which understands Notification messages.
|
||||
|
||||
A MessageHandlingServer is constructed by passing a callable dispatcher
|
||||
which is invoked with a list of message dictionaries each time 'batch_size'
|
||||
messages are received or 'batch_timeout' seconds is reached.
|
||||
"""
|
||||
|
||||
def __init__(self, targets, endpoints, serializer, allow_requeue,
|
||||
pool=None, batch_size=None, batch_timeout=None):
|
||||
super(BatchNotificationDispatcher, self).__init__(targets, endpoints,
|
||||
serializer,
|
||||
allow_requeue,
|
||||
pool)
|
||||
self.batch_size = batch_size
|
||||
self.batch_timeout = batch_timeout
|
||||
|
@ -142,3 +142,46 @@ def get_notification_listener(transport, targets, endpoints,
|
||||
serializer,
|
||||
allow_requeue, pool)
|
||||
return msg_server.MessageHandlingServer(transport, dispatcher, executor)
|
||||
|
||||
|
||||
def get_batch_notification_listener(transport, targets, endpoints,
|
||||
executor='blocking', serializer=None,
|
||||
allow_requeue=False, pool=None,
|
||||
batch_size=None, batch_timeout=None):
|
||||
"""Construct a batch notification listener
|
||||
|
||||
The executor parameter controls how incoming messages will be received and
|
||||
dispatched. By default, the most simple executor is used - the blocking
|
||||
executor.
|
||||
|
||||
If the eventlet executor is used, the threading and time library need to be
|
||||
monkeypatched.
|
||||
|
||||
:param transport: the messaging transport
|
||||
:type transport: Transport
|
||||
:param targets: the exchanges and topics to listen on
|
||||
:type targets: list of Target
|
||||
:param endpoints: a list of endpoint objects
|
||||
:type endpoints: list
|
||||
:param executor: name of a message executor - for example
|
||||
'eventlet', 'blocking'
|
||||
:type executor: str
|
||||
:param serializer: an optional entity serializer
|
||||
:type serializer: Serializer
|
||||
:param allow_requeue: whether NotificationResult.REQUEUE support is needed
|
||||
:type allow_requeue: bool
|
||||
:param pool: the pool name
|
||||
:type pool: str
|
||||
:param batch_size: number of messages to wait before calling
|
||||
endpoints callacks
|
||||
:type batch_size: int
|
||||
:param batch_timeout: number of seconds to wait before calling
|
||||
endpoints callacks
|
||||
:type batch_timeout: int
|
||||
:raises: NotImplementedError
|
||||
"""
|
||||
transport._require_driver_features(requeue=allow_requeue)
|
||||
dispatcher = notify_dispatcher.BatchNotificationDispatcher(
|
||||
targets, endpoints, serializer, allow_requeue, pool,
|
||||
batch_size, batch_timeout)
|
||||
return msg_server.MessageHandlingServer(transport, dispatcher, executor)
|
||||
|
@ -131,9 +131,9 @@ class RPCDispatcher(dispatcher.DispatcherBase):
|
||||
return self.serializer.serialize_entity(ctxt, result)
|
||||
|
||||
def __call__(self, incoming, executor_callback=None):
|
||||
incoming.acknowledge()
|
||||
incoming[0].acknowledge()
|
||||
return dispatcher.DispatcherExecutorContext(
|
||||
incoming, self._dispatch_and_reply,
|
||||
incoming[0], self._dispatch_and_reply,
|
||||
executor_callback=executor_callback)
|
||||
|
||||
def _dispatch_and_reply(self, incoming, executor_callback):
|
||||
|
@ -226,7 +226,7 @@ class TestKafkaListener(test_utils.BaseTestCase):
|
||||
listener.stop()
|
||||
fake_response = listener.poll()
|
||||
self.assertEqual(1, len(listener.conn.consume.mock_calls))
|
||||
self.assertEqual(fake_response, None)
|
||||
self.assertEqual([], fake_response)
|
||||
|
||||
|
||||
class TestWithRealKafkaBroker(test_utils.BaseTestCase):
|
||||
@ -251,7 +251,7 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
|
||||
self.driver.send_notification(
|
||||
target, fake_context, fake_message, None)
|
||||
|
||||
received_message = listener.poll()
|
||||
received_message = listener.poll()[0]
|
||||
self.assertEqual(fake_context, received_message.ctxt)
|
||||
self.assertEqual(fake_message, received_message.message)
|
||||
|
||||
@ -268,7 +268,7 @@ class TestWithRealKafkaBroker(test_utils.BaseTestCase):
|
||||
self.driver.send_notification(
|
||||
target, fake_context, fake_message, None)
|
||||
|
||||
received_message = listener.poll()
|
||||
received_message = listener.poll()[0]
|
||||
self.assertEqual(fake_context, received_message.ctxt)
|
||||
self.assertEqual(fake_message, received_message.message)
|
||||
|
||||
|
@ -423,7 +423,7 @@ class TestSendReceive(test_utils.BaseTestCase):
|
||||
for i in range(len(senders)):
|
||||
senders[i].start()
|
||||
|
||||
received = listener.poll()
|
||||
received = listener.poll()[0]
|
||||
self.assertIsNotNone(received)
|
||||
self.assertEqual(self.ctxt, received.ctxt)
|
||||
self.assertEqual({'tx_id': i}, received.message)
|
||||
@ -501,7 +501,7 @@ class TestPollAsync(test_utils.BaseTestCase):
|
||||
target = oslo_messaging.Target(topic='testtopic')
|
||||
listener = driver.listen(target)
|
||||
received = listener.poll(timeout=0.050)
|
||||
self.assertIsNone(received)
|
||||
self.assertEqual([], received)
|
||||
|
||||
|
||||
class TestRacyWaitForReply(test_utils.BaseTestCase):
|
||||
@ -561,13 +561,13 @@ class TestRacyWaitForReply(test_utils.BaseTestCase):
|
||||
senders[0].start()
|
||||
notify_condition.wait()
|
||||
|
||||
msgs.append(listener.poll())
|
||||
msgs.extend(listener.poll())
|
||||
self.assertEqual({'tx_id': 0}, msgs[-1].message)
|
||||
|
||||
# Start the second guy, receive his message
|
||||
senders[1].start()
|
||||
|
||||
msgs.append(listener.poll())
|
||||
msgs.extend(listener.poll())
|
||||
self.assertEqual({'tx_id': 1}, msgs[-1].message)
|
||||
|
||||
# Reply to both in order, making the second thread queue
|
||||
@ -581,7 +581,7 @@ class TestRacyWaitForReply(test_utils.BaseTestCase):
|
||||
# Start the 3rd guy, receive his message
|
||||
senders[2].start()
|
||||
|
||||
msgs.append(listener.poll())
|
||||
msgs.extend(listener.poll())
|
||||
self.assertEqual({'tx_id': 2}, msgs[-1].message)
|
||||
|
||||
# Verify the _send_reply was not invoked by driver:
|
||||
@ -862,7 +862,7 @@ class TestReplyWireFormat(test_utils.BaseTestCase):
|
||||
|
||||
producer.publish(msg)
|
||||
|
||||
received = listener.poll()
|
||||
received = listener.poll()[0]
|
||||
self.assertIsNotNone(received)
|
||||
self.assertEqual(self.expected_ctxt, received.ctxt)
|
||||
self.assertEqual(self.expected, received.message)
|
||||
|
@ -52,7 +52,8 @@ class TestServerListener(object):
|
||||
def _run(self):
|
||||
try:
|
||||
message = self.listener.poll()
|
||||
if message is not None:
|
||||
if message:
|
||||
message = message[0]
|
||||
message.acknowledge()
|
||||
self._received.set()
|
||||
self.message = message
|
||||
|
@ -132,11 +132,14 @@ class TestExecutor(test_utils.BaseTestCase):
|
||||
endpoint = mock.MagicMock(return_value='result')
|
||||
event = None
|
||||
|
||||
class Dispatcher(object):
|
||||
class Dispatcher(dispatcher_base.DispatcherBase):
|
||||
def __init__(self, endpoint):
|
||||
self.endpoint = endpoint
|
||||
self.result = "not set"
|
||||
|
||||
def _listen(self, transport):
|
||||
pass
|
||||
|
||||
def callback(self, incoming, executor_callback):
|
||||
if executor_callback is None:
|
||||
result = self.endpoint(incoming.ctxt,
|
||||
@ -152,7 +155,7 @@ class TestExecutor(test_utils.BaseTestCase):
|
||||
|
||||
def __call__(self, incoming, executor_callback=None):
|
||||
return dispatcher_base.DispatcherExecutorContext(
|
||||
incoming, self.callback, executor_callback)
|
||||
incoming[0], self.callback, executor_callback)
|
||||
|
||||
return Dispatcher(endpoint), endpoint, event, run_executor
|
||||
|
||||
@ -162,7 +165,7 @@ class TestExecutor(test_utils.BaseTestCase):
|
||||
executor = self.executor(self.conf, listener, dispatcher)
|
||||
incoming_message = mock.MagicMock(ctxt={}, message={'payload': 'data'})
|
||||
|
||||
def fake_poll(timeout=None):
|
||||
def fake_poll(timeout=None, prefetch_size=1):
|
||||
time.sleep(0.1)
|
||||
if listener.poll.call_count == 10:
|
||||
if event is not None:
|
||||
@ -190,9 +193,9 @@ class TestExecutor(test_utils.BaseTestCase):
|
||||
executor = self.executor(self.conf, listener, dispatcher)
|
||||
incoming_message = mock.MagicMock(ctxt={}, message={'payload': 'data'})
|
||||
|
||||
def fake_poll(timeout=None):
|
||||
def fake_poll(timeout=None, prefetch_size=1):
|
||||
if listener.poll.call_count == 1:
|
||||
return incoming_message
|
||||
return [incoming_message]
|
||||
if event is not None:
|
||||
event.wait()
|
||||
executor.stop()
|
||||
|
@ -16,6 +16,7 @@ import uuid
|
||||
|
||||
import concurrent.futures
|
||||
from oslo_config import cfg
|
||||
import six.moves
|
||||
from testtools import matchers
|
||||
|
||||
import oslo_messaging
|
||||
@ -27,8 +28,8 @@ class CallTestCase(utils.SkipIfNoTransportURL):
|
||||
def setUp(self):
|
||||
super(CallTestCase, self).setUp(conf=cfg.ConfigOpts())
|
||||
|
||||
self.conf.prog="test_prog"
|
||||
self.conf.project="test_project"
|
||||
self.conf.prog = "test_prog"
|
||||
self.conf.project = "test_project"
|
||||
|
||||
self.config(heartbeat_timeout_threshold=0,
|
||||
group='oslo_messaging_rabbit')
|
||||
@ -324,3 +325,18 @@ class NotifyTestCase(utils.SkipIfNoTransportURL):
|
||||
self.assertEqual(expected[1], actual[0])
|
||||
self.assertEqual(expected[2], actual[1])
|
||||
self.assertEqual(expected[3], actual[2])
|
||||
|
||||
def test_simple_batch(self):
|
||||
listener = self.useFixture(
|
||||
utils.BatchNotificationFixture(self.conf, self.url,
|
||||
['test_simple_batch'],
|
||||
batch_size=100, batch_timeout=2))
|
||||
notifier = listener.notifier('abc')
|
||||
|
||||
for i in six.moves.range(0, 205):
|
||||
notifier.info({}, 'test%s' % i, 'Hello World!')
|
||||
events = listener.get_events(timeout=3)
|
||||
self.assertEqual(3, len(events), events)
|
||||
self.assertEqual(100, len(events[0][1]))
|
||||
self.assertEqual(100, len(events[1][1]))
|
||||
self.assertEqual(5, len(events[2][1]))
|
||||
|
@ -293,13 +293,14 @@ class SkipIfNoTransportURL(test_utils.BaseTestCase):
|
||||
|
||||
|
||||
class NotificationFixture(fixtures.Fixture):
|
||||
def __init__(self, conf, url, topics):
|
||||
def __init__(self, conf, url, topics, batch=None):
|
||||
super(NotificationFixture, self).__init__()
|
||||
self.conf = conf
|
||||
self.url = url
|
||||
self.topics = topics
|
||||
self.events = moves.queue.Queue()
|
||||
self.name = str(id(self))
|
||||
self.batch = batch
|
||||
|
||||
def setUp(self):
|
||||
super(NotificationFixture, self).setUp()
|
||||
@ -307,10 +308,7 @@ class NotificationFixture(fixtures.Fixture):
|
||||
# add a special topic for internal notifications
|
||||
targets.append(oslo_messaging.Target(topic=self.name))
|
||||
transport = self.useFixture(TransportFixture(self.conf, self.url))
|
||||
self.server = oslo_messaging.get_notification_listener(
|
||||
transport.transport,
|
||||
targets,
|
||||
[self], 'eventlet')
|
||||
self.server = self._get_server(transport, targets)
|
||||
self._ctrl = self.notifier('internal', topic=self.name)
|
||||
self._start()
|
||||
transport.wait()
|
||||
@ -319,6 +317,12 @@ class NotificationFixture(fixtures.Fixture):
|
||||
self._stop()
|
||||
super(NotificationFixture, self).cleanUp()
|
||||
|
||||
def _get_server(self, transport, targets):
|
||||
return oslo_messaging.get_notification_listener(
|
||||
transport.transport,
|
||||
targets,
|
||||
[self], 'eventlet')
|
||||
|
||||
def _start(self):
|
||||
self.thread = test_utils.ServerThreadHelper(self.server)
|
||||
self.thread.start()
|
||||
@ -366,3 +370,39 @@ class NotificationFixture(fixtures.Fixture):
|
||||
except moves.queue.Empty:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
class BatchNotificationFixture(NotificationFixture):
|
||||
def __init__(self, conf, url, topics, batch_size=5, batch_timeout=2):
|
||||
super(BatchNotificationFixture, self).__init__(conf, url, topics)
|
||||
self.batch_size = batch_size
|
||||
self.batch_timeout = batch_timeout
|
||||
|
||||
def _get_server(self, transport, targets):
|
||||
return oslo_messaging.get_batch_notification_listener(
|
||||
transport.transport,
|
||||
targets,
|
||||
[self], 'eventlet',
|
||||
batch_timeout=self.batch_timeout,
|
||||
batch_size=self.batch_size)
|
||||
|
||||
def debug(self, messages):
|
||||
self.events.put(['debug', messages])
|
||||
|
||||
def audit(self, messages):
|
||||
self.events.put(['audit', messages])
|
||||
|
||||
def info(self, messages):
|
||||
self.events.put(['info', messages])
|
||||
|
||||
def warn(self, messages):
|
||||
self.events.put(['warn', messages])
|
||||
|
||||
def error(self, messages):
|
||||
self.events.put(['error', messages])
|
||||
|
||||
def critical(self, messages):
|
||||
self.events.put(['critical', messages])
|
||||
|
||||
def sample(self, messages):
|
||||
pass # Just used for internal shutdown control
|
||||
|
@ -107,7 +107,7 @@ class TestDispatcher(test_utils.BaseTestCase):
|
||||
sorted(dispatcher._targets_priorities))
|
||||
|
||||
incoming = mock.Mock(ctxt={}, message=msg)
|
||||
callback = dispatcher(incoming)
|
||||
callback = dispatcher([incoming])
|
||||
callback.run()
|
||||
callback.done()
|
||||
|
||||
@ -144,7 +144,7 @@ class TestDispatcher(test_utils.BaseTestCase):
|
||||
msg['priority'] = 'what???'
|
||||
dispatcher = notify_dispatcher.NotificationDispatcher(
|
||||
[mock.Mock()], [mock.Mock()], None, allow_requeue=True, pool=None)
|
||||
callback = dispatcher(mock.Mock(ctxt={}, message=msg))
|
||||
callback = dispatcher([mock.Mock(ctxt={}, message=msg)])
|
||||
callback.run()
|
||||
callback.done()
|
||||
mylog.warning.assert_called_once_with('Unknown priority "%s"',
|
||||
@ -246,7 +246,7 @@ class TestDispatcherFilter(test_utils.BaseTestCase):
|
||||
'timestamp': '2014-03-03 18:21:04.369234',
|
||||
'message_id': '99863dda-97f0-443a-a0c1-6ed317b7fd45'}
|
||||
incoming = mock.Mock(ctxt=self.context, message=message)
|
||||
callback = dispatcher(incoming)
|
||||
callback = dispatcher([incoming])
|
||||
callback.run()
|
||||
callback.done()
|
||||
|
||||
|
@ -23,6 +23,7 @@ import oslo_messaging
|
||||
from oslo_messaging.notify import dispatcher
|
||||
from oslo_messaging.notify import notifier as msg_notifier
|
||||
from oslo_messaging.tests import utils as test_utils
|
||||
import six
|
||||
from six.moves import mock
|
||||
|
||||
load_tests = testscenarios.load_tests_apply_scenarios
|
||||
@ -56,7 +57,7 @@ class ListenerSetupMixin(object):
|
||||
self.threads = []
|
||||
self.lock = threading.Condition()
|
||||
|
||||
def info(self, ctxt, publisher_id, event_type, payload, metadata):
|
||||
def info(self, *args, **kwargs):
|
||||
# NOTE(sileht): this run into an other thread
|
||||
with self.lock:
|
||||
self._received_msgs += 1
|
||||
@ -86,7 +87,7 @@ class ListenerSetupMixin(object):
|
||||
self.trackers = {}
|
||||
|
||||
def _setup_listener(self, transport, endpoints,
|
||||
targets=None, pool=None):
|
||||
targets=None, pool=None, batch=False):
|
||||
|
||||
if pool is None:
|
||||
tracker_name = '__default__'
|
||||
@ -98,9 +99,15 @@ class ListenerSetupMixin(object):
|
||||
|
||||
tracker = self.trackers.setdefault(
|
||||
tracker_name, self.ThreadTracker())
|
||||
listener = oslo_messaging.get_notification_listener(
|
||||
transport, targets=targets, endpoints=[tracker] + endpoints,
|
||||
allow_requeue=True, pool=pool, executor='eventlet')
|
||||
if batch:
|
||||
listener = oslo_messaging.get_batch_notification_listener(
|
||||
transport, targets=targets, endpoints=[tracker] + endpoints,
|
||||
allow_requeue=True, pool=pool, executor='eventlet',
|
||||
batch_size=batch[0], batch_timeout=batch[1])
|
||||
else:
|
||||
listener = oslo_messaging.get_notification_listener(
|
||||
transport, targets=targets, endpoints=[tracker] + endpoints,
|
||||
allow_requeue=True, pool=pool, executor='eventlet')
|
||||
|
||||
thread = RestartableServerThread(listener)
|
||||
tracker.start(thread)
|
||||
@ -170,6 +177,82 @@ class TestNotifyListener(test_utils.BaseTestCase, ListenerSetupMixin):
|
||||
else:
|
||||
self.assertTrue(False)
|
||||
|
||||
def test_batch_timeout(self):
|
||||
transport = oslo_messaging.get_transport(self.conf, url='fake:')
|
||||
|
||||
endpoint = mock.Mock()
|
||||
endpoint.info.return_value = None
|
||||
listener_thread = self._setup_listener(transport, [endpoint],
|
||||
batch=(5, 1))
|
||||
|
||||
notifier = self._setup_notifier(transport)
|
||||
for i in six.moves.range(12):
|
||||
notifier.info({}, 'an_event.start', 'test message')
|
||||
|
||||
self.wait_for_messages(3)
|
||||
self.assertFalse(listener_thread.stop())
|
||||
|
||||
messages = [dict(ctxt={},
|
||||
publisher_id='testpublisher',
|
||||
event_type='an_event.start',
|
||||
payload='test message',
|
||||
metadata={'message_id': mock.ANY,
|
||||
'timestamp': mock.ANY})]
|
||||
|
||||
endpoint.info.assert_has_calls([mock.call(messages * 5),
|
||||
mock.call(messages * 5),
|
||||
mock.call(messages * 2)])
|
||||
|
||||
def test_batch_size(self):
|
||||
transport = oslo_messaging.get_transport(self.conf, url='fake:')
|
||||
|
||||
endpoint = mock.Mock()
|
||||
endpoint.info.return_value = None
|
||||
listener_thread = self._setup_listener(transport, [endpoint],
|
||||
batch=(5, None))
|
||||
|
||||
notifier = self._setup_notifier(transport)
|
||||
for i in six.moves.range(10):
|
||||
notifier.info({}, 'an_event.start', 'test message')
|
||||
|
||||
self.wait_for_messages(2)
|
||||
self.assertFalse(listener_thread.stop())
|
||||
|
||||
messages = [dict(ctxt={},
|
||||
publisher_id='testpublisher',
|
||||
event_type='an_event.start',
|
||||
payload='test message',
|
||||
metadata={'message_id': mock.ANY,
|
||||
'timestamp': mock.ANY})]
|
||||
|
||||
endpoint.info.assert_has_calls([mock.call(messages * 5),
|
||||
mock.call(messages * 5)])
|
||||
|
||||
def test_batch_size_exception_path(self):
|
||||
transport = oslo_messaging.get_transport(self.conf, url='fake:')
|
||||
|
||||
endpoint = mock.Mock()
|
||||
endpoint.info.side_effect = [None, Exception('boom!')]
|
||||
listener_thread = self._setup_listener(transport, [endpoint],
|
||||
batch=(5, None))
|
||||
|
||||
notifier = self._setup_notifier(transport)
|
||||
for i in six.moves.range(10):
|
||||
notifier.info({}, 'an_event.start', 'test message')
|
||||
|
||||
self.wait_for_messages(2)
|
||||
self.assertFalse(listener_thread.stop())
|
||||
|
||||
messages = [dict(ctxt={},
|
||||
publisher_id='testpublisher',
|
||||
event_type='an_event.start',
|
||||
payload='test message',
|
||||
metadata={'message_id': mock.ANY,
|
||||
'timestamp': mock.ANY})]
|
||||
|
||||
endpoint.info.assert_has_calls([mock.call(messages * 5)])
|
||||
|
||||
|
||||
def test_one_topic(self):
|
||||
transport = msg_notifier.get_notification_transport(
|
||||
self.conf, url='fake:')
|
||||
|
@ -133,7 +133,7 @@ class TestDispatcher(test_utils.BaseTestCase):
|
||||
incoming = mock.Mock(ctxt=self.ctxt, message=self.msg)
|
||||
incoming.reply.side_effect = check_reply
|
||||
|
||||
callback = dispatcher(incoming)
|
||||
callback = dispatcher([incoming])
|
||||
callback.run()
|
||||
callback.done()
|
||||
|
||||
|
@ -60,7 +60,7 @@ class _ListenerThread(threading.Thread):
|
||||
def run(self):
|
||||
LOG.debug("Listener started")
|
||||
while self.msg_count > 0:
|
||||
in_msg = self.listener.poll()
|
||||
in_msg = self.listener.poll()[0]
|
||||
self.messages.put(in_msg)
|
||||
self.msg_count -= 1
|
||||
if in_msg.message.get('method') == 'echo':
|
||||
|
@ -79,14 +79,34 @@ class LoggingNoParsingFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
class NotifyEndpoint(object):
|
||||
def __init__(self):
|
||||
class Monitor(object):
|
||||
def __init__(self, show_stats=False, *args, **kwargs):
|
||||
self._count = self._prev_count = 0
|
||||
self.show_stats = show_stats
|
||||
if self.show_stats:
|
||||
self._monitor()
|
||||
|
||||
def _monitor(self):
|
||||
threading.Timer(1.0, self._monitor).start()
|
||||
print ("%d msg was received per second"
|
||||
% (self._count - self._prev_count))
|
||||
self._prev_count = self._count
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
self._count += 1
|
||||
|
||||
|
||||
class NotifyEndpoint(Monitor):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(NotifyEndpoint, self).__init__(*args, **kwargs)
|
||||
self.cache = []
|
||||
|
||||
def info(self, ctxt, publisher_id, event_type, payload, metadata):
|
||||
super(NotifyEndpoint, self).info(ctxt, publisher_id, event_type,
|
||||
payload, metadata)
|
||||
LOG.info('msg rcv')
|
||||
LOG.info("%s %s %s %s" % (ctxt, publisher_id, event_type, payload))
|
||||
if payload not in self.cache:
|
||||
if not self.show_stats and payload not in self.cache:
|
||||
LOG.info('requeue msg')
|
||||
self.cache.append(payload)
|
||||
for i in range(15):
|
||||
@ -97,8 +117,8 @@ class NotifyEndpoint(object):
|
||||
return messaging.NotificationResult.HANDLED
|
||||
|
||||
|
||||
def notify_server(transport):
|
||||
endpoints = [NotifyEndpoint()]
|
||||
def notify_server(transport, show_stats):
|
||||
endpoints = [NotifyEndpoint(show_stats)]
|
||||
target = messaging.Target(topic='n-t1')
|
||||
server = notify.get_notification_listener(transport, [target],
|
||||
endpoints, executor='eventlet')
|
||||
@ -106,8 +126,41 @@ def notify_server(transport):
|
||||
server.wait()
|
||||
|
||||
|
||||
class RpcEndpoint(object):
|
||||
def __init__(self, wait_before_answer):
|
||||
class BatchNotifyEndpoint(Monitor):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BatchNotifyEndpoint, self).__init__(*args, **kwargs)
|
||||
self.cache = []
|
||||
|
||||
def info(self, messages):
|
||||
super(BatchNotifyEndpoint, self).info(messages)
|
||||
self._count += len(messages) - 1
|
||||
|
||||
LOG.info('msg rcv')
|
||||
LOG.info("%s" % messages)
|
||||
if not self.show_stats and messages not in self.cache:
|
||||
LOG.info('requeue msg')
|
||||
self.cache.append(messages)
|
||||
for i in range(15):
|
||||
eventlet.sleep(1)
|
||||
return messaging.NotificationResult.REQUEUE
|
||||
else:
|
||||
LOG.info('ack msg')
|
||||
return messaging.NotificationResult.HANDLED
|
||||
|
||||
|
||||
def batch_notify_server(transport, show_stats):
|
||||
endpoints = [BatchNotifyEndpoint(show_stats)]
|
||||
target = messaging.Target(topic='n-t1')
|
||||
server = notify.get_batch_notification_listener(
|
||||
transport, [target],
|
||||
endpoints, executor='eventlet',
|
||||
batch_size=1000, batch_time=5)
|
||||
server.start()
|
||||
server.wait()
|
||||
|
||||
|
||||
class RpcEndpoint(Monitor):
|
||||
def __init__(self, wait_before_answer, show_stats):
|
||||
self.count = None
|
||||
self.wait_before_answer = wait_before_answer
|
||||
|
||||
@ -126,27 +179,8 @@ class RpcEndpoint(object):
|
||||
return "OK: %s" % message
|
||||
|
||||
|
||||
class RpcEndpointMonitor(RpcEndpoint):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(RpcEndpointMonitor, self).__init__(*args, **kwargs)
|
||||
|
||||
self._count = self._prev_count = 0
|
||||
self._monitor()
|
||||
|
||||
def _monitor(self):
|
||||
threading.Timer(1.0, self._monitor).start()
|
||||
print ("%d msg was received per second"
|
||||
% (self._count - self._prev_count))
|
||||
self._prev_count = self._count
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
self._count += 1
|
||||
super(RpcEndpointMonitor, self).info(*args, **kwargs)
|
||||
|
||||
|
||||
def rpc_server(transport, target, wait_before_answer, executor, show_stats):
|
||||
endpoint_cls = RpcEndpointMonitor if show_stats else RpcEndpoint
|
||||
endpoints = [endpoint_cls(wait_before_answer)]
|
||||
endpoints = [RpcEndpoint(wait_before_answer, show_stats)]
|
||||
server = rpc.get_rpc_server(transport, target, endpoints,
|
||||
executor=executor)
|
||||
server.start()
|
||||
@ -244,6 +278,11 @@ def main():
|
||||
help='notify/rpc server/client mode')
|
||||
|
||||
server = subparsers.add_parser('notify-server')
|
||||
server.add_argument('--show-stats', dest='show_stats',
|
||||
type=bool, default=True)
|
||||
server = subparsers.add_parser('batch-notify-server')
|
||||
server.add_argument('--show-stats', dest='show_stats',
|
||||
type=bool, default=True)
|
||||
client = subparsers.add_parser('notify-client')
|
||||
client.add_argument('-p', dest='threads', type=int, default=1,
|
||||
help='number of client threads')
|
||||
@ -302,7 +341,9 @@ def main():
|
||||
rpc_server(transport, target, args.wait_before_answer, args.executor,
|
||||
args.show_stats)
|
||||
elif args.mode == 'notify-server':
|
||||
notify_server(transport)
|
||||
notify_server(transport, args.show_stats)
|
||||
elif args.mode == 'batch-notify-server':
|
||||
batch_notify_server(transport, args.show_stats)
|
||||
elif args.mode == 'notify-client':
|
||||
threads_spawner(args.threads, notifier, transport, args.messages,
|
||||
args.wait_after_msg, args.timeout)
|
||||
|
Loading…
Reference in New Issue
Block a user