Merge "Allow node owners to administer associated ports"

This commit is contained in:
Zuul 2020-02-07 11:14:41 +00:00 committed by Gerrit Code Review
commit a960c548fe
11 changed files with 542 additions and 51 deletions

View File

@ -339,7 +339,8 @@ class PortsController(rest.RestController):
def _get_ports_collection(self, node_ident, address, portgroup_ident,
marker, limit, sort_key, sort_dir,
resource_url=None, fields=None, detail=None):
resource_url=None, fields=None, detail=None,
owner=None):
limit = api_utils.validate_limit(limit)
sort_dir = api_utils.validate_sort_dir(sort_dir)
@ -370,7 +371,8 @@ class PortsController(rest.RestController):
portgroup.id, limit,
marker_obj,
sort_key=sort_key,
sort_dir=sort_dir)
sort_dir=sort_dir,
owner=owner)
elif node_ident:
# FIXME(comstud): Since all we need is the node ID, we can
# make this more efficient by only querying
@ -380,13 +382,14 @@ class PortsController(rest.RestController):
ports = objects.Port.list_by_node_id(api.request.context,
node.id, limit, marker_obj,
sort_key=sort_key,
sort_dir=sort_dir)
sort_dir=sort_dir,
owner=owner)
elif address:
ports = self._get_ports_by_address(address)
ports = self._get_ports_by_address(address, owner=owner)
else:
ports = objects.Port.list(api.request.context, limit,
marker_obj, sort_key=sort_key,
sort_dir=sort_dir)
sort_dir=sort_dir, owner=owner)
parameters = {}
if detail is not None:
@ -399,7 +402,7 @@ class PortsController(rest.RestController):
sort_dir=sort_dir,
**parameters)
def _get_ports_by_address(self, address):
def _get_ports_by_address(self, address, owner=None):
"""Retrieve a port by its address.
:param address: MAC address of a port, to get the port which has
@ -408,7 +411,8 @@ class PortsController(rest.RestController):
"""
try:
port = objects.Port.get_by_address(api.request.context, address)
port = objects.Port.get_by_address(api.request.context, address,
owner=owner)
return [port]
except exception.PortNotFound:
return []
@ -469,8 +473,7 @@ class PortsController(rest.RestController):
for that portgroup.
:raises: NotAcceptable, HTTPNotFound
"""
cdict = api.request.context.to_policy_values()
policy.authorize('baremetal:port:get', cdict, cdict)
owner = api_utils.check_port_list_policy()
api_utils.check_allow_specify_fields(fields)
self._check_allowed_port_fields(fields)
@ -493,7 +496,7 @@ class PortsController(rest.RestController):
return self._get_ports_collection(node_uuid or node, address,
portgroup, marker, limit, sort_key,
sort_dir, fields=fields,
detail=detail)
detail=detail, owner=owner)
@METRICS.timer('PortsController.detail')
@expose.expose(PortCollection, types.uuid_or_name, types.uuid,
@ -523,8 +526,7 @@ class PortsController(rest.RestController):
:param sort_dir: direction to sort. "asc" or "desc". Default: asc.
:raises: NotAcceptable, HTTPNotFound
"""
cdict = api.request.context.to_policy_values()
policy.authorize('baremetal:port:get', cdict, cdict)
owner = api_utils.check_port_list_policy()
self._check_allowed_port_fields([sort_key])
if portgroup and not api_utils.allow_portgroups_subcontrollers():
@ -546,7 +548,7 @@ class PortsController(rest.RestController):
resource_url = '/'.join(['ports', 'detail'])
return self._get_ports_collection(node_uuid or node, address,
portgroup, marker, limit, sort_key,
sort_dir, resource_url)
sort_dir, resource_url, owner=owner)
@METRICS.timer('PortsController.get_one')
@expose.expose(Port, types.uuid, types.listtype)
@ -558,16 +560,15 @@ class PortsController(rest.RestController):
of the resource to be returned.
:raises: NotAcceptable, HTTPNotFound
"""
cdict = api.request.context.to_policy_values()
policy.authorize('baremetal:port:get', cdict, cdict)
if self.parent_node_ident or self.parent_portgroup_ident:
raise exception.OperationNotPermitted()
rpc_port, rpc_node = api_utils.check_port_policy_and_retrieve(
'baremetal:port:get', port_uuid)
api_utils.check_allow_specify_fields(fields)
self._check_allowed_port_fields(fields)
rpc_port = objects.Port.get_by_uuid(api.request.context, port_uuid)
return Port.convert_with_links(rpc_port, fields=fields)
@METRICS.timer('PortsController.post')
@ -578,13 +579,13 @@ class PortsController(rest.RestController):
:param port: a port within the request body.
:raises: NotAcceptable, HTTPNotFound, Conflict
"""
if self.parent_node_ident or self.parent_portgroup_ident:
raise exception.OperationNotPermitted()
context = api.request.context
cdict = context.to_policy_values()
policy.authorize('baremetal:port:create', cdict, cdict)
if self.parent_node_ident or self.parent_portgroup_ident:
raise exception.OperationNotPermitted()
pdict = port.as_dict()
self._check_allowed_port_fields(pdict)
@ -660,13 +661,14 @@ class PortsController(rest.RestController):
:param patch: a json PATCH document to apply to this port.
:raises: NotAcceptable, HTTPNotFound
"""
context = api.request.context
cdict = context.to_policy_values()
policy.authorize('baremetal:port:update', cdict, cdict)
if self.parent_node_ident or self.parent_portgroup_ident:
raise exception.OperationNotPermitted()
rpc_port, rpc_node = api_utils.check_port_policy_and_retrieve(
'baremetal:port:update', port_uuid)
context = api.request.context
fields_to_check = set()
for field in (self.advanced_net_fields
+ ['portgroup_uuid', 'physical_network',
@ -677,7 +679,6 @@ class PortsController(rest.RestController):
fields_to_check.add(field)
self._check_allowed_port_fields(fields_to_check)
rpc_port = objects.Port.get_by_uuid(context, port_uuid)
port_dict = rpc_port.as_dict()
# NOTE(lucasagomes):
# 1) Remove node_id because it's an internal value and
@ -708,7 +709,6 @@ class PortsController(rest.RestController):
if rpc_port[field] != patch_val:
rpc_port[field] = patch_val
rpc_node = objects.Node.get_by_id(context, rpc_port.node_id)
if (rpc_node.provision_state == ir_states.INSPECTING
and api_utils.allow_inspect_wait_state()):
msg = _('Cannot update port "%(port)s" on "%(node)s" while it is '
@ -742,15 +742,13 @@ class PortsController(rest.RestController):
:param port_uuid: UUID of a port.
:raises: OperationNotPermitted, HTTPNotFound
"""
context = api.request.context
cdict = context.to_policy_values()
policy.authorize('baremetal:port:delete', cdict, cdict)
if self.parent_node_ident or self.parent_portgroup_ident:
raise exception.OperationNotPermitted()
rpc_port = objects.Port.get_by_uuid(context, port_uuid)
rpc_node = objects.Node.get_by_id(context, rpc_port.node_id)
rpc_port, rpc_node = api_utils.check_port_policy_and_retrieve(
'baremetal:port:delete', port_uuid)
context = api.request.context
portgroup_uuid = None
if rpc_port.portgroup_id:

