Abstract out the worker finding from the WBE engine
To be able to easily plug-in future types of ways to get which topics (and tasks) workers exist on (and can perform) and to identify and keep this information up-to date refactor the functionality that currently does this using periodic messages into a finder type and a periodic function that exists on it (that will be periodically activated by an updated and improved periodic worker). Part of blueprint wbe-worker-info Change-Id: Ib3ae29758af3d244b4ac4624ac380caf88b159fd
This commit is contained in:
parent
934e15a029
commit
19f9674877
@ -44,6 +44,11 @@ Notifier
|
||||
|
||||
.. automodule:: taskflow.types.notifier
|
||||
|
||||
Periodic
|
||||
========
|
||||
|
||||
.. automodule:: taskflow.types.periodic
|
||||
|
||||
Table
|
||||
=====
|
||||
|
||||
|
@ -15,7 +15,6 @@
|
||||
# under the License.
|
||||
|
||||
from kombu import exceptions as kombu_exc
|
||||
import six
|
||||
|
||||
from taskflow import exceptions as excp
|
||||
from taskflow import logging
|
||||
@ -27,14 +26,35 @@ LOG = logging.getLogger(__name__)
|
||||
class TypeDispatcher(object):
|
||||
"""Receives messages and dispatches to type specific handlers."""
|
||||
|
||||
def __init__(self, type_handlers):
|
||||
self._handlers = dict(type_handlers)
|
||||
self._requeue_filters = []
|
||||
def __init__(self, type_handlers=None, requeue_filters=None):
|
||||
if type_handlers is not None:
|
||||
self._type_handlers = dict(type_handlers)
|
||||
else:
|
||||
self._type_handlers = {}
|
||||
if requeue_filters is not None:
|
||||
self._requeue_filters = list(requeue_filters)
|
||||
else:
|
||||
self._requeue_filters = []
|
||||
|
||||
def add_requeue_filter(self, callback):
|
||||
"""Add a callback that can *request* message requeuing.
|
||||
@property
|
||||
def type_handlers(self):
|
||||
"""Dictionary of message type -> callback to handle that message.
|
||||
|
||||
The callback will be activated before the message has been acked and
|
||||
The callback(s) will be activated by looking for a message
|
||||
property 'type' and locating a callback in this dictionary that maps
|
||||
to that type; if one is found it is expected to be a callback that
|
||||
accepts two positional parameters; the first being the message data
|
||||
and the second being the message object. If a callback is not found
|
||||
then the message is rejected and it will be up to the underlying
|
||||
message transport to determine what this means/implies...
|
||||
"""
|
||||
return self._type_handlers
|
||||
|
||||
@property
|
||||
def requeue_filters(self):
|
||||
"""List of filters (callbacks) to request a message to be requeued.
|
||||
|
||||
The callback(s) will be activated before the message has been acked and
|
||||
it can be used to instruct the dispatcher to requeue the message
|
||||
instead of processing it. The callback, when called, will be provided
|
||||
two positional parameters; the first being the message data and the
|
||||
@ -42,9 +62,7 @@ class TypeDispatcher(object):
|
||||
filter should return a truthy object if the message should be requeued
|
||||
and a falsey object if it should not.
|
||||
"""
|
||||
if not six.callable(callback):
|
||||
raise ValueError("Requeue filter callback must be callable")
|
||||
self._requeue_filters.append(callback)
|
||||
return self._requeue_filters
|
||||
|
||||
def _collect_requeue_votes(self, data, message):
|
||||
# Returns how many of the filters asked for the message to be requeued.
|
||||
@ -74,7 +92,7 @@ class TypeDispatcher(object):
|
||||
LOG.debug("Message '%s' was requeued.", ku.DelayedPretty(message))
|
||||
|
||||
def _process_message(self, data, message, message_type):
|
||||
handler = self._handlers.get(message_type)
|
||||
handler = self._type_handlers.get(message_type)
|
||||
if handler is None:
|
||||
message.reject_log_error(logger=LOG,
|
||||
errors=(kombu_exc.MessageStateError,))
|
||||
|
@ -25,7 +25,7 @@ from taskflow.engines.worker_based import types as wt
|
||||
from taskflow import exceptions as exc
|
||||
from taskflow import logging
|
||||
from taskflow import task as task_atom
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.types import periodic
|
||||
from taskflow.utils import kombu_utils as ku
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import threading_utils as tu
|
||||
@ -41,51 +41,41 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
url=None, transport=None, transport_options=None,
|
||||
retry_options=None):
|
||||
self._uuid = uuid
|
||||
self._topics = topics
|
||||
self._requests_cache = wt.RequestsCache()
|
||||
self._workers = wt.TopicWorkers()
|
||||
self._transition_timeout = transition_timeout
|
||||
type_handlers = {
|
||||
pr.NOTIFY: [
|
||||
self._process_notify,
|
||||
functools.partial(pr.Notify.validate, response=True),
|
||||
],
|
||||
pr.RESPONSE: [
|
||||
self._process_response,
|
||||
pr.Response.validate,
|
||||
],
|
||||
}
|
||||
self._proxy = proxy.Proxy(uuid, exchange, type_handlers,
|
||||
self._proxy = proxy.Proxy(uuid, exchange,
|
||||
type_handlers=type_handlers,
|
||||
on_wait=self._on_wait, url=url,
|
||||
transport=transport,
|
||||
transport_options=transport_options,
|
||||
retry_options=retry_options)
|
||||
self._periodic = wt.PeriodicWorker(tt.Timeout(pr.NOTIFY_PERIOD),
|
||||
[self._notify_topics])
|
||||
# NOTE(harlowja): This is the most simplest finder impl. that
|
||||
# doesn't have external dependencies (outside of what this engine
|
||||
# already requires); it though does create periodic 'polling' traffic
|
||||
# to workers to 'learn' of the tasks they can perform (and requires
|
||||
# pre-existing knowledge of the topics those workers are on to gather
|
||||
# and update this information).
|
||||
self._finder = wt.ProxyWorkerFinder(uuid, self._proxy, topics)
|
||||
self._finder.on_worker = self._on_worker
|
||||
self._helpers = tu.ThreadBundle()
|
||||
self._helpers.bind(lambda: tu.daemon_thread(self._proxy.start),
|
||||
after_start=lambda t: self._proxy.wait(),
|
||||
before_join=lambda t: self._proxy.stop())
|
||||
self._helpers.bind(lambda: tu.daemon_thread(self._periodic.start),
|
||||
before_join=lambda t: self._periodic.stop(),
|
||||
after_join=lambda t: self._periodic.reset(),
|
||||
before_start=lambda t: self._periodic.reset())
|
||||
p_worker = periodic.PeriodicWorker.create([self._finder])
|
||||
if p_worker:
|
||||
self._helpers.bind(lambda: tu.daemon_thread(p_worker.start),
|
||||
before_join=lambda t: p_worker.stop(),
|
||||
after_join=lambda t: p_worker.reset(),
|
||||
before_start=lambda t: p_worker.reset())
|
||||
|
||||
def _process_notify(self, notify, message):
|
||||
"""Process notify message from remote side."""
|
||||
LOG.debug("Started processing notify message '%s'",
|
||||
ku.DelayedPretty(message))
|
||||
|
||||
topic = notify['topic']
|
||||
tasks = notify['tasks']
|
||||
|
||||
# Add worker info to the cache
|
||||
worker = self._workers.add(topic, tasks)
|
||||
LOG.debug("Received notification about worker '%s' (%s"
|
||||
" total workers are currently known)", worker,
|
||||
len(self._workers))
|
||||
|
||||
# Publish waiting requests
|
||||
def _on_worker(self, worker):
|
||||
"""Process new worker that has arrived (and fire off any work)."""
|
||||
for request in self._requests_cache.get_waiting_requests(worker):
|
||||
if request.transition_and_log_error(pr.PENDING, logger=LOG):
|
||||
self._publish_request(request, worker)
|
||||
@ -174,7 +164,7 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
request.result.add_done_callback(lambda fut: cleaner())
|
||||
|
||||
# Get task's worker and publish request if worker was found.
|
||||
worker = self._workers.get_worker_for_task(task)
|
||||
worker = self._finder.get_worker_for_task(task)
|
||||
if worker is not None:
|
||||
# NOTE(skudriashev): Make sure request is set to the PENDING state
|
||||
# before putting it into the requests cache to prevent the notify
|
||||
@ -208,10 +198,6 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
del self._requests_cache[request.uuid]
|
||||
request.set_result(failure)
|
||||
|
||||
def _notify_topics(self):
|
||||
"""Cyclically called to publish notify message to each topic."""
|
||||
self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid)
|
||||
|
||||
def execute_task(self, task, task_uuid, arguments,
|
||||
progress_callback=None):
|
||||
return self._submit_task(task, task_uuid, pr.EXECUTE, arguments,
|
||||
@ -232,7 +218,8 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
return how many workers are still needed, otherwise it will
|
||||
return zero.
|
||||
"""
|
||||
return self._workers.wait_for_workers(workers=workers, timeout=timeout)
|
||||
return self._finder.wait_for_workers(workers=workers,
|
||||
timeout=timeout)
|
||||
|
||||
def start(self):
|
||||
"""Starts proxy thread and associated topic notification thread."""
|
||||
@ -242,4 +229,4 @@ class WorkerTaskExecutor(executor.TaskExecutor):
|
||||
"""Stops proxy thread and associated topic notification thread."""
|
||||
self._helpers.stop()
|
||||
self._requests_cache.clear(self._handle_expired_request)
|
||||
self._workers.clear()
|
||||
self._finder.clear()
|
||||
|
@ -68,19 +68,19 @@ class Proxy(object):
|
||||
# value is valid...
|
||||
_RETRY_INT_OPTS = frozenset(['max_retries'])
|
||||
|
||||
def __init__(self, topic, exchange, type_handlers,
|
||||
on_wait=None, url=None,
|
||||
def __init__(self, topic, exchange,
|
||||
type_handlers=None, on_wait=None, url=None,
|
||||
transport=None, transport_options=None,
|
||||
retry_options=None):
|
||||
self._topic = topic
|
||||
self._exchange_name = exchange
|
||||
self._on_wait = on_wait
|
||||
self._running = threading_utils.Event()
|
||||
self._dispatcher = dispatcher.TypeDispatcher(type_handlers)
|
||||
self._dispatcher.add_requeue_filter(
|
||||
self._dispatcher = dispatcher.TypeDispatcher(
|
||||
# NOTE(skudriashev): Process all incoming messages only if proxy is
|
||||
# running, otherwise requeue them.
|
||||
lambda data, message: not self.is_running)
|
||||
requeue_filters=[lambda data, message: not self.is_running],
|
||||
type_handlers=type_handlers)
|
||||
|
||||
ensure_options = self.DEFAULT_RETRY_OPTIONS.copy()
|
||||
if retry_options is not None:
|
||||
@ -112,11 +112,16 @@ class Proxy(object):
|
||||
|
||||
# create exchange
|
||||
self._exchange = kombu.Exchange(name=self._exchange_name,
|
||||
durable=False,
|
||||
auto_delete=True)
|
||||
durable=False, auto_delete=True)
|
||||
|
||||
@property
|
||||
def dispatcher(self):
|
||||
"""Dispatcher internally used to dispatch message(s) that match."""
|
||||
return self._dispatcher
|
||||
|
||||
@property
|
||||
def connection_details(self):
|
||||
"""Details about the connection (read-only)."""
|
||||
# The kombu drivers seem to use 'N/A' when they don't have a version...
|
||||
driver_version = self._conn.transport.driver_version()
|
||||
if driver_version and driver_version.lower() == 'n/a':
|
||||
|
@ -59,7 +59,8 @@ class Server(object):
|
||||
pr.Request.validate,
|
||||
],
|
||||
}
|
||||
self._proxy = proxy.Proxy(topic, exchange, type_handlers,
|
||||
self._proxy = proxy.Proxy(topic, exchange,
|
||||
type_handlers=type_handlers,
|
||||
url=url, transport=transport,
|
||||
transport_options=transport_options,
|
||||
retry_options=retry_options)
|
||||
|
@ -14,17 +14,21 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import abc
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
|
||||
from oslo.utils import reflection
|
||||
from oslo_utils import reflection
|
||||
import six
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow import logging
|
||||
from taskflow.types import cache as base
|
||||
from taskflow.types import periodic
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.utils import kombu_utils as ku
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -91,8 +95,37 @@ class TopicWorker(object):
|
||||
return r
|
||||
|
||||
|
||||
class TopicWorkers(object):
|
||||
"""A collection of topic based workers."""
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class WorkerFinder(object):
|
||||
"""Base class for worker finders..."""
|
||||
|
||||
def __init__(self):
|
||||
self._cond = threading.Condition()
|
||||
self.on_worker = None
|
||||
|
||||
@abc.abstractmethod
|
||||
def _total_workers(self):
|
||||
"""Returns how many workers are known."""
|
||||
|
||||
def wait_for_workers(self, workers=1, timeout=None):
|
||||
"""Waits for geq workers to notify they are ready to do work.
|
||||
|
||||
NOTE(harlowja): if a timeout is provided this function will wait
|
||||
until that timeout expires, if the amount of workers does not reach
|
||||
the desired amount of workers before the timeout expires then this will
|
||||
return how many workers are still needed, otherwise it will
|
||||
return zero.
|
||||
"""
|
||||
if workers <= 0:
|
||||
raise ValueError("Worker amount must be greater than zero")
|
||||
watch = tt.StopWatch(duration=timeout)
|
||||
watch.start()
|
||||
with self._cond:
|
||||
while self._total_workers() < workers:
|
||||
if watch.expired():
|
||||
return max(0, workers - self._total_workers())
|
||||
self._cond.wait(watch.leftover(return_none=True))
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def _match_worker(task, available_workers):
|
||||
@ -110,14 +143,30 @@ class TopicWorkers(object):
|
||||
else:
|
||||
return random.choice(available_workers)
|
||||
|
||||
def __init__(self):
|
||||
self._workers = {}
|
||||
self._cond = threading.Condition()
|
||||
# Used to name workers with more useful identities...
|
||||
self._counter = itertools.count()
|
||||
@abc.abstractmethod
|
||||
def get_worker_for_task(self, task):
|
||||
"""Gets a worker that can perform a given task."""
|
||||
|
||||
def __len__(self):
|
||||
return len(self._workers)
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyWorkerFinder(WorkerFinder):
|
||||
"""Requests and receives responses about workers topic+task details."""
|
||||
|
||||
def __init__(self, uuid, proxy, topics):
|
||||
super(ProxyWorkerFinder, self).__init__()
|
||||
self._proxy = proxy
|
||||
self._topics = topics
|
||||
self._workers = {}
|
||||
self._uuid = uuid
|
||||
self._proxy.dispatcher.type_handlers.update({
|
||||
pr.NOTIFY: [
|
||||
self._process_response,
|
||||
functools.partial(pr.Notify.validate, response=True),
|
||||
],
|
||||
})
|
||||
self._counter = itertools.count()
|
||||
|
||||
def _next_worker(self, topic, tasks, temporary=False):
|
||||
if not temporary:
|
||||
@ -126,48 +175,54 @@ class TopicWorkers(object):
|
||||
else:
|
||||
return TopicWorker(topic, tasks)
|
||||
|
||||
def add(self, topic, tasks):
|
||||
@periodic.periodic(pr.NOTIFY_PERIOD)
|
||||
def beat(self):
|
||||
"""Cyclically called to publish notify message to each topic."""
|
||||
self._proxy.publish(pr.Notify(), self._topics, reply_to=self._uuid)
|
||||
|
||||
def _total_workers(self):
|
||||
return len(self._workers)
|
||||
|
||||
def _add(self, topic, tasks):
|
||||
"""Adds/updates a worker for the topic for the given tasks."""
|
||||
try:
|
||||
worker = self._workers[topic]
|
||||
# Check if we already have an equivalent worker, if so just
|
||||
# return it...
|
||||
if worker == self._next_worker(topic, tasks, temporary=True):
|
||||
return (worker, False)
|
||||
# This *fall through* is done so that if someone is using an
|
||||
# active worker object that already exists that we just create
|
||||
# a new one; so that the existing object doesn't get
|
||||
# affected (workers objects are supposed to be immutable).
|
||||
except KeyError:
|
||||
pass
|
||||
worker = self._next_worker(topic, tasks)
|
||||
self._workers[topic] = worker
|
||||
return (worker, True)
|
||||
|
||||
def _process_response(self, response, message):
|
||||
"""Process notify message from remote side."""
|
||||
LOG.debug("Started processing notify message '%s'",
|
||||
ku.DelayedPretty(message))
|
||||
topic = response['topic']
|
||||
tasks = response['tasks']
|
||||
with self._cond:
|
||||
try:
|
||||
worker = self._workers[topic]
|
||||
# Check if we already have an equivalent worker, if so just
|
||||
# return it...
|
||||
if worker == self._next_worker(topic, tasks, temporary=True):
|
||||
return worker
|
||||
# This *fall through* is done so that if someone is using an
|
||||
# active worker object that already exists that we just create
|
||||
# a new one; so that the existing object doesn't get
|
||||
# affected (workers objects are supposed to be immutable).
|
||||
except KeyError:
|
||||
pass
|
||||
worker = self._next_worker(topic, tasks)
|
||||
self._workers[topic] = worker
|
||||
worker, new_or_updated = self._add(topic, tasks)
|
||||
if new_or_updated:
|
||||
LOG.debug("Received notification about worker '%s' (%s"
|
||||
" total workers are currently known)", worker,
|
||||
self._total_workers())
|
||||
self._cond.notify_all()
|
||||
if self.on_worker is not None and new_or_updated:
|
||||
self.on_worker(worker)
|
||||
|
||||
def clear(self):
|
||||
with self._cond:
|
||||
self._workers.clear()
|
||||
self._cond.notify_all()
|
||||
return worker
|
||||
|
||||
def wait_for_workers(self, workers=1, timeout=None):
|
||||
"""Waits for geq workers to notify they are ready to do work.
|
||||
|
||||
NOTE(harlowja): if a timeout is provided this function will wait
|
||||
until that timeout expires, if the amount of workers does not reach
|
||||
the desired amount of workers before the timeout expires then this will
|
||||
return how many workers are still needed, otherwise it will
|
||||
return zero.
|
||||
"""
|
||||
if workers <= 0:
|
||||
raise ValueError("Worker amount must be greater than zero")
|
||||
watch = tt.StopWatch(duration=timeout)
|
||||
watch.start()
|
||||
with self._cond:
|
||||
while len(self._workers) < workers:
|
||||
if watch.expired():
|
||||
return max(0, workers - len(self._workers))
|
||||
self._cond.wait(watch.leftover(return_none=True))
|
||||
return 0
|
||||
|
||||
def get_worker_for_task(self, task):
|
||||
"""Gets a worker that can perform a given task."""
|
||||
available_workers = []
|
||||
with self._cond:
|
||||
for worker in six.itervalues(self._workers):
|
||||
@ -177,37 +232,3 @@ class TopicWorkers(object):
|
||||
return self._match_worker(task, available_workers)
|
||||
else:
|
||||
return None
|
||||
|
||||
def clear(self):
|
||||
with self._cond:
|
||||
self._workers.clear()
|
||||
self._cond.notify_all()
|
||||
|
||||
|
||||
class PeriodicWorker(object):
|
||||
"""Calls a set of functions when activated periodically.
|
||||
|
||||
NOTE(harlowja): the provided timeout object determines the periodicity.
|
||||
"""
|
||||
def __init__(self, timeout, functors):
|
||||
self._timeout = timeout
|
||||
self._functors = []
|
||||
for f in functors:
|
||||
self._functors.append((f, reflection.get_callable_name(f)))
|
||||
|
||||
def start(self):
|
||||
while not self._timeout.is_stopped():
|
||||
for (f, f_name) in self._functors:
|
||||
LOG.debug("Calling periodic function '%s'", f_name)
|
||||
try:
|
||||
f()
|
||||
except Exception:
|
||||
LOG.warn("Failed to call periodic function '%s'", f_name,
|
||||
exc_info=True)
|
||||
self._timeout.wait()
|
||||
|
||||
def stop(self):
|
||||
self._timeout.interrupt()
|
||||
|
||||
def reset(self):
|
||||
self._timeout.reset()
|
||||
|
@ -14,6 +14,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import time
|
||||
|
||||
import networkx as nx
|
||||
import six
|
||||
|
||||
@ -21,9 +23,31 @@ from taskflow import exceptions as excp
|
||||
from taskflow import test
|
||||
from taskflow.types import fsm
|
||||
from taskflow.types import graph
|
||||
from taskflow.types import latch
|
||||
from taskflow.types import periodic
|
||||
from taskflow.types import table
|
||||
from taskflow.types import timing as tt
|
||||
from taskflow.types import tree
|
||||
from taskflow.utils import threading_utils as tu
|
||||
|
||||
|
||||
class PeriodicThingy(object):
|
||||
def __init__(self):
|
||||
self.capture = []
|
||||
|
||||
@periodic.periodic(0.01)
|
||||
def a(self):
|
||||
self.capture.append('a')
|
||||
|
||||
@periodic.periodic(0.02)
|
||||
def b(self):
|
||||
self.capture.append('b')
|
||||
|
||||
def c(self):
|
||||
pass
|
||||
|
||||
def d(self):
|
||||
pass
|
||||
|
||||
|
||||
class GraphTest(test.TestCase):
|
||||
@ -451,3 +475,112 @@ class FSMTest(test.TestCase):
|
||||
m.add_state('broken')
|
||||
self.assertRaises(ValueError, m.add_state, 'b', on_enter=2)
|
||||
self.assertRaises(ValueError, m.add_state, 'b', on_exit=2)
|
||||
|
||||
|
||||
class PeriodicTest(test.TestCase):
|
||||
|
||||
def test_invalid_periodic(self):
|
||||
|
||||
def no_op():
|
||||
pass
|
||||
|
||||
self.assertRaises(ValueError, periodic.periodic, -1)
|
||||
|
||||
def test_valid_periodic(self):
|
||||
|
||||
@periodic.periodic(2)
|
||||
def no_op():
|
||||
pass
|
||||
|
||||
self.assertTrue(getattr(no_op, '_periodic'))
|
||||
self.assertEqual(2, getattr(no_op, '_periodic_spacing'))
|
||||
self.assertEqual(True, getattr(no_op, '_periodic_run_immediately'))
|
||||
|
||||
def test_scanning_periodic(self):
|
||||
p = PeriodicThingy()
|
||||
w = periodic.PeriodicWorker.create([p])
|
||||
self.assertEqual(2, len(w))
|
||||
|
||||
t = tu.daemon_thread(target=w.start)
|
||||
t.start()
|
||||
time.sleep(0.1)
|
||||
w.stop()
|
||||
t.join()
|
||||
|
||||
b_calls = [c for c in p.capture if c == 'b']
|
||||
self.assertGreater(0, len(b_calls))
|
||||
a_calls = [c for c in p.capture if c == 'a']
|
||||
self.assertGreater(0, len(a_calls))
|
||||
|
||||
def test_periodic_single(self):
|
||||
barrier = latch.Latch(5)
|
||||
capture = []
|
||||
tombstone = tu.Event()
|
||||
|
||||
@periodic.periodic(0.01)
|
||||
def callee():
|
||||
barrier.countdown()
|
||||
if barrier.needed == 0:
|
||||
tombstone.set()
|
||||
capture.append(1)
|
||||
|
||||
w = periodic.PeriodicWorker([callee], tombstone=tombstone)
|
||||
t = tu.daemon_thread(target=w.start)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
self.assertEqual(0, barrier.needed)
|
||||
self.assertEqual(5, sum(capture))
|
||||
self.assertTrue(tombstone.is_set())
|
||||
|
||||
def test_immediate(self):
|
||||
capture = []
|
||||
|
||||
@periodic.periodic(120, run_immediately=True)
|
||||
def a():
|
||||
capture.append('a')
|
||||
|
||||
w = periodic.PeriodicWorker([a])
|
||||
t = tu.daemon_thread(target=w.start)
|
||||
t.start()
|
||||
time.sleep(0.1)
|
||||
w.stop()
|
||||
t.join()
|
||||
|
||||
a_calls = [c for c in capture if c == 'a']
|
||||
self.assertGreater(0, len(a_calls))
|
||||
|
||||
def test_period_double_no_immediate(self):
|
||||
capture = []
|
||||
|
||||
@periodic.periodic(0.01, run_immediately=False)
|
||||
def a():
|
||||
capture.append('a')
|
||||
|
||||
@periodic.periodic(0.02, run_immediately=False)
|
||||
def b():
|
||||
capture.append('b')
|
||||
|
||||
w = periodic.PeriodicWorker([a, b])
|
||||
t = tu.daemon_thread(target=w.start)
|
||||
t.start()
|
||||
time.sleep(0.1)
|
||||
w.stop()
|
||||
t.join()
|
||||
|
||||
b_calls = [c for c in capture if c == 'b']
|
||||
self.assertGreater(0, len(b_calls))
|
||||
a_calls = [c for c in capture if c == 'a']
|
||||
self.assertGreater(0, len(a_calls))
|
||||
|
||||
def test_start_nothing_error(self):
|
||||
w = periodic.PeriodicWorker([])
|
||||
self.assertRaises(RuntimeError, w.start)
|
||||
|
||||
def test_missing_function_attrs(self):
|
||||
|
||||
def fake_periodic():
|
||||
pass
|
||||
|
||||
cb = fake_periodic
|
||||
self.assertRaises(ValueError, periodic.PeriodicWorker, [cb])
|
||||
|
@ -41,12 +41,12 @@ class TestDispatcher(test.TestCase):
|
||||
def test_creation(self):
|
||||
on_hello = mock.MagicMock()
|
||||
handlers = {'hello': on_hello}
|
||||
dispatcher.TypeDispatcher(handlers)
|
||||
dispatcher.TypeDispatcher(type_handlers=handlers)
|
||||
|
||||
def test_on_message(self):
|
||||
on_hello = mock.MagicMock()
|
||||
handlers = {'hello': on_hello}
|
||||
d = dispatcher.TypeDispatcher(handlers)
|
||||
d = dispatcher.TypeDispatcher(type_handlers=handlers)
|
||||
msg = mock_acked_message(properties={'type': 'hello'})
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(on_hello.called)
|
||||
@ -54,15 +54,15 @@ class TestDispatcher(test.TestCase):
|
||||
self.assertTrue(msg.acknowledged)
|
||||
|
||||
def test_on_rejected_message(self):
|
||||
d = dispatcher.TypeDispatcher({})
|
||||
d = dispatcher.TypeDispatcher()
|
||||
msg = mock_acked_message(properties={'type': 'hello'})
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(msg.reject_log_error.called)
|
||||
self.assertFalse(msg.acknowledged)
|
||||
|
||||
def test_on_requeue_message(self):
|
||||
d = dispatcher.TypeDispatcher({})
|
||||
d.add_requeue_filter(lambda data, message: True)
|
||||
d = dispatcher.TypeDispatcher()
|
||||
d.requeue_filters.append(lambda data, message: True)
|
||||
msg = mock_acked_message()
|
||||
d.on_message("", msg)
|
||||
self.assertTrue(msg.requeue.called)
|
||||
@ -71,7 +71,7 @@ class TestDispatcher(test.TestCase):
|
||||
def test_failed_ack(self):
|
||||
on_hello = mock.MagicMock()
|
||||
handlers = {'hello': on_hello}
|
||||
d = dispatcher.TypeDispatcher(handlers)
|
||||
d = dispatcher.TypeDispatcher(type_handlers=handlers)
|
||||
msg = mock_acked_message(ack_ok=False,
|
||||
properties={'type': 'hello'})
|
||||
d.on_message("", msg)
|
||||
|
@ -86,11 +86,12 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
ex = self.executor(reset_master_mock=False)
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.executor_uuid, self.executor_exchange,
|
||||
mock.ANY, on_wait=ex._on_wait,
|
||||
on_wait=ex._on_wait,
|
||||
url=self.broker_url, transport=mock.ANY,
|
||||
transport_options=mock.ANY,
|
||||
retry_options=mock.ANY
|
||||
)
|
||||
retry_options=mock.ANY,
|
||||
type_handlers=mock.ANY),
|
||||
mock.call.proxy.dispatcher.type_handlers.update(mock.ANY),
|
||||
]
|
||||
self.assertEqual(self.master_mock.mock_calls, master_mock_calls)
|
||||
|
||||
@ -212,10 +213,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertEqual(len(ex._requests_cache), 0)
|
||||
|
||||
def test_execute_task(self):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
||||
ex = self.executor()
|
||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
||||
ex._finder._add(self.executor_topic, [self.task.name])
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
@ -231,10 +230,8 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertEqual(expected_calls, self.master_mock.mock_calls)
|
||||
|
||||
def test_revert_task(self):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
||||
ex = self.executor()
|
||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
||||
ex._finder._add(self.executor_topic, [self.task.name])
|
||||
ex.revert_task(self.task, self.task_uuid, self.task_args,
|
||||
self.task_result, self.task_failures)
|
||||
|
||||
@ -263,11 +260,9 @@ class TestWorkerTaskExecutor(test.MockTestCase):
|
||||
self.assertEqual(self.master_mock.mock_calls, expected_calls)
|
||||
|
||||
def test_execute_task_publish_error(self):
|
||||
self.message_mock.properties['type'] = pr.NOTIFY
|
||||
self.proxy_inst_mock.publish.side_effect = Exception('Woot!')
|
||||
notify = pr.Notify(topic=self.executor_topic, tasks=[self.task.name])
|
||||
ex = self.executor()
|
||||
ex._process_notify(notify.to_dict(), self.message_mock)
|
||||
ex._finder._add(self.executor_topic, [self.task.name])
|
||||
ex.execute_task(self.task, self.task_uuid, self.task_args)
|
||||
|
||||
expected_calls = [
|
||||
|
@ -86,7 +86,7 @@ class TestServer(test.MockTestCase):
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.server_topic, self.server_exchange,
|
||||
mock.ANY, url=self.broker_url,
|
||||
type_handlers=mock.ANY, url=self.broker_url,
|
||||
transport=mock.ANY, transport_options=mock.ANY,
|
||||
retry_options=mock.ANY)
|
||||
]
|
||||
@ -99,7 +99,7 @@ class TestServer(test.MockTestCase):
|
||||
# check calls
|
||||
master_mock_calls = [
|
||||
mock.call.Proxy(self.server_topic, self.server_exchange,
|
||||
mock.ANY, url=self.broker_url,
|
||||
type_handlers=mock.ANY, url=self.broker_url,
|
||||
transport=mock.ANY, transport_options=mock.ANY,
|
||||
retry_options=mock.ANY)
|
||||
]
|
||||
|
@ -14,23 +14,20 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from oslo.utils import reflection
|
||||
|
||||
from taskflow.engines.worker_based import protocol as pr
|
||||
from taskflow.engines.worker_based import types as worker_types
|
||||
from taskflow import test
|
||||
from taskflow.test import mock
|
||||
from taskflow.tests import utils
|
||||
from taskflow.types import latch
|
||||
from taskflow.types import timing
|
||||
|
||||
|
||||
class TestWorkerTypes(test.TestCase):
|
||||
class TestRequestCache(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestWorkerTypes, self).setUp()
|
||||
super(TestRequestCache, self).setUp()
|
||||
self.addCleanup(timing.StopWatch.clear_overrides)
|
||||
self.task = utils.DummyTask()
|
||||
self.task_uuid = 'task-uuid'
|
||||
@ -76,6 +73,8 @@ class TestWorkerTypes(test.TestCase):
|
||||
self.assertEqual(1, len(matches))
|
||||
self.assertEqual(2, len(cache))
|
||||
|
||||
|
||||
class TestTopicWorker(test.TestCase):
|
||||
def test_topic_worker(self):
|
||||
worker = worker_types.TopicWorker("dummy-topic",
|
||||
[utils.DummyTask], identity="dummy")
|
||||
@ -84,52 +83,37 @@ class TestWorkerTypes(test.TestCase):
|
||||
self.assertEqual('dummy', worker.identity)
|
||||
self.assertEqual('dummy-topic', worker.topic)
|
||||
|
||||
def test_single_topic_workers(self):
|
||||
workers = worker_types.TopicWorkers()
|
||||
w = workers.add('dummy-topic', [utils.DummyTask])
|
||||
|
||||
class TestProxyFinder(test.TestCase):
|
||||
def test_single_topic_worker(self):
|
||||
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
|
||||
w, emit = finder._add('dummy-topic', [utils.DummyTask])
|
||||
self.assertIsNotNone(w)
|
||||
self.assertEqual(1, len(workers))
|
||||
w2 = workers.get_worker_for_task(utils.DummyTask)
|
||||
self.assertTrue(emit)
|
||||
self.assertEqual(1, finder._total_workers())
|
||||
w2 = finder.get_worker_for_task(utils.DummyTask)
|
||||
self.assertEqual(w.identity, w2.identity)
|
||||
|
||||
def test_multi_same_topic_workers(self):
|
||||
workers = worker_types.TopicWorkers()
|
||||
w = workers.add('dummy-topic', [utils.DummyTask])
|
||||
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
|
||||
w, emit = finder._add('dummy-topic', [utils.DummyTask])
|
||||
self.assertIsNotNone(w)
|
||||
w2 = workers.add('dummy-topic-2', [utils.DummyTask])
|
||||
self.assertTrue(emit)
|
||||
w2, emit = finder._add('dummy-topic-2', [utils.DummyTask])
|
||||
self.assertIsNotNone(w2)
|
||||
w3 = workers.get_worker_for_task(
|
||||
self.assertTrue(emit)
|
||||
w3 = finder.get_worker_for_task(
|
||||
reflection.get_class_name(utils.DummyTask))
|
||||
self.assertIn(w3.identity, [w.identity, w2.identity])
|
||||
|
||||
def test_multi_different_topic_workers(self):
|
||||
workers = worker_types.TopicWorkers()
|
||||
finder = worker_types.ProxyWorkerFinder('me', mock.MagicMock(), [])
|
||||
added = []
|
||||
added.append(workers.add('dummy-topic', [utils.DummyTask]))
|
||||
added.append(workers.add('dummy-topic-2', [utils.DummyTask]))
|
||||
added.append(workers.add('dummy-topic-3', [utils.NastyTask]))
|
||||
self.assertEqual(3, len(workers))
|
||||
w = workers.get_worker_for_task(utils.NastyTask)
|
||||
self.assertEqual(added[-1].identity, w.identity)
|
||||
w = workers.get_worker_for_task(utils.DummyTask)
|
||||
self.assertIn(w.identity, [w_a.identity for w_a in added[0:2]])
|
||||
|
||||
def test_periodic_worker(self):
|
||||
barrier = latch.Latch(5)
|
||||
to = timing.Timeout(0.01)
|
||||
called_at = []
|
||||
|
||||
def callee():
|
||||
barrier.countdown()
|
||||
if barrier.needed == 0:
|
||||
to.interrupt()
|
||||
called_at.append(time.time())
|
||||
|
||||
w = worker_types.PeriodicWorker(to, [callee])
|
||||
t = threading.Thread(target=w.start)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
self.assertEqual(0, barrier.needed)
|
||||
self.assertEqual(5, len(called_at))
|
||||
self.assertTrue(to.is_stopped())
|
||||
added.append(finder._add('dummy-topic', [utils.DummyTask]))
|
||||
added.append(finder._add('dummy-topic-2', [utils.DummyTask]))
|
||||
added.append(finder._add('dummy-topic-3', [utils.NastyTask]))
|
||||
self.assertEqual(3, finder._total_workers())
|
||||
w = finder.get_worker_for_task(utils.NastyTask)
|
||||
self.assertEqual(added[-1][0].identity, w.identity)
|
||||
w = finder.get_worker_for_task(utils.DummyTask)
|
||||
self.assertIn(w.identity, [w_a[0].identity for w_a in added[0:2]])
|
||||
|
179
taskflow/types/periodic.py
Normal file
179
taskflow/types/periodic.py
Normal file
@ -0,0 +1,179 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright (C) 2015 Yahoo! Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import heapq
|
||||
import inspect
|
||||
|
||||
from oslo_utils import reflection
|
||||
import six
|
||||
|
||||
from taskflow import logging
|
||||
from taskflow.utils import misc
|
||||
from taskflow.utils import threading_utils as tu
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# Find a monotonic providing time (or fallback to using time.time()
|
||||
# which isn't *always* accurate but will suffice).
|
||||
_now = misc.find_monotonic(allow_time_time=True)
|
||||
|
||||
# Attributes expected on periodic tagged/decorated functions or methods...
|
||||
_PERIODIC_ATTRS = tuple([
|
||||
'_periodic',
|
||||
'_periodic_spacing',
|
||||
'_periodic_run_immediately',
|
||||
])
|
||||
|
||||
|
||||
def periodic(spacing, run_immediately=True):
|
||||
"""Tags a method/function as wanting/able to execute periodically."""
|
||||
|
||||
if spacing <= 0:
|
||||
raise ValueError("Periodicity/spacing must be greater than"
|
||||
" zero instead of %s" % spacing)
|
||||
|
||||
def wrapper(f):
|
||||
f._periodic = True
|
||||
f._periodic_spacing = spacing
|
||||
f._periodic_run_immediately = run_immediately
|
||||
|
||||
@six.wraps(f)
|
||||
def decorator(*args, **kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return decorator
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PeriodicWorker(object):
|
||||
"""Calls a collection of callables periodically (sleeping as needed...).
|
||||
|
||||
NOTE(harlowja): typically the :py:meth:`.start` method is executed in a
|
||||
background thread so that the periodic callables are executed in
|
||||
the background/asynchronously (using the defined periods to determine
|
||||
when each is called).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create(cls, objects, exclude_hidden=True):
|
||||
"""Automatically creates a worker by analyzing object(s) methods.
|
||||
|
||||
Only picks up methods that have been tagged/decorated with
|
||||
the :py:func:`.periodic` decorator (does not match against private
|
||||
or protected methods unless explicitly requested to).
|
||||
"""
|
||||
callables = []
|
||||
for obj in objects:
|
||||
for (name, member) in inspect.getmembers(obj):
|
||||
if name.startswith("_") and exclude_hidden:
|
||||
continue
|
||||
if reflection.is_bound_method(member):
|
||||
consume = True
|
||||
for attr_name in _PERIODIC_ATTRS:
|
||||
if not hasattr(member, attr_name):
|
||||
consume = False
|
||||
break
|
||||
if consume:
|
||||
callables.append(member)
|
||||
return cls(callables)
|
||||
|
||||
def __init__(self, callables, tombstone=None):
|
||||
if tombstone is None:
|
||||
self._tombstone = tu.Event()
|
||||
else:
|
||||
# Allows someone to share an event (if they so want to...)
|
||||
self._tombstone = tombstone
|
||||
almost_callables = list(callables)
|
||||
for cb in almost_callables:
|
||||
if not six.callable(cb):
|
||||
raise ValueError("Periodic callback must be callable")
|
||||
for attr_name in _PERIODIC_ATTRS:
|
||||
if not hasattr(cb, attr_name):
|
||||
raise ValueError("Periodic callback missing required"
|
||||
" attribute '%s'" % attr_name)
|
||||
self._callables = tuple((cb, reflection.get_callable_name(cb))
|
||||
for cb in almost_callables)
|
||||
self._schedule = []
|
||||
self._immediates = []
|
||||
now = _now()
|
||||
for i, (cb, cb_name) in enumerate(self._callables):
|
||||
spacing = getattr(cb, '_periodic_spacing')
|
||||
next_run = now + spacing
|
||||
heapq.heappush(self._schedule, (next_run, i))
|
||||
for (cb, cb_name) in reversed(self._callables):
|
||||
if getattr(cb, '_periodic_run_immediately', False):
|
||||
self._immediates.append((cb, cb_name))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._callables)
|
||||
|
||||
@staticmethod
|
||||
def _safe_call(cb, cb_name, kind='periodic'):
|
||||
try:
|
||||
cb()
|
||||
except Exception:
|
||||
LOG.warn("Failed to call %s callable '%s'",
|
||||
kind, cb_name, exc_info=True)
|
||||
|
||||
def start(self):
|
||||
"""Starts running (will not stop/return until the tombstone is set).
|
||||
|
||||
NOTE(harlowja): If this worker has no contained callables this raises
|
||||
a runtime error and does not run since it is impossible to periodically
|
||||
run nothing.
|
||||
"""
|
||||
if not self._callables:
|
||||
raise RuntimeError("A periodic worker can not start"
|
||||
" without any callables")
|
||||
while not self._tombstone.is_set():
|
||||
if self._immediates:
|
||||
cb, cb_name = self._immediates.pop()
|
||||
LOG.debug("Calling immediate callable '%s'", cb_name)
|
||||
self._safe_call(cb, cb_name, kind='immediate')
|
||||
else:
|
||||
# Figure out when we should run next (by selecting the
|
||||
# minimum item from the heap, where the minimum should be
|
||||
# the callable that needs to run next and has the lowest
|
||||
# next desired run time).
|
||||
now = _now()
|
||||
next_run, i = heapq.heappop(self._schedule)
|
||||
when_next = next_run - now
|
||||
if when_next <= 0:
|
||||
cb, cb_name = self._callables[i]
|
||||
spacing = getattr(cb, '_periodic_spacing')
|
||||
LOG.debug("Calling periodic callable '%s' (it runs every"
|
||||
" %s seconds)", cb_name, spacing)
|
||||
self._safe_call(cb, cb_name)
|
||||
# Run again someday...
|
||||
next_run = now + spacing
|
||||
heapq.heappush(self._schedule, (next_run, i))
|
||||
else:
|
||||
# Gotta wait...
|
||||
heapq.heappush(self._schedule, (next_run, i))
|
||||
self._tombstone.wait(when_next)
|
||||
|
||||
def stop(self):
|
||||
"""Sets the tombstone (this stops any further executions)."""
|
||||
self._tombstone.set()
|
||||
|
||||
def reset(self):
|
||||
"""Resets the tombstone and re-queues up any immediate executions."""
|
||||
self._tombstone.clear()
|
||||
self._immediates = []
|
||||
for (cb, cb_name) in reversed(self._callables):
|
||||
if getattr(cb, '_periodic_run_immediately', False):
|
||||
self._immediates.append((cb, cb_name))
|
Loading…
Reference in New Issue
Block a user