diff --git a/mistral/engine/actions/action_factory.py b/mistral/engine/actions/action_factory.py index 4d1a2c1b2..e3a11eb7d 100644 --- a/mistral/engine/actions/action_factory.py +++ b/mistral/engine/actions/action_factory.py @@ -48,7 +48,8 @@ def _get_mapping(): action_types.REST_API: get_rest_action, action_types.MISTRAL_REST_API: get_mistral_rest_action, action_types.OSLO_RPC: get_amqp_action, - action_types.SEND_EMAIL: get_send_email_action + action_types.SEND_EMAIL: get_send_email_action, + action_types.SSH: get_ssh_action } @@ -143,3 +144,23 @@ def get_send_email_action(db_task, task, service): return actions.SendEmailAction(action_type, action_name, task_params, service_params) + + +def get_ssh_action(db_task, task, service): + action_type = service.type + action_name = task.get_action_name() + task_params = task.parameters + action = service.actions.get(action_name) + action_params = action.parameters + + # Merge/replace action_params by task_params. + all_params = copy.copy(action_params) + all_params.update(task_params) + + cmd = all_params['cmd'] + host = all_params['host'] + username = all_params['username'] + password = all_params['password'] + + return actions.SSHAction(action_type, action_name, cmd, + host, username, password) diff --git a/mistral/engine/actions/action_types.py b/mistral/engine/actions/action_types.py index fc4660056..5e8a1d572 100644 --- a/mistral/engine/actions/action_types.py +++ b/mistral/engine/actions/action_types.py @@ -22,8 +22,9 @@ REST_API = 'REST_API' OSLO_RPC = 'OSLO_RPC' MISTRAL_REST_API = 'MISTRAL_REST_API' SEND_EMAIL = "SEND_EMAIL" +SSH = "SSH" -_ALL = [ECHO, REST_API, OSLO_RPC, MISTRAL_REST_API, SEND_EMAIL] +_ALL = [ECHO, REST_API, OSLO_RPC, MISTRAL_REST_API, SEND_EMAIL, SSH] def is_valid(action_type): diff --git a/mistral/engine/actions/actions.py b/mistral/engine/actions/actions.py index 0a312c2d7..ae08f63c6 100644 --- a/mistral/engine/actions/actions.py +++ b/mistral/engine/actions/actions.py @@ -28,6 +28,7 @@ import six from mistral.openstack.common import log as logging from mistral.engine import expressions as expr from mistral import exceptions as exc +from mistral.utils import ssh_utils LOG = logging.getLogger(__name__) @@ -211,3 +212,32 @@ class SendEmailAction(Action): except (smtplib.SMTPException, IOError) as e: raise exc.ActionException("Failed to send an email message: %s" % e) + + +class SSHAction(Action): + def __init__(self, action_type, action_name, + cmd, host, username, password): + super(SSHAction, self).__init__(action_type, action_name) + self.cmd = cmd + self.host = host + self.username = username + self.password = password + + def run(self): + def raise_exc(parent_exc=None): + message = ("Failed to execute ssh cmd " + "'%s' on %s" % (self.cmd, self.host)) + if parent_exc: + message += "\nException: %s" % str(parent_exc) + raise exc.ActionException(message) + + try: + status_code, result = ssh_utils.execute_command(self.cmd, + self.host, + self.username, + self.password) + if status_code > 0: + return raise_exc() + return result + except Exception as e: + return raise_exc(parent_exc=e) diff --git a/mistral/tests/unit/engine/actions/test_action_factory.py b/mistral/tests/unit/engine/actions/test_action_factory.py index e858f3130..c2f722d1c 100644 --- a/mistral/tests/unit/engine/actions/test_action_factory.py +++ b/mistral/tests/unit/engine/actions/test_action_factory.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import unittest2 @@ -136,3 +137,21 @@ class ActionFactoryTest(unittest2.TestCase): action = action_factory.create_action(task) self.assertIn("X-Auth-Token", action.headers) self.assertEqual(auth_token, action.headers["X-Auth-Token"]) + + def test_get_ssh_action(self): + task = copy.copy(SAMPLE_TASK) + task['service_spec'].update({'type': action_types.SSH}) + create_vm = task['service_spec']['actions']['create-vm'] + create_vm['parameters'].update({'host': '10.0.0.1', + 'cmd': 'ls -l'}) + task_params = task['task_spec']['parameters'] + task_params.update({'username': '$.ssh_username', + 'password': '$.ssh_password'}) + task['in_context'] = {'ssh_username': 'ubuntu', + 'ssh_password': 'ubuntu_password'} + action = action_factory.create_action(task) + + self.assertEqual("ubuntu", action.username) + self.assertEqual("ubuntu_password", action.password) + self.assertEqual("ls -l", action.cmd) + self.assertEqual("10.0.0.1", action.host) diff --git a/mistral/utils/ssh_utils.py b/mistral/utils/ssh_utils.py new file mode 100644 index 000000000..a774ec6ac --- /dev/null +++ b/mistral/utils/ssh_utils.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2014 - Mirantis, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paramiko + +from mistral.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +def _read_paramimko_stream(recv_func): + result = '' + buf = recv_func(1024) + while buf != '': + result += buf + buf = recv_func(1024) + + return result + + +def _connect(host, username, password): + LOG.debug('Creating SSH connection to %s' % host) + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(host, username=username, password=password) + return ssh + + +def _cleanup(ssh): + ssh.close() + + +def execute_command(cmd, host, username, password, + get_stderr=False, raise_when_error=True): + ssh = _connect(host, username, password) + + LOG.debug("Executing command %s") + + try: + chan = ssh.get_transport().open_session() + chan.exec_command(cmd) + + # TODO (nmakhotkin): that could hang if stderr buffer overflows + stdout = _read_paramimko_stream(chan.recv) + stderr = _read_paramimko_stream(chan.recv_stderr) + + ret_code = chan.recv_exit_status() + + if ret_code and raise_when_error: + raise RuntimeError("Cmd: %s\nReturn code: %s\nstdout: %s" + % (cmd, ret_code, stdout)) + if get_stderr: + return ret_code, stdout, stderr + else: + return ret_code, stdout + finally: + _cleanup(ssh) diff --git a/requirements.txt b/requirements.txt index 26819b921..6de181f2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ requests kombu>=2.4.8 oslo.config>=1.2.0 oslo.messaging>=1.3.0a4 +paramiko>=1.9.0 python-keystoneclient>=0.3.2 networkx six>=1.5.2