Merge "Allow async action execution to be paused and resumed"

This commit is contained in:
Jenkins 2017-08-04 05:07:31 +00:00 committed by Gerrit Code Review
commit abebc649a0
14 changed files with 741 additions and 15 deletions

View File

@ -34,6 +34,14 @@ from mistral_lib import actions as ml_actions
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
SUPPORTED_TRANSITION_STATES = [
states.SUCCESS,
states.ERROR,
states.CANCELLED,
states.PAUSED,
states.RUNNING
]
def _load_deferred_output_field(action_ex): def _load_deferred_output_field(action_ex):
# We need to refer to this lazy-load field explicitly in # We need to refer to this lazy-load field explicitly in
@ -180,6 +188,15 @@ class ActionExecutionsController(rest.RestController):
action_ex action_ex
) )
if action_ex.state not in SUPPORTED_TRANSITION_STATES:
raise exc.InvalidResultException(
"Error. Expected one of %s, actual: %s" % (
SUPPORTED_TRANSITION_STATES,
action_ex.state
)
)
if states.is_completed(action_ex.state):
output = action_ex.output output = action_ex.output
if action_ex.state == states.SUCCESS: if action_ex.state == states.SUCCESS:
@ -190,16 +207,13 @@ class ActionExecutionsController(rest.RestController):
result = ml_actions.Result(error=output) result = ml_actions.Result(error=output)
elif action_ex.state == states.CANCELLED: elif action_ex.state == states.CANCELLED:
result = ml_actions.Result(cancel=True) result = ml_actions.Result(cancel=True)
else:
raise exc.InvalidResultException(
"Error. Expected one of %s, actual: %s" % (
[states.SUCCESS, states.ERROR, states.CANCELLED],
action_ex.state
)
)
values = rpc.get_engine_client().on_action_complete(id, result) values = rpc.get_engine_client().on_action_complete(id, result)
if action_ex.state in [states.PAUSED, states.RUNNING]:
state = action_ex.state
values = rpc.get_engine_client().on_action_update(id, state)
return resources.ActionExecution.from_dict(values) return resources.ActionExecution.from_dict(values)
@rest_utils.wrap_wsme_controller_exception @rest_utils.wrap_wsme_controller_exception

View File

@ -940,7 +940,8 @@ def _get_incomplete_task_executions_query(kwargs):
models.TaskExecution.state == states.IDLE, models.TaskExecution.state == states.IDLE,
models.TaskExecution.state == states.RUNNING, models.TaskExecution.state == states.RUNNING,
models.TaskExecution.state == states.WAITING, models.TaskExecution.state == states.WAITING,
models.TaskExecution.state == states.RUNNING_DELAYED models.TaskExecution.state == states.RUNNING_DELAYED,
models.TaskExecution.state == states.PAUSED
) )
) )

View File

@ -54,6 +54,31 @@ def on_action_complete(action_ex, result):
task_handler.schedule_on_action_complete(action_ex) task_handler.schedule_on_action_complete(action_ex)
@profiler.trace('action-handler-on-action-update', hide_args=True)
def on_action_update(action_ex, state):
task_ex = action_ex.task_execution
action = _build_action(action_ex)
try:
action.update(state)
except exc.MistralException as e:
# If the update of the action execution fails, do not fail
# the action execution. Log the exception and re-raise the
# exception.
msg = (
"Failed to update action [error=%s, action=%s, task=%s]:\n%s"
% (e, action_ex.name, task_ex.name, tb.format_exc())
)
LOG.error(msg)
raise
if task_ex:
task_handler.schedule_on_action_update(action_ex)
@profiler.trace('action-handler-build-action', hide_args=True) @profiler.trace('action-handler-build-action', hide_args=True)
def _build_action(action_ex): def _build_action(action_ex):
if isinstance(action_ex, models.WorkflowExecution): if isinstance(action_ex, models.WorkflowExecution):

View File

