Merge "Add console feature to ssh driver"

This commit is contained in:
Jenkins 2015-12-23 23:04:15 +00:00 committed by Gerrit Code Review
commit 930dee6242
7 changed files with 229 additions and 7 deletions

View File

@ -121,6 +121,7 @@ class AgentAndSSHDriver(base.BaseDriver):
self.raid = agent.AgentRAID()
self.inspect = inspector.Inspector.create_if_enabled(
'AgentAndSSHDriver')
self.console = ssh.ShellinaboxConsole()
class AgentAndVirtualBoxDriver(base.BaseDriver):

View File

@ -104,6 +104,7 @@ class FakeSSHDriver(base.BaseDriver):
self.power = ssh.SSHPower()
self.deploy = fake.FakeDeploy()
self.management = ssh.SSHManagement()
self.console = ssh.ShellinaboxConsole()
class FakeIPMINativeDriver(base.BaseDriver):

View File

@ -45,6 +45,7 @@ from ironic.common import states
from ironic.common import utils
from ironic.conductor import task_manager
from ironic.drivers import base
from ironic.drivers.modules import console_utils
from ironic.drivers import utils as driver_utils
libvirt_opts = [
@ -87,6 +88,10 @@ OTHER_PROPERTIES = {
}
COMMON_PROPERTIES = REQUIRED_PROPERTIES.copy()
COMMON_PROPERTIES.update(OTHER_PROPERTIES)
CONSOLE_PROPERTIES = {
'ssh_terminal_port': _("node's UDP port to connect to. Only required for "
"console access and only applicable for 'virsh'.")
}
# NOTE(dguerri) Generic boot device map. Virtualisation types that don't define
# a more specific one, will use this.
@ -369,6 +374,11 @@ def _parse_driver_info(node):
key_contents = info.get('ssh_key_contents')
key_filename = info.get('ssh_key_filename')
virt_type = info.get('ssh_virt_type')
terminal_port = info.get('ssh_terminal_port')
if terminal_port is not None:
terminal_port = utils.validate_network_port(terminal_port,
'ssh_terminal_port')
# NOTE(deva): we map 'address' from API to 'host' for common utils
res = {
@ -376,7 +386,8 @@ def _parse_driver_info(node):
'username': username,
'port': port,
'virt_type': virt_type,
'uuid': node.uuid
'uuid': node.uuid,
'terminal_port': terminal_port
}
cmd_set = _get_command_sets(virt_type)
@ -788,3 +799,80 @@ class SSHManagement(base.ManagementInterface):
"""
raise NotImplementedError()
class ShellinaboxConsole(base.ConsoleInterface):
"""A ConsoleInterface that uses ssh and shellinabox."""
def get_properties(self):
properties = COMMON_PROPERTIES.copy()
properties.update(CONSOLE_PROPERTIES)
return properties
def validate(self, task):
"""Validate the Node console info.
:param task: a task from TaskManager.
:raises: MissingParameterValue if required ssh parameters are
missing
:raises: InvalidParameterValue if required parameters are invalid.
"""
driver_info = _parse_driver_info(task.node)
if driver_info['virt_type'] != 'virsh':
raise exception.InvalidParameterValue(_(
"not supported for non-virsh types"))
if not driver_info['terminal_port']:
raise exception.MissingParameterValue(_(
"Missing 'ssh_terminal_port' parameter in node's "
"'driver_info'"))
def start_console(self, task):
"""Start a remote console for the node.
:param task: a task from TaskManager
:raises: MissingParameterValue if required ssh parameters are
missing
:raises: ConsoleError if the directory for the PID file cannot be
created
:raises: ConsoleSubprocessFailed when invoking the subprocess failed
:raises: InvalidParameterValue if required parameters are invalid.
"""
driver_info = _parse_driver_info(task.node)
driver_info['macs'] = driver_utils.get_node_mac_addresses(task)
ssh_obj = _get_connection(task.node)
node_name = _get_hosts_name_for_node(ssh_obj, driver_info)
ssh_cmd = ("/:%(uid)s:%(gid)s:HOME:virsh console %(node)s"
% {'uid': os.getuid(),
'gid': os.getgid(),
'node': node_name})
console_utils.start_shellinabox_console(driver_info['uuid'],
driver_info['terminal_port'],
ssh_cmd)
def stop_console(self, task):
"""Stop the remote console session for the node.
:param task: a task from TaskManager
:raises: ConsoleError if unable to stop the console
"""
console_utils.stop_shellinabox_console(task.node.uuid)
def get_console(self, task):
"""Get the type and connection information about the console.
:param task: a task from TaskManager
:raises: MissingParameterValue if required ssh parameters are
missing
:raises: InvalidParameterValue if required parameter are invalid.
"""
driver_info = _parse_driver_info(task.node)
url = console_utils.get_shellinabox_console_url(
driver_info['terminal_port'])
return {'type': 'shellinabox', 'url': url}

View File

@ -107,6 +107,7 @@ class PXEAndSSHDriver(base.BaseDriver):
self.inspect = inspector.Inspector.create_if_enabled(
'PXEAndSSHDriver')
self.raid = agent.AgentRAID()
self.console = ssh.ShellinaboxConsole()
class PXEAndIPMINativeDriver(base.BaseDriver):

View File

@ -3326,7 +3326,7 @@ class ManagerTestProperties(tests_db_base.DbTestCase):
def test_driver_properties_fake_ssh(self):
expected = ['ssh_address', 'ssh_username', 'ssh_virt_type',
'ssh_key_contents', 'ssh_key_filename',
'ssh_password', 'ssh_port']
'ssh_password', 'ssh_port', 'ssh_terminal_port']
self._check_driver_properties("fake_ssh", expected)
def test_driver_properties_fake_pxe(self):
@ -3365,7 +3365,7 @@ class ManagerTestProperties(tests_db_base.DbTestCase):
expected = ['deploy_kernel', 'deploy_ramdisk',
'ssh_address', 'ssh_username', 'ssh_virt_type',
'ssh_key_contents', 'ssh_key_filename',
'ssh_password', 'ssh_port']
'ssh_password', 'ssh_port', 'ssh_terminal_port']
self._check_driver_properties("pxe_ssh", expected)
def test_driver_properties_pxe_seamicro(self):

View File

@ -29,6 +29,7 @@ from ironic.common import exception
from ironic.common import states
from ironic.common import utils
from ironic.conductor import task_manager
from ironic.drivers.modules import console_utils
from ironic.drivers.modules import ssh
from ironic.drivers import utils as driver_utils
from ironic.tests.unit.conductor import mgr_utils
@ -619,11 +620,9 @@ class SSHDriverTestCase(db_base.DbTestCase):
@mock.patch.object(utils, 'ssh_connect', autospec=True)
def test__validate_info_ssh_connect_failed(self, ssh_connect_mock):
info = ssh._parse_driver_info(self.node)
ssh_connect_mock.side_effect = iter(
[exception.SSHConnectFailed(host='fake')])
with task_manager.acquire(self.context, info['uuid'],
with task_manager.acquire(self.context, self.node.uuid,
shared=False) as task:
self.assertRaises(exception.InvalidParameterValue,
task.driver.power.validate, task)
@ -632,11 +631,17 @@ class SSHDriverTestCase(db_base.DbTestCase):
def test_get_properties(self):
expected = ssh.COMMON_PROPERTIES
expected2 = list(ssh.COMMON_PROPERTIES) + list(ssh.CONSOLE_PROPERTIES)
with task_manager.acquire(self.context, self.node.uuid,
shared=True) as task:
self.assertEqual(expected, task.driver.power.get_properties())
self.assertEqual(expected, task.driver.get_properties())
self.assertEqual(expected, task.driver.management.get_properties())
self.assertEqual(
sorted(expected2),
sorted(task.driver.console.get_properties().keys()))
self.assertEqual(
sorted(expected2),
sorted(task.driver.get_properties().keys()))
def test_validate_fail_no_port(self):
new_node = obj_utils.create_test_node(
@ -1075,3 +1080,126 @@ class SSHDriverTestCase(db_base.DbTestCase):
with task_manager.acquire(self.context, node.uuid) as task:
self.assertRaises(exception.MissingParameterValue,
task.driver.management.validate, task)
def test_console_validate(self):
with task_manager.acquire(
self.context, self.node.uuid, shared=True) as task:
task.node.driver_info['ssh_virt_type'] = 'virsh'
task.node.driver_info['ssh_terminal_port'] = 123
task.driver.console.validate(task)
def test_console_validate_missing_port(self):
with task_manager.acquire(
self.context, self.node.uuid, shared=True) as task:
task.node.driver_info['ssh_virt_type'] = 'virsh'
task.node.driver_info.pop('ssh_terminal_port', None)
self.assertRaises(exception.MissingParameterValue,
task.driver.console.validate, task)
def test_console_validate_not_virsh(self):
with task_manager.acquire(
self.context, self.node.uuid, shared=True) as task:
self.assertRaisesRegex(exception.InvalidParameterValue,
'not supported for non-virsh types',
task.driver.console.validate, task)
def test_console_validate_invalid_port(self):
with task_manager.acquire(
self.context, self.node.uuid, shared=True) as task:
task.node.driver_info['ssh_terminal_port'] = ''
self.assertRaisesRegex(exception.InvalidParameterValue,
'is not a valid integer',
task.driver.console.validate, task)
@mock.patch.object(ssh, '_get_connection', autospec=True)
@mock.patch.object(ssh, '_get_hosts_name_for_node', autospec=True)
@mock.patch.object(console_utils, 'start_shellinabox_console',
autospec=True)
def test_start_console(self, mock_exec,
get_hosts_name_mock, mock_get_conn):
info = ssh._parse_driver_info(self.node)
mock_exec.return_value = None
get_hosts_name_mock.return_value = "NodeName"
mock_get_conn.return_value = self.sshclient
with task_manager.acquire(self.context,
self.node.uuid) as task:
self.driver.console.start_console(task)
mock_exec.assert_called_once_with(info['uuid'],
info['terminal_port'],
mock.ANY)
@mock.patch.object(ssh, '_get_connection', autospec=True)
@mock.patch.object(ssh, '_get_hosts_name_for_node', autospec=True)
@mock.patch.object(console_utils, 'start_shellinabox_console',
autospec=True)
def test_start_console_fail(self, mock_exec,
get_hosts_name_mock, mock_get_conn):
get_hosts_name_mock.return_value = "NodeName"
mock_get_conn.return_value = self.sshclient
mock_exec.side_effect = exception.ConsoleSubprocessFailed(
error='error')
with task_manager.acquire(self.context,
self.node.uuid) as task:
self.assertRaises(exception.ConsoleSubprocessFailed,
self.driver.console.start_console,
task)
mock_exec.assert_called_once_with(self.node.uuid, mock.ANY, mock.ANY)
@mock.patch.object(ssh, '_get_connection', autospec=True)
@mock.patch.object(ssh, '_get_hosts_name_for_node', autospec=True)
@mock.patch.object(console_utils, 'start_shellinabox_console',
autospec=True)
def test_start_console_fail_nodir(self, mock_exec,
get_hosts_name_mock, mock_get_conn):
get_hosts_name_mock.return_value = "NodeName"
mock_get_conn.return_value = self.sshclient
mock_exec.side_effect = exception.ConsoleError()
with task_manager.acquire(self.context,
self.node.uuid) as task:
self.assertRaises(exception.ConsoleError,
self.driver.console.start_console,
task)
mock_exec.assert_called_once_with(self.node.uuid, mock.ANY, mock.ANY)
@mock.patch.object(console_utils, 'stop_shellinabox_console',
autospec=True)
def test_stop_console(self, mock_exec):
mock_exec.return_value = None
with task_manager.acquire(self.context,
self.node.uuid) as task:
self.driver.console.stop_console(task)
mock_exec.assert_called_once_with(self.node.uuid)
@mock.patch.object(console_utils, 'stop_shellinabox_console',
autospec=True)
def test_stop_console_fail(self, mock_stop):
mock_stop.side_effect = exception.ConsoleError()
with task_manager.acquire(self.context,
self.node.uuid) as task:
self.assertRaises(exception.ConsoleError,
self.driver.console.stop_console,
task)
mock_stop.assert_called_once_with(self.node.uuid)
@mock.patch.object(console_utils, 'get_shellinabox_console_url',
autospec=True)
def test_get_console(self, mock_exec):
url = 'http://localhost:4201'
mock_exec.return_value = url
expected = {'type': 'shellinabox', 'url': url}
with task_manager.acquire(self.context,
self.node.uuid) as task:
task.node.driver_info['ssh_terminal_port'] = 6900
console_info = self.driver.console.get_console(task)
self.assertEqual(expected, console_info)
mock_exec.assert_called_once_with(6900)

View File

@ -0,0 +1,3 @@
---
features:
- Adds ShellinaboxConsole support for virsh SSH driver.