Merge "Rework joining mechanism"

This commit is contained in:
Zuul 2019-05-20 03:44:57 +00:00 committed by Gerrit Code Review
commit 26b9cc2bac
3 changed files with 91 additions and 16 deletions

View File

@ -350,6 +350,18 @@ class DirectWorkflowTaskSpec(TaskSpec):
def get_on_error(self): def get_on_error(self):
return self._on_error return self._on_error
def is_conditional_transition(self, state):
data = self._data.get(state) or self._data.get('on-complete')
if not data:
return False
for item in data:
if type(item) is dict:
return True
return False
class ReverseWorkflowTaskSpec(TaskSpec): class ReverseWorkflowTaskSpec(TaskSpec):
_polymorphic_value = 'reverse' _polymorphic_value = 'reverse'

View File

@ -236,7 +236,6 @@ def evaluate_task_outbound_context(task_ex):
:param task_ex: DB task. :param task_ex: DB task.
:return: Outbound task Data Flow context. :return: Outbound task Data Flow context.
""" """
# NOTE(rakhmerov): 'task_ex.in_context' has the SQLAlchemy specific # NOTE(rakhmerov): 'task_ex.in_context' has the SQLAlchemy specific
# type MutableDict. So we need to create a shallow copy using dict(...) # type MutableDict. So we need to create a shallow copy using dict(...)
# initializer and use it. It's enough to be safe in order to manipulate # initializer and use it. It's enough to be safe in order to manipulate
@ -248,10 +247,10 @@ def evaluate_task_outbound_context(task_ex):
# footprint and reduces performance. # footprint and reduces performance.
in_context = ( in_context = (
dict(task_ex.in_context) dict(task_ex.in_context)
if task_ex.in_context is not None else {} if getattr(task_ex, 'in_context', None) is not None else {}
) )
return utils.update_dict(in_context, task_ex.published) return utils.update_dict(in_context, getattr(task_ex, 'published', {}))
def evaluate_workflow_output(wf_ex, wf_output, ctx): def evaluate_workflow_output(wf_ex, wf_output, ctx):

View File

