diff --git a/zun/common/exception.py b/zun/common/exception.py index 67ca7d4f3..2d27a7ddf 100644 --- a/zun/common/exception.py +++ b/zun/common/exception.py @@ -503,3 +503,25 @@ class ValidationError(ZunException): class ResourcesUnavailable(ZunException): message = _("Insufficient compute resources: %(reason)s.") + + +class PciConfigInvalidWhitelist(Invalid): + msg_fmt = _("Invalid PCI devices Whitelist config %(reason)s") + + +class PciDeviceWrongAddressFormat(ZunException): + msg_fmt = _("The PCI address %(address)s has an incorrect format.") + + +class PciDeviceInvalidDeviceName(ZunException): + msg_fmt = _("Invalid PCI Whitelist: " + "The PCI whitelist can specify devname or address," + " but not both") + + +class PciDeviceNotFoundById(NotFound): + msg_fmt = _("PCI device %(id)s not found") + + +class PciDeviceNotFound(NotFound): + msg_fmt = _("PCI Device %(node_id)s:%(address)s not found.") diff --git a/zun/pci/__init__.py b/zun/pci/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zun/pci/utils.py b/zun/pci/utils.py new file mode 100644 index 000000000..bcc9fcaaf --- /dev/null +++ b/zun/pci/utils.py @@ -0,0 +1,183 @@ +# Copyright (c) 2017 Intel, Inc. +# Copyright (c) 2017 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import glob +import os +import re + +from oslo_log import log as logging + + +from zun.common import exception + +LOG = logging.getLogger(__name__) + +PCI_VENDOR_PATTERN = "^(hex{4})$".replace("hex", "[\da-fA-F]") +_PCI_ADDRESS_PATTERN = ("^(hex{4}):(hex{2}):(hex{2}).(oct{1})$". + replace("hex", "[\da-fA-F]"). + replace("oct", "[0-7]")) +_PCI_ADDRESS_REGEX = re.compile(_PCI_ADDRESS_PATTERN) + +_SRIOV_TOTALVFS = "sriov_totalvfs" + + +def pci_device_prop_match(pci_dev, specs): + """Check if the pci_dev meet spec requirement + + Specs is a list of PCI device property requirements. + An example of device requirement that the PCI should be either: + a) Device with vendor_id as 0x8086 and product_id as 0x8259, or + b) Device with vendor_id as 0x10de and product_id as 0x10d8: + + [{"vendor_id":"8086", "product_id":"8259"}, + {"vendor_id":"10de", "product_id":"10d8", + "capabilities_network": ["rx", "tx", "tso", "gso"]}] + + """ + def _matching_devices(spec): + for k, v in spec.items(): + pci_dev_v = pci_dev.get(k) + if isinstance(v, list) and isinstance(pci_dev_v, list): + if not all(x in pci_dev.get(k) for x in v): + return False + elif pci_dev_v != v: + return False + return True + + return any(_matching_devices(spec) for spec in specs) + + +def parse_address(address): + """Returns (domain, bus, slot, function) from PCI address + + Which is stored in PciDevice DB table. + """ + m = _PCI_ADDRESS_REGEX.match(address) + if not m: + raise exception.PciDeviceWrongAddressFormat(address=address) + return m.groups() + + +def get_pci_address_fields(pci_addr): + dbs, sep, func = pci_addr.partition('.') + domain, bus, slot = dbs.split(':') + return (domain, bus, slot, func) + + +def get_pci_address(domain, bus, slot, func): + return '%s:%s:%s.%s' % (domain, bus, slot, func) + + +def get_function_by_ifname(ifname): + """Get the function by the interface name + + Given the device name, returns the PCI address of a device + and returns True if the address is in a physical function. + """ + dev_path = "/sys/class/net/%s/device" % ifname + sriov_totalvfs = 0 + if os.path.isdir(dev_path): + try: + # sriov_totalvfs contains the maximum possible VFs for this PF + with open(os.path.join(dev_path, _SRIOV_TOTALVFS)) as fd: + sriov_totalvfs = int(fd.read()) + return (os.readlink(dev_path).strip("./"), + sriov_totalvfs > 0) + except (IOError, ValueError): + return os.readlink(dev_path).strip("./"), False + return None, False + + +def is_physical_function(domain, bus, slot, function): + dev_path = "/sys/bus/pci/devices/%(d)s:%(b)s:%(s)s.%(f)s/" % { + "d": domain, "b": bus, "s": slot, "f": function} + if os.path.isdir(dev_path): + sriov_totalvfs = 0 + try: + with open(dev_path + _SRIOV_TOTALVFS) as fd: + sriov_totalvfs = int(fd.read()) + return sriov_totalvfs > 0 + except (IOError, ValueError): + pass + return False + + +def _get_sysfs_netdev_path(pci_addr, pf_interface): + """Get the sysfs path based on the PCI address of the device. + + Assumes a networking device - will not check for the existence of the path. + """ + if pf_interface: + return "/sys/bus/pci/devices/%s/physfn/net" % (pci_addr) + return "/sys/bus/pci/devices/%s/net" % (pci_addr) + + +def get_ifname_by_pci_address(pci_addr, pf_interface=False): + """Get the interface name based on a VF's pci address + + The returned interface name is either the parent PF's or that of the VF + itself based on the argument of pf_interface. + """ + dev_path = _get_sysfs_netdev_path(pci_addr, pf_interface) + try: + dev_info = os.listdir(dev_path) + return dev_info.pop() + except Exception: + raise exception.PciDeviceNotFoundById(id=pci_addr) + + +def get_mac_by_pci_address(pci_addr, pf_interface=False): + """Get the MAC address of the nic based on it's PCI address + + Raises PciDeviceNotFoundById in case the pci device is not a NIC + """ + dev_path = _get_sysfs_netdev_path(pci_addr, pf_interface) + if_name = get_ifname_by_pci_address(pci_addr, pf_interface) + addr_file = os.path.join(dev_path, if_name, 'address') + + try: + with open(addr_file) as f: + mac = next(f).strip() + return mac + except (IOError, StopIteration) as e: + LOG.warning("Could not find the expected sysfs file for " + "determining the MAC address of the PCI device " + "%(addr)s. May not be a NIC. Error: %(e)s", + {'addr': pci_addr, 'e': e}) + raise exception.PciDeviceNotFoundById(id=pci_addr) + + +def get_vf_num_by_pci_address(pci_addr): + """Get the VF number based on a VF's pci address + + A VF is associated with an VF number, which ip link command uses to + configure it. This number can be obtained from the PCI device filesystem. + """ + VIRTFN_RE = re.compile("virtfn(\d+)") + virtfns_path = "/sys/bus/pci/devices/%s/physfn/virtfn*" % (pci_addr) + vf_num = None + try: + for vf_path in glob.iglob(virtfns_path): + if re.search(pci_addr, os.readlink(vf_path)): + t = VIRTFN_RE.search(vf_path) + vf_num = t.group(1) + break + except Exception: + pass + if vf_num is None: + raise exception.PciDeviceNotFoundById(id=pci_addr) + return vf_num diff --git a/zun/tests/unit/pci/__init__.py b/zun/tests/unit/pci/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zun/tests/unit/pci/test_utils.py b/zun/tests/unit/pci/test_utils.py new file mode 100644 index 000000000..eb60071b6 --- /dev/null +++ b/zun/tests/unit/pci/test_utils.py @@ -0,0 +1,265 @@ +# Copyright (c) 2017 Intel, Inc. +# Copyright (c) 2017 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import glob +import os + +import fixtures +import mock +from six.moves import builtins + +from zun.common import exception +from zun.pci import utils +from zun.tests import base + + +class PciDeviceMatchTestCase(base.TestCase): + def setUp(self): + super(PciDeviceMatchTestCase, self).setUp() + self.fake_pci_1 = {'vendor_id': 'v1', + 'device_id': 'd1', + 'capabilities_network': ['cap1', 'cap2', 'cap3']} + + def test_single_spec_match(self): + self.assertTrue(utils.pci_device_prop_match( + self.fake_pci_1, [{'vendor_id': 'v1', 'device_id': 'd1'}])) + + def test_multiple_spec_match(self): + self.assertTrue(utils.pci_device_prop_match( + self.fake_pci_1, + [{'vendor_id': 'v1', 'device_id': 'd1'}, + {'vendor_id': 'v3', 'device_id': 'd3'}])) + + def test_spec_dismatch(self): + self.assertFalse(utils.pci_device_prop_match( + self.fake_pci_1, + [{'vendor_id': 'v4', 'device_id': 'd4'}, + {'vendor_id': 'v3', 'device_id': 'd3'}])) + + def test_spec_extra_key(self): + self.assertFalse(utils.pci_device_prop_match( + self.fake_pci_1, + [{'vendor_id': 'v1', 'device_id': 'd1', 'wrong_key': 'k1'}])) + + def test_spec_list(self): + self.assertTrue(utils.pci_device_prop_match( + self.fake_pci_1, [{'vendor_id': 'v1', 'device_id': 'd1', + 'capabilities_network': ['cap1', 'cap2', + 'cap3']}])) + self.assertTrue(utils.pci_device_prop_match( + self.fake_pci_1, [{'vendor_id': 'v1', 'device_id': 'd1', + 'capabilities_network': ['cap3', 'cap1']}])) + + def test_spec_list_no_matching(self): + self.assertFalse(utils.pci_device_prop_match( + self.fake_pci_1, [{'vendor_id': 'v1', 'device_id': 'd1', + 'capabilities_network': ['cap1', 'cap33']}])) + + def test_spec_list_wrong_type(self): + self.assertFalse(utils.pci_device_prop_match( + self.fake_pci_1, [{'vendor_id': 'v1', 'device_id': ['d1']}])) + + +class PciDeviceAddressParserTestCase(base.TestCase): + def test_parse_address(self): + self.parse_result = utils.parse_address("0000:04:12.6") + self.assertEqual(self.parse_result, ('0000', '04', '12', '6')) + + def test_parse_address_wrong(self): + self.assertRaises(exception.PciDeviceWrongAddressFormat, + utils.parse_address, "0000:04.12:6") + + def test_parse_address_invalid_character(self): + self.assertRaises(exception.PciDeviceWrongAddressFormat, + utils.parse_address, "0000:h4.12:6") + + +class GetFunctionByIfnameTestCase(base.TestCase): + + @mock.patch('os.path.isdir', return_value=True) + @mock.patch.object(os, 'readlink') + def test_virtual_function(self, mock_readlink, *args): + mock_readlink.return_value = '../../../0000.00.00.1' + with mock.patch.object(builtins, 'open', side_effect=IOError()): + address, physical_function = utils.get_function_by_ifname('eth0') + self.assertEqual(address, '0000.00.00.1') + self.assertFalse(physical_function) + + @mock.patch('os.path.isdir', return_value=True) + @mock.patch.object(os, 'readlink') + def test_physical_function(self, mock_readlink, *args): + ifname = 'eth0' + totalvf_path = "/sys/class/net/%s/device/%s" % (ifname, + utils._SRIOV_TOTALVFS) + mock_readlink.return_value = '../../../0000:00:00.1' + with mock.patch.object(builtins, 'open', + mock.mock_open(read_data='4')) as mock_open: + address, physical_function = utils.get_function_by_ifname('eth0') + self.assertEqual(address, '0000:00:00.1') + self.assertTrue(physical_function) + mock_open.assert_called_once_with(totalvf_path) + + @mock.patch('os.path.isdir', return_value=False) + def test_exception(self, *args): + address, physical_function = utils.get_function_by_ifname('lo') + self.assertIsNone(address) + self.assertFalse(physical_function) + + +class IsPhysicalFunctionTestCase(base.TestCase): + + def setUp(self): + super(IsPhysicalFunctionTestCase, self).setUp() + self.pci_args = utils.get_pci_address_fields('0000:00:00.1') + + @mock.patch('os.path.isdir', return_value=True) + def test_virtual_function(self, *args): + with mock.patch.object(builtins, 'open', side_effect=IOError()): + self.assertFalse(utils.is_physical_function(*self.pci_args)) + + @mock.patch('os.path.isdir', return_value=True) + def test_physical_function(self, *args): + with mock.patch.object(builtins, 'open', + mock.mock_open(read_data='4')): + self.assertTrue(utils.is_physical_function(*self.pci_args)) + + @mock.patch('os.path.isdir', return_value=False) + def test_exception(self, *args): + self.assertFalse(utils.is_physical_function(*self.pci_args)) + + +class GetIfnameByPciAddressTestCase(base.TestCase): + + def setUp(self): + super(GetIfnameByPciAddressTestCase, self).setUp() + self.pci_address = '0000:00:00.1' + + @mock.patch.object(os, 'listdir') + def test_physical_function_inferface_name(self, mock_listdir): + mock_listdir.return_value = ['foo', 'bar'] + ifname = utils.get_ifname_by_pci_address( + self.pci_address, pf_interface=True) + self.assertEqual(ifname, 'bar') + + @mock.patch.object(os, 'listdir') + def test_virtual_function_inferface_name(self, mock_listdir): + mock_listdir.return_value = ['foo', 'bar'] + ifname = utils.get_ifname_by_pci_address( + self.pci_address, pf_interface=False) + self.assertEqual(ifname, 'bar') + + @mock.patch.object(os, 'listdir') + def test_exception(self, mock_listdir): + mock_listdir.side_effect = OSError('No such file or directory') + self.assertRaises( + exception.PciDeviceNotFoundById, + utils.get_ifname_by_pci_address, + self.pci_address + ) + + +class GetMacByPciAddressTestCase(base.TestCase): + def setUp(self): + super(GetMacByPciAddressTestCase, self).setUp() + self.pci_address = '0000:07:00.1' + self.if_name = 'enp7s0f1' + self.tmpdir = self.useFixture(fixtures.TempDir()) + self.fake_file = os.path.join(self.tmpdir.path, "address") + with open(self.fake_file, "w") as f: + f.write("a0:36:9f:72:00:00\n") + + @mock.patch.object(os, 'listdir') + @mock.patch.object(os.path, 'join') + def test_get_mac(self, mock_join, mock_listdir): + mock_listdir.return_value = [self.if_name] + mock_join.return_value = self.fake_file + mac = utils.get_mac_by_pci_address(self.pci_address) + mock_join.assert_called_once_with( + "/sys/bus/pci/devices/%s/net" % self.pci_address, self.if_name, + "address") + self.assertEqual("a0:36:9f:72:00:00", mac) + + @mock.patch.object(os, 'listdir') + @mock.patch.object(os.path, 'join') + def test_get_mac_fails(self, mock_join, mock_listdir): + os.unlink(self.fake_file) + mock_listdir.return_value = [self.if_name] + mock_join.return_value = self.fake_file + self.assertRaises( + exception.PciDeviceNotFoundById, + utils.get_mac_by_pci_address, self.pci_address) + + @mock.patch.object(os, 'listdir') + @mock.patch.object(os.path, 'join') + def test_get_mac_fails_empty(self, mock_join, mock_listdir): + with open(self.fake_file, "w") as f: + f.truncate(0) + mock_listdir.return_value = [self.if_name] + mock_join.return_value = self.fake_file + self.assertRaises( + exception.PciDeviceNotFoundById, + utils.get_mac_by_pci_address, self.pci_address) + + @mock.patch.object(os, 'listdir') + @mock.patch.object(os.path, 'join') + def test_get_physical_function_mac(self, mock_join, mock_listdir): + mock_listdir.return_value = [self.if_name] + mock_join.return_value = self.fake_file + mac = utils.get_mac_by_pci_address(self.pci_address, pf_interface=True) + mock_join.assert_called_once_with( + "/sys/bus/pci/devices/%s/physfn/net" % self.pci_address, + self.if_name, "address") + self.assertEqual("a0:36:9f:72:00:00", mac) + + +class GetVfNumByPciAddressTestCase(base.TestCase): + + def setUp(self): + super(GetVfNumByPciAddressTestCase, self).setUp() + self.pci_address = '0000:00:00.1' + self.paths = [ + '/sys/bus/pci/devices/0000:00:00.1/physfn/virtfn3', + ] + + @mock.patch.object(os, 'readlink') + @mock.patch.object(glob, 'iglob') + def test_vf_number_found(self, mock_iglob, mock_readlink): + mock_iglob.return_value = self.paths + mock_readlink.return_value = '../../0000:00:00.1' + vf_num = utils.get_vf_num_by_pci_address(self.pci_address) + self.assertEqual(vf_num, '3') + + @mock.patch.object(os, 'readlink') + @mock.patch.object(glob, 'iglob') + def test_vf_number_not_found(self, mock_iglob, mock_readlink): + mock_iglob.return_value = self.paths + mock_readlink.return_value = '../../0000:00:00.2' + self.assertRaises( + exception.PciDeviceNotFoundById, + utils.get_vf_num_by_pci_address, + self.pci_address + ) + + @mock.patch.object(os, 'readlink') + @mock.patch.object(glob, 'iglob') + def test_exception(self, mock_iglob, mock_readlink): + mock_iglob.return_value = self.paths + mock_readlink.side_effect = OSError('No such file or directory') + self.assertRaises( + exception.PciDeviceNotFoundById, + utils.get_vf_num_by_pci_address, + self.pci_address + )