@ -73,6 +73,23 @@ class Action(object):
self.action_ex.state = states.ERROR self.action_ex.state = states.ERROR
self.action_ex.output = {'result': msg} self.action_ex.output = {'result': msg}
def update(self, state):
assert self.action_ex
if state == states.PAUSED and self.is_sync(self.action_ex.input):
raise exc.InvalidStateTransitionException(
'Transition to the PAUSED state is only supported '
'for asynchronous action execution.'
)
if not states.is_valid_transition(self.action_ex.state, state):
raise exc.InvalidStateTransitionException(
'Invalid state transition from %s to %s.' %
(self.action_ex.state, state)
)
self.action_ex.state = state
@abc.abstractmethod @abc.abstractmethod
def schedule(self, input_dict, target, index=0, desc='', safe_rerun=False): def schedule(self, input_dict, target, index=0, desc='', safe_rerun=False):
"""Schedule action run. """Schedule action run.

View File

@ -118,6 +118,21 @@ class DefaultEngine(base.Engine):
return action_ex.get_clone() return action_ex.get_clone()
@db_utils.retry_on_deadlock
@action_queue.process
@profiler.trace('engine-on-action-update', hide_args=True)
def on_action_update(self, action_ex_id, state, wf_action=False,
async_=False):
with db_api.transaction():
if wf_action:
action_ex = db_api.get_workflow_execution(action_ex_id)
else:
action_ex = db_api.get_action_execution(action_ex_id)
action_handler.on_action_update(action_ex, state)
return action_ex.get_clone()
def pause_workflow(self, wf_ex_id): def pause_workflow(self, wf_ex_id):
with db_api.transaction(): with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex_id) wf_ex = db_api.get_workflow_execution(wf_ex_id)

View File

@ -153,6 +153,24 @@ class EngineServer(service_base.MistralService):
return self.engine.on_action_complete(action_ex_id, result, wf_action) return self.engine.on_action_complete(action_ex_id, result, wf_action)
def on_action_update(self, rpc_ctx, action_ex_id, state, wf_action):
"""Receives RPC calls to communicate action execution state to engine.
:param rpc_ctx: RPC request context.
:param action_ex_id: Action execution id.
:param state: Action execution state.
:param wf_action: True if given id points to a workflow execution.
:return: Action execution.
"""
LOG.info(
"Received RPC request 'on_action_update'"
"[action_ex_id=%s, state=%s]",
action_ex_id,
state
)
return self.engine.on_action_update(action_ex_id, state, wf_action)
def pause_workflow(self, rpc_ctx, execution_id): def pause_workflow(self, rpc_ctx, execution_id):
"""Receives calls over RPC to pause workflows on engine. """Receives calls over RPC to pause workflows on engine.

View File

