def is_conditional_transition(self, state):
    """Tell whether the transition clause for ``state`` has dict entries.

    The clause is read from the raw task definition (``self._data``)
    under the key ``state`` (e.g. 'on-success' or 'on-error'). If that
    key is missing or its value is empty, the 'on-complete' clause is
    used instead.

    An entry that is a dict is treated as a conditional transition
    (presumably a ``{task_name: condition}`` mapping - confirm against
    the workflow language spec); plain entries are unconditional.

    :param state: Key of the transition clause to inspect.
    :return: True if at least one entry of the selected clause is a
        dict, False otherwise (including when there is no clause).
    """
    data = self._data.get(state) or self._data.get('on-complete')

    if not data:
        return False

    # isinstance() rather than an exact type check so that dict
    # subclasses (e.g. OrderedDict) are recognized as well.
    return any(isinstance(item, dict) for item in data)
in_context = ( dict(task_ex.in_context) - if task_ex.in_context is not None else {} + if getattr(task_ex, 'in_context', None) is not None else {} ) - return utils.update_dict(in_context, task_ex.published) + return utils.update_dict(in_context, getattr(task_ex, 'published', {})) def evaluate_workflow_output(wf_ex, wf_output, ctx): diff --git a/mistral/workflow/direct_workflow.py b/mistral/workflow/direct_workflow.py index 1b4061abb..20cde942e 100644 --- a/mistral/workflow/direct_workflow.py +++ b/mistral/workflow/direct_workflow.py @@ -15,6 +15,7 @@ from oslo_log import log as logging from osprofiler import profiler +from mistral.db.v2 import api as db_api from mistral import exceptions as exc from mistral import expressions as expr from mistral import utils @@ -64,7 +65,8 @@ class DirectWorkflowController(base.WorkflowController): induced_state, _, _ = self._get_induced_join_state( self.wf_spec.get_tasks()[t_ex_candidate.name], self._find_task_execution_by_name(t_ex_candidate.name), - t_spec + t_spec, + {} ) return induced_state == states.RUNNING @@ -353,9 +355,6 @@ class DirectWorkflowController(base.WorkflowController): # TODO(rakhmerov): We need to use task_ex instead of task_spec # in order to cover a use case when there's more than one instance # of the same 'join' task in a workflow. - # TODO(rakhmerov): In some cases this method will be expensive because - # it uses a multi-step recursive search. We need to optimize it moving - # forward (e.g. with Workflow Execution Graph). 
join_expr = task_spec.get_join() @@ -364,13 +363,29 @@ class DirectWorkflowController(base.WorkflowController): if not in_task_specs: return base.TaskLogicalState(states.RUNNING) + names = self._find_all_parent_task_names(task_spec) + + t_execs_cache = { + t_ex.name: t_ex for t_ex in db_api.get_task_executions( + fields=('id', 'name', 'state'), + sort_keys=[], + workflow_execution_id=self.wf_ex.id, + name={'in': names} + ) + } if names else {} # don't perform a db request if 'names' are empty + # List of tuples (task_name, task_ex, state, depth, event_name). induced_states = [] for t_s in in_task_specs: - t_ex = self._find_task_execution_by_name(t_s.get_name()) + t_ex = t_execs_cache.get(t_s.get_name()) - tup = self._get_induced_join_state(t_s, t_ex, task_spec) + tup = self._get_induced_join_state( + t_s, + t_ex, + task_spec, + t_execs_cache + ) induced_states.append( ( @@ -470,11 +485,14 @@ class DirectWorkflowController(base.WorkflowController): # we may have multiple task executions for a task. It should # accept inbound task execution rather than a spec. def _get_induced_join_state(self, in_task_spec, in_task_ex, - join_task_spec): + join_task_spec, t_execs_cache): join_task_name = join_task_spec.get_name() if not in_task_ex: - possible, depth = self._possible_route(in_task_spec) + possible, depth = self._possible_route( + in_task_spec, + t_execs_cache + ) if possible: return states.WAITING, depth, None @@ -484,6 +502,11 @@ class DirectWorkflowController(base.WorkflowController): if not states.is_completed(in_task_ex.state): return states.WAITING, 1, None + if self._is_conditional_transition(in_task_ex, in_task_spec) and \ + not hasattr(in_task_ex, "in_context"): + in_task_ex = db_api.get_task_execution(in_task_ex.id) + t_execs_cache[in_task_ex.name] = in_task_ex + # [(task name, params, event name), ...] 
def _possible_route(self, task_spec, t_execs_cache, depth=1):
    """Check whether the workflow may still reach the given task.

    Walks the inbound task specs of ``task_spec``. Inbound tasks that
    have no execution in ``t_execs_cache`` are explored recursively.

    :param task_spec: Spec of the task whose reachability is checked.
    :param t_execs_cache: Dict mapping a task name to its task
        execution. Entries may be replaced with fully loaded executions
        as a side effect.
    :param depth: Current search depth.
    :return: Tuple (possible, depth).
    """
    in_task_specs = self.wf_spec.find_inbound_task_specs(task_spec)

    if not in_task_specs:
        return True, depth

    for t_s in in_task_specs:
        t_ex = t_execs_cache.get(t_s.get_name())

        if not t_ex:
            possible, depth = self._possible_route(
                t_s,
                t_execs_cache,
                depth + 1
            )

            if possible:
                return True, depth
        else:
            t_name = task_spec.get_name()

            if not states.is_completed(t_ex.state):
                return True, depth

            # By default we don't download task context from the database,
            # but just basic fields: 'id', 'name' and 'state'. It's a good
            # optimization, because contexts can be too heavy and we don't
            # need them most of the time.
            # But sometimes we need it for conditional transitions (when
            # the decision where to go is based on the current context),
            # and if this is the case, we download full task execution
            # and then evaluate its context to find the route.
            # NOTE: the transitions being evaluated belong to the inbound
            # task, so its own spec (t_s) is the one to inspect here.
            # TODO(mfedosin): Think of a way to avoid this.
            if self._is_conditional_transition(t_ex, t_s) and \
                    not hasattr(t_ex, "in_context"):
                t_ex = db_api.get_task_execution(t_ex.id)
                t_execs_cache[t_ex.name] = t_ex

            if t_name in self._find_next_task_names(t_ex):
                return True, depth

    return False, depth