@ -15,6 +15,7 @@
from oslo_log import log as logging from oslo_log import log as logging
from osprofiler import profiler from osprofiler import profiler
from mistral.db.v2 import api as db_api
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral import expressions as expr from mistral import expressions as expr
from mistral import utils from mistral import utils
@ -64,7 +65,8 @@ class DirectWorkflowController(base.WorkflowController):
induced_state, _, _ = self._get_induced_join_state( induced_state, _, _ = self._get_induced_join_state(
self.wf_spec.get_tasks()[t_ex_candidate.name], self.wf_spec.get_tasks()[t_ex_candidate.name],
self._find_task_execution_by_name(t_ex_candidate.name), self._find_task_execution_by_name(t_ex_candidate.name),
t_spec t_spec,
{}
) )
return induced_state == states.RUNNING return induced_state == states.RUNNING
@ -353,9 +355,6 @@ class DirectWorkflowController(base.WorkflowController):
# TODO(rakhmerov): We need to use task_ex instead of task_spec # TODO(rakhmerov): We need to use task_ex instead of task_spec
# in order to cover a use case when there's more than one instance # in order to cover a use case when there's more than one instance
# of the same 'join' task in a workflow. # of the same 'join' task in a workflow.
# TODO(rakhmerov): In some cases this method will be expensive because
# it uses a multi-step recursive search. We need to optimize it moving
# forward (e.g. with Workflow Execution Graph).
join_expr = task_spec.get_join() join_expr = task_spec.get_join()
@ -364,13 +363,29 @@ class DirectWorkflowController(base.WorkflowController):
if not in_task_specs: if not in_task_specs:
return base.TaskLogicalState(states.RUNNING) return base.TaskLogicalState(states.RUNNING)
names = self._find_all_parent_task_names(task_spec)
t_execs_cache = {
t_ex.name: t_ex for t_ex in db_api.get_task_executions(
fields=('id', 'name', 'state'),
sort_keys=[],
workflow_execution_id=self.wf_ex.id,
name={'in': names}
)
} if names else {} # don't perform a db request if 'names' are empty
# List of tuples (task_name, task_ex, state, depth, event_name). # List of tuples (task_name, task_ex, state, depth, event_name).
induced_states = [] induced_states = []
for t_s in in_task_specs: for t_s in in_task_specs:
t_ex = self._find_task_execution_by_name(t_s.get_name()) t_ex = t_execs_cache.get(t_s.get_name())
tup = self._get_induced_join_state(t_s, t_ex, task_spec) tup = self._get_induced_join_state(
t_s,
t_ex,
task_spec,
t_execs_cache
)
induced_states.append( induced_states.append(
( (
@ -470,11 +485,14 @@ class DirectWorkflowController(base.WorkflowController):
# we may have multiple task executions for a task. It should # we may have multiple task executions for a task. It should
# accept inbound task execution rather than a spec. # accept inbound task execution rather than a spec.
def _get_induced_join_state(self, in_task_spec, in_task_ex, def _get_induced_join_state(self, in_task_spec, in_task_ex,
join_task_spec): join_task_spec, t_execs_cache):
join_task_name = join_task_spec.get_name() join_task_name = join_task_spec.get_name()
if not in_task_ex: if not in_task_ex:
possible, depth = self._possible_route(in_task_spec) possible, depth = self._possible_route(
in_task_spec,
t_execs_cache
)
if possible: if possible:
return states.WAITING, depth, None return states.WAITING, depth, None
@ -484,6 +502,11 @@ class DirectWorkflowController(base.WorkflowController):
if not states.is_completed(in_task_ex.state): if not states.is_completed(in_task_ex.state):
return states.WAITING, 1, None return states.WAITING, 1, None
if self._is_conditional_transition(in_task_ex, in_task_spec) and \
not hasattr(in_task_ex, "in_context"):
in_task_ex = db_api.get_task_execution(in_task_ex.id)
t_execs_cache[in_task_ex.name] = in_task_ex
# [(task name, params, event name), ...] # [(task name, params, event name), ...]
next_tasks_tuples = self._find_next_tasks(in_task_ex) next_tasks_tuples = self._find_next_tasks(in_task_ex)
@ -506,25 +529,66 @@ class DirectWorkflowController(base.WorkflowController):
# TODO(rakhmerov): Temporary hack. See the previous comment. # TODO(rakhmerov): Temporary hack. See the previous comment.
return t_execs[-1] if t_execs else None return t_execs[-1] if t_execs else None
def _possible_route(self, task_spec, depth=1): def _possible_route(self, task_spec, t_execs_cache, depth=1):
in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec) in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)
if not in_task_specs: if not in_task_specs:
return True, depth return True, depth
for t_s in in_task_specs: for t_s in in_task_specs:
t_ex = self._find_task_execution_by_name(t_s.get_name()) t_ex = t_execs_cache.get(t_s.get_name())
if not t_ex: if not t_ex:
possible, depth = self._possible_route(t_s, depth + 1) possible, depth = self._possible_route(
t_s,
t_execs_cache,
depth + 1
)
if possible: if possible:
return True, depth return True, depth
else: else:
t_name = task_spec.get_name() t_name = task_spec.get_name()
if (not states.is_completed(t_ex.state) or if not states.is_completed(t_ex.state):
t_name in self._find_next_task_names(t_ex)): return True, depth
# By default we don't download task context from the database,
# but just basic fields: 'id', 'name' and 'state'. It's a good
# optimization, because contexts can be too heavy and we don't
# need them most of the time.
# But sometimes we need it for conditional transitions (when
# the decision where to go is based on the current context),
# and if this is the case, we download full task execution
# and then evaluate its context to find the route.
# TODO(mfedosin): Think of a way to avoid this.
if self._is_conditional_transition(t_ex, task_spec) and \
not hasattr(t_ex, "in_context"):
t_ex = db_api.get_task_execution(t_ex.id)
t_execs_cache[t_ex.name] = t_ex
if t_name in self._find_next_task_names(t_ex):
return True, depth return True, depth
return False, depth return False, depth
def _find_all_parent_task_names(self, task_spec):
all_parent_names = set()
inbound_specs = self.wf_spec.find_inbound_task_specs(task_spec)[:]
while inbound_specs:
spec = inbound_specs.pop()
all_parent_names.add(spec.get_name())
inbound_specs += self.wf_spec.find_inbound_task_specs(spec)
return all_parent_names
@staticmethod
def _is_conditional_transition(t_ex, t_spec):
if t_ex.state == states.SUCCESS:
return t_spec.is_conditional_transition('on-success')
elif t_ex.state == states.ERROR:
return t_spec.is_conditional_transition('on-error')
return False