diff --git a/ironic/api/controllers/v1/port.py b/ironic/api/controllers/v1/port.py index 6422ae27f6..789de86228 100644 --- a/ironic/api/controllers/v1/port.py +++ b/ironic/api/controllers/v1/port.py @@ -339,7 +339,30 @@ class PortsController(rest.RestController): def _get_ports_collection(self, node_ident, address, portgroup_ident, marker, limit, sort_key, sort_dir, resource_url=None, fields=None, detail=None, - owner=None): + project=None): + """Retrieve a collection of ports. + + :param node_ident: UUID or name of a node, to get only ports for that + node. + :param address: MAC address of a port, to get the port which has + this MAC address. + :param portgroup_ident: UUID or name of a portgroup, to get only ports + for that portgroup. + :param marker: pagination marker for large data sets. + :param limit: maximum number of resources to return in a single result. + This value cannot be larger than the value of max_limit + in the [api] section of the ironic configuration, or only + max_limit resources will be returned. + :param sort_key: column to sort results by. Default: id. + :param sort_dir: direction to sort. "asc" or "desc". Default: asc. + :param resource_url: Optional, base url to be used for links + :param fields: Optional, a list with a specified set of fields + of the resource to be returned. + :param detail: Optional, show detailed list of ports + :param project: Optional, filter by project + :returns: a list of ports. + + """ limit = api_utils.validate_limit(limit) sort_dir = api_utils.validate_sort_dir(sort_dir) @@ -371,7 +394,7 @@ class PortsController(rest.RestController): marker_obj, sort_key=sort_key, sort_dir=sort_dir, - owner=owner) + project=project) elif node_ident: # FIXME(comstud): Since all we need is the node ID, we can # make this more efficient by only querying @@ -382,13 +405,13 @@ class PortsController(rest.RestController): node.id, limit, marker_obj, sort_key=sort_key, sort_dir=sort_dir, - owner=owner) + project=project) elif address: - ports = self._get_ports_by_address(address, owner=owner) + ports = self._get_ports_by_address(address, project=project) else: ports = objects.Port.list(api.request.context, limit, marker_obj, sort_key=sort_key, - sort_dir=sort_dir, owner=owner) + sort_dir=sort_dir, project=project) parameters = {} if detail is not None: @@ -401,17 +424,18 @@ class PortsController(rest.RestController): sort_dir=sort_dir, **parameters) - def _get_ports_by_address(self, address, owner=None): + def _get_ports_by_address(self, address, project=None): """Retrieve a port by its address. :param address: MAC address of a port, to get the port which has this MAC address. + :param project: Optional, filter by project :returns: a list with the port, or an empty list if no port is found. """ try: port = objects.Port.get_by_address(api.request.context, address, - owner=owner) + project=project) return [port] except exception.PortNotFound: return [] @@ -480,7 +504,7 @@ class PortsController(rest.RestController): for that portgroup. :raises: NotAcceptable, HTTPNotFound """ - owner = api_utils.check_port_list_policy() + project = api_utils.check_port_list_policy() api_utils.check_allow_specify_fields(fields) self._check_allowed_port_fields(fields) @@ -503,7 +527,7 @@ class PortsController(rest.RestController): return self._get_ports_collection(node_uuid or node, address, portgroup, marker, limit, sort_key, sort_dir, fields=fields, - detail=detail, owner=owner) + detail=detail, project=project) @METRICS.timer('PortsController.detail') @expose.expose(PortCollection, types.uuid_or_name, types.uuid, @@ -533,7 +557,7 @@ class PortsController(rest.RestController): :param sort_dir: direction to sort. "asc" or "desc". Default: asc. :raises: NotAcceptable, HTTPNotFound """ - owner = api_utils.check_port_list_policy() + project = api_utils.check_port_list_policy() self._check_allowed_port_fields([sort_key]) if portgroup and not api_utils.allow_portgroups_subcontrollers(): @@ -555,7 +579,8 @@ class PortsController(rest.RestController): resource_url = '/'.join(['ports', 'detail']) return self._get_ports_collection(node_uuid or node, address, portgroup, marker, limit, sort_key, - sort_dir, resource_url, owner=owner) + sort_dir, resource_url, + project=project) @METRICS.timer('PortsController.get_one') @expose.expose(Port, types.uuid, types.listtype) diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index e8f4c737a9..4eca24c49f 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -155,6 +155,13 @@ def add_port_filter_by_node_owner(query, value): return query.filter(models.Node.owner == value) +def add_port_filter_by_node_project(query, value): + query = query.join(models.Node, + models.Port.node_id == models.Node.id) + return query.filter((models.Node.owner == value) + | (models.Node.lessee == value)) + + def add_portgroup_filter(query, value): """Adds a portgroup-specific filter to a query. @@ -687,38 +694,49 @@ class Connection(api.Connection): except NoResultFound: raise exception.PortNotFound(port=port_uuid) - def get_port_by_address(self, address, owner=None): + def get_port_by_address(self, address, owner=None, project=None): query = model_query(models.Port).filter_by(address=address) if owner: query = add_port_filter_by_node_owner(query, owner) + elif project: + query = add_port_filter_by_node_project(query, project) try: return query.one() except NoResultFound: raise exception.PortNotFound(port=address) def get_port_list(self, limit=None, marker=None, - sort_key=None, sort_dir=None, owner=None): + sort_key=None, sort_dir=None, owner=None, + project=None): query = model_query(models.Port) if owner: query = add_port_filter_by_node_owner(query, owner) + elif project: + query = add_port_filter_by_node_project(query, project) return _paginate_query(models.Port, limit, marker, sort_key, sort_dir, query) def get_ports_by_node_id(self, node_id, limit=None, marker=None, - sort_key=None, sort_dir=None, owner=None): + sort_key=None, sort_dir=None, owner=None, + project=None): query = model_query(models.Port) query = query.filter_by(node_id=node_id) if owner: query = add_port_filter_by_node_owner(query, owner) + elif project: + query = add_port_filter_by_node_project(query, project) return _paginate_query(models.Port, limit, marker, sort_key, sort_dir, query) def get_ports_by_portgroup_id(self, portgroup_id, limit=None, marker=None, - sort_key=None, sort_dir=None, owner=None): + sort_key=None, sort_dir=None, owner=None, + project=None): query = model_query(models.Port) query = query.filter_by(portgroup_id=portgroup_id) if owner: query = add_port_filter_by_node_owner(query, owner) + elif project: + query = add_port_filter_by_node_project(query, project) return _paginate_query(models.Port, limit, marker, sort_key, sort_dir, query) diff --git a/ironic/objects/port.py b/ironic/objects/port.py index 6c75c8c21e..85690b1624 100644 --- a/ironic/objects/port.py +++ b/ironic/objects/port.py @@ -203,18 +203,21 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): # Implications of calling new remote procedures should be thought through. # @object_base.remotable_classmethod @classmethod - def get_by_address(cls, context, address, owner=None): + def get_by_address(cls, context, address, owner=None, project=None): """Find a port based on address and return a :class:`Port` object. :param cls: the :class:`Port` :param context: Security context :param address: the address of a port. - :param owner: a node owner to match against + :param owner: DEPRECATED a node owner to match against + :param project: a node owner or lessee to match against :returns: a :class:`Port` object. :raises: PortNotFound """ - db_port = cls.dbapi.get_port_by_address(address, owner=owner) + if owner and not project: + project = owner + db_port = cls.dbapi.get_port_by_address(address, project=project) port = cls._from_db_object(context, cls(), db_port) return port @@ -224,7 +227,7 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): # @object_base.remotable_classmethod @classmethod def list(cls, context, limit=None, marker=None, - sort_key=None, sort_dir=None, owner=None): + sort_key=None, sort_dir=None, owner=None, project=None): """Return a list of Port objects. :param context: Security context. @@ -232,16 +235,19 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): :param marker: pagination marker for large data sets. :param sort_key: column to sort results by. :param sort_dir: direction to sort. "asc" or "desc". - :param owner: a node owner to match against + :param owner: DEPRECATED a node owner to match against + :param project: a node owner or lessee to match against :returns: a list of :class:`Port` object. :raises: InvalidParameterValue """ + if owner and not project: + project = owner db_ports = cls.dbapi.get_port_list(limit=limit, marker=marker, sort_key=sort_key, sort_dir=sort_dir, - owner=owner) + project=project) return cls._from_db_object_list(context, db_ports) # NOTE(xek): We don't want to enable RPC on this call just yet. Remotable @@ -250,7 +256,8 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): # @object_base.remotable_classmethod @classmethod def list_by_node_id(cls, context, node_id, limit=None, marker=None, - sort_key=None, sort_dir=None, owner=None): + sort_key=None, sort_dir=None, owner=None, + project=None): """Return a list of Port objects associated with a given node ID. :param context: Security context. @@ -259,15 +266,18 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): :param marker: pagination marker for large data sets. :param sort_key: column to sort results by. :param sort_dir: direction to sort. "asc" or "desc". - :param owner: a node owner to match against + :param owner: DEPRECATED a node owner to match against + :param project: a node owner or lessee to match against :returns: a list of :class:`Port` object. """ + if owner and not project: + project = owner db_ports = cls.dbapi.get_ports_by_node_id(node_id, limit=limit, marker=marker, sort_key=sort_key, sort_dir=sort_dir, - owner=owner) + project=project) return cls._from_db_object_list(context, db_ports) # NOTE(xek): We don't want to enable RPC on this call just yet. Remotable @@ -277,7 +287,7 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): @classmethod def list_by_portgroup_id(cls, context, portgroup_id, limit=None, marker=None, sort_key=None, sort_dir=None, - owner=None): + owner=None, project=None): """Return a list of Port objects associated with a given portgroup ID. :param context: Security context. @@ -286,16 +296,19 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat): :param marker: pagination marker for large data sets. :param sort_key: column to sort results by. :param sort_dir: direction to sort. "asc" or "desc". - :param owner: a node owner to match against + :param owner: DEPRECATED a node owner to match against + :param project: a node owner or lessee to match against :returns: a list of :class:`Port` object. """ + if owner and not project: + project = owner db_ports = cls.dbapi.get_ports_by_portgroup_id(portgroup_id, limit=limit, marker=marker, sort_key=sort_key, sort_dir=sort_dir, - owner=owner) + project=project) return cls._from_db_object_list(context, db_ports) # NOTE(xek): We don't want to enable RPC on this call just yet. Remotable diff --git a/ironic/tests/unit/api/controllers/v1/test_port.py b/ironic/tests/unit/api/controllers/v1/test_port.py index 24df8bf265..b5e145b9fc 100644 --- a/ironic/tests/unit/api/controllers/v1/test_port.py +++ b/ironic/tests/unit/api/controllers/v1/test_port.py @@ -24,6 +24,7 @@ from oslo_utils import timeutils from oslo_utils import uuidutils from testtools import matchers +from ironic import api from ironic.api.controllers import base as api_base from ironic.api.controllers import v1 as api_v1 from ironic.api.controllers.v1 import notification_utils @@ -195,6 +196,40 @@ class TestPortsController__CheckAllowedPortFields(base.TestCase): mock_allow_port.assert_called_once_with() +@mock.patch.object(objects.Port, 'list', autospec=True) +@mock.patch.object(api, 'request', spec_set=['context']) +class TestPortsController__GetPortsCollection(base.TestCase): + + def setUp(self): + super(TestPortsController__GetPortsCollection, self).setUp() + self.controller = api_port.PortsController() + + def test__get_ports_collection(self, mock_request, mock_list): + mock_request.context = 'fake-context' + mock_list.return_value = [] + self.controller._get_ports_collection(None, None, None, None, None, + None, 'asc') + mock_list.assert_called_once_with('fake-context', 1000, None, + project=None, sort_dir='asc', + sort_key=None) + + +@mock.patch.object(objects.Port, 'get_by_address', autospec=True) +@mock.patch.object(api, 'request', spec_set=['context']) +class TestPortsController__GetPortByAddress(base.TestCase): + + def setUp(self): + super(TestPortsController__GetPortByAddress, self).setUp() + self.controller = api_port.PortsController() + + def test__get_ports_by_address(self, mock_request, mock_gba): + mock_request.context = 'fake-context' + mock_gba.return_value = None + self.controller._get_ports_by_address('fake-address') + mock_gba.assert_called_once_with('fake-context', 'fake-address', + project=None) + + class TestListPorts(test_api_base.BaseApiTest): def setUp(self): diff --git a/ironic/tests/unit/db/test_ports.py b/ironic/tests/unit/db/test_ports.py index e0d2e1d665..d2434d603a 100644 --- a/ironic/tests/unit/db/test_ports.py +++ b/ironic/tests/unit/db/test_ports.py @@ -28,7 +28,8 @@ class DbPortTestCase(base.DbTestCase): # This method creates a port for every test and # replaces a test for creating a port. super(DbPortTestCase, self).setUp() - self.node = db_utils.create_test_node(owner='12345') + self.node = db_utils.create_test_node(owner='12345', + lessee='54321') self.portgroup = db_utils.create_test_portgroup(node_id=self.node.id) self.port = db_utils.create_test_port(node_id=self.node.id, portgroup_id=self.portgroup.id) @@ -56,6 +57,17 @@ class DbPortTestCase(base.DbTestCase): self.port.address, owner='54321') + def test_get_port_by_address_filter_by_project(self): + res = self.dbapi.get_port_by_address(self.port.address, + project=self.node.lessee) + self.assertEqual(self.port.id, res.id) + + def test_get_port_by_address_filter_by_project_no_match(self): + self.assertRaises(exception.PortNotFound, + self.dbapi.get_port_by_address, + self.port.address, + project='55555') + def test_get_port_list(self): uuids = [] for i in range(1, 6): @@ -99,6 +111,30 @@ class DbPortTestCase(base.DbTestCase): res_uuids = [r.uuid for r in res] self.assertCountEqual(uuids, res_uuids) + def test_get_port_list_filter_by_node_project(self): + lessee_node = db_utils.create_test_node(uuid=uuidutils.generate_uuid(), + lessee=self.node.owner) + + uuids = [] + for i in range(1, 3): + port = db_utils.create_test_port(uuid=uuidutils.generate_uuid(), + node_id=lessee_node.id, + address='52:54:00:cf:2d:4%s' % i) + uuids.append(str(port.uuid)) + for i in range(4, 6): + port = db_utils.create_test_port(uuid=uuidutils.generate_uuid(), + address='52:54:00:cf:2d:4%s' % i) + for i in range(7, 9): + port = db_utils.create_test_port(uuid=uuidutils.generate_uuid(), + node_id=self.node.id, + address='52:54:00:cf:2d:4%s' % i) + uuids.append(str(port.uuid)) + # Also add the uuid for the port created in setUp() + uuids.append(str(self.port.uuid)) + res = self.dbapi.get_port_list(project=self.node.owner) + res_uuids = [r.uuid for r in res] + self.assertCountEqual(uuids, res_uuids) + def test_get_ports_by_node_id(self): res = self.dbapi.get_ports_by_node_id(self.node.id) self.assertEqual(self.port.address, res[0].address) @@ -113,6 +149,16 @@ class DbPortTestCase(base.DbTestCase): owner='54321') self.assertEqual([], res) + def test_get_ports_by_node_id_filter_by_node_project(self): + res = self.dbapi.get_ports_by_node_id(self.node.id, + project=self.node.lessee) + self.assertEqual(self.port.address, res[0].address) + + def test_get_ports_by_node_id_filter_by_node_project_no_match(self): + res = self.dbapi.get_ports_by_node_id(self.node.id, + owner='11111') + self.assertEqual([], res) + def test_get_ports_by_node_id_that_does_not_exist(self): self.assertEqual([], self.dbapi.get_ports_by_node_id(99)) @@ -130,6 +176,16 @@ class DbPortTestCase(base.DbTestCase): owner='54321') self.assertEqual([], res) + def test_get_ports_by_portgroup_id_filter_by_node_project(self): + res = self.dbapi.get_ports_by_portgroup_id(self.portgroup.id, + project=self.node.lessee) + self.assertEqual(self.port.address, res[0].address) + + def test_get_ports_by_portgroup_id_filter_by_node_project_no_match(self): + res = self.dbapi.get_ports_by_portgroup_id(self.portgroup.id, + project='11111') + self.assertEqual([], res) + def test_get_ports_by_portgroup_id_that_does_not_exist(self): self.assertEqual([], self.dbapi.get_ports_by_portgroup_id(99)) diff --git a/ironic/tests/unit/objects/test_port.py b/ironic/tests/unit/objects/test_port.py index 5a88ecc1bd..43c58876e9 100644 --- a/ironic/tests/unit/objects/test_port.py +++ b/ironic/tests/unit/objects/test_port.py @@ -66,7 +66,7 @@ class TestPortObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn): port = objects.Port.get(self.context, address) - mock_get_port.assert_called_once_with(address, owner=None) + mock_get_port.assert_called_once_with(address, project=None) self.assertEqual(self.context, port._context) def test_get_bad_id_and_uuid_and_address(self): @@ -146,6 +146,22 @@ class TestPortObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn): self.assertThat(ports, matchers.HasLength(1)) self.assertIsInstance(ports[0], objects.Port) self.assertEqual(self.context, ports[0]._context) + mock_get_list.assert_called_once_with( + limit=None, marker=None, project=None, sort_dir=None, + sort_key=None) + + def test_list_deprecated_owner(self): + with mock.patch.object(self.dbapi, 'get_port_list', + autospec=True) as mock_get_list: + mock_get_list.return_value = [self.fake_port] + ports = objects.Port.list(self.context, + owner='12345') + self.assertThat(ports, matchers.HasLength(1)) + self.assertIsInstance(ports[0], objects.Port) + self.assertEqual(self.context, ports[0]._context) + mock_get_list.assert_called_once_with( + limit=None, marker=None, project='12345', sort_dir=None, + sort_key=None) @mock.patch.object(obj_base.IronicObject, 'supports_version', spec_set=types.FunctionType) diff --git a/releasenotes/notes/port-list-by-project-8cfaf3b2cf0dd627.yaml b/releasenotes/notes/port-list-by-project-8cfaf3b2cf0dd627.yaml new file mode 100644 index 0000000000..32d2821afe --- /dev/null +++ b/releasenotes/notes/port-list-by-project-8cfaf3b2cf0dd627.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Allow port lists to be filtered by project. Doing so checks the specified + project against the port's node's owner and lessee.