@ -44,6 +44,10 @@ _SCHEDULED_ON_ACTION_COMPLETE_PATH = (
'mistral.engine.task_handler._scheduled_on_action_complete' 'mistral.engine.task_handler._scheduled_on_action_complete'
) )
_SCHEDULED_ON_ACTION_UPDATE_PATH = (
'mistral.engine.task_handler._scheduled_on_action_update'
)
@profiler.trace('task-handler-run-task', hide_args=True) @profiler.trace('task-handler-run-task', hide_args=True)
def run_task(wf_cmd): def run_task(wf_cmd):
@ -116,6 +120,46 @@ def _on_action_complete(action_ex):
wf_handler.force_fail_workflow(wf_ex, msg) wf_handler.force_fail_workflow(wf_ex, msg)
@profiler.trace('task-handler-on-action-update', hide_args=True)
def _on_action_update(action_ex):
"""Handles action update event.
:param action_ex: Action execution.
"""
task_ex = action_ex.task_execution
if not task_ex:
return
task_spec = spec_parser.get_task_spec(task_ex.spec)
wf_ex = task_ex.workflow_execution
task = _create_task(
wf_ex,
spec_parser.get_workflow_spec_by_execution_id(wf_ex.id),
task_spec,
task_ex.in_context,
task_ex
)
try:
task.on_action_update(action_ex)
except exc.MistralException as e:
wf_ex = task_ex.workflow_execution
msg = ("Failed to handle action update [error=%s, wf=%s, task=%s,"
" action=%s]:\n%s" %
(e, wf_ex.name, task_ex.name, action_ex.name, tb.format_exc()))
LOG.error(msg)
task.set_state(states.ERROR, msg)
wf_handler.force_fail_workflow(wf_ex, msg)
return return
@ -386,3 +430,48 @@ def schedule_on_action_complete(action_ex, delay=0):
action_ex_id=action_ex.id, action_ex_id=action_ex.id,
wf_action=isinstance(action_ex, models.WorkflowExecution) wf_action=isinstance(action_ex, models.WorkflowExecution)
) )
@action_queue.process
def _scheduled_on_action_update(action_ex_id, wf_action):
with db_api.transaction():
if wf_action:
action_ex = db_api.get_workflow_execution(action_ex_id)
else:
action_ex = db_api.get_action_execution(action_ex_id)
_on_action_update(action_ex)
def schedule_on_action_update(action_ex, delay=0):
"""Schedules task update check.
This method provides transactional decoupling of action update from
task update check. It's needed in non-locking model in order to
avoid 'phantom read' phenomena when reading state of multiple actions
to see if a task is updated. Just starting a separate transaction
without using scheduler is not safe due to concurrency window that we'll
have in this case (time between transactions) whereas scheduler is a
special component that is designed to be resistant to failures.
:param action_ex: Action execution.
:param delay: Minimum amount of time before task update check
should be made.
"""
# Optimization to avoid opening a new transaction if it's not needed.
if not action_ex.task_execution.spec.get('with-items'):
_on_action_update(action_ex)
return
key = 'th_on_a_c-%s' % action_ex.task_execution_id
scheduler.schedule_call(
None,
_SCHEDULED_ON_ACTION_UPDATE_PATH,
delay,
key=key,
action_ex_id=action_ex.id,
wf_action=isinstance(action_ex, models.WorkflowExecution)
)

View File

@ -77,6 +77,14 @@ class Task(object):
""" """
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def on_action_update(self, action_ex):
"""Handle action update.
:param action_ex: Action execution.
"""
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def run(self): def run(self):
"""Runs task.""" """Runs task."""
@ -191,6 +199,26 @@ class Task(object):
dispatcher.dispatch_workflow_commands(self.wf_ex, cmds) dispatcher.dispatch_workflow_commands(self.wf_ex, cmds)
@profiler.trace('task-update')
def update(self, state, state_info=None):
"""Update task and set specified state.
Method sets specified task state.
:param state: New task state.
:param state_info: New state information (i.e. error message).
"""
assert self.task_ex
# Ignore if task already completed.
if states.is_completed(self.task_ex.state):
return
# Update only if state transition is valid.
if states.is_valid_transition(self.task_ex.state, state):
self.set_state(state, state_info)
def _before_task_start(self): def _before_task_start(self):
policies_spec = self.task_spec.get_policies() policies_spec = self.task_spec.get_policies()
@ -268,6 +296,10 @@ class RegularTask(Task):
self.complete(state, state_info) self.complete(state, state_info)
@profiler.trace('regular-task-on-action-update', hide_args=True)
def on_action_update(self, action_ex):
self.update(action_ex.state)
@profiler.trace('task-run') @profiler.trace('task-run')
def run(self): def run(self):
if not self.task_ex: if not self.task_ex:

View File

@ -215,3 +215,8 @@ class KombuException(Exception):
self.exc_type = e.__class__.__name__ self.exc_type = e.__class__.__name__
self.value = str(e) self.value = str(e)
class InvalidStateTransitionException(MistralException):
http_code = 400
message = 'Invalid state transition'

View File

