From c9fcb03c6d60cbbf7d7b6cff67ddc5ef8a45c6a2 Mon Sep 17 00:00:00 2001 From: Winson Chan Date: Fri, 7 Jul 2017 18:48:27 +0000 Subject: [PATCH] Allow async action execution to be paused and resumed Allow async action execution to be paused and resumed by updating the action execution via API. When an action execution is paused, the state transition will cascade up to the task execution. Implements: blueprint mistral-action-ex-pause-resume Change-Id: I87233d27c46cfe86a23beb8dfdc96f77e58d24c1 --- .../api/controllers/v2/action_execution.py | 40 ++- mistral/db/v2/sqlalchemy/api.py | 3 +- mistral/engine/action_handler.py | 25 ++ mistral/engine/actions.py | 17 + mistral/engine/default_engine.py | 15 + mistral/engine/engine_server.py | 18 ++ mistral/engine/task_handler.py | 89 +++++ mistral/engine/tasks.py | 32 ++ mistral/exceptions.py | 5 + mistral/rpc/clients.py | 34 ++ .../unit/api/v2/test_action_executions.py | 42 ++- mistral/tests/unit/engine/base.py | 8 + .../tests/unit/engine/test_default_engine.py | 125 ++++++++ .../unit/engine/test_task_pause_resume.py | 303 ++++++++++++++++++ 14 files changed, 741 insertions(+), 15 deletions(-) create mode 100644 mistral/tests/unit/engine/test_task_pause_resume.py diff --git a/mistral/api/controllers/v2/action_execution.py b/mistral/api/controllers/v2/action_execution.py index 4612b084f..f8346154d 100644 --- a/mistral/api/controllers/v2/action_execution.py +++ b/mistral/api/controllers/v2/action_execution.py @@ -34,6 +34,14 @@ from mistral_lib import actions as ml_actions LOG = logging.getLogger(__name__) +SUPPORTED_TRANSITION_STATES = [ + states.SUCCESS, + states.ERROR, + states.CANCELLED, + states.PAUSED, + states.RUNNING +] + def _load_deferred_output_field(action_ex): # We need to refer to this lazy-load field explicitly in @@ -180,25 +188,31 @@ class ActionExecutionsController(rest.RestController): action_ex ) - output = action_ex.output - - if action_ex.state == states.SUCCESS: - result = ml_actions.Result(data=output) - elif action_ex.state == states.ERROR: - if not output: - output = 'Unknown error' - result = ml_actions.Result(error=output) - elif action_ex.state == states.CANCELLED: - result = ml_actions.Result(cancel=True) - else: + if action_ex.state not in SUPPORTED_TRANSITION_STATES: raise exc.InvalidResultException( "Error. Expected one of %s, actual: %s" % ( - [states.SUCCESS, states.ERROR, states.CANCELLED], + SUPPORTED_TRANSITION_STATES, action_ex.state ) ) - values = rpc.get_engine_client().on_action_complete(id, result) + if states.is_completed(action_ex.state): + output = action_ex.output + + if action_ex.state == states.SUCCESS: + result = ml_actions.Result(data=output) + elif action_ex.state == states.ERROR: + if not output: + output = 'Unknown error' + result = ml_actions.Result(error=output) + elif action_ex.state == states.CANCELLED: + result = ml_actions.Result(cancel=True) + + values = rpc.get_engine_client().on_action_complete(id, result) + + if action_ex.state in [states.PAUSED, states.RUNNING]: + state = action_ex.state + values = rpc.get_engine_client().on_action_update(id, state) return resources.ActionExecution.from_dict(values) diff --git a/mistral/db/v2/sqlalchemy/api.py b/mistral/db/v2/sqlalchemy/api.py index 28b2bb83c..97d4c3804 100644 --- a/mistral/db/v2/sqlalchemy/api.py +++ b/mistral/db/v2/sqlalchemy/api.py @@ -940,7 +940,8 @@ def _get_incomplete_task_executions_query(kwargs): models.TaskExecution.state == states.IDLE, models.TaskExecution.state == states.RUNNING, models.TaskExecution.state == states.WAITING, - models.TaskExecution.state == states.RUNNING_DELAYED + models.TaskExecution.state == states.RUNNING_DELAYED, + models.TaskExecution.state == states.PAUSED ) ) diff --git a/mistral/engine/action_handler.py b/mistral/engine/action_handler.py index 96c66c8e6..0b057f581 100644 --- a/mistral/engine/action_handler.py +++ b/mistral/engine/action_handler.py @@ -54,6 +54,31 @@ def on_action_complete(action_ex, result): task_handler.schedule_on_action_complete(action_ex) +@profiler.trace('action-handler-on-action-update', hide_args=True) +def on_action_update(action_ex, state): + task_ex = action_ex.task_execution + + action = _build_action(action_ex) + + try: + action.update(state) + except exc.MistralException as e: + # If the update of the action execution fails, do not fail + # the action execution. Log the exception and re-raise the + # exception. + msg = ( + "Failed to update action [error=%s, action=%s, task=%s]:\n%s" + % (e, action_ex.name, task_ex.name, tb.format_exc()) + ) + + LOG.error(msg) + + raise + + if task_ex: + task_handler.schedule_on_action_update(action_ex) + + @profiler.trace('action-handler-build-action', hide_args=True) def _build_action(action_ex): if isinstance(action_ex, models.WorkflowExecution): diff --git a/mistral/engine/actions.py b/mistral/engine/actions.py index 392b8568d..ec7a59c11 100644 --- a/mistral/engine/actions.py +++ b/mistral/engine/actions.py @@ -73,6 +73,23 @@ class Action(object): self.action_ex.state = states.ERROR self.action_ex.output = {'result': msg} + def update(self, state): + assert self.action_ex + + if state == states.PAUSED and self.is_sync(self.action_ex.input): + raise exc.InvalidStateTransitionException( + 'Transition to the PAUSED state is only supported ' + 'for asynchronous action execution.' + ) + + if not states.is_valid_transition(self.action_ex.state, state): + raise exc.InvalidStateTransitionException( + 'Invalid state transition from %s to %s.' % + (self.action_ex.state, state) + ) + + self.action_ex.state = state + @abc.abstractmethod def schedule(self, input_dict, target, index=0, desc='', safe_rerun=False): """Schedule action run. diff --git a/mistral/engine/default_engine.py b/mistral/engine/default_engine.py index d4c895dad..7f7d04ef3 100644 --- a/mistral/engine/default_engine.py +++ b/mistral/engine/default_engine.py @@ -118,6 +118,21 @@ class DefaultEngine(base.Engine): return action_ex.get_clone() + @db_utils.retry_on_deadlock + @action_queue.process + @profiler.trace('engine-on-action-update', hide_args=True) + def on_action_update(self, action_ex_id, state, wf_action=False, + async_=False): + with db_api.transaction(): + if wf_action: + action_ex = db_api.get_workflow_execution(action_ex_id) + else: + action_ex = db_api.get_action_execution(action_ex_id) + + action_handler.on_action_update(action_ex, state) + + return action_ex.get_clone() + def pause_workflow(self, wf_ex_id): with db_api.transaction(): wf_ex = db_api.get_workflow_execution(wf_ex_id) diff --git a/mistral/engine/engine_server.py b/mistral/engine/engine_server.py index 3ef4b9e09..74fc87787 100644 --- a/mistral/engine/engine_server.py +++ b/mistral/engine/engine_server.py @@ -153,6 +153,24 @@ class EngineServer(service_base.MistralService): return self.engine.on_action_complete(action_ex_id, result, wf_action) + def on_action_update(self, rpc_ctx, action_ex_id, state, wf_action): + """Receives RPC calls to communicate action execution state to engine. + + :param rpc_ctx: RPC request context. + :param action_ex_id: Action execution id. + :param state: Action execution state. + :param wf_action: True if given id points to a workflow execution. + :return: Action execution. + """ + LOG.info( + "Received RPC request 'on_action_update'" + "[action_ex_id=%s, state=%s]", + action_ex_id, + state + ) + + return self.engine.on_action_update(action_ex_id, state, wf_action) + def pause_workflow(self, rpc_ctx, execution_id): """Receives calls over RPC to pause workflows on engine. diff --git a/mistral/engine/task_handler.py b/mistral/engine/task_handler.py index a7e4163b7..0eb2c1308 100644 --- a/mistral/engine/task_handler.py +++ b/mistral/engine/task_handler.py @@ -44,6 +44,10 @@ _SCHEDULED_ON_ACTION_COMPLETE_PATH = ( 'mistral.engine.task_handler._scheduled_on_action_complete' ) +_SCHEDULED_ON_ACTION_UPDATE_PATH = ( + 'mistral.engine.task_handler._scheduled_on_action_update' +) + @profiler.trace('task-handler-run-task', hide_args=True) def run_task(wf_cmd): @@ -116,6 +120,46 @@ def _on_action_complete(action_ex): wf_handler.force_fail_workflow(wf_ex, msg) + +@profiler.trace('task-handler-on-action-update', hide_args=True) +def _on_action_update(action_ex): + """Handles action update event. + + :param action_ex: Action execution. + """ + + task_ex = action_ex.task_execution + + if not task_ex: + return + + task_spec = spec_parser.get_task_spec(task_ex.spec) + + wf_ex = task_ex.workflow_execution + + task = _create_task( + wf_ex, + spec_parser.get_workflow_spec_by_execution_id(wf_ex.id), + task_spec, + task_ex.in_context, + task_ex + ) + + try: + task.on_action_update(action_ex) + except exc.MistralException as e: + wf_ex = task_ex.workflow_execution + + msg = ("Failed to handle action update [error=%s, wf=%s, task=%s," + " action=%s]:\n%s" % + (e, wf_ex.name, task_ex.name, action_ex.name, tb.format_exc())) + + LOG.error(msg) + + task.set_state(states.ERROR, msg) + + wf_handler.force_fail_workflow(wf_ex, msg) + return @@ -386,3 +430,48 @@ def schedule_on_action_complete(action_ex, delay=0): action_ex_id=action_ex.id, wf_action=isinstance(action_ex, models.WorkflowExecution) ) + + +@action_queue.process +def _scheduled_on_action_update(action_ex_id, wf_action): + with db_api.transaction(): + if wf_action: + action_ex = db_api.get_workflow_execution(action_ex_id) + else: + action_ex = db_api.get_action_execution(action_ex_id) + + _on_action_update(action_ex) + + +def schedule_on_action_update(action_ex, delay=0): + """Schedules task update check. + + This method provides transactional decoupling of action update from + task update check. It's needed in non-locking model in order to + avoid 'phantom read' phenomena when reading state of multiple actions + to see if a task is updated. Just starting a separate transaction + without using scheduler is not safe due to concurrency window that we'll + have in this case (time between transactions) whereas scheduler is a + special component that is designed to be resistant to failures. + + :param action_ex: Action execution. + :param delay: Minimum amount of time before task update check + should be made. + """ + + # Optimization to avoid opening a new transaction if it's not needed. + if not action_ex.task_execution.spec.get('with-items'): + _on_action_update(action_ex) + + return + + key = 'th_on_a_c-%s' % action_ex.task_execution_id + + scheduler.schedule_call( + None, + _SCHEDULED_ON_ACTION_UPDATE_PATH, + delay, + key=key, + action_ex_id=action_ex.id, + wf_action=isinstance(action_ex, models.WorkflowExecution) + ) diff --git a/mistral/engine/tasks.py b/mistral/engine/tasks.py index 758fe6e37..b4f2ca086 100644 --- a/mistral/engine/tasks.py +++ b/mistral/engine/tasks.py @@ -77,6 +77,14 @@ class Task(object): """ raise NotImplementedError + @abc.abstractmethod + def on_action_update(self, action_ex): + """Handle action update. + + :param action_ex: Action execution. + """ + raise NotImplementedError + @abc.abstractmethod def run(self): """Runs task.""" @@ -191,6 +199,26 @@ class Task(object): dispatcher.dispatch_workflow_commands(self.wf_ex, cmds) + @profiler.trace('task-update') + def update(self, state, state_info=None): + """Update task and set specified state. + + Method sets specified task state. + + :param state: New task state. + :param state_info: New state information (i.e. error message). + """ + + assert self.task_ex + + # Ignore if task already completed. + if states.is_completed(self.task_ex.state): + return + + # Update only if state transition is valid. + if states.is_valid_transition(self.task_ex.state, state): + self.set_state(state, state_info) + def _before_task_start(self): policies_spec = self.task_spec.get_policies() @@ -268,6 +296,10 @@ class RegularTask(Task): self.complete(state, state_info) + @profiler.trace('regular-task-on-action-update', hide_args=True) + def on_action_update(self, action_ex): + self.update(action_ex.state) + @profiler.trace('task-run') def run(self): if not self.task_ex: diff --git a/mistral/exceptions.py b/mistral/exceptions.py index 81003b9ff..b2f25593b 100644 --- a/mistral/exceptions.py +++ b/mistral/exceptions.py @@ -215,3 +215,8 @@ class KombuException(Exception): self.exc_type = e.__class__.__name__ self.value = str(e) + + +class InvalidStateTransitionException(MistralException): + http_code = 400 + message = 'Invalid state transition' diff --git a/mistral/rpc/clients.py b/mistral/rpc/clients.py index 1cdaa96cb..d4b13e0aa 100644 --- a/mistral/rpc/clients.py +++ b/mistral/rpc/clients.py @@ -147,6 +147,40 @@ class EngineClient(eng.Engine): wf_action=wf_action ) + @base.wrap_messaging_exception + @profiler.trace('engine-client-on-action-update', hide_args=True) + def on_action_update(self, action_ex_id, state, wf_action=False, + async_=False): + """Conveys update of action state to Mistral Engine. + + This method should be used by clients of Mistral Engine to update + state of a action execution once action has executed. + + Note: calling this method serves an event notifying Mistral that it + may need to change the state of the parent task and workflow. Use + on_action_complete if the action execution reached completion state. + + :param action_ex_id: Action execution id. + :param action_ex_id: Updated state. + :param wf_action: If True it means that the given id points to + a workflow execution rather than action execution. It happens + when a nested workflow execution sends its result to a parent + workflow. + :param async: If True, run action in asynchronous mode (w/o waiting + for completion). + :return: Action(or workflow if wf_action=True) execution object. + """ + + call = self._client.async_call if async_ else self._client.sync_call + + return call( + auth_ctx.ctx(), + 'on_action_update', + action_ex_id=action_ex_id, + state=state, + wf_action=wf_action + ) + @base.wrap_messaging_exception def pause_workflow(self, wf_ex_id): """Stops the workflow with the given execution id. diff --git a/mistral/tests/unit/api/v2/test_action_executions.py b/mistral/tests/unit/api/v2/test_action_executions.py index a0314426f..25b07038c 100644 --- a/mistral/tests/unit/api/v2/test_action_executions.py +++ b/mistral/tests/unit/api/v2/test_action_executions.py @@ -140,6 +140,18 @@ CANCELLED_ACTION_EX_DB['task_name'] = 'task1' CANCELLED_ACTION = copy.deepcopy(ACTION_EX) CANCELLED_ACTION['state'] = 'CANCELLED' +PAUSED_ACTION_EX_DB = copy.copy(ACTION_EX_DB).to_dict() +PAUSED_ACTION_EX_DB['state'] = 'PAUSED' +PAUSED_ACTION_EX_DB['task_name'] = 'task1' +PAUSED_ACTION = copy.deepcopy(ACTION_EX) +PAUSED_ACTION['state'] = 'PAUSED' + +RUNNING_ACTION_EX_DB = copy.copy(ACTION_EX_DB).to_dict() +RUNNING_ACTION_EX_DB['state'] = 'RUNNING' +RUNNING_ACTION_EX_DB['task_name'] = 'task1' +RUNNING_ACTION = copy.deepcopy(ACTION_EX) +RUNNING_ACTION['state'] = 'RUNNING' + ERROR_ACTION_EX = copy.copy(ACTION_EX_DB).to_dict() ERROR_ACTION_EX['state'] = 'ERROR' ERROR_ACTION_EX['task_name'] = 'task1' @@ -395,6 +407,34 @@ class TestActionExecutionsController(base.APITest): ml_actions.Result(cancel=True) ) + @mock.patch.object(rpc_clients.EngineClient, 'on_action_update') + def test_put_paused(self, on_action_update_mock_func): + on_action_update_mock_func.return_value = PAUSED_ACTION_EX_DB + + resp = self.app.put_json('/v2/action_executions/123', PAUSED_ACTION) + + self.assertEqual(200, resp.status_int) + self.assertDictEqual(PAUSED_ACTION, resp.json) + + on_action_update_mock_func.assert_called_once_with( + PAUSED_ACTION['id'], + PAUSED_ACTION['state'] + ) + + @mock.patch.object(rpc_clients.EngineClient, 'on_action_update') + def test_put_resume(self, on_action_update_mock_func): + on_action_update_mock_func.return_value = RUNNING_ACTION_EX_DB + + resp = self.app.put_json('/v2/action_executions/123', RUNNING_ACTION) + + self.assertEqual(200, resp.status_int) + self.assertDictEqual(RUNNING_ACTION, resp.json) + + on_action_update_mock_func.assert_called_once_with( + RUNNING_ACTION['id'], + RUNNING_ACTION['state'] + ) + @mock.patch.object( rpc_clients.EngineClient, 'on_action_complete', @@ -411,7 +451,7 @@ class TestActionExecutionsController(base.APITest): def test_put_bad_state(self): action = copy.deepcopy(ACTION_EX) - action['state'] = 'PAUSED' + action['state'] = 'DELAYED' resp = self.app.put_json( '/v2/action_executions/123', diff --git a/mistral/tests/unit/engine/base.py b/mistral/tests/unit/engine/base.py index de09a0ad9..394418a65 100644 --- a/mistral/tests/unit/engine/base.py +++ b/mistral/tests/unit/engine/base.py @@ -208,6 +208,10 @@ class EngineTestCase(base.DbTestCase): def is_task_processed(self, task_ex_id): return db_api.get_task_execution(task_ex_id).processed + def await_task_running(self, ex_id, delay=DEFAULT_DELAY, + timeout=DEFAULT_TIMEOUT): + self.await_task_state(ex_id, states.RUNNING, delay, timeout) + def await_task_success(self, ex_id, delay=DEFAULT_DELAY, timeout=DEFAULT_TIMEOUT): self.await_task_state(ex_id, states.SUCCESS, delay, timeout) @@ -220,6 +224,10 @@ class EngineTestCase(base.DbTestCase): timeout=DEFAULT_TIMEOUT): self.await_task_state(ex_id, states.CANCELLED, delay, timeout) + def await_task_paused(self, ex_id, delay=DEFAULT_DELAY, + timeout=DEFAULT_TIMEOUT): + self.await_task_state(ex_id, states.PAUSED, delay, timeout) + def await_task_delayed(self, ex_id, delay=DEFAULT_DELAY, timeout=DEFAULT_TIMEOUT): self.await_task_state(ex_id, states.RUNNING_DELAYED, delay, timeout) diff --git a/mistral/tests/unit/engine/test_default_engine.py b/mistral/tests/unit/engine/test_default_engine.py index 93fd21d8c..df2e31a79 100644 --- a/mistral/tests/unit/engine/test_default_engine.py +++ b/mistral/tests/unit/engine/test_default_engine.py @@ -26,6 +26,7 @@ from mistral.engine import default_engine as d_eng from mistral import exceptions as exc from mistral.executors import base as exe from mistral.services import workbooks as wb_service +from mistral.services import workflows as wf_service from mistral.tests.unit import base from mistral.tests.unit.engine import base as eng_test_base from mistral.workflow import states @@ -299,6 +300,130 @@ class DefaultEngineTest(base.DbTestCase): self.assertIn("Invalid input", str(e)) self.assertIn("unexpected=['unexpected_param']", str(e)) + def test_on_action_update(self): + workflow = """ + version: '2.0' + wf_async: + type: direct + tasks: + task1: + action: std.async_noop + on-success: + - task2 + task2: + action: std.noop + """ + + # Start workflow. + wf_service.create_workflows(workflow) + wf_ex = self.engine.start_workflow('wf_async') + + self.assertIsNotNone(wf_ex) + self.assertEqual(states.RUNNING, wf_ex.state) + + with db_api.transaction(): + # Note: We need to reread execution to access related tasks. + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + self.assertEqual(1, len(task_execs)) + + task1_ex = task_execs[0] + + self.assertEqual('task1', task1_ex.name) + self.assertEqual(states.RUNNING, task1_ex.state) + + action_execs = db_api.get_action_executions( + task_execution_id=task1_ex.id + ) + + self.assertEqual(1, len(action_execs)) + + task1_action_ex = action_execs[0] + + self.assertEqual(states.RUNNING, task1_action_ex.state) + + # Pause action execution of 'task1'. + task1_action_ex = self.engine.on_action_update( + task1_action_ex.id, + states.PAUSED + ) + + self.assertIsInstance(task1_action_ex, models.ActionExecution) + self.assertEqual(states.PAUSED, task1_action_ex.state) + + with db_api.transaction(): + # Note: We need to reread execution to access related tasks. + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.PAUSED, task_execs[0].state) + self.assertEqual(states.RUNNING, wf_ex.state) + + action_execs = db_api.get_action_executions( + task_execution_id=task1_ex.id + ) + + self.assertEqual(1, len(action_execs)) + + task1_action_ex = action_execs[0] + + self.assertEqual(states.PAUSED, task1_action_ex.state) + + def test_on_action_update_non_async(self): + workflow = """ + version: '2.0' + wf_sync: + type: direct + tasks: + task1: + action: std.noop + on-success: + - task2 + task2: + action: std.noop + """ + + # Start workflow. + wf_service.create_workflows(workflow) + wf_ex = self.engine.start_workflow('wf_sync') + + self.assertIsNotNone(wf_ex) + self.assertEqual(states.RUNNING, wf_ex.state) + + with db_api.transaction(): + # Note: We need to reread execution to access related tasks. + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + self.assertEqual(1, len(task_execs)) + + task1_ex = task_execs[0] + + self.assertEqual('task1', task1_ex.name) + self.assertEqual(states.RUNNING, task1_ex.state) + + action_execs = db_api.get_action_executions( + task_execution_id=task1_ex.id + ) + + self.assertEqual(1, len(action_execs)) + + task1_action_ex = action_execs[0] + + self.assertEqual(states.RUNNING, task1_action_ex.state) + + self.assertRaises( + exc.InvalidStateTransitionException, + self.engine.on_action_update, + task1_action_ex.id, + states.PAUSED + ) + def test_on_action_complete(self): wf_input = {'param1': 'Hey', 'param2': 'Hi'} diff --git a/mistral/tests/unit/engine/test_task_pause_resume.py b/mistral/tests/unit/engine/test_task_pause_resume.py new file mode 100644 index 000000000..8d7ea8cba --- /dev/null +++ b/mistral/tests/unit/engine/test_task_pause_resume.py @@ -0,0 +1,303 @@ +# Copyright 2015 - StackStorm, Inc. +# Copyright 2016 - Brocade Communications Systems, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mistral.db.v2 import api as db_api +from mistral.services import workflows as wf_service +from mistral.tests.unit.engine import base +from mistral.workflow import states +from mistral_lib import actions as ml_actions + + +class TaskPauseResumeTest(base.EngineTestCase): + + def test_pause_resume_action_ex(self): + workflow = """ + version: '2.0' + + wf: + tasks: + task1: + action: std.async_noop + on-success: + - task2 + task2: + action: std.noop + """ + + wf_service.create_workflows(workflow) + + wf_ex = self.engine.start_workflow('wf') + + self.await_workflow_state(wf_ex.id, states.RUNNING) + + with db_api.transaction(): + wf_execs = db_api.get_workflow_executions() + + wf_ex = self._assert_single_item(wf_execs, name='wf') + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.RUNNING, task_1_ex.state) + self.assertEqual(1, len(task_1_action_exs)) + self.assertEqual(states.RUNNING, task_1_action_exs[0].state) + + # Pause the action execution of task 1. + self.engine.on_action_update(task_1_action_exs[0].id, states.PAUSED) + + self.await_task_paused(task_1_ex.id) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.PAUSED, task_1_ex.state) + self.assertEqual(1, len(task_1_action_exs)) + self.assertEqual(states.PAUSED, task_1_action_exs[0].state) + + # Resume the action execution of task 1. + self.engine.on_action_update(task_1_action_exs[0].id, states.RUNNING) + + self.await_task_running(task_1_ex.id) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.RUNNING, task_1_ex.state) + self.assertEqual(1, len(task_1_action_exs)) + self.assertEqual(states.RUNNING, task_1_action_exs[0].state) + + # Complete action execution of task 1. + self.engine.on_action_complete( + task_1_action_exs[0].id, + ml_actions.Result(data={'result': 'foobar'}) + ) + + # Wait for the workflow execution to complete. + self.await_workflow_success(wf_ex.id) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item(task_execs, name='task1') + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + task_2_ex = self._assert_single_item(task_execs, name='task2') + + self.assertEqual(states.SUCCESS, wf_ex.state) + self.assertEqual(2, len(task_execs)) + self.assertEqual(states.SUCCESS, task_1_ex.state) + self.assertEqual(1, len(task_1_action_exs)) + self.assertEqual(states.SUCCESS, task_1_action_exs[0].state) + self.assertEqual(states.SUCCESS, task_2_ex.state) + + def test_pause_resume_action_ex_with_items_task(self): + workflow = """ + version: '2.0' + + wf: + tasks: + task1: + with-items: i in <% range(3) %> + action: std.async_noop + on-success: + - task2 + task2: + action: std.noop + """ + + wf_service.create_workflows(workflow) + + wf_ex = self.engine.start_workflow('wf') + + self.await_workflow_state(wf_ex.id, states.RUNNING) + + with db_api.transaction(): + wf_execs = db_api.get_workflow_executions() + + wf_ex = self._assert_single_item(wf_execs, name='wf') + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.RUNNING, task_1_ex.state) + self.assertEqual(3, len(task_1_action_exs)) + self.assertEqual(states.RUNNING, task_1_action_exs[0].state) + self.assertEqual(states.RUNNING, task_1_action_exs[1].state) + self.assertEqual(states.RUNNING, task_1_action_exs[2].state) + + # Pause the 1st action execution of task 1. + self.engine.on_action_update(task_1_action_exs[0].id, states.PAUSED) + + self.await_task_paused(task_1_ex.id) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.PAUSED, task_1_ex.state) + self.assertEqual(3, len(task_1_action_exs)) + self.assertEqual(states.PAUSED, task_1_action_exs[0].state) + self.assertEqual(states.RUNNING, task_1_action_exs[1].state) + self.assertEqual(states.RUNNING, task_1_action_exs[2].state) + + # Complete 2nd and 3rd action executions of task 1. + self.engine.on_action_complete( + task_1_action_exs[1].id, + ml_actions.Result(data={'result': 'two'}) + ) + + self.engine.on_action_complete( + task_1_action_exs[2].id, + ml_actions.Result(data={'result': 'three'}) + ) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.PAUSED, task_1_ex.state) + self.assertEqual(3, len(task_1_action_exs)) + self.assertEqual(states.PAUSED, task_1_action_exs[0].state) + self.assertEqual(states.SUCCESS, task_1_action_exs[1].state) + self.assertEqual(states.SUCCESS, task_1_action_exs[2].state) + + # Resume the 1st action execution of task 1. + self.engine.on_action_update(task_1_action_exs[0].id, states.RUNNING) + + self.await_task_running(task_1_ex.id) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_1_ex = self._assert_single_item( + wf_ex.task_executions, + name='task1' + ) + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + self.assertEqual(states.RUNNING, wf_ex.state) + self.assertEqual(1, len(task_execs)) + self.assertEqual(states.RUNNING, task_1_ex.state) + self.assertEqual(3, len(task_1_action_exs)) + self.assertEqual(states.RUNNING, task_1_action_exs[0].state) + self.assertEqual(states.SUCCESS, task_1_action_exs[1].state) + self.assertEqual(states.SUCCESS, task_1_action_exs[2].state) + + # Complete the 1st action execution of task 1. + self.engine.on_action_complete( + task_1_action_exs[0].id, + ml_actions.Result(data={'result': 'foobar'}) + ) + + # Wait for the workflow execution to complete. + self.await_workflow_success(wf_ex.id) + + with db_api.transaction(): + wf_ex = db_api.get_workflow_execution(wf_ex.id) + + task_execs = wf_ex.task_executions + + task_1_ex = self._assert_single_item(task_execs, name='task1') + + task_1_action_exs = db_api.get_action_executions( + task_execution_id=task_1_ex.id + ) + + task_2_ex = self._assert_single_item(task_execs, name='task2') + + self.assertEqual(states.SUCCESS, wf_ex.state) + self.assertEqual(2, len(task_execs)) + self.assertEqual(states.SUCCESS, task_1_ex.state) + self.assertEqual(3, len(task_1_action_exs)) + self.assertEqual(states.SUCCESS, task_1_action_exs[0].state) + self.assertEqual(states.SUCCESS, task_1_action_exs[1].state) + self.assertEqual(states.SUCCESS, task_1_action_exs[2].state) + self.assertEqual(states.SUCCESS, task_2_ex.state)