Refactor driver loading to load a driver instance per node

This paves the way for composable drivers. It:

Creates a BareDriver class, which is a minimal subclass of BaseDriver,
with no interfaces attached.

Adds a method driver_factory.build_driver_for_task, that accepts a
task argument and builds an instance of BareDriver that does have
interfaces attached. These interfaces come from loading the driver in
node.driver, and attaching each of the core, standard, and vendor
interfaces to the BareDriver created. This also accepts a driver_name
argument, for loading a driver that is not the one specified in
node.driver (for example, when updating the driver for a node).
This method will eventually need to take arguments for each interface
that is broken out of the main driver singleton.

By doing this, we create a driver instance per node, instead of using
the monolithic driver singletons shared across nodes. Note that the
attached interfaces are references to the interfaces in the driver
singleton, and thus themselves still singletons. It is *which* interface
implementations are referenced that will vary by node.

This means that in the future, we can dynamically load and attach these
interfaces, with the implementation chosen being defined by a property
of the node.

This patch also does a small refactoring as to how the list of
interfaces attached to a driver are referenced, for cleanliness.

Change-Id: Ic2b2525f2abd0d252f36442097e68f73aeaec9f7
This commit is contained in:
Jim Rollenhagen 2016-02-29 01:16:03 +00:00
parent 22aa8a9eb4
commit 7f46a03fca
7 changed files with 115 additions and 61 deletions

View File

@ -23,6 +23,7 @@ from stevedore import dispatch
from ironic.common import exception from ironic.common import exception
from ironic.common.i18n import _ from ironic.common.i18n import _
from ironic.common.i18n import _LI from ironic.common.i18n import _LI
from ironic.drivers import base as driver_base
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
@ -47,6 +48,34 @@ CONF.register_opts(driver_opts)
EM_SEMAPHORE = 'extension_manager' EM_SEMAPHORE = 'extension_manager'
def build_driver_for_task(task, driver_name=None):
"""Builds a composable driver for a given task.
Starts with a `BareDriver` object, and attaches implementations of the
various driver interfaces to it. Currently these all come from the
monolithic driver singleton, but later will come from separate
driver factories and configurable via the database.
:param task: The task containing the node to build a driver for.
:param driver_name: The name of the monolithic driver to use as a base,
if different than task.node.driver.
:returns: A driver object for the task.
:raises: DriverNotFound if node.driver could not be
found in the "ironic.drivers" namespace.
"""
node = task.node
driver = driver_base.BareDriver()
_attach_interfaces_to_driver(driver, node, driver_name=driver_name)
return driver
def _attach_interfaces_to_driver(driver, node, driver_name=None):
driver_singleton = get_driver(driver_name or node.driver)
for iface in driver_singleton.all_interfaces:
impl = getattr(driver_singleton, iface, None)
setattr(driver, iface, impl)
def get_driver(driver_name): def get_driver(driver_name):
"""Simple method to get a ref to an instance of a driver. """Simple method to get a ref to an instance of a driver.

View File

@ -119,9 +119,7 @@ class BaseConductorManager(object):
self._collect_periodic_tasks(self, (admin_context,)) self._collect_periodic_tasks(self, (admin_context,))
for driver_obj in drivers.values(): for driver_obj in drivers.values():
self._collect_periodic_tasks(driver_obj, (self, admin_context)) self._collect_periodic_tasks(driver_obj, (self, admin_context))
for iface_name in (driver_obj.core_interfaces + for iface_name in driver_obj.all_interfaces:
driver_obj.standard_interfaces +
['vendor']):
iface = getattr(driver_obj, iface_name, None) iface = getattr(driver_obj, iface_name, None)
if iface and iface.__class__ not in periodic_task_classes: if iface and iface.__class__ not in periodic_task_classes:
self._collect_periodic_tasks(iface, (self, admin_context)) self._collect_periodic_tasks(iface, (self, admin_context))

View File

@ -1468,8 +1468,7 @@ class ConductorManager(base_manager.BaseConductorManager):
iwdi = images.is_whole_disk_image(context, iwdi = images.is_whole_disk_image(context,
task.node.instance_info) task.node.instance_info)
task.node.driver_internal_info['is_whole_disk_image'] = iwdi task.node.driver_internal_info['is_whole_disk_image'] = iwdi
for iface_name in (task.driver.core_interfaces + for iface_name in task.driver.non_vendor_interfaces:
task.driver.standard_interfaces):
iface = getattr(task.driver, iface_name, None) iface = getattr(task.driver, iface_name, None)
result = reason = None result = reason = None
if iface: if iface:

