Abstract out the worker finding from the WBE engine

To make it easy to plug in future ways of discovering which
topics workers exist on (and which tasks they can perform), and to
keep this information up to date, refactor the functionality that
currently does this via periodic messages into a finder type with a
periodic function on it (which will be periodically activated by an
updated and improved periodic worker).

Part of blueprint wbe-worker-info

Change-Id: Ib3ae29758af3d244b4ac4624ac380caf88b159fd
Joshua Harlow 2015-01-26 19:48:26 -08:00 committed by Joshua Harlow
parent 934e15a029
commit 19f9674877
12 changed files with 529 additions and 201 deletions
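
In rough terms the change replaces the engine's hard-wired notify loop
with a finder object whose periodically-decorated method is driven by
the new generic periodic worker. A minimal sketch of that shape
(hypothetical names, not the actual taskflow API):

    import threading

    def periodic(spacing):
        # Tag a function as wanting to run every ``spacing`` seconds
        # (mirrors what taskflow.types.periodic.periodic does below).
        def wrapper(f):
            f._periodic = True
            f._periodic_spacing = spacing
            return f
        return wrapper

    class Finder(object):
        # The finder now owns worker discovery; engines just plug one in.
        @periodic(1.0)
        def beat(self):
            print("publish notify message to the configured topics...")

    class Runner(object):
        # Stand-in for the improved periodic worker that drives the beat.
        def __init__(self, cb):
            self._cb = cb
            self._tombstone = threading.Event()

        def start(self):
            while not self._tombstone.is_set():
                self._cb()
                self._tombstone.wait(self._cb._periodic_spacing)

        def stop(self):
            self._tombstone.set()

    runner = Runner(Finder().beat)
    t = threading.Thread(target=runner.start)
    t.daemon = True
    t.start()
    runner.stop()
    t.join()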


@ -44,6 +44,11 @@ Notifier
.. automodule:: taskflow.types.notifier
Periodic
========
.. automodule:: taskflow.types.periodic
Table
=====


@ -15,7 +15,6 @@
# under the License.
from kombu import exceptions as kombu_exc
import six
from taskflow import exceptions as excp
from taskflow import logging
@ -27,14 +26,35 @@ LOG = logging.getLogger(__name__)
class TypeDispatcher(object):
"""Receives messages and dispatches to type specific handlers."""
def __init__(self, type_handlers):
self._handlers = dict(type_handlers)
self._requeue_filters = []
def __init__(self, type_handlers=None, requeue_filters=None):
if type_handlers is not None:
self._type_handlers = dict(type_handlers)
else:
self._type_handlers = {}
if requeue_filters is not None:
self._requeue_filters = list(requeue_filters)
else:
self._requeue_filters = []
def add_requeue_filter(self, callback):
"""Add a callback that can *request* message requeuing.
@property
def type_handlers(self):
"""Dictionary of message type -> callback to handle that message.
The callback will be activated before the message has been acked and
The callback(s) will be activated by looking for a message
property 'type' and locating a callback in this dictionary that maps
to that type; if one is found it is expected to be a callback that
accepts two positional parameters: the first being the message data
and the second being the message object. If a callback is not found
then the message is rejected and it will be up to the underlying
message transport to determine what this means/implies...
"""
return self._type_handlers
@property
def requeue_filters(self):
"""List of filters (callbacks) to request a message to be requeued.
The callback(s) will be activated before the message has been acked and
it can be used to instruct the dispatcher to requeue the message
instead of processing it. The callback, when called, will be provided
two positional parameters: the first being the message data and the
@ -42,9 +62,7 @@ class TypeDispatcher(object):
filter should return a truthy object if the message should be requeued
and a falsey object if it should not.
"""
if not six.callable(callback):
raise ValueError("Requeue filter callback must be callable")
self._requeue_filters.append(callback)
return self._requeue_filters
def _collect_requeue_votes(self, data, message):
# Returns how many of the filters asked for the message to be requeued.
@ -74,7 +92,7 @@ class TypeDispatcher(object):
LOG.debug("Message '%s' was requeued.", ku.DelayedPretty(message))
def _process_message(self, data, message, message_type):
handler = self._handlers.get(message_type)
handler = self._type_handlers.get(message_type)
if handler is None:
message.reject_log_error(logger=LOG,
errors=(kombu_exc.MessageStateError,))
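
For illustration, the reworked constructor and the two new properties
might be exercised like this (a sketch patterned on the updated unit
tests later in this commit; the handler signature follows the
docstring above):

    from taskflow.engines.worker_based import dispatcher

    def on_hello(data, message):
        # Called with the message data and the message object.
        print("hello payload: %s" % (data,))

    # Both collections can now be supplied up front...
    d = dispatcher.TypeDispatcher(
        type_handlers={'hello': on_hello},
        requeue_filters=[lambda data, message: False])

    # ...or grown afterwards through the new (mutable) properties.
    d.type_handlers['goodbye'] = lambda data, message: None
    d.requeue_filters.append(lambda data, message: False)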


@ -25,7 +25,7 @@ from taskflow.engines.worker_based import types as wt
from taskflow import exceptions as exc
from taskflow import logging
from taskflow import task as task_atom
from taskflow.types import timing as tt
from taskflow.types import periodic
from taskflow.utils import kombu_utils as ku
from taskflow.utils import misc
from taskflow.utils import threading_utils as tu
@ -41,51 +41,41 @@ class WorkerTaskExecutor(executor.TaskExecutor):
url=None, transport=None, transport_options=None,
retry_options=None):
self._uuid = uuid
self._topics = topics
self._requests_cache = wt.RequestsCache()
self._workers = wt.TopicWorkers()
self._transition_timeout = transition_timeout
type_handlers = {
pr.NOTIFY: [
self._process_notify,
functools.partial(pr.Notify.validate, response=True),
],
pr.RESPONSE: [
self._process_response,
pr.Response.validate,
],
}
self._proxy = proxy.Proxy(uuid, exchange, type_handlers,
self._proxy = proxy.Proxy(uuid, exchange,
type_handlers=type_handlers,
on_wait=self._on_wait, url=url,
transport=transport,
transport_options=transport_options,
retry_options=retry_options)
self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
[self._notify_topics])
# NOTE(harlowja): This is the simplest finder impl. that
# doesn't have external dependencies (outside of what this engine
# already requires); it does though create periodic 'polling' traffic
# to workers to 'learn' of the tasks they can perform (and it requires
# pre-existing knowledge of the topics those workers are on to gather
# and update this information).
self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics)
self._finder.on_worker = self._on_worker
self._helpers = tu.ThreadBundle()
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
after_start=lambda t: self._proxy.wait(),
before_join=lambda t: self._proxy.stop())
self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start),
before_join=lambda t: self._periodic.stop(),
after_join=lambda t: self._periodic.reset(),
before_start=lambda t: self._periodic.reset())
p_worker = periodic.PeriodicWorker.create([self._finder])
if p_worker:
self._helpers.bind(lambda: tu.daemon_thread(p_worker.start),
before_join=lambda t: p_worker.stop(),
after_join=lambda t: p_worker.reset(),
before_start=lambda t: p_worker.reset())
def _process_notify(self, notify, message):
"""Process notify message from remote side."""
LOG.debug("Started processing notify message '%s'",
ku.DelayedPretty(message))
topic = notify['topic']
tasks = notify['tasks']
# Add worker info to the cache
worker = self._workers.add(topic, tasks)
LOG.debug("Received notification about worker '%s' (%s"
" total workers are currently known)", worker,
len(self._workers))
# Publish waiting requests
def _on_worker(self, worker):
"""Process new worker that has arrived (and fire off any work)."""
for request in self._requests_cache.get_waiting_requests(worker):
if request.transition_and_log_error(pr.PENDING, logger=LOG):
self._publish_request(request, worker)
@ -174,7 +164,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
request.result.add_done_callback(lambda fut: cleaner())
# Get task's worker and publish request if worker was found.
worker = self._workers.get_worker_for_task(task)
worker = self._finder.get_worker_for_task(task)
if worker is not None:
# NOTE(skudriashev): Make sure request is set to the PENDING state
# before putting it into the requests cache to prevent the notify
@ -208,10 +198,6 @@ class WorkerTaskExecutor(executor.TaskExecutor):
del self._requests_cache[request.uuid]
request.set_result(failure)
def _notify_topics(self):
"""Cyclically called to publish notify message to each topic."""
self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid)
def execute_task(self, task, task_uuid, arguments,
progress_callback=None):
return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
@ -232,7 +218,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
return how many workers are still needed, otherwise it will
return zero.
"""
return self._workers.wait_for_workers(workers=workers, timeout=timeout)
return self._finder.wait_for_workers(workers=workers,
timeout=timeout)
def start(self):
"""Starts proxy thread and associated topic notification thread."""
@ -242,4 +229,4 @@ class WorkerTaskExecutor(executor.TaskExecutor):
"""Stops proxy thread and associated topic notification thread."""
self._helpers.stop()
self._requests_cache.clear(self._handle_expired_request)
self._workers.clear()
self._finder.clear()
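
The ThreadBundle bindings above boil down to the following start/stop
choreography for the periodic worker thread (a simplified sketch using
a stub; the real worker is the one added in taskflow/types/periodic.py
at the bottom of this commit):

    import threading

    class StubPeriodicWorker(object):
        # Minimal stand-in exposing the lifecycle the executor binds to.
        def __init__(self):
            self._tombstone = threading.Event()

        def start(self):
            while not self._tombstone.wait(0.1):
                pass  # ...call any due periodic callables here...

        def stop(self):
            self._tombstone.set()

        def reset(self):
            self._tombstone.clear()

    p_worker = StubPeriodicWorker()
    p_worker.reset()                        # before_start
    t = threading.Thread(target=p_worker.start)
    t.daemon = True
    t.start()
    # ... engine executes tasks ...
    p_worker.stop()                         # before_join
    t.join()
    p_worker.reset()                        # after_join (ready for reuse)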


@ -68,19 +68,19 @@ class Proxy(object):
# value is valid...
_RETRY_INT_OPTS = frozenset(['max_retries'])
def __init__(self, topic, exchange, type_handlers,
on_wait=None, url=None,
def __init__(self, topic, exchange,
type_handlers=None, on_wait=None, url=None,
transport=None, transport_options=None,
retry_options=None):
self._topic = topic
self._exchange_name = exchange
self._on_wait = on_wait
self._running = threading_utils.Event()
self._dispatcher = dispatcher.TypeDispatcher(type_handlers)
self._dispatcher.add_requeue_filter(
self._dispatcher = dispatcher.TypeDispatcher(
# NOTE(skudriashev): Process all incoming messages only if proxy is
# running, otherwise requeue them.
lambda data, message: not self.is_running)
requeue_filters=[lambda data, message: not self.is_running],
type_handlers=type_handlers)
ensure_options = self.DEFAULT_RETRY_OPTIONS.copy()
if retry_options is not None:
@ -112,11 +112,16 @@ class Proxy(object):
# create exchange
self._exchange = kombu.Exchange(name=self._exchange_name,
durable=False,
auto_delete=True)
durable=False, auto_delete=True)
@property
def dispatcher(self):
"""Dispatcher internally used to dispatch message(s) that match."""
return self._dispatcher
@property
def connection_details(self):
"""Details about the connection (read-only)."""
# The kombu drivers seem to use 'N/A' when they don't have a version...
driver_version = self._conn.transport.driver_version()
if driver_version and driver_version.lower() == 'n/a':


@ -59,7 +59,8 @@ class Server(object):
pr.Request.validate,
],
}
self._proxy = proxy.Proxy(topic, exchange, type_handlers,
self._proxy = proxy.Proxy(topic, exchange,
type_handlers=type_handlers,
url=url, transport=transport,
transport_options=transport_options,
retry_options=retry_options)


@ -14,17 +14,21 @@
# License for the specific language governing permissions and limitations
# under the License.
import abc
import functools
import itertools
import logging
import random
import threading
from oslo.utils import reflection
from oslo_utils import reflection
import six
from taskflow.engines.worker_based import protocol as pr
from taskflow import logging
from taskflow.types import cache as base
from taskflow.types import periodic
from taskflow.types import timing as tt
from taskflow.utils import kombu_utils as ku
LOG = logging.getLogger(__name__)
@ -91,8 +95,37 @@ class TopicWorker(object):
return r
class TopicWorkers(object):
"""A collection of topic based workers."""
@six.add_metaclass(abc.ABCMeta)
class WorkerFinder(object):
"""Base class for worker finders..."""
def __init__(self):
self._cond = threading.Condition()
self.on_worker = None
@abc.abstractmethod
def _total_workers(self):
"""Returns how many workers are known."""
def wait_for_workers(self, workers=1, timeout=None):
"""Waits for geq workers to notify they are ready to do work.
NOTE(harlowja): if a timeout is provided this function will wait
until that timeout expires, if the amount of workers does not reach
the desired amount of workers before the timeout expires then this will
return how many workers are still needed, otherwise it will
return zero.
"""
if workers <= 0:
raise ValueError("Worker amount must be greater than zero")
watch = tt.StopWatch(duration=timeout)
watch.start()
with self._cond:
while self._total_workers() < workers:
if watch.expired():
return max(0, workers - self._total_workers())
self._cond.wait(watch.leftover(return_none=True))
return 0
@staticmethod
def _match_worker(task, available_workers):
@ -110,14 +143,30 @@ class TopicWorkers(object):
else:
return random.choice(available_workers)
def __init__(self):
self._workers = {}
self._cond = threading.Condition()
# Used to name workers with more useful identities...
self._counter = itertools.count()
@abc.abstractmethod
def get_worker_for_task(self, task):
"""Gets a worker that can perform a given task."""
def __len__(self):
return len(self._workers)
def clear(self):
pass
class ProxyWorkerFinder(WorkerFinder):
"""Requests and receives responses about workers topic+task details."""
def __init__(self, uuid, proxy, topics):
super(ProxyWorkerFinder, self).__init__()
self._proxy = proxy
self._topics = topics
self._workers = {}
self._uuid = uuid
self._proxy.dispatcher.type_handlers.update({
pr.NOTIFY: [
self._process_response,
functools.partial(pr.Notify.validate, response=True),
],
})
self._counter = itertools.count()
def _next_worker(self, topic, tasks, temporary=False):
if not temporary:
@ -126,48 +175,54 @@ class TopicWorkers(object):
else:
return TopicWorker(topic, tasks)
def add(self, topic, tasks):
@periodic.periodic(pr.NOTIFY_PERIOD)
def beat(self):
"""Cyclically called to publish notify message to each topic."""
self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid)
def _total_workers(self):
return len(self._workers)
def _add(self, topic, tasks):
"""Adds/updates a worker for the topic for the given tasks."""
try:
worker = self._workers[topic]
# Check if we already have an equivalent worker, if so just
# return it...
if worker == self._next_worker(topic, tasks, temporary=True):
return (worker, False)
# This *fall through* is done so that if someone is using an
# active worker object that already exists, we create a new one
# rather than mutating the existing one (worker objects are
# supposed to be immutable).
except KeyError:
pass
worker = self._next_worker(topic, tasks)
self._workers[topic] = worker
return (worker, True)
def _process_response(self, response, message):
"""Process notify message from remote side."""
LOG.debug("Started processing notify message '%s'",
ku.DelayedPretty(message))
topic = response['topic']
tasks = response['tasks']
with self._cond:
try:
worker = self._workers[topic]
# Check if we already have an equivalent worker, if so just
# return it...
if worker == self._next_worker(topic, tasks, temporary=True):
return worker
# This *fall through* is done so that if someone is using an
# active worker object that already exists that we just create
# a new one; so that the existing object doesn't get
# affected (workers objects are supposed to be immutable).
except KeyError:
pass
worker = self._next_worker(topic, tasks)
self._workers[topic] = worker
worker, new_or_updated = self._add(topic, tasks)
if new_or_updated:
LOG.debug("Received notification about worker '%s' (%s"
" total workers are currently known)", worker,
self._total_workers())
self._cond.notify_all()
if self.on_worker is not None and new_or_updated:
self.on_worker(worker)
def clear(self):
with self._cond:
self._workers.clear()
self._cond.notify_all()
return worker
def wait_for_workers(self, workers=1, timeout=None):
"""Waits for geq workers to notify they are ready to do work.
NOTE(harlowja): if a timeout is provided this function will wait
until that timeout expires, if the amount of workers does not reach
the desired amount of workers before the timeout expires then this will
return how many workers are still needed, otherwise it will
return zero.
"""
if workers <= 0:
raise ValueError("Worker amount must be greater than zero")
watch = tt.StopWatch(duration=timeout)
watch.start()
with self._cond:
while len(self._workers) < workers:
if watch.expired():
return max(0, workers - len(self._workers))
self._cond.wait(watch.leftover(return_none=True))
return 0
def get_worker_for_task(self, task):
"""Gets a worker that can perform a given task."""
available_workers = []
with self._cond:
for worker in six.itervalues(self._workers):
@ -177,37 +232,3 @@ class TopicWorkers(object):
return self._match_worker(task, available_workers)
else:
return None
def clear(self):
with self._cond:
self._workers.clear()
self._cond.notify_all()
class PeriodicWorker(object):
"""Calls a set of functions when activated periodically.
NOTE(harlowja): the provided timeout object determines the periodicity.
"""
def __init__(self, timeout, functors):
self._timeout = timeout
self._functors = []
for f in functors:
self._functors.append((f, reflection.get_callable_name(f)))
def start(self):
while not self._timeout.is_stopped():
for (f, f_name) in self._functors:
LOG.debug("Calling periodic function '%s'", f_name)
try:
f()
except Exception:
LOG.warn("Failed to call periodic function '%s'", f_name,
exc_info=True)
self._timeout.wait()
def stop(self):
self._timeout.interrupt()
def reset(self):
self._timeout.reset()
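
To make the wait_for_workers contract concrete (it returns the
shortfall on timeout and zero once enough workers have announced
themselves), here is a hedged sketch patterned on the updated unit
tests below (note it drives the private _add the same way they do):

    from taskflow.engines.worker_based import types as worker_types
    from taskflow.test import mock
    from taskflow.tests import utils

    finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])

    # Nothing has announced itself yet, so a short wait reports how
    # many workers are still needed instead of blocking forever...
    assert finder.wait_for_workers(workers=2, timeout=0.1) == 2

    # Simulate two notify responses arriving...
    finder._add('dummy-topic', [utils.DummyTask])
    finder._add('dummy-topic-2', [utils.DummyTask])

    assert finder.wait_for_workers(workers=2, timeout=0.1) == 0
    assert finder.get_worker_for_task(utils.DummyTask) is not None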


@ -14,6 +14,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import time
import networkx as nx
import six
@ -21,9 +23,31 @@ from taskflow import exceptions as excp
from taskflow import test
from taskflow.types import fsm
from taskflow.types import graph
from taskflow.types import latch
from taskflow.types import periodic
from taskflow.types import table
from taskflow.types import timing as tt
from taskflow.types import tree
from taskflow.utils import threading_utils as tu
class PeriodicThingy(object):
def __init__(self):
self.capture = []
@periodic.periodic(0.01)
def a(self):
self.capture.append('a')
@periodic.periodic(0.02)
def b(self):
self.capture.append('b')
def c(self):
pass
def d(self):
pass
class GraphTest(test.TestCase):
@ -451,3 +475,112 @@ class FSMTest(test.TestCase):
m.add_state('broken')
self.assertRaises(ValueError, m.add_state, 'b', on_enter=2)
self.assertRaises(ValueError, m.add_state, 'b', on_exit=2)
class PeriodicTest(test.TestCase):
def test_invalid_periodic(self):
def no_op():
pass
self.assertRaises(ValueError, periodic.periodic, -1)
def test_valid_periodic(self):
@periodic.periodic(2)
def no_op():
pass
self.assertTrue(getattr(no_op, '_periodic'))
self.assertEqual(2, getattr(no_op, '_periodic_spacing'))
self.assertEqual(True, getattr(no_op, '_periodic_run_immediately'))
def test_scanning_periodic(self):
p = PeriodicThingy()
w = periodic.PeriodicWorker.create([p])
self.assertEqual(2, len(w))
t = tu.daemon_thread(target=w.start)
t.start()
time.sleep(0.1)
w.stop()
t.join()
b_calls = [c for c in p.capture if c == 'b']
self.assertGreater(len(b_calls), 0)
a_calls = [c for c in p.capture if c == 'a']
self.assertGreater(len(a_calls), 0)
def test_periodic_single(self):
barrier = latch.Latch(5)
capture = []
tombstone = tu.Event()
@periodic.periodic(0.01)
def callee():
barrier.countdown()
if barrier.needed == 0:
tombstone.set()
capture.append(1)
w = periodic.PeriodicWorker([callee], tombstone=tombstone)
t = tu.daemon_thread(target=w.start)
t.start()
t.join()
self.assertEqual(0, barrier.needed)
self.assertEqual(5, sum(capture))
self.assertTrue(tombstone.is_set())
def test_immediate(self):
capture = []
@periodic.periodic(120, run_immediately=True)
def a():
capture.append('a')
w = periodic.PeriodicWorker([a])
t = tu.daemon_thread(target=w.start)
t.start()
time.sleep(0.1)
w.stop()
t.join()
a_calls = [c for c in capture if c == 'a']
self.assertGreater(len(a_calls), 0)
def test_period_double_no_immediate(self):
capture = []
@periodic.periodic(0.01, run_immediately=False)
def a():
capture.append('a')
@periodic.periodic(0.02, run_immediately=False)
def b():
capture.append('b')
w = periodic.PeriodicWorker([a, b])
t = tu.daemon_thread(target=w.start)
t.start()
time.sleep(0.1)
w.stop()
t.join()
b_calls = [c for c in capture if c == 'b']
self.assertGreater(len(b_calls), 0)
a_calls = [c for c in capture if c == 'a']
self.assertGreater(len(a_calls), 0)
def test_start_nothing_error(self):
w = periodic.PeriodicWorker([])
self.assertRaises(RuntimeError, w.start)
def test_missing_function_attrs(self):
def fake_periodic():
pass
cb = fake_periodic
self.assertRaises(ValueError, periodic.PeriodicWorker, [cb])


@ -41,12 +41,12 @@ class TestDispatcher(test.TestCase):
def test_creation(self):
on_hello = mock.MagicMock()
handlers = {'hello': on_hello}
dispatcher.TypeDispatcher(handlers)
dispatcher.TypeDispatcher(type_handlers=handlers)
def test_on_message(self):
on_hello = mock.MagicMock()
handlers = {'hello': on_hello}
d = dispatcher.TypeDispatcher(handlers)
d = dispatcher.TypeDispatcher(type_handlers=handlers)
msg = mock_acked_message(properties={'type': 'hello'})
d.on_message("", msg)
self.assertTrue(on_hello.called)
@ -54,15 +54,15 @@ class TestDispatcher(test.TestCase):
self.assertTrue(msg.acknowledged)
def test_on_rejected_message(self):
d = dispatcher.TypeDispatcher({})
d = dispatcher.TypeDispatcher()
msg = mock_acked_message(properties={'type': 'hello'})
d.on_message("", msg)
self.assertTrue(msg.reject_log_error.called)
self.assertFalse(msg.acknowledged)
def test_on_requeue_message(self):
d = dispatcher.TypeDispatcher({})
d.add_requeue_filter(lambda data, message: True)
d = dispatcher.TypeDispatcher()
d.requeue_filters.append(lambda data, message: True)
msg = mock_acked_message()
d.on_message("", msg)
self.assertTrue(msg.requeue.called)
@ -71,7 +71,7 @@ class TestDispatcher(test.TestCase):
def test_failed_ack(self):
on_hello = mock.MagicMock()
handlers = {'hello': on_hello}
d = dispatcher.TypeDispatcher(handlers)
d = dispatcher.TypeDispatcher(type_handlers=handlers)
msg = mock_acked_message(ack_ok=False,
properties={'type': 'hello'})
d.on_message("", msg)


@ -86,11 +86,12 @@ class TestWorkerTaskExecutor(test.MockTestCase):
ex = self.executor(reset_master_mock=False)
master_mock_calls = [
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
mock.ANY, on_wait=ex._on_wait,
on_wait=ex._on_wait,
url=self.broker_url, transport=mock.ANY,
transport_options=mock.ANY,
retry_options=mock.ANY
)
retry_options=mock.ANY,
type_handlers=mock.ANY),
mock.call.proxy.dispatcher.type_handlers.update(mock.ANY),
]
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
@ -212,10 +213,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertEqual(len(ex._requests_cache), 0)
def test_execute_task(self):
self.message_mock.properties['type'] = pr.NOTIFY
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
ex = self.executor()
ex._process_notify(notify.to_dict(), self.message_mock)
ex._finder._add(self.executor_topic, [self.task.name])
ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [
@ -231,10 +230,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertEqual(expected_calls, self.master_mock.mock_calls)
def test_revert_task(self):
self.message_mock.properties['type'] = pr.NOTIFY
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
ex = self.executor()
ex._process_notify(notify.to_dict(), self.message_mock)
ex._finder._add(self.executor_topic, [self.task.name])
ex.revert_task(self.task, self.task_uuid, self.task_args,
self.task_result, self.task_failures)
@ -263,11 +260,9 @@ class TestWorkerTaskExecutor(test.MockTestCase):
self.assertEqual(self.master_mock.mock_calls, expected_calls)
def test_execute_task_publish_error(self):
self.message_mock.properties['type'] = pr.NOTIFY
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
ex = self.executor()
ex._process_notify(notify.to_dict(), self.message_mock)
ex._finder._add(self.executor_topic, [self.task.name])
ex.execute_task(self.task, self.task_uuid, self.task_args)
expected_calls = [


@ -86,7 +86,7 @@ class TestServer(test.MockTestCase):
# check calls
master_mock_calls = [
mock.call.Proxy(self.server_topic, self.server_exchange,
mock.ANY, url=self.broker_url,
type_handlers=mock.ANY, url=self.broker_url,
transport=mock.ANY, transport_options=mock.ANY,
retry_options=mock.ANY)
]
@ -99,7 +99,7 @@ class TestServer(test.MockTestCase):
# check calls
master_mock_calls = [
mock.call.Proxy(self.server_topic, self.server_exchange,
mock.ANY, url=self.broker_url,
type_handlers=mock.ANY, url=self.broker_url,
transport=mock.ANY, transport_options=mock.ANY,
retry_options=mock.ANY)
]


@ -14,23 +14,20 @@
# License for the specific language governing permissions and limitations
# under the License.
import threading
import time
from oslo.utils import reflection
from taskflow.engines.worker_based import protocol as pr
from taskflow.engines.worker_based import types as worker_types
from taskflow import test
from taskflow.test import mock
from taskflow.tests import utils
from taskflow.types import latch
from taskflow.types import timing
class TestWorkerTypes(test.TestCase):
class TestRequestCache(test.TestCase):
def setUp(self):
super(TestWorkerTypes, self).setUp()
super(TestRequestCache, self).setUp()
self.addCleanup(timing.StopWatch.clear_overrides)
self.task = utils.DummyTask()
self.task_uuid = 'task-uuid'
@ -76,6 +73,8 @@ class TestWorkerTypes(test.TestCase):
self.assertEqual(1, len(matches))
self.assertEqual(2, len(cache))
class TestTopicWorker(test.TestCase):
def test_topic_worker(self):
worker = worker_types.TopicWorker("dummy-topic",
[utils.DummyTask], identity="dummy")
@ -84,52 +83,37 @@ class TestWorkerTypes(test.TestCase):
self.assertEqual('dummy', worker.identity)
self.assertEqual('dummy-topic', worker.topic)
def test_single_topic_workers(self):
workers = worker_types.TopicWorkers()
w = workers.add('dummy-topic', [utils.DummyTask])
class TestProxyFinder(test.TestCase):
def test_single_topic_worker(self):
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
w, emit = finder._add('dummy-topic', [utils.DummyTask])
self.assertIsNotNone(w)
self.assertEqual(1, len(workers))
w2 = workers.get_worker_for_task(utils.DummyTask)
self.assertTrue(emit)
self.assertEqual(1, finder._total_workers())
w2 = finder.get_worker_for_task(utils.DummyTask)
self.assertEqual(w.identity, w2.identity)
def test_multi_same_topic_workers(self):
workers = worker_types.TopicWorkers()
w = workers.add('dummy-topic', [utils.DummyTask])
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
w, emit = finder._add('dummy-topic', [utils.DummyTask])
self.assertIsNotNone(w)
w2 = workers.add('dummy-topic-2', [utils.DummyTask])
self.assertTrue(emit)
w2, emit = finder._add('dummy-topic-2', [utils.DummyTask])
self.assertIsNotNone(w2)
w3 = workers.get_worker_for_task(
self.assertTrue(emit)
w3 = finder.get_worker_for_task(
reflection.get_class_name(utils.DummyTask))
self.assertIn(w3.identity, [w.identity, w2.identity])
def test_multi_different_topic_workers(self):
workers = worker_types.TopicWorkers()
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
added = []
added.append(workers.add('dummy-topic', [utils.DummyTask]))
added.append(workers.add('dummy-topic-2', [utils.DummyTask]))
added.append(workers.add('dummy-topic-3', [utils.NastyTask]))
self.assertEqual(3, len(workers))
w = workers.get_worker_for_task(utils.NastyTask)
self.assertEqual(added[-1].identity, w.identity)
w = workers.get_worker_for_task(utils.DummyTask)
self.assertIn(w.identity, [w_a.identity for w_a in added[0:2]])
def test_periodic_worker(self):
barrier = latch.Latch(5)
to = timing.Timeout(0.01)
called_at = []
def callee():
barrier.countdown()
if barrier.needed == 0:
to.interrupt()
called_at.append(time.time())
w = worker_types.PeriodicWorker(to, [callee])
t = threading.Thread(target=w.start)
t.start()
t.join()
self.assertEqual(0, barrier.needed)
self.assertEqual(5, len(called_at))
self.assertTrue(to.is_stopped())
added.append(finder._add('dummy-topic', [utils.DummyTask]))
added.append(finder._add('dummy-topic-2', [utils.DummyTask]))
added.append(finder._add('dummy-topic-3', [utils.NastyTask]))
self.assertEqual(3, finder._total_workers())
w = finder.get_worker_for_task(utils.NastyTask)
self.assertEqual(added[-1][0].identity, w.identity)
w = finder.get_worker_for_task(utils.DummyTask)
self.assertIn(w.identity, [w_a[0].identity for w_a in added[0:2]])

taskflow/types/periodic.py (new file, 179 lines)

@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import heapq
import inspect
from oslo_utils import reflection
import six
from taskflow import logging
from taskflow.utils import misc
from taskflow.utils import threading_utils as tu
LOG = logging.getLogger(__name__)
# Find a monotonic time provider (or fall back to time.time(), which
# isn't *always* accurate but will suffice).
_now = misc.find_monotonic(allow_time_time=True)
# Attributes expected on periodic tagged/decorated functions or methods...
_PERIODIC_ATTRS = tuple([
'_periodic',
'_periodic_spacing',
'_periodic_run_immediately',
])
def periodic(spacing, run_immediately=True):
"""Tags a method/function as wanting/able to execute periodically."""
if spacing <= 0:
raise ValueError("Periodicity/spacing must be greater than"
" zero instead of %s" % spacing)
def wrapper(f):
f._periodic = True
f._periodic_spacing = spacing
f._periodic_run_immediately = run_immediately
@six.wraps(f)
def decorator(*args, **kwargs):
return f(*args, **kwargs)
return decorator
return wrapper
class PeriodicWorker(object):
"""Calls a collection of callables periodically (sleeping as needed...).
NOTE(harlowja): typically the :py:meth:`.start` method is executed in a
background thread so that the periodic callables are executed in
the background/asynchronously (using the defined periods to determine
when each is called).
"""
@classmethod
def create(cls, objects, exclude_hidden=True):
"""Automatically creates a worker by analyzing object(s) methods.
Only picks up methods that have been tagged/decorated with
the :py:func:`.periodic` decorator (does not match against private
or protected methods unless explicitly requested to).
"""
callables = []
for obj in objects:
for (name, member) in inspect.getmembers(obj):
if name.startswith("_") and exclude_hidden:
continue
if reflection.is_bound_method(member):
consume = True
for attr_name in _PERIODIC_ATTRS:
if not hasattr(member, attr_name):
consume = False
break
if consume:
callables.append(member)
return cls(callables)
def __init__(self, callables, tombstone=None):
if tombstone is None:
self._tombstone = tu.Event()
else:
# Allows someone to share an event (if they so want to...)
self._tombstone = tombstone
almost_callables = list(callables)
for cb in almost_callables:
if not six.callable(cb):
raise ValueError("Periodic callback must be callable")
for attr_name in _PERIODIC_ATTRS:
if not hasattr(cb, attr_name):
raise ValueError("Periodic callback missing required"
" attribute '%s'" % attr_name)
self._callables = tuple((cb, reflection.get_callable_name(cb))
for cb in almost_callables)
self._schedule = []
self._immediates = []
now = _now()
for i, (cb, cb_name) in enumerate(self._callables):
spacing = getattr(cb, '_periodic_spacing')
next_run = now + spacing
heapq.heappush(self._schedule, (next_run, i))
for (cb, cb_name) in reversed(self._callables):
if getattr(cb, '_periodic_run_immediately', False):
self._immediates.append((cb, cb_name))
def __len__(self):
return len(self._callables)
@staticmethod
def _safe_call(cb, cb_name, kind='periodic'):
try:
cb()
except Exception:
LOG.warn("Failed to call %s callable '%s'",
kind, cb_name, exc_info=True)
def start(self):
"""Starts running (will not stop/return until the tombstone is set).
NOTE(harlowja): If this worker has no contained callables this raises
a runtime error and does not run since it is impossible to periodically
run nothing.
"""
if not self._callables:
raise RuntimeError("A periodic worker can not start"
" without any callables")
while not self._tombstone.is_set():
if self._immediates:
cb, cb_name = self._immediates.pop()
LOG.debug("Calling immediate callable '%s'", cb_name)
self._safe_call(cb, cb_name, kind='immediate')
else:
# Figure out when we should run next (by selecting the
# minimum item from the heap, where the minimum should be
# the callable that needs to run next and has the lowest
# next desired run time).
now = _now()
next_run, i = heapq.heappop(self._schedule)
when_next = next_run - now
if when_next <= 0:
cb, cb_name = self._callables[i]
spacing = getattr(cb, '_periodic_spacing')
LOG.debug("Calling periodic callable '%s' (it runs every"
" %s seconds)", cb_name, spacing)
self._safe_call(cb, cb_name)
# Run again someday...
next_run = now + spacing
heapq.heappush(self._schedule, (next_run, i))
else:
# Gotta wait...
heapq.heappush(self._schedule, (next_run, i))
self._tombstone.wait(when_next)
def stop(self):
"""Sets the tombstone (this stops any further executions)."""
self._tombstone.set()
def reset(self):
"""Resets the tombstone and re-queues up any immediate executions."""
self._tombstone.clear()
self._immediates = []
for (cb, cb_name) in reversed(self._callables):
if getattr(cb, '_periodic_run_immediately', False):
self._immediates.append((cb, cb_name))
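
A usage sketch tying the decorator and the worker together (it mirrors
the new unit tests above; Heartbeater is a hypothetical example class):

    import time

    from taskflow.types import periodic
    from taskflow.utils import threading_utils as tu

    class Heartbeater(object):
        @periodic.periodic(0.5)
        def beat(self):
            print("thump")

    # Scan the object for periodically-decorated methods...
    w = periodic.PeriodicWorker.create([Heartbeater()])

    # ...and drive them from a background thread (as the class
    # docstring above suggests).
    t = tu.daemon_thread(target=w.start)
    t.start()
    time.sleep(2)
    w.stop()
    t.join()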