@ -147,6 +147,40 @@ class EngineClient(eng.Engine):
wf_action=wf_action wf_action=wf_action
) )
@base.wrap_messaging_exception
@profiler.trace('engine-client-on-action-update', hide_args=True)
def on_action_update(self, action_ex_id, state, wf_action=False,
async_=False):
"""Conveys update of action state to Mistral Engine.
This method should be used by clients of Mistral Engine to update
state of a action execution once action has executed.
Note: calling this method serves an event notifying Mistral that it
may need to change the state of the parent task and workflow. Use
on_action_complete if the action execution reached completion state.
:param action_ex_id: Action execution id.
:param action_ex_id: Updated state.
:param wf_action: If True it means that the given id points to
a workflow execution rather than action execution. It happens
when a nested workflow execution sends its result to a parent
workflow.
:param async: If True, run action in asynchronous mode (w/o waiting
for completion).
:return: Action(or workflow if wf_action=True) execution object.
"""
call = self._client.async_call if async_ else self._client.sync_call
return call(
auth_ctx.ctx(),
'on_action_update',
action_ex_id=action_ex_id,
state=state,
wf_action=wf_action
)
@base.wrap_messaging_exception @base.wrap_messaging_exception
def pause_workflow(self, wf_ex_id): def pause_workflow(self, wf_ex_id):
"""Stops the workflow with the given execution id. """Stops the workflow with the given execution id.

View File