View File

@ -208,8 +208,8 @@ class TaskManager(object):
self.ports = objects.Port.list_by_node_id(context, self.node.id) self.ports = objects.Port.list_by_node_id(context, self.node.id)
self.portgroups = objects.Portgroup.list_by_node_id(context, self.portgroups = objects.Portgroup.list_by_node_id(context,
self.node.id) self.node.id)
self.driver = driver_factory.get_driver(driver_name or self.driver = driver_factory.build_driver_for_task(
self.node.driver) self, driver_name=driver_name)
# NOTE(deva): this handles the Juno-era NOSTATE state # NOTE(deva): this handles the Juno-era NOSTATE state
# and should be deleted after Kilo is released # and should be deleted after Kilo is released

View File

@ -41,7 +41,6 @@ RAID_CONFIG_SCHEMA = os.path.join(os.path.dirname(__file__),
CONF = cfg.CONF CONF = cfg.CONF
CONF.import_opt('periodic_interval', 'ironic.common.service')
@six.add_metaclass(abc.ABCMeta) @six.add_metaclass(abc.ABCMeta)
@ -132,6 +131,14 @@ class BaseDriver(object):
def __init__(self): def __init__(self):
pass pass
@property
def all_interfaces(self):
return self.core_interfaces + self.standard_interfaces + ['vendor']
@property
def non_vendor_interfaces(self):
return self.core_interfaces + self.standard_interfaces
def get_properties(self): def get_properties(self):
"""Get the properties of the driver. """Get the properties of the driver.
@ -139,15 +146,23 @@ class BaseDriver(object):
""" """
properties = {} properties = {}
for iface_name in (self.core_interfaces + for iface_name in self.all_interfaces:
self.standard_interfaces +
['vendor']):
iface = getattr(self, iface_name, None) iface = getattr(self, iface_name, None)
if iface: if iface:
properties.update(iface.get_properties()) properties.update(iface.get_properties())
return properties return properties
class BareDriver(BaseDriver):
"""A bare driver object which will have interfaces attached later.
Any composable interfaces should be added as class attributes of this
class, as well as appended to core_interfaces or standard_interfaces here.
"""
def __init__(self):
pass
class BaseInterface(object): class BaseInterface(object):
"""A base interface implementing common functions for Driver Interfaces.""" """A base interface implementing common functions for Driver Interfaces."""
interface_type = 'base' interface_type = 'base'
@ -1145,6 +1160,11 @@ def driver_periodic_task(**kwargs):
new_kwargs[arg] = kwargs.pop(arg) new_kwargs[arg] = kwargs.pop(arg)
except KeyError: except KeyError:
pass pass
# NOTE(jroll) this is here to avoid a circular import when a module
# imports ironic.common.service. Normally I would balk at this, but this
# option is deprecared for removal and this code only runs at startup.
CONF.import_opt('periodic_interval', 'ironic.common.service')
new_kwargs.setdefault('spacing', CONF.periodic_interval) new_kwargs.setdefault('spacing', CONF.periodic_interval)
if kwargs: if kwargs:

View File

@ -112,6 +112,7 @@ class StartStopTestCase(mgr_utils.ServiceSetUpMixin, tests_db_base.DbTestCase):
class Driver(object): class Driver(object):
core_interfaces = [] core_interfaces = []
standard_interfaces = ['iface'] standard_interfaces = ['iface']
all_interfaces = core_interfaces + standard_interfaces
iface = TestInterface() iface = TestInterface()

View File

@ -34,7 +34,7 @@ from ironic.tests.unit.objects import utils as obj_utils
@mock.patch.object(objects.Node, 'get') @mock.patch.object(objects.Node, 'get')
@mock.patch.object(objects.Node, 'release') @mock.patch.object(objects.Node, 'release')
@mock.patch.object(objects.Node, 'reserve') @mock.patch.object(objects.Node, 'reserve')
@mock.patch.object(driver_factory, 'get_driver') @mock.patch.object(driver_factory, 'build_driver_for_task')
@mock.patch.object(objects.Port, 'list_by_node_id') @mock.patch.object(objects.Port, 'list_by_node_id')
@mock.patch.object(objects.Portgroup, 'list_by_node_id') @mock.patch.object(objects.Portgroup, 'list_by_node_id')
class TaskManagerTestCase(tests_db_base.DbTestCase): class TaskManagerTestCase(tests_db_base.DbTestCase):
@ -48,7 +48,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.future_mock = mock.Mock(spec=['cancel', 'add_done_callback']) self.future_mock = mock.Mock(spec=['cancel', 'add_done_callback'])
def test_excl_lock(self, get_portgroups_mock, get_ports_mock, def test_excl_lock(self, get_portgroups_mock, get_ports_mock,
get_driver_mock, reserve_mock, release_mock, build_driver_mock, reserve_mock, release_mock,
node_get_mock): node_get_mock):
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
with task_manager.TaskManager(self.context, 'fake-node-id') as task: with task_manager.TaskManager(self.context, 'fake-node-id') as task:
@ -56,20 +56,20 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(self.node, task.node) self.assertEqual(self.node, task.node)
self.assertEqual(get_ports_mock.return_value, task.ports) self.assertEqual(get_ports_mock.return_value, task.ports)
self.assertEqual(get_portgroups_mock.return_value, task.portgroups) self.assertEqual(get_portgroups_mock.return_value, task.portgroups)
self.assertEqual(get_driver_mock.return_value, task.driver) self.assertEqual(build_driver_mock.return_value, task.driver)
self.assertFalse(task.shared) self.assertFalse(task.shared)
build_driver_mock.assert_called_once_with(task, driver_name=None)
reserve_mock.assert_called_once_with(self.context, self.host, reserve_mock.assert_called_once_with(self.context, self.host,
'fake-node-id') 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with(self.node.driver)
release_mock.assert_called_once_with(self.context, self.host, release_mock.assert_called_once_with(self.context, self.host,
self.node.id) self.node.id)
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_excl_lock_with_driver( def test_excl_lock_with_driver(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
with task_manager.TaskManager(self.context, 'fake-node-id', with task_manager.TaskManager(self.context, 'fake-node-id',
@ -78,20 +78,21 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(self.node, task.node) self.assertEqual(self.node, task.node)
self.assertEqual(get_ports_mock.return_value, task.ports) self.assertEqual(get_ports_mock.return_value, task.ports)
self.assertEqual(get_portgroups_mock.return_value, task.portgroups) self.assertEqual(get_portgroups_mock.return_value, task.portgroups)
self.assertEqual(get_driver_mock.return_value, task.driver) self.assertEqual(build_driver_mock.return_value, task.driver)
self.assertFalse(task.shared) self.assertFalse(task.shared)
build_driver_mock.assert_called_once_with(
task, driver_name='fake-driver')
reserve_mock.assert_called_once_with(self.context, self.host, reserve_mock.assert_called_once_with(self.context, self.host,
'fake-node-id') 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with('fake-driver')
release_mock.assert_called_once_with(self.context, self.host, release_mock.assert_called_once_with(self.context, self.host,
self.node.id) self.node.id)
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_excl_nested_acquire( def test_excl_nested_acquire(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node2 = obj_utils.create_test_node(self.context, node2 = obj_utils.create_test_node(self.context,
uuid=uuidutils.generate_uuid(), uuid=uuidutils.generate_uuid(),
@ -100,13 +101,13 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
get_ports_mock.return_value = mock.sentinel.ports1 get_ports_mock.return_value = mock.sentinel.ports1
get_portgroups_mock.return_value = mock.sentinel.portgroups1 get_portgroups_mock.return_value = mock.sentinel.portgroups1
get_driver_mock.return_value = mock.sentinel.driver1 build_driver_mock.return_value = mock.sentinel.driver1
with task_manager.TaskManager(self.context, 'node-id1') as task: with task_manager.TaskManager(self.context, 'node-id1') as task:
reserve_mock.return_value = node2 reserve_mock.return_value = node2
get_ports_mock.return_value = mock.sentinel.ports2 get_ports_mock.return_value = mock.sentinel.ports2
get_portgroups_mock.return_value = mock.sentinel.portgroups2 get_portgroups_mock.return_value = mock.sentinel.portgroups2
get_driver_mock.return_value = mock.sentinel.driver2 build_driver_mock.return_value = mock.sentinel.driver2
with task_manager.TaskManager(self.context, 'node-id2') as task2: with task_manager.TaskManager(self.context, 'node-id2') as task2:
self.assertEqual(self.context, task.context) self.assertEqual(self.context, task.context)
self.assertEqual(self.node, task.node) self.assertEqual(self.node, task.node)
@ -121,15 +122,16 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(mock.sentinel.driver2, task2.driver) self.assertEqual(mock.sentinel.driver2, task2.driver)
self.assertFalse(task2.shared) self.assertFalse(task2.shared)
self.assertEqual([mock.call(task, driver_name=None),
mock.call(task2, driver_name=None)],
build_driver_mock.call_args_list)
self.assertEqual([mock.call(self.context, self.host, 'node-id1'), self.assertEqual([mock.call(self.context, self.host, 'node-id1'),
mock.call(self.context, self.host, 'node-id2')], mock.call(self.context, self.host, 'node-id2')],
reserve_mock.call_args_list) reserve_mock.call_args_list)
self.assertEqual([mock.call(self.context, self.node.id), self.assertEqual([mock.call(self.context, self.node.id),
mock.call(self.context, node2.id)], mock.call(self.context, node2.id)],
get_ports_mock.call_args_list) get_ports_mock.call_args_list)
self.assertEqual([mock.call(self.node.driver),
mock.call(node2.driver)],
get_driver_mock.call_args_list)
# release should be in reverse order # release should be in reverse order
self.assertEqual([mock.call(self.context, self.host, node2.id), self.assertEqual([mock.call(self.context, self.host, node2.id),
mock.call(self.context, self.host, self.node.id)], mock.call(self.context, self.host, self.node.id)],
@ -137,7 +139,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_excl_lock_exception_then_lock( def test_excl_lock_exception_then_lock(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
retry_attempts = 3 retry_attempts = 3
self.config(node_locked_retry_attempts=retry_attempts, self.config(node_locked_retry_attempts=retry_attempts,
@ -157,7 +159,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(2, reserve_mock.call_count) self.assertEqual(2, reserve_mock.call_count)
def test_excl_lock_reserve_exception( def test_excl_lock_reserve_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
retry_attempts = 3 retry_attempts = 3
self.config(node_locked_retry_attempts=retry_attempts, self.config(node_locked_retry_attempts=retry_attempts,
@ -175,12 +177,12 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(retry_attempts, reserve_mock.call_count) self.assertEqual(retry_attempts, reserve_mock.call_count)
self.assertFalse(get_ports_mock.called) self.assertFalse(get_ports_mock.called)
self.assertFalse(get_portgroups_mock.called) self.assertFalse(get_portgroups_mock.called)
self.assertFalse(get_driver_mock.called) self.assertFalse(build_driver_mock.called)
self.assertFalse(release_mock.called) self.assertFalse(release_mock.called)
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_excl_lock_get_ports_exception( def test_excl_lock_get_ports_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
get_ports_mock.side_effect = exception.IronicException('foo') get_ports_mock.side_effect = exception.IronicException('foo')
@ -193,13 +195,13 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
reserve_mock.assert_called_once_with(self.context, self.host, reserve_mock.assert_called_once_with(self.context, self.host,
'fake-node-id') 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
self.assertFalse(get_driver_mock.called) self.assertFalse(build_driver_mock.called)
release_mock.assert_called_once_with(self.context, self.host, release_mock.assert_called_once_with(self.context, self.host,
self.node.id) self.node.id)
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_excl_lock_get_portgroups_exception( def test_excl_lock_get_portgroups_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
get_portgroups_mock.side_effect = exception.IronicException('foo') get_portgroups_mock.side_effect = exception.IronicException('foo')
@ -212,16 +214,16 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
reserve_mock.assert_called_once_with(self.context, self.host, reserve_mock.assert_called_once_with(self.context, self.host,
'fake-node-id') 'fake-node-id')
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
self.assertFalse(get_driver_mock.called) self.assertFalse(build_driver_mock.called)
release_mock.assert_called_once_with(self.context, self.host, release_mock.assert_called_once_with(self.context, self.host,
self.node.id) self.node.id)
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_excl_lock_get_driver_exception( def test_excl_lock_build_driver_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
get_driver_mock.side_effect = ( build_driver_mock.side_effect = (
exception.DriverNotFound(driver_name='foo')) exception.DriverNotFound(driver_name='foo'))
self.assertRaises(exception.DriverNotFound, self.assertRaises(exception.DriverNotFound,
@ -233,13 +235,13 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
'fake-node-id') 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with(self.node.driver) build_driver_mock.assert_called_once_with(mock.ANY, driver_name=None)
release_mock.assert_called_once_with(self.context, self.host, release_mock.assert_called_once_with(self.context, self.host,
self.node.id) self.node.id)
self.assertFalse(node_get_mock.called) self.assertFalse(node_get_mock.called)
def test_shared_lock( def test_shared_lock(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.return_value = self.node node_get_mock.return_value = self.node
with task_manager.TaskManager(self.context, 'fake-node-id', with task_manager.TaskManager(self.context, 'fake-node-id',
@ -248,18 +250,19 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(self.node, task.node) self.assertEqual(self.node, task.node)
self.assertEqual(get_ports_mock.return_value, task.ports) self.assertEqual(get_ports_mock.return_value, task.ports)
self.assertEqual(get_portgroups_mock.return_value, task.portgroups) self.assertEqual(get_portgroups_mock.return_value, task.portgroups)
self.assertEqual(get_driver_mock.return_value, task.driver) self.assertEqual(build_driver_mock.return_value, task.driver)
self.assertTrue(task.shared) self.assertTrue(task.shared)
build_driver_mock.assert_called_once_with(task, driver_name=None)
self.assertFalse(reserve_mock.called) self.assertFalse(reserve_mock.called)
self.assertFalse(release_mock.called) self.assertFalse(release_mock.called)
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with(self.node.driver)
def test_shared_lock_with_driver( def test_shared_lock_with_driver(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.return_value = self.node node_get_mock.return_value = self.node
with task_manager.TaskManager(self.context, with task_manager.TaskManager(self.context,
@ -270,18 +273,20 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(self.node, task.node) self.assertEqual(self.node, task.node)
self.assertEqual(get_ports_mock.return_value, task.ports) self.assertEqual(get_ports_mock.return_value, task.ports)
self.assertEqual(get_portgroups_mock.return_value, task.portgroups) self.assertEqual(get_portgroups_mock.return_value, task.portgroups)
self.assertEqual(get_driver_mock.return_value, task.driver) self.assertEqual(build_driver_mock.return_value, task.driver)
self.assertTrue(task.shared) self.assertTrue(task.shared)
build_driver_mock.assert_called_once_with(
task, driver_name='fake-driver')
self.assertFalse(reserve_mock.called) self.assertFalse(reserve_mock.called)
self.assertFalse(release_mock.called) self.assertFalse(release_mock.called)
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with('fake-driver')
def test_shared_lock_node_get_exception( def test_shared_lock_node_get_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.side_effect = exception.NodeNotFound(node='foo') node_get_mock.side_effect = exception.NodeNotFound(node='foo')
@ -296,10 +301,10 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
self.assertFalse(get_ports_mock.called) self.assertFalse(get_ports_mock.called)
self.assertFalse(get_portgroups_mock.called) self.assertFalse(get_portgroups_mock.called)
self.assertFalse(get_driver_mock.called) self.assertFalse(build_driver_mock.called)
def test_shared_lock_get_ports_exception( def test_shared_lock_get_ports_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.return_value = self.node node_get_mock.return_value = self.node
get_ports_mock.side_effect = exception.IronicException('foo') get_ports_mock.side_effect = exception.IronicException('foo')
@ -314,10 +319,10 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertFalse(release_mock.called) self.assertFalse(release_mock.called)
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
self.assertFalse(get_driver_mock.called) self.assertFalse(build_driver_mock.called)
def test_shared_lock_get_portgroups_exception( def test_shared_lock_get_portgroups_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.return_value = self.node node_get_mock.return_value = self.node
get_portgroups_mock.side_effect = exception.IronicException('foo') get_portgroups_mock.side_effect = exception.IronicException('foo')
@ -332,13 +337,13 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertFalse(release_mock.called) self.assertFalse(release_mock.called)
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
self.assertFalse(get_driver_mock.called) self.assertFalse(build_driver_mock.called)
def test_shared_lock_get_driver_exception( def test_shared_lock_build_driver_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.return_value = self.node node_get_mock.return_value = self.node
get_driver_mock.side_effect = ( build_driver_mock.side_effect = (
exception.DriverNotFound(driver_name='foo')) exception.DriverNotFound(driver_name='foo'))
self.assertRaises(exception.DriverNotFound, self.assertRaises(exception.DriverNotFound,
@ -352,10 +357,10 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with(self.node.driver) build_driver_mock.assert_called_once_with(mock.ANY, driver_name=None)
def test_upgrade_lock( def test_upgrade_lock(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
node_get_mock.return_value = self.node node_get_mock.return_value = self.node
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
@ -365,7 +370,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertEqual(self.node, task.node) self.assertEqual(self.node, task.node)
self.assertEqual(get_ports_mock.return_value, task.ports) self.assertEqual(get_ports_mock.return_value, task.ports)
self.assertEqual(get_portgroups_mock.return_value, task.portgroups) self.assertEqual(get_portgroups_mock.return_value, task.portgroups)
self.assertEqual(get_driver_mock.return_value, task.driver) self.assertEqual(build_driver_mock.return_value, task.driver)
self.assertTrue(task.shared) self.assertTrue(task.shared)
self.assertFalse(reserve_mock.called) self.assertFalse(reserve_mock.called)
@ -375,6 +380,9 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
task.upgrade_lock() task.upgrade_lock()
self.assertFalse(task.shared) self.assertFalse(task.shared)
build_driver_mock.assert_called_once_with(mock.ANY,
driver_name=None)
# make sure reserve() was called only once # make sure reserve() was called only once
reserve_mock.assert_called_once_with(self.context, self.host, reserve_mock.assert_called_once_with(self.context, self.host,
'fake-node-id') 'fake-node-id')
@ -383,10 +391,9 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
node_get_mock.assert_called_once_with(self.context, 'fake-node-id') node_get_mock.assert_called_once_with(self.context, 'fake-node-id')
get_ports_mock.assert_called_once_with(self.context, self.node.id) get_ports_mock.assert_called_once_with(self.context, self.node.id)
get_portgroups_mock.assert_called_once_with(self.context, self.node.id) get_portgroups_mock.assert_called_once_with(self.context, self.node.id)
get_driver_mock.assert_called_once_with(self.node.driver)
def test_spawn_after( def test_spawn_after(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
spawn_mock = mock.Mock(return_value=self.future_mock) spawn_mock = mock.Mock(return_value=self.future_mock)
task_release_mock = mock.Mock() task_release_mock = mock.Mock()
@ -406,7 +413,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
self.assertFalse(task_release_mock.called) self.assertFalse(task_release_mock.called)
def test_spawn_after_exception_while_yielded( def test_spawn_after_exception_while_yielded(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
spawn_mock = mock.Mock() spawn_mock = mock.Mock()
task_release_mock = mock.Mock() task_release_mock = mock.Mock()
@ -423,7 +430,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
task_release_mock.assert_called_once_with() task_release_mock.assert_called_once_with()
def test_spawn_after_spawn_fails( def test_spawn_after_spawn_fails(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
spawn_mock = mock.Mock(side_effect=exception.IronicException('foo')) spawn_mock = mock.Mock(side_effect=exception.IronicException('foo'))
task_release_mock = mock.Mock() task_release_mock = mock.Mock()
@ -440,7 +447,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
task_release_mock.assert_called_once_with() task_release_mock.assert_called_once_with()
def test_spawn_after_link_fails( def test_spawn_after_link_fails(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
self.future_mock.add_done_callback.side_effect = ( self.future_mock.add_done_callback.side_effect = (
exception.IronicException('foo')) exception.IronicException('foo'))
@ -463,7 +470,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
task_release_mock.assert_called_once_with() task_release_mock.assert_called_once_with()
def test_spawn_after_on_error_hook( def test_spawn_after_on_error_hook(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
expected_exception = exception.IronicException('foo') expected_exception = exception.IronicException('foo')
spawn_mock = mock.Mock(side_effect=expected_exception) spawn_mock = mock.Mock(side_effect=expected_exception)
@ -485,7 +492,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
'fake-argument') 'fake-argument')
def test_spawn_after_on_error_hook_exception( def test_spawn_after_on_error_hook_exception(
self, get_portgroups_mock, get_ports_mock, get_driver_mock, self, get_portgroups_mock, get_ports_mock, build_driver_mock,
reserve_mock, release_mock, node_get_mock): reserve_mock, release_mock, node_get_mock):
expected_exception = exception.IronicException('foo') expected_exception = exception.IronicException('foo')
spawn_mock = mock.Mock(side_effect=expected_exception) spawn_mock = mock.Mock(side_effect=expected_exception)
@ -512,7 +519,7 @@ class TaskManagerTestCase(tests_db_base.DbTestCase):
@mock.patch.object(states.machine, 'copy') @mock.patch.object(states.machine, 'copy')
def test_init_prepares_fsm( def test_init_prepares_fsm(
self, copy_mock, get_portgroups_mock, get_ports_mock, self, copy_mock, get_portgroups_mock, get_ports_mock,
get_driver_mock, reserve_mock, release_mock, node_get_mock): build_driver_mock, reserve_mock, release_mock, node_get_mock):
m = mock.Mock(spec=fsm.FSM) m = mock.Mock(spec=fsm.FSM)
reserve_mock.return_value = self.node reserve_mock.return_value = self.node
copy_mock.return_value = m copy_mock.return_value = m