Improve join by removing periodic jobs

* This patch removes the approach with DB polling needed to
  determine if a "join" task is ready to run. Instead of running
  a periodic scheduled job, each task completion now runs the
  algorithm that finds all potentially affected join tasks
  and schedules just one job (instead of a periodic job) to check
  their readiness.
  This solves a problem of system cascaded overloading in case of
  having many very large joins (when a workflow has many joins with
  many  dependencies each). Previously, in such case Mistral created
  too many periodic jobs that just didn't let the workflow progress
  well, i.e. most CPU was used by scheduler to run those periodic
  jobs that very rarely switched "join" tasks to the RUNNING state.

Change-Id: I5ebc44c7a3f95c868d653689dc5cea689c788cd0
Closes-Bug: #1799356
This commit is contained in:
Renat Akhmerov 2018-10-10 14:37:08 +07:00
parent 0b38cd8028
commit 1a4c599a4d
10 changed files with 156 additions and 85 deletions

View File

@ -99,7 +99,7 @@ def process(func):
# NOTE(rakhmerov): Since we make RPC calls to the engine itself # NOTE(rakhmerov): Since we make RPC calls to the engine itself
# we need to process the action queue asynchronously in a new # we need to process the action queue asynchronously in a new
# thread. Otherwise, if we have one engine process the engine # thread. Otherwise, if we have one engine process the engine
# will may send a request to itself while already processing # may send a request to itself while already processing
# another one. In conjunction with blocking RPC it will lead # another one. In conjunction with blocking RPC it will lead
# to a deadlock (and RPC timeout). # to a deadlock (and RPC timeout).
def _within_new_thread(): def _within_new_thread():

View File

