diff --git a/neutron/plugins/ml2/drivers/mech_agent.py b/neutron/plugins/ml2/drivers/mech_agent.py index b33de46ad3e..f5b67bee712 100644 --- a/neutron/plugins/ml2/drivers/mech_agent.py +++ b/neutron/plugins/ml2/drivers/mech_agent.py @@ -49,6 +49,7 @@ class AgentMechanismDriverBase(api.MechanismDriver, metaclass=abc.ABCMeta): :param agent_type: Constant identifying agent type in agents_db :param supported_vnic_types: The binding:vnic_type values we can bind """ + super(AgentMechanismDriverBase, self).__init__() self.agent_type = agent_type self.supported_vnic_types = supported_vnic_types diff --git a/neutron/plugins/ml2/plugin.py b/neutron/plugins/ml2/plugin.py index ee21558a76e..ab642c501d1 100644 --- a/neutron/plugins/ml2/plugin.py +++ b/neutron/plugins/ml2/plugin.py @@ -227,7 +227,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, sg_rpc.disable_security_group_extension_by_config(aliases) vlantransparent._disable_extension_by_config(aliases) filter_validation._disable_extension_by_config(aliases) - self._aliases = aliases + self._aliases = self._filter_extensions_by_mech_driver(aliases) return self._aliases def __new__(cls, *args, **kwargs): @@ -292,6 +292,17 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, driver=extension_driver, service_plugin=service_plugin ) + def _filter_extensions_by_mech_driver(self, aliases): + """Return the supported extensions by the loaded mech drivers""" + if not self.mechanism_manager.ordered_mech_drivers: + return aliases + + supported_extensions = set([]) + for mech_driver in self.mechanism_manager.ordered_mech_drivers: + supported_extensions |= mech_driver.obj.supported_extensions( + set(aliases)) + return list(supported_extensions) + @registry.receives(resources.PORT, [provisioning_blocks.PROVISIONING_COMPLETE]) def _port_provisioned(self, rtype, event, trigger, payload=None): diff --git a/neutron/tests/unit/plugins/ml2/drivers/mechanism_logger.py b/neutron/tests/unit/plugins/ml2/drivers/mechanism_logger.py index e7d3ba41fc0..2f4a2c08f6e 100644 --- a/neutron/tests/unit/plugins/ml2/drivers/mechanism_logger.py +++ b/neutron/tests/unit/plugins/ml2/drivers/mechanism_logger.py @@ -25,6 +25,9 @@ class LoggerMechanismDriver(api.MechanismDriver): Generally used for testing and debugging. """ + def __init__(self, *args, **kwargs): + super(LoggerMechanismDriver, self).__init__(*args, **kwargs) + self._supported_extensions = set() def initialize(self): pass @@ -140,3 +143,8 @@ class LoggerMechanismDriver(api.MechanismDriver): "%(segments)s, candidate hosts %(hosts)s ", {'segments': segments, 'hosts': candidate_hosts}) return set() + + def supported_extensions(self, extensions): + if self._supported_extensions: + return extensions & self._supported_extensions + return extensions diff --git a/neutron/tests/unit/plugins/ml2/drivers/mechanism_test.py b/neutron/tests/unit/plugins/ml2/drivers/mechanism_test.py index a17ed2d987a..b4ed8109a41 100644 --- a/neutron/tests/unit/plugins/ml2/drivers/mechanism_test.py +++ b/neutron/tests/unit/plugins/ml2/drivers/mechanism_test.py @@ -32,6 +32,7 @@ class TestMechanismDriver(api.MechanismDriver): def __init__(self, *args, **kwargs): super(TestMechanismDriver, self).__init__(*args, **kwargs) self._supported_vnic_types = ('test_mechanism_driver_vnic_type', ) + self._supported_extensions = set([]) def initialize(self): self.bound_ports = set() @@ -267,6 +268,11 @@ class TestMechanismDriver(api.MechanismDriver): def get_standard_device_mappings(self, agent): return {} + def supported_extensions(self, extensions): + if self._supported_extensions: + return extensions & self._supported_extensions + return extensions + class TestMechanismDriverWithAgent(mech_agent.AgentMechanismDriverBase, TestMechanismDriver): diff --git a/neutron/tests/unit/plugins/ml2/test_plugin.py b/neutron/tests/unit/plugins/ml2/test_plugin.py index a9e00c12327..62244748565 100644 --- a/neutron/tests/unit/plugins/ml2/test_plugin.py +++ b/neutron/tests/unit/plugins/ml2/test_plugin.py @@ -166,6 +166,24 @@ class TestMl2BulkToggleWithoutBulkless(Ml2PluginV2TestCase): self.assertFalse(self._skip_native_bulk) +class TestMl2FilterExtensions(Ml2PluginV2TestCase): + + _mechanism_drivers = ['logger', 'test'] + + def test__filter_extensions_by_mech_driver(self): + extension_aliases = ['ext1', 'ext2', 'ext3', 'ext4', 'ext5'] + supported_aliases = [{'ext0', 'ext1', 'ext2'}, + {'ext4', 'ext5', 'ext6'}] + for idx, mech_driver in enumerate( + self.plugin.mechanism_manager.ordered_mech_drivers): + mech_driver.obj._supported_extensions = supported_aliases[idx] + + supported_extensions = sorted( + self.plugin._filter_extensions_by_mech_driver(extension_aliases)) + self.assertEqual(['ext1', 'ext2', 'ext4', 'ext5'], + supported_extensions) + + class TestMl2BasicGet(test_plugin.TestBasicGet, Ml2PluginV2TestCase): pass