@ -140,6 +140,18 @@ CANCELLED_ACTION_EX_DB['task_name'] = 'task1'
CANCELLED_ACTION = copy.deepcopy(ACTION_EX) CANCELLED_ACTION = copy.deepcopy(ACTION_EX)
CANCELLED_ACTION['state'] = 'CANCELLED' CANCELLED_ACTION['state'] = 'CANCELLED'
PAUSED_ACTION_EX_DB = copy.copy(ACTION_EX_DB).to_dict()
PAUSED_ACTION_EX_DB['state'] = 'PAUSED'
PAUSED_ACTION_EX_DB['task_name'] = 'task1'
PAUSED_ACTION = copy.deepcopy(ACTION_EX)
PAUSED_ACTION['state'] = 'PAUSED'
RUNNING_ACTION_EX_DB = copy.copy(ACTION_EX_DB).to_dict()
RUNNING_ACTION_EX_DB['state'] = 'RUNNING'
RUNNING_ACTION_EX_DB['task_name'] = 'task1'
RUNNING_ACTION = copy.deepcopy(ACTION_EX)
RUNNING_ACTION['state'] = 'RUNNING'
ERROR_ACTION_EX = copy.copy(ACTION_EX_DB).to_dict() ERROR_ACTION_EX = copy.copy(ACTION_EX_DB).to_dict()
ERROR_ACTION_EX['state'] = 'ERROR' ERROR_ACTION_EX['state'] = 'ERROR'
ERROR_ACTION_EX['task_name'] = 'task1' ERROR_ACTION_EX['task_name'] = 'task1'
@ -395,6 +407,34 @@ class TestActionExecutionsController(base.APITest):
ml_actions.Result(cancel=True) ml_actions.Result(cancel=True)
) )
@mock.patch.object(rpc_clients.EngineClient, 'on_action_update')
def test_put_paused(self, on_action_update_mock_func):
on_action_update_mock_func.return_value = PAUSED_ACTION_EX_DB
resp = self.app.put_json('/v2/action_executions/123', PAUSED_ACTION)
self.assertEqual(200, resp.status_int)
self.assertDictEqual(PAUSED_ACTION, resp.json)
on_action_update_mock_func.assert_called_once_with(
PAUSED_ACTION['id'],
PAUSED_ACTION['state']
)
@mock.patch.object(rpc_clients.EngineClient, 'on_action_update')
def test_put_resume(self, on_action_update_mock_func):
on_action_update_mock_func.return_value = RUNNING_ACTION_EX_DB
resp = self.app.put_json('/v2/action_executions/123', RUNNING_ACTION)
self.assertEqual(200, resp.status_int)
self.assertDictEqual(RUNNING_ACTION, resp.json)
on_action_update_mock_func.assert_called_once_with(
RUNNING_ACTION['id'],
RUNNING_ACTION['state']
)
@mock.patch.object( @mock.patch.object(
rpc_clients.EngineClient, rpc_clients.EngineClient,
'on_action_complete', 'on_action_complete',
@ -411,7 +451,7 @@ class TestActionExecutionsController(base.APITest):
def test_put_bad_state(self): def test_put_bad_state(self):
action = copy.deepcopy(ACTION_EX) action = copy.deepcopy(ACTION_EX)
action['state'] = 'PAUSED' action['state'] = 'DELAYED'
resp = self.app.put_json( resp = self.app.put_json(
'/v2/action_executions/123', '/v2/action_executions/123',

View File

@ -208,6 +208,10 @@ class EngineTestCase(base.DbTestCase):
def is_task_processed(self, task_ex_id): def is_task_processed(self, task_ex_id):
return db_api.get_task_execution(task_ex_id).processed return db_api.get_task_execution(task_ex_id).processed
def await_task_running(self, ex_id, delay=DEFAULT_DELAY,
timeout=DEFAULT_TIMEOUT):
self.await_task_state(ex_id, states.RUNNING, delay, timeout)
def await_task_success(self, ex_id, delay=DEFAULT_DELAY, def await_task_success(self, ex_id, delay=DEFAULT_DELAY,
timeout=DEFAULT_TIMEOUT): timeout=DEFAULT_TIMEOUT):
self.await_task_state(ex_id, states.SUCCESS, delay, timeout) self.await_task_state(ex_id, states.SUCCESS, delay, timeout)
@ -220,6 +224,10 @@ class EngineTestCase(base.DbTestCase):
timeout=DEFAULT_TIMEOUT): timeout=DEFAULT_TIMEOUT):
self.await_task_state(ex_id, states.CANCELLED, delay, timeout) self.await_task_state(ex_id, states.CANCELLED, delay, timeout)
def await_task_paused(self, ex_id, delay=DEFAULT_DELAY,
timeout=DEFAULT_TIMEOUT):
self.await_task_state(ex_id, states.PAUSED, delay, timeout)
def await_task_delayed(self, ex_id, delay=DEFAULT_DELAY, def await_task_delayed(self, ex_id, delay=DEFAULT_DELAY,
timeout=DEFAULT_TIMEOUT): timeout=DEFAULT_TIMEOUT):
self.await_task_state(ex_id, states.RUNNING_DELAYED, delay, timeout) self.await_task_state(ex_id, states.RUNNING_DELAYED, delay, timeout)

View File

@ -26,6 +26,7 @@ from mistral.engine import default_engine as d_eng
from mistral import exceptions as exc from mistral import exceptions as exc
from mistral.executors import base as exe from mistral.executors import base as exe
from mistral.services import workbooks as wb_service from mistral.services import workbooks as wb_service
from mistral.services import workflows as wf_service
from mistral.tests.unit import base from mistral.tests.unit import base
from mistral.tests.unit.engine import base as eng_test_base from mistral.tests.unit.engine import base as eng_test_base
from mistral.workflow import states from mistral.workflow import states
@ -299,6 +300,130 @@ class DefaultEngineTest(base.DbTestCase):
self.assertIn("Invalid input", str(e)) self.assertIn("Invalid input", str(e))
self.assertIn("unexpected=['unexpected_param']", str(e)) self.assertIn("unexpected=['unexpected_param']", str(e))
def test_on_action_update(self):
workflow = """
version: '2.0'
wf_async:
type: direct
tasks:
task1:
action: std.async_noop
on-success:
- task2
task2:
action: std.noop
"""
# Start workflow.
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf_async')
self.assertIsNotNone(wf_ex)
self.assertEqual(states.RUNNING, wf_ex.state)
with db_api.transaction():
# Note: We need to reread execution to access related tasks.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(1, len(task_execs))
task1_ex = task_execs[0]
self.assertEqual('task1', task1_ex.name)
self.assertEqual(states.RUNNING, task1_ex.state)
action_execs = db_api.get_action_executions(
task_execution_id=task1_ex.id
)
self.assertEqual(1, len(action_execs))
task1_action_ex = action_execs[0]
self.assertEqual(states.RUNNING, task1_action_ex.state)
# Pause action execution of 'task1'.
task1_action_ex = self.engine.on_action_update(
task1_action_ex.id,
states.PAUSED
)
self.assertIsInstance(task1_action_ex, models.ActionExecution)
self.assertEqual(states.PAUSED, task1_action_ex.state)
with db_api.transaction():
# Note: We need to reread execution to access related tasks.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(1, len(task_execs))
self.assertEqual(states.PAUSED, task_execs[0].state)
self.assertEqual(states.RUNNING, wf_ex.state)
action_execs = db_api.get_action_executions(
task_execution_id=task1_ex.id
)
self.assertEqual(1, len(action_execs))
task1_action_ex = action_execs[0]
self.assertEqual(states.PAUSED, task1_action_ex.state)
def test_on_action_update_non_async(self):
workflow = """
version: '2.0'
wf_sync:
type: direct
tasks:
task1:
action: std.noop
on-success:
- task2
task2:
action: std.noop
"""
# Start workflow.
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf_sync')
self.assertIsNotNone(wf_ex)
self.assertEqual(states.RUNNING, wf_ex.state)
with db_api.transaction():
# Note: We need to reread execution to access related tasks.
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
self.assertEqual(1, len(task_execs))
task1_ex = task_execs[0]
self.assertEqual('task1', task1_ex.name)
self.assertEqual(states.RUNNING, task1_ex.state)
action_execs = db_api.get_action_executions(
task_execution_id=task1_ex.id
)
self.assertEqual(1, len(action_execs))
task1_action_ex = action_execs[0]
self.assertEqual(states.RUNNING, task1_action_ex.state)
self.assertRaises(
exc.InvalidStateTransitionException,
self.engine.on_action_update,
task1_action_ex.id,
states.PAUSED
)
def test_on_action_complete(self): def test_on_action_complete(self):
wf_input = {'param1': 'Hey', 'param2': 'Hi'} wf_input = {'param1': 'Hey', 'param2': 'Hi'}

View File

@ -0,0 +1,303 @@
# Copyright 2015 - StackStorm, Inc.
# Copyright 2016 - Brocade Communications Systems, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mistral.db.v2 import api as db_api
from mistral.services import workflows as wf_service
from mistral.tests.unit.engine import base
from mistral.workflow import states
from mistral_lib import actions as ml_actions
class TaskPauseResumeTest(base.EngineTestCase):
def test_pause_resume_action_ex(self):
workflow = """
version: '2.0'
wf:
tasks:
task1:
action: std.async_noop
on-success:
- task2
task2:
action: std.noop
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_state(wf_ex.id, states.RUNNING)
with db_api.transaction():
wf_execs = db_api.get_workflow_executions()
wf_ex = self._assert_single_item(wf_execs, name='wf')
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.RUNNING, task_1_ex.state)
self.assertEqual(1, len(task_1_action_exs))
self.assertEqual(states.RUNNING, task_1_action_exs[0].state)
# Pause the action execution of task 1.
self.engine.on_action_update(task_1_action_exs[0].id, states.PAUSED)
self.await_task_paused(task_1_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.PAUSED, task_1_ex.state)
self.assertEqual(1, len(task_1_action_exs))
self.assertEqual(states.PAUSED, task_1_action_exs[0].state)
# Resume the action execution of task 1.
self.engine.on_action_update(task_1_action_exs[0].id, states.RUNNING)
self.await_task_running(task_1_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.RUNNING, task_1_ex.state)
self.assertEqual(1, len(task_1_action_exs))
self.assertEqual(states.RUNNING, task_1_action_exs[0].state)
# Complete action execution of task 1.
self.engine.on_action_complete(
task_1_action_exs[0].id,
ml_actions.Result(data={'result': 'foobar'})
)
# Wait for the workflow execution to complete.
self.await_workflow_success(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(task_execs, name='task1')
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
task_2_ex = self._assert_single_item(task_execs, name='task2')
self.assertEqual(states.SUCCESS, wf_ex.state)
self.assertEqual(2, len(task_execs))
self.assertEqual(states.SUCCESS, task_1_ex.state)
self.assertEqual(1, len(task_1_action_exs))
self.assertEqual(states.SUCCESS, task_1_action_exs[0].state)
self.assertEqual(states.SUCCESS, task_2_ex.state)
def test_pause_resume_action_ex_with_items_task(self):
workflow = """
version: '2.0'
wf:
tasks:
task1:
with-items: i in <% range(3) %>
action: std.async_noop
on-success:
- task2
task2:
action: std.noop
"""
wf_service.create_workflows(workflow)
wf_ex = self.engine.start_workflow('wf')
self.await_workflow_state(wf_ex.id, states.RUNNING)
with db_api.transaction():
wf_execs = db_api.get_workflow_executions()
wf_ex = self._assert_single_item(wf_execs, name='wf')
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.RUNNING, task_1_ex.state)
self.assertEqual(3, len(task_1_action_exs))
self.assertEqual(states.RUNNING, task_1_action_exs[0].state)
self.assertEqual(states.RUNNING, task_1_action_exs[1].state)
self.assertEqual(states.RUNNING, task_1_action_exs[2].state)
# Pause the 1st action execution of task 1.
self.engine.on_action_update(task_1_action_exs[0].id, states.PAUSED)
self.await_task_paused(task_1_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.PAUSED, task_1_ex.state)
self.assertEqual(3, len(task_1_action_exs))
self.assertEqual(states.PAUSED, task_1_action_exs[0].state)
self.assertEqual(states.RUNNING, task_1_action_exs[1].state)
self.assertEqual(states.RUNNING, task_1_action_exs[2].state)
# Complete 2nd and 3rd action executions of task 1.
self.engine.on_action_complete(
task_1_action_exs[1].id,
ml_actions.Result(data={'result': 'two'})
)
self.engine.on_action_complete(
task_1_action_exs[2].id,
ml_actions.Result(data={'result': 'three'})
)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.PAUSED, task_1_ex.state)
self.assertEqual(3, len(task_1_action_exs))
self.assertEqual(states.PAUSED, task_1_action_exs[0].state)
self.assertEqual(states.SUCCESS, task_1_action_exs[1].state)
self.assertEqual(states.SUCCESS, task_1_action_exs[2].state)
# Resume the 1st action execution of task 1.
self.engine.on_action_update(task_1_action_exs[0].id, states.RUNNING)
self.await_task_running(task_1_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_1_ex = self._assert_single_item(
wf_ex.task_executions,
name='task1'
)
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
self.assertEqual(states.RUNNING, wf_ex.state)
self.assertEqual(1, len(task_execs))
self.assertEqual(states.RUNNING, task_1_ex.state)
self.assertEqual(3, len(task_1_action_exs))
self.assertEqual(states.RUNNING, task_1_action_exs[0].state)
self.assertEqual(states.SUCCESS, task_1_action_exs[1].state)
self.assertEqual(states.SUCCESS, task_1_action_exs[2].state)
# Complete the 1st action execution of task 1.
self.engine.on_action_complete(
task_1_action_exs[0].id,
ml_actions.Result(data={'result': 'foobar'})
)
# Wait for the workflow execution to complete.
self.await_workflow_success(wf_ex.id)
with db_api.transaction():
wf_ex = db_api.get_workflow_execution(wf_ex.id)
task_execs = wf_ex.task_executions
task_1_ex = self._assert_single_item(task_execs, name='task1')
task_1_action_exs = db_api.get_action_executions(
task_execution_id=task_1_ex.id
)
task_2_ex = self._assert_single_item(task_execs, name='task2')
self.assertEqual(states.SUCCESS, wf_ex.state)
self.assertEqual(2, len(task_execs))
self.assertEqual(states.SUCCESS, task_1_ex.state)
self.assertEqual(3, len(task_1_action_exs))
self.assertEqual(states.SUCCESS, task_1_action_exs[0].state)
self.assertEqual(states.SUCCESS, task_1_action_exs[1].state)
self.assertEqual(states.SUCCESS, task_1_action_exs[2].state)
self.assertEqual(states.SUCCESS, task_2_ex.state)