@ -78,8 +78,7 @@ def run_task(wf_cmd):
return return
if task.is_waiting() and (task.is_created() or task.is_state_changed()): _check_affected_tasks(task)
_schedule_refresh_task_state(task.task_ex, 1)
def rerun_task(task_ex, wf_spec): def rerun_task(task_ex, wf_spec):
@ -129,6 +128,10 @@ def _on_action_complete(action_ex):
wf_handler.force_fail_workflow(wf_ex, msg) wf_handler.force_fail_workflow(wf_ex, msg)
return
_check_affected_tasks(task)
@profiler.trace('task-handler-on-action-update', hide_args=True) @profiler.trace('task-handler-on-action-update', hide_args=True)
def _on_action_update(action_ex): def _on_action_update(action_ex):
@ -186,6 +189,8 @@ def _on_action_update(action_ex):
return return
_check_affected_tasks(task)
def force_fail_task(task_ex, msg): def force_fail_task(task_ex, msg):
"""Forces the given task to fail. """Forces the given task to fail.
@ -238,6 +243,8 @@ def continue_task(task_ex):
return return
_check_affected_tasks(task)
def complete_task(task_ex, state, state_info): def complete_task(task_ex, state, state_info):
wf_spec = spec_parser.get_workflow_spec_by_execution_id( wf_spec = spec_parser.get_workflow_spec_by_execution_id(
@ -264,6 +271,33 @@ def complete_task(task_ex, state, state_info):
return return
_check_affected_tasks(task)
def _check_affected_tasks(task):
if not task.is_completed():
return
task_ex = task.task_ex
wf_ex = task_ex.workflow_execution
if states.is_completed(wf_ex.state):
return
wf_spec = spec_parser.get_workflow_spec_by_execution_id(
task_ex.workflow_execution_id
)
wf_ctrl = wf_base.get_controller(wf_ex, wf_spec)
affected_task_execs = wf_ctrl.find_indirectly_affected_task_executions(
task_ex.name
)
for t_ex in affected_task_execs:
_schedule_refresh_task_state(t_ex)
def _build_task_from_execution(wf_spec, task_ex): def _build_task_from_execution(wf_spec, task_ex):
return _create_task( return _create_task(
@ -350,9 +384,14 @@ def _refresh_task_state(task_ex_id):
wf_ctrl = wf_base.get_controller(wf_ex, wf_spec) wf_ctrl = wf_base.get_controller(wf_ex, wf_spec)
log_state = wf_ctrl.get_logical_task_state( with db_api.named_lock(task_ex.id):
task_ex db_api.refresh(task_ex)
)
if (states.is_completed(task_ex.state)
or task_ex.state == states.RUNNING):
return
log_state = wf_ctrl.get_logical_task_state(task_ex)
state = log_state.state state = log_state.state
state_info = log_state.state_info state_info = log_state.state_info
@ -365,23 +404,18 @@ def _refresh_task_state(task_ex_id):
elif state == states.ERROR: elif state == states.ERROR:
complete_task(task_ex, state, state_info) complete_task(task_ex, state, state_info)
elif state == states.WAITING: elif state == states.WAITING:
# Let's assume that a task takes 0.01 sec in average to complete LOG.info(
# and based on this assumption calculate a time of the next check. "Task execution is still in WAITING state"
# The estimation is very rough, of course, but this delay will be " [task_ex_id=%s, task_name=%s]",
# decreasing as task preconditions will be completing which will task_ex_id,
# give a decent asymptotic approximation. task_ex.name
# For example, if a 'join' task has 100 inbound incomplete tasks )
# then the next 'refresh_task_state' call will happen in 10
# seconds. For 500 tasks it will be 50 seconds. The larger the
# workflow is, the more beneficial this mechanism will be.
delay = int(log_state.cardinality * 0.01)
_schedule_refresh_task_state(task_ex, max(1, delay))
else: else:
# Must never get here. # Must never get here.
raise RuntimeError( raise RuntimeError(
'Unexpected logical task state [task_ex_id=%s, task_name=%s, ' 'Unexpected logical task state [task_ex_id=%s, '
'state=%s]' % (task_ex_id, task_ex.name, state) 'task_name=%s, state=%s]' %
(task_ex_id, task_ex.name, state)
) )
@ -401,7 +435,7 @@ def _schedule_refresh_task_state(task_ex, delay=0):
:param task_ex: Task execution. :param task_ex: Task execution.
:param delay: Delay. :param delay: Delay.
""" """
key = 'th_c_t_s_a-%s' % task_ex.id key = _get_refresh_state_job_key(task_ex.id)
scheduler.schedule_call( scheduler.schedule_call(
None, None,
@ -412,6 +446,10 @@ def _schedule_refresh_task_state(task_ex, delay=0):
) )
def _get_refresh_state_job_key(task_ex_id):
return 'th_r_t_s-%s' % task_ex_id
@db_utils.retry_on_db_error @db_utils.retry_on_db_error
@action_queue.process @action_queue.process
def _scheduled_on_action_complete(action_ex_id, wf_action): def _scheduled_on_action_complete(action_ex_id, wf_action):
@ -492,7 +530,7 @@ def schedule_on_action_update(action_ex, delay=0):
return return
key = 'th_on_a_c-%s' % action_ex.task_execution_id key = 'th_on_a_u-%s' % action_ex.task_execution_id
scheduler.schedule_call( scheduler.schedule_call(
None, None,

View File

@ -140,13 +140,14 @@ class EngineTestCase(base.DbTestCase):
for t in w.task_executions: for t in w.task_executions:
print( print(
"\t%s [id=%s, state=%s, state_info=%s, processed=%s," "\t%s [id=%s, state=%s, state_info=%s, processed=%s,"
" published=%s]" % " published=%s, runtime_context=%s]" %
(t.name, (t.name,
t.id, t.id,
t.state, t.state,
t.state_info, t.state_info,
t.processed, t.processed,
t.published) t.published,
t.runtime_context)
) )
child_execs = t.executions child_execs = t.executions

View File

@ -873,18 +873,11 @@ class JoinEngineTest(base.EngineTestCase):
state=states.WAITING state=states.WAITING
) )
calls = db_api.get_delayed_calls()
mtd_name = 'mistral.engine.task_handler._refresh_task_state'
cnt = sum([1 for c in calls if c.target_method_name == mtd_name])
# There can be 2 calls with different value of 'processing' flag.
self.assertTrue(cnt == 1 or cnt == 2)
# Stop the workflow. # Stop the workflow.
self.engine.stop_workflow(wf_ex.id, state=states.CANCELLED) self.engine.stop_workflow(wf_ex.id, state=states.CANCELLED)
mtd_name = 'mistral.engine.task_handler._refresh_task_state'
self._await( self._await(
lambda: lambda:
len(db_api.get_delayed_calls(target_method_name=mtd_name)) == 0 len(db_api.get_delayed_calls(target_method_name=mtd_name)) == 0
@ -931,18 +924,11 @@ class JoinEngineTest(base.EngineTestCase):
state=states.WAITING state=states.WAITING
) )
calls = db_api.get_delayed_calls()
mtd_name = 'mistral.engine.task_handler._refresh_task_state'
cnt = sum([1 for c in calls if c.target_method_name == mtd_name])
# There can be 2 calls with different value of 'processing' flag.
self.assertTrue(cnt == 1 or cnt == 2)
# Stop the workflow. # Stop the workflow.
db_api.delete_workflow_execution(wf_ex.id) db_api.delete_workflow_execution(wf_ex.id)
mtd_name = 'mistral.engine.task_handler._refresh_task_state'
self._await( self._await(
lambda: lambda:
len(db_api.get_delayed_calls(target_method_name=mtd_name)) == 0 len(db_api.get_delayed_calls(target_method_name=mtd_name)) == 0

View File

@ -1025,8 +1025,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
# Get objects for the parent workflow execution. # Get objects for the parent workflow execution.
wf_1_ex = self._assert_single_item(wf_execs, name='wb.wf1') wf_1_ex = self._assert_single_item(wf_execs, name='wb.wf1')
wf_1_task_execs = wf_1_ex.task_executions
wf_1_task_1_ex = self._assert_single_item( wf_1_task_1_ex = self._assert_single_item(
wf_1_ex.task_executions, wf_1_ex.task_executions,
name='task1' name='task1'
@ -1049,8 +1047,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[0].id wf_1_task_1_action_exs[0].id
) )
wf_2_ex_1_task_execs = wf_2_ex_1.task_executions
wf_2_ex_1_task_1_ex = self._assert_single_item( wf_2_ex_1_task_1_ex = self._assert_single_item(
wf_2_ex_1.task_executions, wf_2_ex_1.task_executions,
name='task1' name='task1'
@ -1064,8 +1060,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[1].id wf_1_task_1_action_exs[1].id
) )
wf_2_ex_2_task_execs = wf_2_ex_2.task_executions
wf_2_ex_2_task_1_ex = self._assert_single_item( wf_2_ex_2_task_1_ex = self._assert_single_item(
wf_2_ex_2.task_executions, wf_2_ex_2.task_executions,
name='task1' name='task1'
@ -1079,8 +1073,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[2].id wf_1_task_1_action_exs[2].id
) )
wf_2_ex_3_task_execs = wf_2_ex_3.task_executions
wf_2_ex_3_task_1_ex = self._assert_single_item( wf_2_ex_3_task_1_ex = self._assert_single_item(
wf_2_ex_3.task_executions, wf_2_ex_3.task_executions,
name='task1' name='task1'
@ -1093,8 +1085,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
# Get objects for the wf3 subworkflow execution. # Get objects for the wf3 subworkflow execution.
wf_3_ex = self._assert_single_item(wf_execs, name='wb.wf3') wf_3_ex = self._assert_single_item(wf_execs, name='wb.wf3')
wf_3_task_execs = wf_3_ex.task_executions
wf_3_task_1_ex = self._assert_single_item( wf_3_task_1_ex = self._assert_single_item(
wf_3_ex.task_executions, wf_3_ex.task_executions,
name='task1' name='task1'
@ -1149,8 +1139,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
# Get objects for the parent workflow execution. # Get objects for the parent workflow execution.
wf_1_ex = self._assert_single_item(wf_execs, name='wb.wf1') wf_1_ex = self._assert_single_item(wf_execs, name='wb.wf1')
wf_1_task_execs = wf_1_ex.task_executions
wf_1_task_1_ex = self._assert_single_item( wf_1_task_1_ex = self._assert_single_item(
wf_1_ex.task_executions, wf_1_ex.task_executions,
name='task1' name='task1'
@ -1173,8 +1161,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[0].id wf_1_task_1_action_exs[0].id
) )
wf_2_ex_1_task_execs = wf_2_ex_1.task_executions
wf_2_ex_1_task_1_ex = self._assert_single_item( wf_2_ex_1_task_1_ex = self._assert_single_item(
wf_2_ex_1.task_executions, wf_2_ex_1.task_executions,
name='task1' name='task1'
@ -1188,8 +1174,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[1].id wf_1_task_1_action_exs[1].id
) )
wf_2_ex_2_task_execs = wf_2_ex_2.task_executions
wf_2_ex_2_task_1_ex = self._assert_single_item( wf_2_ex_2_task_1_ex = self._assert_single_item(
wf_2_ex_2.task_executions, wf_2_ex_2.task_executions,
name='task1' name='task1'
@ -1203,8 +1187,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[2].id wf_1_task_1_action_exs[2].id
) )
wf_2_ex_3_task_execs = wf_2_ex_3.task_executions
wf_2_ex_3_task_1_ex = self._assert_single_item( wf_2_ex_3_task_1_ex = self._assert_single_item(
wf_2_ex_3.task_executions, wf_2_ex_3.task_executions,
name='task1' name='task1'
@ -1217,8 +1199,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
# Get objects for the wf3 subworkflow execution. # Get objects for the wf3 subworkflow execution.
wf_3_ex = self._assert_single_item(wf_execs, name='wb.wf3') wf_3_ex = self._assert_single_item(wf_execs, name='wb.wf3')
wf_3_task_execs = wf_3_ex.task_executions
wf_3_task_1_ex = self._assert_single_item( wf_3_task_1_ex = self._assert_single_item(
wf_3_ex.task_executions, wf_3_ex.task_executions,
name='task1' name='task1'
@ -1292,8 +1272,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
# Get objects for the parent workflow execution. # Get objects for the parent workflow execution.
wf_1_ex = self._assert_single_item(wf_execs, name='wb.wf1') wf_1_ex = self._assert_single_item(wf_execs, name='wb.wf1')
wf_1_task_execs = wf_1_ex.task_executions
wf_1_task_1_ex = self._assert_single_item( wf_1_task_1_ex = self._assert_single_item(
wf_1_ex.task_executions, wf_1_ex.task_executions,
name='task1' name='task1'
@ -1316,8 +1294,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[0].id wf_1_task_1_action_exs[0].id
) )
wf_2_ex_1_task_execs = wf_2_ex_1.task_executions
wf_2_ex_1_task_1_ex = self._assert_single_item( wf_2_ex_1_task_1_ex = self._assert_single_item(
wf_2_ex_1.task_executions, wf_2_ex_1.task_executions,
name='task1' name='task1'
@ -1331,8 +1307,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[1].id wf_1_task_1_action_exs[1].id
) )
wf_2_ex_2_task_execs = wf_2_ex_2.task_executions
wf_2_ex_2_task_1_ex = self._assert_single_item( wf_2_ex_2_task_1_ex = self._assert_single_item(
wf_2_ex_2.task_executions, wf_2_ex_2.task_executions,
name='task1' name='task1'
@ -1346,8 +1320,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
wf_1_task_1_action_exs[2].id wf_1_task_1_action_exs[2].id
) )
wf_2_ex_3_task_execs = wf_2_ex_3.task_executions
wf_2_ex_3_task_1_ex = self._assert_single_item( wf_2_ex_3_task_1_ex = self._assert_single_item(
wf_2_ex_3.task_executions, wf_2_ex_3.task_executions,
name='task1' name='task1'
@ -1360,8 +1332,6 @@ class SubworkflowPauseResumeTest(base.EngineTestCase):
# Get objects for the wf3 subworkflow execution. # Get objects for the wf3 subworkflow execution.
wf_3_ex = self._assert_single_item(wf_execs, name='wb.wf3') wf_3_ex = self._assert_single_item(wf_execs, name='wb.wf3')
wf_3_task_execs = wf_3_ex.task_executions
wf_3_task_1_ex = self._assert_single_item( wf_3_task_1_ex = self._assert_single_item(
wf_3_ex.task_executions, wf_3_ex.task_executions,
name='task1' name='task1'

View File

@ -167,6 +167,16 @@ class WorkflowController(object):
""" """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def find_indirectly_affected_task_executions(self, task_name):
"""Get a set of task executions indirectly affected by the given.
:param task_name: Task name.
:return: Task executions that can be indirectly affected by a task
identified by the given name.
"""
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def is_error_handled_for(self, task_ex): def is_error_handled_for(self, task_ex):
"""Determines if error is handled for specific task. """Determines if error is handled for specific task.

View File

@ -191,6 +191,9 @@ class DirectWorkflowController(base.WorkflowController):
return self._get_join_logical_state(task_spec) return self._get_join_logical_state(task_spec)
def find_indirectly_affected_task_executions(self, task_name):
return self._find_indirectly_affected_created_joins(task_name)
def is_error_handled_for(self, task_ex): def is_error_handled_for(self, task_ex):
return bool(self.wf_spec.get_on_error_clause(task_ex.name)) return bool(self.wf_spec.get_on_error_clause(task_ex.name))
@ -308,6 +311,54 @@ class DirectWorkflowController(base.WorkflowController):
if not condition or expr.evaluate(condition, ctx) if not condition or expr.evaluate(condition, ctx)
] ]
@profiler.trace('direct-wf-controller-find-downstream-joins')
def _find_indirectly_affected_created_joins(self, task_name, result=None,
visited_task_names=None):
visited_task_names = visited_task_names or set()
if task_name in visited_task_names:
return
visited_task_names.add(task_name)
result = result or set()
def _process_clause(clause):
for t_name, condition, params in clause:
t_spec = self.wf_spec.get_tasks()[t_name]
# Encountered an engine command.
if not t_spec:
continue
if t_spec.get_join():
# TODO(rakhmerov): This is a fundamental limitation
# that prevents us having cycles within workflows
# that contain joins because we assume that there
# can be only one "join" task with a given name.
t_ex = self._find_task_execution_by_name(t_name)
if t_ex:
result.add(t_ex)
# If we found a "join" we don't need to go further
# because completion of the found join will handle
# other deeper joins.
continue
# Recursion.
self._find_indirectly_affected_created_joins(
t_name,
result=result,
visited_task_names=visited_task_names
)
_process_clause(self.wf_spec.get_on_success_clause(task_name))
_process_clause(self.wf_spec.get_on_error_clause(task_name))
_process_clause(self.wf_spec.get_on_complete_clause(task_name))
return result
@profiler.trace('direct-wf-controller-get-join-logical-state') @profiler.trace('direct-wf-controller-get-join-logical-state')
def _get_join_logical_state(self, task_spec): def _get_join_logical_state(self, task_spec):
"""Evaluates logical state of 'join' task. """Evaluates logical state of 'join' task.

View File

@ -88,7 +88,8 @@ def find_task_executions_by_name(wf_ex_id, task_name):
:param wf_ex_id: Workflow execution id. :param wf_ex_id: Workflow execution id.
:param task_name: Task name. :param task_name: Task name.
:return: Task executions (possibly a cached value). :return: Task executions (possibly a cached value). The returned list
may contain task execution clones not bound to the DB session.
""" """
with _TASK_EX_CACHE_LOCK: with _TASK_EX_CACHE_LOCK:
t_execs = _TASK_EX_CACHE[wf_ex_id].get(task_name) t_execs = _TASK_EX_CACHE[wf_ex_id].get(task_name)
@ -102,6 +103,8 @@ def find_task_executions_by_name(wf_ex_id, task_name):
sort_keys=[] # disable sorting sort_keys=[] # disable sorting
) )
t_execs = [t_ex.get_clone() for t_ex in t_execs]
# We can cache only finished tasks because they won't change. # We can cache only finished tasks because they won't change.
all_finished = ( all_finished = (
t_execs and t_execs and

View File

@ -118,6 +118,9 @@ class ReverseWorkflowController(base.WorkflowController):
# TODO(rakhmerov): Implement. # TODO(rakhmerov): Implement.
return base.TaskLogicalState(task_ex.state, task_ex.state_info) return base.TaskLogicalState(task_ex.state, task_ex.state_info)
def find_indirectly_affected_task_executions(self, task_name):
return set()
def is_error_handled_for(self, task_ex): def is_error_handled_for(self, task_ex):
return task_ex.state != states.ERROR return task_ex.state != states.ERROR

View File

@ -0,0 +1,9 @@
---
fixes:
- |
Removed DB polling from the logic that checks readiness of a "join" task
which leads to situations when CPU was mostly occupied by scheduler that
runs corresponding periodic jobs and that doesn't let the workflow move
forward with a proper speed. That happens in case if a workflow has lots
of "join" tasks with many dependencies. It's fixed now.