View File

@ -1231,6 +1231,52 @@ def check_node_list_policy(owner=None):
return owner
def check_port_policy_and_retrieve(policy_name, port_uuid):
"""Check if the specified policy authorizes this request on a port.
:param: policy_name: Name of the policy to check.
:param: port_uuid: the UUID of a port.
:raises: HTTPForbidden if the policy forbids access.
:raises: NodeNotFound if the node is not found.
:return: RPC port identified by port_uuid and associated node
"""
context = api.request.context
cdict = context.to_policy_values()
try:
rpc_port = objects.Port.get_by_uuid(context, port_uuid)
except exception.PortNotFound:
# don't expose non-existence of port unless requester
# has generic access to policy
policy.authorize(policy_name, cdict, cdict)
raise
rpc_node = objects.Node.get_by_id(context, rpc_port.node_id)
target_dict = dict(cdict)
target_dict['node.owner'] = rpc_node['owner']
policy.authorize(policy_name, target_dict, cdict)
return rpc_port, rpc_node
def check_port_list_policy():
"""Check if the specified policy authorizes this request on a port.
:raises: HTTPForbidden if the policy forbids access.
:return: owner that should be used for list query, if needed
"""
cdict = api.request.context.to_policy_values()
try:
policy.authorize('baremetal:port:list_all', cdict, cdict)
except exception.HTTPForbidden:
owner = cdict.get('project_id')
if not owner:
raise
policy.authorize('baremetal:port:list', cdict, cdict)
return owner
def allow_build_configdrive():
"""Check if building configdrive is allowed.

View File

@ -231,14 +231,24 @@ port_policies = [
'baremetal:port:get',
'rule:is_admin or rule:is_observer',
'Retrieve Port records',
[{'path': '/ports', 'method': 'GET'},
{'path': '/ports/detail', 'method': 'GET'},
{'path': '/ports/{port_id}', 'method': 'GET'},
[{'path': '/ports/{port_id}', 'method': 'GET'},
{'path': '/nodes/{node_ident}/ports', 'method': 'GET'},
{'path': '/nodes/{node_ident}/ports/detail', 'method': 'GET'},
{'path': '/portgroups/{portgroup_ident}/ports', 'method': 'GET'},
{'path': '/portgroups/{portgroup_ident}/ports/detail',
'method': 'GET'}]),
policy.DocumentedRuleDefault(
'baremetal:port:list',
'rule:baremetal:port:get',
'Retrieve multiple Port records, filtered by owner',
[{'path': '/ports', 'method': 'GET'},
{'path': '/ports/detail', 'method': 'GET'}]),
policy.DocumentedRuleDefault(
'baremetal:port:list_all',
'rule:baremetal:port:get',
'Retrieve multiple Port records',
[{'path': '/ports', 'method': 'GET'},
{'path': '/ports/detail', 'method': 'GET'}]),
policy.DocumentedRuleDefault(
'baremetal:port:create',
'rule:is_admin',

View File

@ -149,6 +149,12 @@ def add_port_filter_by_node(query, value):
return query.filter(models.Node.uuid == value)
def add_port_filter_by_node_owner(query, value):
query = query.join(models.Node,
models.Port.node_id == models.Node.id)
return query.filter(models.Node.owner == value)
def add_portgroup_filter(query, value):
"""Adds a portgroup-specific filter to a query.
@ -672,29 +678,38 @@ class Connection(api.Connection):
except NoResultFound:
raise exception.PortNotFound(port=port_uuid)
def get_port_by_address(self, address):
def get_port_by_address(self, address, owner=None):
query = model_query(models.Port).filter_by(address=address)
if owner:
query = add_port_filter_by_node_owner(query, owner)
try:
return query.one()
except NoResultFound:
raise exception.PortNotFound(port=address)
def get_port_list(self, limit=None, marker=None,
sort_key=None, sort_dir=None):
sort_key=None, sort_dir=None, owner=None):
query = model_query(models.Port)
if owner:
query = add_port_filter_by_node_owner(query, owner)
return _paginate_query(models.Port, limit, marker,
sort_key, sort_dir)
sort_key, sort_dir, query)
def get_ports_by_node_id(self, node_id, limit=None, marker=None,
sort_key=None, sort_dir=None):
sort_key=None, sort_dir=None, owner=None):
query = model_query(models.Port)
query = query.filter_by(node_id=node_id)
if owner:
query = add_port_filter_by_node_owner(query, owner)
return _paginate_query(models.Port, limit, marker,
sort_key, sort_dir, query)
def get_ports_by_portgroup_id(self, portgroup_id, limit=None, marker=None,
sort_key=None, sort_dir=None):
sort_key=None, sort_dir=None, owner=None):
query = model_query(models.Port)
query = query.filter_by(portgroup_id=portgroup_id)
if owner:
query = add_port_filter_by_node_owner(query, owner)
return _paginate_query(models.Port, limit, marker,
sort_key, sort_dir, query)

