diff --git a/oslo_limit/exception.py b/oslo_limit/exception.py index 4111f4d..bb84d21 100644 --- a/oslo_limit/exception.py +++ b/oslo_limit/exception.py @@ -21,3 +21,15 @@ class SessionInitError(Exception): "Can't initialise OpenStackSDK session: %(reason)s." ) % {'reason': reason} super(SessionInitError, self).__init__(msg) + + +class LimitNotFound(Exception): + def __init__(self, resource, service, region): + msg = _("Can't find the limit for resource %(resource)s " + "for service %(service)s in region %(region)s." + ) % { + 'resource': resource, 'service': service, 'region': region} + self.resource = resource + self.service = service + self.region = region + super(LimitNotFound, self).__init__(msg) diff --git a/oslo_limit/limit.py b/oslo_limit/limit.py index 55348ed..fbc5210 100644 --- a/oslo_limit/limit.py +++ b/oslo_limit/limit.py @@ -144,3 +144,69 @@ class _StrictTwoLevelEnforcer(object): _MODELS = [_FlatEnforcer, _StrictTwoLevelEnforcer] + + +class _EnforcerUtils(object): + """Logic common used by multiple enforcers""" + + def __init__(self): + self.connection = _get_keystone_connection() + + # get and cache endpoint info + endpoint_id = CONF.oslo_limit.endpoint_id + self._endpoint = self.connection.get_endpoint(endpoint_id) + if not self._endpoint: + raise ValueError("can't find endpoint for %s" % endpoint_id) + self._service_id = self._endpoint.service_id + self._region_id = self._endpoint.region_id + + def get_project_limits(self, project_id, resource_names): + """Get all the limits for given project a resource_name list + + We will raise + :param project_id: + :param resource_names: list of resource_name strings + :return: list of (resource_name,limit) pairs + + :raises exception.LimitNotFound if no limit is found + """ + # Using a list to preserver the resource_name order + project_limits = [] + for resource_name in resource_names: + limit = self._get_limit(project_id, resource_name) + project_limits.append((resource_name, limit)) + return project_limits + + def _get_limit(self, project_id, resource_name): + # TODO(johngarbutt): might need to cache here + project_limit = self._get_project_limit(project_id, resource_name) + if project_limit: + return project_limit.resource_limit + + registered_limit = self._get_registered_limit(resource_name) + if registered_limit: + return registered_limit.default_limit + + raise exception.LimitNotFound( + resource_name, self._service_id, self._region_id) + + def _get_project_limit(self, project_id, resource_name): + limit = self.connection.limits( + service_id=self._service_id, + region_id=self._region_id, + resource_name=resource_name, + project_id=project_id) + try: + return next(limit) + except StopIteration: + return None + + def _get_registered_limit(self, resource_name): + reg_limit = self.connection.registered_limits( + service_id=self._service_id, + region_id=self._region_id, + resource_name=resource_name) + try: + return next(reg_limit) + except StopIteration: + return None diff --git a/oslo_limit/tests/test_limit.py b/oslo_limit/tests/test_limit.py index c79af21..44f50c2 100644 --- a/oslo_limit/tests/test_limit.py +++ b/oslo_limit/tests/test_limit.py @@ -22,10 +22,14 @@ Tests for `limit` module. import mock import uuid +from openstack.identity.v3 import endpoint +from openstack.identity.v3 import limit as klimit +from openstack.identity.v3 import registered_limit from oslo_config import cfg from oslo_config import fixture as config_fixture from oslotest import base +from oslo_limit import exception from oslo_limit import limit from oslo_limit import opts @@ -103,3 +107,72 @@ class TestEnforcer(base.BaseTestCase): json.json.return_value = {"model": {"name": "foo"}} e = self.assertRaises(ValueError, enforcer._get_model_impl) self.assertEqual("enforcement model foo is not supported", str(e)) + + +class TestEnforcerUtils(base.BaseTestCase): + def setUp(self): + super(TestEnforcerUtils, self).setUp() + self.mock_conn = mock.MagicMock() + limit._SDK_CONNECTION = self.mock_conn + + def test_get_endpoint(self): + fake_endpoint = endpoint.Endpoint() + self.mock_conn.get_endpoint.return_value = fake_endpoint + + utils = limit._EnforcerUtils() + + self.assertEqual(fake_endpoint, utils._endpoint) + self.mock_conn.get_endpoint.assert_called_once_with(None) + + def test_get_registered_limit_empty(self): + self.mock_conn.registered_limits.return_value = iter([]) + + utils = limit._EnforcerUtils() + reg_limit = utils._get_registered_limit("foo") + + self.assertIsNone(reg_limit) + + def test_get_registered_limit(self): + foo = registered_limit.RegisteredLimit() + foo.resource_name = "foo" + self.mock_conn.registered_limits.return_value = iter([foo]) + + utils = limit._EnforcerUtils() + reg_limit = utils._get_registered_limit("foo") + + self.assertEqual(foo, reg_limit) + + def test_get_project_limits(self): + fake_endpoint = endpoint.Endpoint() + fake_endpoint.service_id = "service_id" + fake_endpoint.region_id = "region_id" + self.mock_conn.get_endpoint.return_value = fake_endpoint + project_id = uuid.uuid4().hex + + # a is a project limit, b and c don't have one + empty_iterator = iter([]) + a = klimit.Limit() + a.resource_name = "a" + a.resource_limit = 1 + a_iterator = iter([a]) + self.mock_conn.limits.side_effect = [a_iterator, empty_iterator, + empty_iterator] + + # b has a limit, but c doesn't, a isn't ever checked + b = registered_limit.RegisteredLimit() + b.resource_name = "b" + b.default_limit = 2 + b_iterator = iter([b]) + self.mock_conn.registered_limits.side_effect = [b_iterator, + empty_iterator] + + utils = limit._EnforcerUtils() + limits = utils.get_project_limits(project_id, ["a", "b"]) + self.assertEqual([('a', 1), ('b', 2)], limits) + + e = self.assertRaises(exception.LimitNotFound, + utils.get_project_limits, + project_id, ["c"]) + self.assertEqual("c", e.resource) + self.assertEqual("service_id", e.service) + self.assertEqual("region_id", e.region)