def _find_all_parent_task_names(self, task_spec):
    """Collect names of all direct and indirect inbound tasks.

    :param task_spec: Spec of the task whose ancestors are collected.
    :return: Set of names of every task spec reachable by repeatedly
        following ``wf_spec.find_inbound_task_specs()``.
    """
    all_parent_names = set()

    # Copy so that the list returned by the spec is never mutated.
    inbound_specs = list(self.wf_spec.find_inbound_task_specs(task_spec))

    while inbound_specs:
        spec = inbound_specs.pop()
        name = spec.get_name()

        # Expand every task only once. Without this guard a cycle in
        # the workflow graph would keep the loop running forever.
        if name in all_parent_names:
            continue

        all_parent_names.add(name)
        inbound_specs.extend(self.wf_spec.find_inbound_task_specs(spec))

    return all_parent_names


@staticmethod
def _is_conditional_transition(t_ex, t_spec):
    """Tell whether the clause matching the task state is conditional.

    :param t_ex: Task execution; only its ``state`` is read.
    :param t_spec: Task spec providing ``is_conditional_transition()``.
    :return: Result of ``t_spec.is_conditional_transition()`` for
        'on-success' (state SUCCESS) or 'on-error' (state ERROR),
        False for any other state.
    """
    if t_ex.state == states.SUCCESS:
        return t_spec.is_conditional_transition('on-success')
    elif t_ex.state == states.ERROR:
        return t_spec.is_conditional_transition('on-error')

    return False