View File

@ -203,17 +203,18 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
# Implications of calling new remote procedures should be thought through.
# @object_base.remotable_classmethod
@classmethod
def get_by_address(cls, context, address):
def get_by_address(cls, context, address, owner=None):
"""Find a port based on address and return a :class:`Port` object.
:param cls: the :class:`Port`
:param context: Security context
:param address: the address of a port.
:param owner: a node owner to match against
:returns: a :class:`Port` object.
:raises: PortNotFound
"""
db_port = cls.dbapi.get_port_by_address(address)
db_port = cls.dbapi.get_port_by_address(address, owner=owner)
port = cls._from_db_object(context, cls(), db_port)
return port
@ -223,7 +224,7 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
# @object_base.remotable_classmethod
@classmethod
def list(cls, context, limit=None, marker=None,
sort_key=None, sort_dir=None):
sort_key=None, sort_dir=None, owner=None):
"""Return a list of Port objects.
:param context: Security context.
@ -231,6 +232,7 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
:param marker: pagination marker for large data sets.
:param sort_key: column to sort results by.
:param sort_dir: direction to sort. "asc" or "desc".
:param owner: a node owner to match against
:returns: a list of :class:`Port` object.
:raises: InvalidParameterValue
@ -238,7 +240,8 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
db_ports = cls.dbapi.get_port_list(limit=limit,
marker=marker,
sort_key=sort_key,
sort_dir=sort_dir)
sort_dir=sort_dir,
owner=owner)
return cls._from_db_object_list(context, db_ports)
# NOTE(xek): We don't want to enable RPC on this call just yet. Remotable
@ -247,7 +250,7 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
# @object_base.remotable_classmethod
@classmethod
def list_by_node_id(cls, context, node_id, limit=None, marker=None,
sort_key=None, sort_dir=None):
sort_key=None, sort_dir=None, owner=None):
"""Return a list of Port objects associated with a given node ID.
:param context: Security context.
@ -256,13 +259,15 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
:param marker: pagination marker for large data sets.
:param sort_key: column to sort results by.
:param sort_dir: direction to sort. "asc" or "desc".
:param owner: a node owner to match against
:returns: a list of :class:`Port` object.
"""
db_ports = cls.dbapi.get_ports_by_node_id(node_id, limit=limit,
marker=marker,
sort_key=sort_key,
sort_dir=sort_dir)
sort_dir=sort_dir,
owner=owner)
return cls._from_db_object_list(context, db_ports)
# NOTE(xek): We don't want to enable RPC on this call just yet. Remotable
@ -271,7 +276,8 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
# @object_base.remotable_classmethod
@classmethod
def list_by_portgroup_id(cls, context, portgroup_id, limit=None,
marker=None, sort_key=None, sort_dir=None):
marker=None, sort_key=None, sort_dir=None,
owner=None):
"""Return a list of Port objects associated with a given portgroup ID.
:param context: Security context.
@ -280,6 +286,7 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
:param marker: pagination marker for large data sets.
:param sort_key: column to sort results by.
:param sort_dir: direction to sort. "asc" or "desc".
:param owner: a node owner to match against
:returns: a list of :class:`Port` object.
"""
@ -287,7 +294,8 @@ class Port(base.IronicObject, object_base.VersionedObjectDictCompat):
limit=limit,
marker=marker,
sort_key=sort_key,
sort_dir=sort_dir)
sort_dir=sort_dir,
owner=owner)
return cls._from_db_object_list(context, db_ports)
# NOTE(xek): We don't want to enable RPC on this call just yet. Remotable

View File

@ -54,6 +54,8 @@ class TestExposedAPIMethodsCheckPolicy(test_base.TestCase):
('api_utils.check_node_policy_and_retrieve' in src) or
('api_utils.check_node_list_policy' in src) or
('self._get_node_and_topic' in src) or
('api_utils.check_port_policy_and_retrieve' in src) or
('api_utils.check_port_list_policy' in src) or
('policy.authorize' in src and
'context.to_policy_values' in src),
'no policy check found in in exposed '

View File

@ -32,6 +32,7 @@ from ironic.api.controllers.v1 import port as api_port
from ironic.api.controllers.v1 import utils as api_utils
from ironic.api.controllers.v1 import versions
from ironic.common import exception
from ironic.common import policy
from ironic.common import states
from ironic.common import utils as common_utils
from ironic.conductor import rpcapi
@ -189,7 +190,7 @@ class TestListPorts(test_api_base.BaseApiTest):
def setUp(self):
super(TestListPorts, self).setUp()
self.node = obj_utils.create_test_node(self.context)
self.node = obj_utils.create_test_node(self.context, owner='12345')
def test_empty(self):
data = self.get_json('/ports')
@ -250,6 +251,42 @@ class TestListPorts(test_api_base.BaseApiTest):
self.assertEqual(port.uuid, data['ports'][0]["uuid"])
self.assertIsNone(data['ports'][0]["portgroup_uuid"])
@mock.patch.object(policy, 'authorize', spec=True)
def test_list_non_admin_forbidden(self, mock_authorize):
def mock_authorize_function(rule, target, creds):
raise exception.HTTPForbidden(resource='fake')
mock_authorize.side_effect = mock_authorize_function
address_template = "aa:bb:cc:dd:ee:f%d"
for id_ in range(3):
obj_utils.create_test_port(self.context,
node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
address=address_template % id_)
response = self.get_json('/ports',
headers={'X-Project-Id': '12345'},
expect_errors=True)
self.assertEqual(http_client.FORBIDDEN, response.status_int)
@mock.patch.object(policy, 'authorize', spec=True)
def test_list_non_admin_forbidden_no_project(self, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
address_template = "aa:bb:cc:dd:ee:f%d"
for id_ in range(3):
obj_utils.create_test_port(self.context,
node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
address=address_template % id_)
response = self.get_json('/ports', expect_errors=True)
self.assertEqual(http_client.FORBIDDEN, response.status_int)
def test_get_one(self):
port = obj_utils.create_test_port(self.context, node_id=self.node.id)
data = self.get_json('/ports/%s' % port.uuid)
@ -581,6 +618,33 @@ class TestListPorts(test_api_base.BaseApiTest):
uuids = [n['uuid'] for n in data['ports']]
self.assertCountEqual(ports, uuids)
@mock.patch.object(policy, 'authorize', spec=True)
def test_many_non_admin(self, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
ports = []
# these ports should be retrieved by the API call
for id_ in range(0, 2):
port = obj_utils.create_test_port(
self.context, node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
address='52:54:00:cf:2d:3%s' % id_)
ports.append(port.uuid)
# these ports should NOT be retrieved by the API call
for id_ in range(3, 5):
port = obj_utils.create_test_port(
self.context, uuid=uuidutils.generate_uuid(),
address='52:54:00:cf:2d:3%s' % id_)
data = self.get_json('/ports', headers={'X-Project-Id': '12345'})
self.assertEqual(len(ports), len(data['ports']))
uuids = [n['uuid'] for n in data['ports']]
self.assertCountEqual(ports, uuids)
def _test_links(self, public_url=None):
cfg.CONF.set_override('public_endpoint', public_url, 'api')
uuid = uuidutils.generate_uuid()
@ -686,6 +750,47 @@ class TestListPorts(test_api_base.BaseApiTest):
self.assertEqual('application/json', response.content_type)
self.assertIn(invalid_address, response.json['error_message'])
@mock.patch.object(policy, 'authorize', spec=True)
def test_port_by_address_non_admin(self, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
address_template = "aa:bb:cc:dd:ee:f%d"
for id_ in range(3):
obj_utils.create_test_port(self.context,
node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
address=address_template % id_)
target_address = address_template % 1
data = self.get_json('/ports?address=%s' % target_address,
headers={'X-Project-Id': '12345'})
self.assertThat(data['ports'], matchers.HasLength(1))
self.assertEqual(target_address, data['ports'][0]['address'])
@mock.patch.object(policy, 'authorize', spec=True)
def test_port_by_address_non_admin_no_match(self, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
address_template = "aa:bb:cc:dd:ee:f%d"
for id_ in range(3):
obj_utils.create_test_port(self.context,
node_id=self.node.id,
uuid=uuidutils.generate_uuid(),
address=address_template % id_)
target_address = address_template % 1
data = self.get_json('/ports?address=%s' % target_address,
headers={'X-Project-Id': '54321'})
self.assertThat(data['ports'], matchers.HasLength(0))
def test_sort_key(self):
ports = []
for id_ in range(3):
@ -765,6 +870,60 @@ class TestListPorts(test_api_base.BaseApiTest):
headers={api_base.Version.string: '1.5'})
self.assertEqual(3, len(data['ports']))
@mock.patch.object(policy, 'authorize', spec=True)
@mock.patch.object(api_utils, 'get_rpc_node')
def test_get_all_by_node_name_non_admin(
self, mock_get_rpc_node, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
mock_get_rpc_node.return_value = self.node
for i in range(5):
if i < 3:
node_id = self.node.id
else:
node_id = 100000 + i
obj_utils.create_test_port(self.context,
node_id=node_id,
uuid=uuidutils.generate_uuid(),
address='52:54:00:cf:2d:3%s' % i)
data = self.get_json("/ports?node=%s" % 'test-node',
headers={
api_base.Version.string: '1.5',
'X-Project-Id': '12345'
})
self.assertEqual(3, len(data['ports']))
@mock.patch.object(policy, 'authorize', spec=True)
@mock.patch.object(api_utils, 'get_rpc_node')
def test_get_all_by_node_name_non_admin_no_match(
self, mock_get_rpc_node, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
mock_get_rpc_node.return_value = self.node
for i in range(5):
if i < 3:
node_id = self.node.id
else:
node_id = 100000 + i
obj_utils.create_test_port(self.context,
node_id=node_id,
uuid=uuidutils.generate_uuid(),
address='52:54:00:cf:2d:3%s' % i)
data = self.get_json("/ports?node=%s" % 'test-node',
headers={
api_base.Version.string: '1.5',
'X-Project-Id': '54321'
})
self.assertEqual(0, len(data['ports']))
@mock.patch.object(api_utils, 'get_rpc_node')
def test_get_all_by_node_uuid_and_name(self, mock_get_rpc_node):
# GET /v1/ports specifying node and uuid - should only use node_uuid
@ -832,6 +991,48 @@ class TestListPorts(test_api_base.BaseApiTest):
self.assertEqual('application/json', response.content_type)
self.assertEqual(http_client.NOT_ACCEPTABLE, response.status_int)
@mock.patch.object(policy, 'authorize', spec=True)
def test_get_all_by_portgroup_uuid_non_admin(self, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
pg = obj_utils.create_test_portgroup(self.context,
node_id=self.node.id)
port = obj_utils.create_test_port(self.context, node_id=self.node.id,
portgroup_id=pg.id)
data = self.get_json('/ports/detail?portgroup=%s' % pg.uuid,
headers={
api_base.Version.string: '1.24',
'X-Project-Id': '12345'
})
self.assertEqual(port.uuid, data['ports'][0]['uuid'])
self.assertEqual(pg.uuid,
data['ports'][0]['portgroup_uuid'])
@mock.patch.object(policy, 'authorize', spec=True)
def test_get_all_by_portgroup_uuid_non_admin_no_match(
self, mock_authorize):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
pg = obj_utils.create_test_portgroup(self.context)
obj_utils.create_test_port(self.context, node_id=self.node.id,
portgroup_id=pg.id)
data = self.get_json('/ports/detail?portgroup=%s' % pg.uuid,
headers={
api_base.Version.string: '1.24',
'X-Project-Id': '54321'
})
self.assertThat(data['ports'], matchers.HasLength(0))
def test_get_all_by_portgroup_name(self):
pg = obj_utils.create_test_portgroup(self.context,
node_id=self.node.id)

View File

@ -1031,3 +1031,159 @@ class TestCheckNodeListPolicy(base.TestCase):
utils.check_node_list_policy,
'54321'
)
class TestCheckPortPolicyAndRetrieve(base.TestCase):
def setUp(self):
super(TestCheckPortPolicyAndRetrieve, self).setUp()
self.valid_port_uuid = uuidutils.generate_uuid()
self.node = test_api_utils.post_get_test_node()
self.node['owner'] = '12345'
self.port = objects.Port(self.context, node_id=42)
@mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True)
@mock.patch.object(objects.Port, 'get_by_uuid')
@mock.patch.object(objects.Node, 'get_by_id')
def test_check_port_policy_and_retrieve(
self, mock_ngbi, mock_pgbu, mock_authorize, mock_pr
):
mock_pr.version.minor = 50
mock_pr.context.to_policy_values.return_value = {}
mock_pgbu.return_value = self.port
mock_ngbi.return_value = self.node
rpc_port, rpc_node = utils.check_port_policy_and_retrieve(
'fake_policy', self.valid_port_uuid
)
mock_pgbu.assert_called_once_with(mock_pr.context,
self.valid_port_uuid)
mock_ngbi.assert_called_once_with(mock_pr.context, 42)
mock_authorize.assert_called_once_with(
'fake_policy', {'node.owner': '12345'}, {})
self.assertEqual(self.port, rpc_port)
self.assertEqual(self.node, rpc_node)
@mock.patch.object(api, 'request', spec_set=["context"])
@mock.patch.object(policy, 'authorize', spec=True)
@mock.patch.object(objects.Port, 'get_by_uuid')
def test_check_port_policy_and_retrieve_no_port_policy_forbidden(
self, mock_pgbu, mock_authorize, mock_pr
):
mock_pr.context.to_policy_values.return_value = {}
mock_authorize.side_effect = exception.HTTPForbidden(resource='fake')
mock_pgbu.side_effect = exception.PortNotFound(
port=self.valid_port_uuid)
self.assertRaises(
exception.HTTPForbidden,
utils.check_port_policy_and_retrieve,
'fake-policy',
self.valid_port_uuid
)
@mock.patch.object(api, 'request', spec_set=["context"])
@mock.patch.object(policy, 'authorize', spec=True)
@mock.patch.object(objects.Port, 'get_by_uuid')
def test_check_port_policy_and_retrieve_no_port(
self, mock_pgbu, mock_authorize, mock_pr
):
mock_pr.context.to_policy_values.return_value = {}
mock_pgbu.side_effect = exception.PortNotFound(
port=self.valid_port_uuid)
self.assertRaises(
exception.PortNotFound,
utils.check_port_policy_and_retrieve,
'fake-policy',
self.valid_port_uuid
)
@mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True)
@mock.patch.object(objects.Port, 'get_by_uuid')
@mock.patch.object(objects.Node, 'get_by_id')
def test_check_port_policy_and_retrieve_policy_forbidden(
self, mock_ngbi, mock_pgbu, mock_authorize, mock_pr
):
mock_pr.version.minor = 50
mock_pr.context.to_policy_values.return_value = {}
mock_authorize.side_effect = exception.HTTPForbidden(resource='fake')
mock_pgbu.return_value = self.port
mock_ngbi.return_value = self.node
self.assertRaises(
exception.HTTPForbidden,
utils.check_port_policy_and_retrieve,
'fake-policy',
self.valid_port_uuid
)
class TestCheckPortListPolicy(base.TestCase):
@mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True)
def test_check_port_list_policy(
self, mock_authorize, mock_pr
):
mock_pr.context.to_policy_values.return_value = {
'project_id': '12345'
}
mock_pr.version.minor = 50
owner = utils.check_port_list_policy()
self.assertIsNone(owner)
@mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True)
def test_check_port_list_policy_forbidden(
self, mock_authorize, mock_pr
):
def mock_authorize_function(rule, target, creds):
raise exception.HTTPForbidden(resource='fake')
mock_authorize.side_effect = mock_authorize_function
mock_pr.context.to_policy_values.return_value = {
'project_id': '12345'
}
mock_pr.version.minor = 50
self.assertRaises(
exception.HTTPForbidden,
utils.check_port_list_policy,
)
@mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True)
def test_check_port_list_policy_forbidden_no_project(
self, mock_authorize, mock_pr
):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
mock_pr.context.to_policy_values.return_value = {}
mock_pr.version.minor = 50
self.assertRaises(
exception.HTTPForbidden,
utils.check_port_list_policy,
)
@mock.patch.object(api, 'request', spec_set=["context", "version"])
@mock.patch.object(policy, 'authorize', spec=True)
def test_check_port_list_policy_non_admin(
self, mock_authorize, mock_pr
):
def mock_authorize_function(rule, target, creds):
if rule == 'baremetal:port:list_all':
raise exception.HTTPForbidden(resource='fake')
return True
mock_authorize.side_effect = mock_authorize_function
mock_pr.context.to_policy_values.return_value = {
'project_id': '12345'
}
mock_pr.version.minor = 50
owner = utils.check_port_list_policy()
self.assertEqual(owner, '12345')

View File

@ -28,7 +28,7 @@ class DbPortTestCase(base.DbTestCase):
# This method creates a port for every test and
# replaces a test for creating a port.
super(DbPortTestCase, self).setUp()
self.node = db_utils.create_test_node()
self.node = db_utils.create_test_node(owner='12345')
self.portgroup = db_utils.create_test_portgroup(node_id=self.node.id)
self.port = db_utils.create_test_port(node_id=self.node.id,
portgroup_id=self.portgroup.id)
@ -45,6 +45,17 @@ class DbPortTestCase(base.DbTestCase):
res = self.dbapi.get_port_by_address(self.port.address)
self.assertEqual(self.port.id, res.id)
def test_get_port_by_address_filter_by_owner(self):
res = self.dbapi.get_port_by_address(self.port.address,
owner=self.node.owner)
self.assertEqual(self.port.id, res.id)
def test_get_port_by_address_filter_by_owner_no_match(self):
self.assertRaises(exception.PortNotFound,
self.dbapi.get_port_by_address,
self.port.address,
owner='54321')
def test_get_port_list(self):
uuids = []
for i in range(1, 6):
@ -72,10 +83,36 @@ class DbPortTestCase(base.DbTestCase):
self.assertRaises(exception.InvalidParameterValue,
self.dbapi.get_port_list, sort_key='foo')
def test_get_port_list_filter_by_node_owner(self):
uuids = []
for i in range(1, 3):
port = db_utils.create_test_port(uuid=uuidutils.generate_uuid(),
address='52:54:00:cf:2d:4%s' % i)
for i in range(4, 6):
port = db_utils.create_test_port(uuid=uuidutils.generate_uuid(),
node_id=self.node.id,
address='52:54:00:cf:2d:4%s' % i)
uuids.append(str(port.uuid))
# Also add the uuid for the port created in setUp()
uuids.append(str(self.port.uuid))
res = self.dbapi.get_port_list(owner=self.node.owner)
res_uuids = [r.uuid for r in res]
self.assertCountEqual(uuids, res_uuids)
def test_get_ports_by_node_id(self):
res = self.dbapi.get_ports_by_node_id(self.node.id)
self.assertEqual(self.port.address, res[0].address)
def test_get_ports_by_node_id_filter_by_node_owner(self):
res = self.dbapi.get_ports_by_node_id(self.node.id,
owner=self.node.owner)
self.assertEqual(self.port.address, res[0].address)
def test_get_ports_by_node_id_filter_by_node_owner_no_match(self):
res = self.dbapi.get_ports_by_node_id(self.node.id,
owner='54321')
self.assertEqual([], res)
def test_get_ports_by_node_id_that_does_not_exist(self):
self.assertEqual([], self.dbapi.get_ports_by_node_id(99))
@ -83,6 +120,16 @@ class DbPortTestCase(base.DbTestCase):
res = self.dbapi.get_ports_by_portgroup_id(self.portgroup.id)
self.assertEqual(self.port.address, res[0].address)
def test_get_ports_by_portgroup_id_filter_by_node_owner(self):
res = self.dbapi.get_ports_by_portgroup_id(self.portgroup.id,
owner=self.node.owner)
self.assertEqual(self.port.address, res[0].address)
def test_get_ports_by_portgroup_id_filter_by_node_owner_no_match(self):
res = self.dbapi.get_ports_by_portgroup_id(self.portgroup.id,
owner='54321')
self.assertEqual([], res)
def test_get_ports_by_portgroup_id_that_does_not_exist(self):
self.assertEqual([], self.dbapi.get_ports_by_portgroup_id(99))

View File

@ -66,7 +66,7 @@ class TestPortObject(db_base.DbTestCase, obj_utils.SchemasTestMixIn):
port = objects.Port.get(self.context, address)
mock_get_port.assert_called_once_with(address)
mock_get_port.assert_called_once_with(address, owner=None)
self.assertEqual(self.context, port._context)
def test_get_bad_id_and_uuid_and_address(self):

View File

@ -0,0 +1,8 @@
---
features:
- |
A port is owned by its associated node's owner. This owner is now exposed
to policy checks, giving Ironic admins the option of modifying the policy
file to allow users specified by a node's owner field to perform API
actions on that node's associated ports through the ``is_node_owner``
rule.