Merge "mypy: cinder/api/common.py"

This commit is contained in:
Zuul 2022-07-06 15:37:07 +00:00 committed by Gerrit Code Review
commit 13a57cfea6
2 changed files with 73 additions and 30 deletions

View File

@ -13,10 +13,14 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from __future__ import annotations
import enum import enum
import json import json
import os import os
import re import re
import typing
from typing import Any, Iterable, Optional, Union # noqa: H301
import urllib import urllib
from oslo_config import cfg from oslo_config import cfg
@ -26,6 +30,8 @@ import webob
from cinder.api import api_utils from cinder.api import api_utils
from cinder.api import microversions as mv from cinder.api import microversions as mv
from cinder.common import constants from cinder.common import constants
if typing.TYPE_CHECKING:
from cinder import context
from cinder import exception from cinder import exception
from cinder.i18n import _ from cinder.i18n import _
@ -56,7 +62,8 @@ ATTRIBUTE_CONVERTERS = {'name~': 'display_name~',
METADATA_TYPES = enum.Enum('METADATA_TYPES', 'user image') METADATA_TYPES = enum.Enum('METADATA_TYPES', 'user image')
def get_pagination_params(params, max_limit=None): def get_pagination_params(params: dict,
max_limit: Optional[int] = None) -> tuple:
"""Return marker, limit, offset tuple from request. """Return marker, limit, offset tuple from request.
:param params: `wsgi.Request`'s GET dictionary, possibly containing :param params: `wsgi.Request`'s GET dictionary, possibly containing
@ -79,7 +86,7 @@ def get_pagination_params(params, max_limit=None):
return marker, limit, offset return marker, limit, offset
def _get_limit_param(params, max_limit=None): def _get_limit_param(params: dict, max_limit: Optional[int] = None) -> int:
"""Extract integer limit from request's dictionary or fail. """Extract integer limit from request's dictionary or fail.
Defaults to max_limit if not present and returns max_limit if present Defaults to max_limit if not present and returns max_limit if present
@ -98,12 +105,12 @@ def _get_limit_param(params, max_limit=None):
return limit return limit
def _get_marker_param(params): def _get_marker_param(params: dict[str, Any]) -> Optional[str]:
"""Extract marker id from request's dictionary (defaults to None).""" """Extract marker id from request's dictionary (defaults to None)."""
return params.pop('marker', None) return params.pop('marker', None)
def _get_offset_param(params): def _get_offset_param(params: dict[str, Any]) -> int:
"""Extract offset id from request's dictionary (defaults to 0) or fail.""" """Extract offset id from request's dictionary (defaults to 0) or fail."""
offset = params.pop('offset', 0) offset = params.pop('offset', 0)
return api_utils.validate_integer(offset, return api_utils.validate_integer(offset,
@ -112,7 +119,9 @@ def _get_offset_param(params):
constants.DB_MAX_INT) constants.DB_MAX_INT)
def limited(items, request, max_limit=None): def limited(items: list,
request: webob.Request,
max_limit: Optional[int] = None) -> list:
"""Return a slice of items according to requested offset and limit. """Return a slice of items according to requested offset and limit.
:param items: A sliceable entity :param items: A sliceable entity
@ -131,7 +140,9 @@ def limited(items, request, max_limit=None):
return items[offset:range_end] return items[offset:range_end]
def get_sort_params(params, default_key='created_at', default_dir='desc'): def get_sort_params(params: dict,
default_key: str = 'created_at',
default_dir: str = 'desc') -> tuple[list[str], list[str]]:
"""Retrieves sort keys/directions parameters. """Retrieves sort keys/directions parameters.
Processes the parameters to create a list of sort keys and sort directions Processes the parameters to create a list of sort keys and sort directions
@ -178,7 +189,7 @@ def get_sort_params(params, default_key='created_at', default_dir='desc'):
return sort_keys, sort_dirs return sort_keys, sort_dirs
def get_request_url(request): def get_request_url(request: webob.Request) -> str:
url = request.application_url url = request.application_url
headers = request.headers headers = request.headers
forwarded = headers.get('X-Forwarded-Host') forwarded = headers.get('X-Forwarded-Host')
@ -189,7 +200,7 @@ def get_request_url(request):
return url return url
def remove_version_from_href(href): def remove_version_from_href(href: str) -> str:
"""Removes the first API version from the href. """Removes the first API version from the href.
Given: 'http://cinder.example.com/v1.1/123' Given: 'http://cinder.example.com/v1.1/123'
@ -202,6 +213,7 @@ def remove_version_from_href(href):
Returns: 'http://cinder.example.com/volume/drivers/flashsystem' Returns: 'http://cinder.example.com/volume/drivers/flashsystem'
""" """
parsed_url: Union[list[str], urllib.parse.SplitResult]
parsed_url = urllib.parse.urlsplit(href) parsed_url = urllib.parse.urlsplit(href)
url_parts = parsed_url.path.split('/') url_parts = parsed_url.path.split('/')
@ -227,9 +239,9 @@ def remove_version_from_href(href):
class ViewBuilder(object): class ViewBuilder(object):
"""Model API responses as dictionaries.""" """Model API responses as dictionaries."""
_collection_name = None _collection_name: Optional[str] = None
def _get_project_id_in_url(self, request): def _get_project_id_in_url(self, request: webob.Request) -> str:
project_id = request.environ["cinder.context"].project_id project_id = request.environ["cinder.context"].project_id
if project_id and ("/v3/%s" % project_id in request.url): if project_id and ("/v3/%s" % project_id in request.url):
# project_ids are not mandatory within v3 URLs, but links need # project_ids are not mandatory within v3 URLs, but links need
@ -237,13 +249,18 @@ class ViewBuilder(object):
return project_id return project_id
return '' return ''
def _get_links(self, request, identifier): def _get_links(self,
request: webob.Request,
identifier: str) -> list[dict[str, str]]:
return [{"rel": "self", return [{"rel": "self",
"href": self._get_href_link(request, identifier), }, "href": self._get_href_link(request, identifier), },
{"rel": "bookmark", {"rel": "bookmark",
"href": self._get_bookmark_link(request, identifier), }] "href": self._get_bookmark_link(request, identifier), }]
def _get_next_link(self, request, identifier, collection_name): def _get_next_link(self,
request: webob.Request,
identifier: str,
collection_name: str) -> str:
"""Return href string with proper limit and marker params.""" """Return href string with proper limit and marker params."""
params = request.params.copy() params = request.params.copy()
params["marker"] = identifier params["marker"] = identifier
@ -254,27 +271,37 @@ class ViewBuilder(object):
collection_name) collection_name)
return "%s?%s" % (url, urllib.parse.urlencode(params)) return "%s?%s" % (url, urllib.parse.urlencode(params))
def _get_href_link(self, request, identifier): def _get_href_link(self, request: webob.Request, identifier: str) -> str:
"""Return an href string pointing to this object.""" """Return an href string pointing to this object."""
prefix = self._update_link_prefix(get_request_url(request), prefix = self._update_link_prefix(get_request_url(request),
CONF.public_endpoint) CONF.public_endpoint)
assert self._collection_name is not None
return os.path.join(prefix, return os.path.join(prefix,
self._get_project_id_in_url(request), self._get_project_id_in_url(request),
self._collection_name, self._collection_name,
str(identifier)) str(identifier))
def _get_bookmark_link(self, request, identifier): def _get_bookmark_link(self,
request: webob.Request,
identifier: str) -> str:
"""Create a URL that refers to a specific resource.""" """Create a URL that refers to a specific resource."""
base_url = remove_version_from_href(get_request_url(request)) base_url = remove_version_from_href(get_request_url(request))
base_url = self._update_link_prefix(base_url, base_url = self._update_link_prefix(base_url,
CONF.public_endpoint) CONF.public_endpoint)
assert self._collection_name is not None
return os.path.join(base_url, return os.path.join(base_url,
self._get_project_id_in_url(request), self._get_project_id_in_url(request),
self._collection_name, self._collection_name,
str(identifier)) str(identifier))
def _get_collection_links(self, request, items, collection_name, def _get_collection_links(self,
item_count=None, id_key="uuid"): request: webob.Request,
items: list,
collection_name: str,
item_count: Optional[int] = None,
id_key: str = "uuid") -> list[dict]:
"""Retrieve 'next' link, if applicable. """Retrieve 'next' link, if applicable.
The next link is included if we are returning as many items as we can, The next link is included if we are returning as many items as we can,
@ -306,8 +333,11 @@ class ViewBuilder(object):
return [] return []
def _generate_next_link(self, items, id_key, request, def _generate_next_link(self,
collection_name): items: list,
id_key: str,
request: webob.Request,
collection_name: str) -> list[dict]:
links = [] links = []
last_item = items[-1] last_item = items[-1]
if id_key in last_item: if id_key in last_item:
@ -321,7 +351,7 @@ class ViewBuilder(object):
}) })
return links return links
def _update_link_prefix(self, orig_url, prefix): def _update_link_prefix(self, orig_url: str, prefix: Optional[str]) -> str:
if not prefix: if not prefix:
return orig_url return orig_url
url_parts = list(urllib.parse.urlsplit(orig_url)) url_parts = list(urllib.parse.urlsplit(orig_url))
@ -332,7 +362,10 @@ class ViewBuilder(object):
return urllib.parse.urlunsplit(url_parts).rstrip('/') return urllib.parse.urlunsplit(url_parts).rstrip('/')
def get_cluster_host(req, params, cluster_version=None): def get_cluster_host(req: webob.Request,
params: dict,
cluster_version=None) -> tuple[Optional[str],
Optional[str]]:
"""Get cluster and host from the parameters. """Get cluster and host from the parameters.
This method checks the presence of cluster and host parameters and returns This method checks the presence of cluster and host parameters and returns
@ -361,7 +394,7 @@ def get_cluster_host(req, params, cluster_version=None):
return cluster_name, host return cluster_name, host
def _initialize_filters(): def _initialize_filters() -> None:
global _FILTERS_COLLECTION global _FILTERS_COLLECTION
if _FILTERS_COLLECTION: if _FILTERS_COLLECTION:
return return
@ -376,7 +409,8 @@ def _initialize_filters():
_FILTERS_COLLECTION = json.load(filters_file) _FILTERS_COLLECTION = json.load(filters_file)
def get_enabled_resource_filters(resource=None): def get_enabled_resource_filters(resource: Optional[str] = None) -> dict[str,
Any]:
"""Get list of configured/allowed filters for the specified resource. """Get list of configured/allowed filters for the specified resource.
This method checks resource_query_filters_file and returns dictionary This method checks resource_query_filters_file and returns dictionary
@ -393,6 +427,8 @@ def get_enabled_resource_filters(resource=None):
""" """
try: try:
_initialize_filters() _initialize_filters()
assert _FILTERS_COLLECTION is not None
if not resource: if not resource:
return _FILTERS_COLLECTION return _FILTERS_COLLECTION
else: else:
@ -402,12 +438,12 @@ def get_enabled_resource_filters(resource=None):
return {} return {}
def get_time_comparison_operators(): def get_time_comparison_operators() -> tuple[str, ...]:
"""Get list of time comparison operators. """Get time comparison operators.
This method returns list which contains the allowed comparison operators. This method returns tuple which contains the allowed comparison operators.
""" """
return ["gt", "gte", "eq", "neq", "lt", "lte"] return ("gt", "gte", "eq", "neq", "lt", "lte")
def convert_filter_attributes(filters, resource): def convert_filter_attributes(filters, resource):
@ -418,8 +454,10 @@ def convert_filter_attributes(filters, resource):
filters.pop(key) filters.pop(key)
def reject_invalid_filters(context, filters, resource, def reject_invalid_filters(context: 'context.RequestContext',
enable_like_filter=False): filters,
resource: str,
enable_like_filter: bool = False):
invalid_filters = [] invalid_filters = []
for key in filters.copy().keys(): for key in filters.copy().keys():
try: try:
@ -439,6 +477,7 @@ def reject_invalid_filters(context, filters, resource,
# pool API is only available for admin # pool API is only available for admin
return return
# Check the configured filters against those passed in resource # Check the configured filters against those passed in resource
configured_filters: Iterable
configured_filters = get_enabled_resource_filters(resource) configured_filters = get_enabled_resource_filters(resource)
if configured_filters: if configured_filters:
configured_filters = configured_filters[resource] configured_filters = configured_filters[resource]
@ -472,12 +511,15 @@ def process_general_filtering(resource):
def _decorator(*args, **kwargs): def _decorator(*args, **kwargs):
req_version = kwargs.get('req_version') req_version = kwargs.get('req_version')
filters = kwargs.get('filters') filters = kwargs.get('filters')
context = kwargs.get('context') ctxt = kwargs.get('context')
ctxt = typing.cast('context.RequestContext', ctxt)
assert req_version is not None
if req_version.matches(mv.RESOURCE_FILTER): if req_version.matches(mv.RESOURCE_FILTER):
support_like = False support_like = False
if req_version.matches(mv.LIKE_FILTER): if req_version.matches(mv.LIKE_FILTER):
support_like = True support_like = True
reject_invalid_filters(context, filters, reject_invalid_filters(ctxt, filters,
resource, support_like) resource, support_like)
convert_filter_attributes(filters, resource) convert_filter_attributes(filters, resource)

View File

@ -1,4 +1,5 @@
cinder/api/api_utils.py cinder/api/api_utils.py
cinder/api/common.py
cinder/api/v3/types.py cinder/api/v3/types.py
cinder/backup/api.py cinder/backup/api.py
cinder/backup/manager.py cinder/backup/manager.py