mypy: annotate volume_utils / utils / exc

Change-Id: I886600b1712f4c9415e59cea7166289c0870e58c
This commit is contained in:
Eric Harney 2020-06-03 14:05:39 -04:00
parent 7441694cd4
commit 8953340820
5 changed files with 318 additions and 176 deletions

View File

@ -22,6 +22,8 @@ SHOULD include dedicated exception logging.
"""
from typing import Union
from oslo_log import log as logging
from oslo_versionedobjects import exception as obj_exc
import webob.exc
@ -35,7 +37,8 @@ LOG = logging.getLogger(__name__)
class ConvertedException(webob.exc.WSGIHTTPException):
def __init__(self, code=500, title="", explanation=""):
def __init__(self, code: int = 500, title: str = "",
explanation: str = ""):
self.code = code
# There is a strict rule about constructing status line for HTTP:
# '...Status-Line, consisting of the protocol version followed by a
@ -66,10 +69,10 @@ class CinderException(Exception):
"""
message = _("An unknown exception occurred.")
code = 500
headers = {}
headers: dict = {}
safe = False
def __init__(self, message=None, **kwargs):
def __init__(self, message: Union[str, tuple] = None, **kwargs):
self.kwargs = kwargs
self.kwargs['message'] = message
@ -112,7 +115,7 @@ class CinderException(Exception):
# with duplicate keyword exception.
self.kwargs.pop('message', None)
def _log_exception(self):
def _log_exception(self) -> None:
# kwargs doesn't match a variable in the message
# log the issue and the kwargs
LOG.exception('Exception in string format operation:')
@ -120,7 +123,7 @@ class CinderException(Exception):
LOG.error("%(name)s: %(value)s",
{'name': name, 'value': value})
def _should_format(self):
def _should_format(self) -> bool:
return self.kwargs['message'] is None or '%(message)' in self.message

View File

@ -297,12 +297,12 @@ class TemporaryChownTestCase(test.TestCase):
mock_stat.return_value.st_uid = 5678
test_filename = 'a_file'
with utils.temporary_chown(test_filename):
mock_exec.assert_called_once_with('chown', 1234, test_filename,
mock_exec.assert_called_once_with('chown', '1234', test_filename,
run_as_root=True)
mock_getuid.assert_called_once_with()
mock_stat.assert_called_once_with(test_filename)
calls = [mock.call('chown', 1234, test_filename, run_as_root=True),
mock.call('chown', 5678, test_filename, run_as_root=True)]
calls = [mock.call('chown', '1234', test_filename, run_as_root=True),
mock.call('chown', '5678', test_filename, run_as_root=True)]
mock_exec.assert_has_calls(calls)
@mock.patch('os.stat')
@ -312,12 +312,12 @@ class TemporaryChownTestCase(test.TestCase):
mock_stat.return_value.st_uid = 5678
test_filename = 'a_file'
with utils.temporary_chown(test_filename, owner_uid=9101):
mock_exec.assert_called_once_with('chown', 9101, test_filename,
mock_exec.assert_called_once_with('chown', '9101', test_filename,
run_as_root=True)
self.assertFalse(mock_getuid.called)
mock_stat.assert_called_once_with(test_filename)
calls = [mock.call('chown', 9101, test_filename, run_as_root=True),
mock.call('chown', 5678, test_filename, run_as_root=True)]
calls = [mock.call('chown', '9101', test_filename, run_as_root=True),
mock.call('chown', '5678', test_filename, run_as_root=True)]
mock_exec.assert_has_calls(calls)
@mock.patch('os.stat')

View File

@ -22,6 +22,7 @@ import contextlib
import datetime
import functools
import inspect
import logging as py_logging
import math
import multiprocessing
import operator
@ -32,6 +33,9 @@ import shutil
import stat
import sys
import tempfile
import typing
from typing import Callable, Dict, Iterable, Iterator, List # noqa: H301
from typing import Optional, Tuple, Type, Union # noqa: H301
import eventlet
from eventlet import tpool
@ -59,7 +63,7 @@ INFINITE_UNKNOWN_VALUES = ('infinite', 'unknown')
synchronized = lockutils.synchronized_with_prefix('cinder-')
def as_int(obj, quiet=True):
def as_int(obj: Union[int, float, str], quiet: bool = True) -> int:
# Try "2" -> 2
try:
return int(obj)
@ -73,10 +77,12 @@ def as_int(obj, quiet=True):
# Eck, not sure what this is then.
if not quiet:
raise TypeError(_("Can not translate %s to integer.") % (obj))
obj = typing.cast(int, obj)
return obj
def check_exclusive_options(**kwargs):
def check_exclusive_options(**kwargs: dict) -> None:
"""Checks that only one of the provided options is actually not-none.
Iterates over all the kwargs passed in and checks that only one of said
@ -99,24 +105,24 @@ def check_exclusive_options(**kwargs):
#
# Ex: 'the_key' -> 'the key'
if pretty_keys:
names = [k.replace('_', ' ') for k in kwargs]
tnames = [k.replace('_', ' ') for k in kwargs]
else:
names = kwargs.keys()
names = ", ".join(sorted(names))
tnames = list(kwargs.keys())
names = ", ".join(sorted(tnames))
msg = (_("May specify only one of %s") % (names))
raise exception.InvalidInput(reason=msg)
def execute(*cmd, **kwargs):
def execute(*cmd: str, **kwargs) -> Tuple[str, str]:
"""Convenience wrapper around oslo's execute() method."""
if 'run_as_root' in kwargs and 'root_helper' not in kwargs:
kwargs['root_helper'] = get_root_helper()
return processutils.execute(*cmd, **kwargs)
def check_ssh_injection(cmd_list):
ssh_injection_pattern = ['`', '$', '|', '||', ';', '&', '&&', '>', '>>',
'<']
def check_ssh_injection(cmd_list: List[str]) -> None:
ssh_injection_pattern: Tuple[str, ...] = ('`', '$', '|', '||', ';', '&',
'&&', '>', '>>', '<')
# Check whether injection attacks exist
for arg in cmd_list:
@ -149,7 +155,8 @@ def check_ssh_injection(cmd_list):
raise exception.SSHInjectionThreat(command=cmd_list)
def check_metadata_properties(metadata=None):
def check_metadata_properties(
metadata: Optional[Dict[str, str]]) -> None:
"""Checks that the volume metadata properties are valid."""
if not metadata:
@ -175,7 +182,9 @@ def check_metadata_properties(metadata=None):
raise exception.InvalidVolumeMetadataSize(reason=msg)
def last_completed_audit_period(unit=None):
def last_completed_audit_period(unit: str = None) -> \
Tuple[Union[datetime.datetime, datetime.timedelta],
Union[datetime.datetime, datetime.timedelta]]:
"""This method gives you the most recently *completed* audit period.
arguments:
@ -196,11 +205,15 @@ def last_completed_audit_period(unit=None):
if not unit:
unit = CONF.volume_usage_audit_period
offset = 0
unit = typing.cast(str, unit)
offset: Union[str, int] = 0
if '@' in unit:
unit, offset = unit.split("@", 1)
offset = int(offset)
offset = typing.cast(int, offset)
rightnow = timeutils.utcnow()
if unit not in ('month', 'day', 'year', 'hour'):
raise ValueError('Time period must be hour, day, month or year')
@ -262,7 +275,7 @@ def last_completed_audit_period(unit=None):
return (begin, end)
def monkey_patch():
def monkey_patch() -> None:
"""Patches decorators for all functions in a specified module.
If the CONF.monkey_patch set as True,
@ -309,7 +322,7 @@ def monkey_patch():
decorator("%s.%s" % (module, key), func))
def make_dev_path(dev, partition=None, base='/dev'):
def make_dev_path(dev: str, partition: str = None, base: str = '/dev') -> str:
"""Return a path to a particular device.
>>> make_dev_path('xvdc')
@ -324,7 +337,7 @@ def make_dev_path(dev, partition=None, base='/dev'):
return path
def robust_file_write(directory, filename, data):
def robust_file_write(directory: str, filename: str, data: str) -> None:
"""Robust file write.
Use "write to temp file and rename" model for writing the
@ -360,15 +373,16 @@ def robust_file_write(directory, filename, data):
with excutils.save_and_reraise_exception():
LOG.error("Failed to write persistence file: %(path)s.",
{'path': os.path.join(directory, filename)})
if tempname is not None:
if os.path.isfile(tempname):
os.unlink(tempname)
finally:
if dirfd:
if dirfd is not None:
os.close(dirfd)
@contextlib.contextmanager
def temporary_chown(path, owner_uid=None):
def temporary_chown(path: str, owner_uid: int = None) -> Iterator[None]:
"""Temporarily chown a path.
:params owner_uid: UID of temporary owner (defaults to current user)
@ -386,16 +400,16 @@ def temporary_chown(path, owner_uid=None):
orig_uid = os.stat(path).st_uid
if orig_uid != owner_uid:
execute('chown', owner_uid, path, run_as_root=True)
execute('chown', str(owner_uid), path, run_as_root=True)
try:
yield
finally:
if orig_uid != owner_uid:
execute('chown', orig_uid, path, run_as_root=True)
execute('chown', str(orig_uid), path, run_as_root=True)
@contextlib.contextmanager
def tempdir(**kwargs):
def tempdir(**kwargs) -> Iterator[str]:
tmpdir = tempfile.mkdtemp(**kwargs)
try:
yield tmpdir
@ -406,11 +420,11 @@ def tempdir(**kwargs):
LOG.debug('Could not remove tmpdir: %s', str(e))
def get_root_helper():
def get_root_helper() -> str:
return 'sudo cinder-rootwrap %s' % CONF.rootwrap_config
def require_driver_initialized(driver):
def require_driver_initialized(driver) -> None:
"""Verifies if `driver` is initialized
If the driver is not initialized, an exception will be raised.
@ -427,7 +441,7 @@ def require_driver_initialized(driver):
log_unsupported_driver_warning(driver)
def log_unsupported_driver_warning(driver):
def log_unsupported_driver_warning(driver) -> None:
"""Annoy the log about unsupported drivers."""
if not driver.supported:
# Check to see if the driver is flagged as supported.
@ -440,22 +454,24 @@ def log_unsupported_driver_warning(driver):
'id': driver.__class__.__name__})
def get_file_mode(path):
def get_file_mode(path: str) -> int:
"""This primarily exists to make unit testing easier."""
return stat.S_IMODE(os.stat(path).st_mode)
def get_file_gid(path):
def get_file_gid(path: str) -> int:
"""This primarily exists to make unit testing easier."""
return os.stat(path).st_gid
def get_file_size(path):
def get_file_size(path: str) -> int:
"""Returns the file size."""
return os.stat(path).st_size
def _get_disk_of_partition(devpath, st=None):
def _get_disk_of_partition(
devpath: str,
st: os.stat_result = None) -> Tuple[str, os.stat_result]:
"""Gets a disk device path and status from partition path.
Returns a disk device path from a partition device path, and stat for
@ -478,7 +494,9 @@ def _get_disk_of_partition(devpath, st=None):
return (devpath, st)
def get_bool_param(param_string, params, default=False):
def get_bool_param(param_string: str,
params: dict,
default: bool = False) -> bool:
param = params.get(param_string, default)
if not strutils.is_valid_boolstr(param):
msg = _("Value '%(param)s' for '%(param_string)s' is not "
@ -488,7 +506,8 @@ def get_bool_param(param_string, params, default=False):
return strutils.bool_from_string(param, strict=True)
def get_blkdev_major_minor(path, lookup_for_file=True):
def get_blkdev_major_minor(path: str,
lookup_for_file: bool = True) -> Optional[str]:
"""Get 'major:minor' number of block device.
Get the device's 'major:minor' number of a block device to control
@ -516,8 +535,9 @@ def get_blkdev_major_minor(path, lookup_for_file=True):
raise exception.CinderException(msg)
def check_string_length(value, name, min_length=0, max_length=None,
allow_all_spaces=True):
def check_string_length(value: str, name: str, min_length: int = 0,
max_length: int = None,
allow_all_spaces: bool = True) -> None:
"""Check the length of specified string.
:param value: the value of the string
@ -537,7 +557,7 @@ def check_string_length(value, name, min_length=0, max_length=None,
raise exception.InvalidInput(reason=msg)
def is_blk_device(dev):
def is_blk_device(dev: str) -> bool:
try:
if stat.S_ISBLK(os.stat(dev).st_mode):
return True
@ -548,30 +568,30 @@ def is_blk_device(dev):
class ComparableMixin(object):
def _compare(self, other, method):
def _compare(self, other: object, method: Callable):
try:
return method(self._cmpkey(), other._cmpkey())
return method(self._cmpkey(), other._cmpkey()) # type: ignore
except (AttributeError, TypeError):
# _cmpkey not implemented, or return different type,
# so I can't compare with "other".
return NotImplemented
def __lt__(self, other):
def __lt__(self, other: object):
return self._compare(other, lambda s, o: s < o)
def __le__(self, other):
def __le__(self, other: object):
return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other):
def __eq__(self, other: object):
return self._compare(other, lambda s, o: s == o)
def __ge__(self, other):
def __ge__(self, other: object):
return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other):
def __gt__(self, other: object):
return self._compare(other, lambda s, o: s > o)
def __ne__(self, other):
def __ne__(self, other: object):
return self._compare(other, lambda s, o: s != o)
@ -586,8 +606,12 @@ class retry_if_exit_code(tenacity.retry_if_exception):
exc.exit_code in self.codes)
def retry(retry_param, interval=1, retries=3, backoff_rate=2,
wait_random=False, retry=tenacity.retry_if_exception_type):
def retry(retry_param: Optional[Type[Exception]],
interval: int = 1,
retries: int = 3,
backoff_rate: int = 2,
wait_random: bool = False,
retry=tenacity.retry_if_exception_type) -> Callable:
if retries < 1:
raise ValueError('Retries must be greater than or '
@ -599,7 +623,7 @@ def retry(retry_param, interval=1, retries=3, backoff_rate=2,
wait = tenacity.wait_exponential(
multiplier=interval, min=0, exp_base=backoff_rate)
def _decorator(f):
def _decorator(f: Callable) -> Callable:
@functools.wraps(f)
def _wrapper(*args, **kwargs):
@ -618,7 +642,7 @@ def retry(retry_param, interval=1, retries=3, backoff_rate=2,
return _decorator
def convert_str(text):
def convert_str(text: Union[str, bytes]) -> str:
"""Convert to native string.
Convert bytes and Unicode strings to native strings:
@ -633,7 +657,8 @@ def convert_str(text):
return text
def build_or_str(elements, str_format=None):
def build_or_str(elements: Union[None, str, Iterable[str]],
str_format: str = None) -> str:
"""Builds a string of elements joined by 'or'.
Will join strings with the 'or' word and if a str_format is provided it
@ -651,18 +676,21 @@ def build_or_str(elements, str_format=None):
if not isinstance(elements, str):
elements = _(' or ').join(elements)
elements = typing.cast(str, elements)
if str_format:
return str_format % elements
return elements
def calculate_virtual_free_capacity(total_capacity,
free_capacity,
provisioned_capacity,
thin_provisioning_support,
max_over_subscription_ratio,
reserved_percentage,
thin):
def calculate_virtual_free_capacity(total_capacity: float,
free_capacity: float,
provisioned_capacity: float,
thin_provisioning_support: bool,
max_over_subscription_ratio: float,
reserved_percentage: float,
thin: bool) -> float:
"""Calculate the virtual free capacity based on thin provisioning support.
:param total_capacity: total_capacity_gb of a host_state or pool.
@ -693,8 +721,9 @@ def calculate_virtual_free_capacity(total_capacity,
return free
def calculate_max_over_subscription_ratio(capability,
global_max_over_subscription_ratio):
def calculate_max_over_subscription_ratio(
capability: dict,
global_max_over_subscription_ratio: float) -> float:
# provisioned_capacity_gb is the apparent total capacity of
# all the volumes created on a backend, which is greater than
# or equal to allocated_capacity_gb, which is the apparent
@ -752,7 +781,7 @@ def calculate_max_over_subscription_ratio(capability,
return max_over_subscription_ratio
def validate_dictionary_string_length(specs):
def validate_dictionary_string_length(specs: dict) -> None:
"""Check the length of each key and value of dictionary."""
if not isinstance(specs, dict):
msg = _('specs must be a dictionary.')
@ -768,7 +797,8 @@ def validate_dictionary_string_length(specs):
min_length=0, max_length=255)
def service_expired_time(with_timezone=False):
def service_expired_time(
with_timezone: Optional[bool] = False) -> datetime.datetime:
return (timeutils.utcnow(with_timezone=with_timezone) -
datetime.timedelta(seconds=CONF.service_down_time))
@ -794,7 +824,7 @@ def notifications_enabled(conf):
return notifications_driver and notifications_driver != {'noop'}
def if_notifications_enabled(f):
def if_notifications_enabled(f: Callable) -> Callable:
"""Calls decorated method only if notifications are enabled."""
@functools.wraps(f)
def wrapped(*args, **kwargs):
@ -807,7 +837,7 @@ def if_notifications_enabled(f):
LOG_LEVELS = ('INFO', 'WARNING', 'ERROR', 'DEBUG')
def get_log_method(level_string):
def get_log_method(level_string: str) -> int:
level_string = level_string or ''
upper_level_string = level_string.upper()
if upper_level_string not in LOG_LEVELS:
@ -816,7 +846,7 @@ def get_log_method(level_string):
return getattr(logging, upper_level_string)
def set_log_levels(prefix, level_string):
def set_log_levels(prefix: str, level_string: str) -> None:
level = get_log_method(level_string)
prefix = prefix or ''
@ -825,18 +855,18 @@ def set_log_levels(prefix, level_string):
v.logger.setLevel(level)
def get_log_levels(prefix):
def get_log_levels(prefix: str) -> dict:
prefix = prefix or ''
return {k: logging.logging.getLevelName(v.logger.getEffectiveLevel())
return {k: py_logging.getLevelName(v.logger.getEffectiveLevel())
for k, v in logging.get_loggers().items()
if k and k.startswith(prefix)}
def paths_normcase_equal(path_a, path_b):
def paths_normcase_equal(path_a: str, path_b: str) -> bool:
return os.path.normcase(path_a) == os.path.normcase(path_b)
def create_ordereddict(adict):
def create_ordereddict(adict: dict) -> OrderedDict:
"""Given a dict, return a sorted OrderedDict."""
return OrderedDict(sorted(adict.items(),
key=operator.itemgetter(0)))
@ -859,7 +889,9 @@ class Semaphore(object):
return self.semaphore.__exit__(*args)
def semaphore_factory(limit, concurrent_processes):
def semaphore_factory(limit: int,
concurrent_processes: int) -> Union[eventlet.Semaphore,
Semaphore]:
"""Get a semaphore to limit concurrent operations.
The semaphore depends on the limit we want to set and the concurrent
@ -876,7 +908,7 @@ def semaphore_factory(limit, concurrent_processes):
return contextlib.suppress()
def limit_operations(func):
def limit_operations(func: Callable) -> Callable:
"""Decorator to limit the number of concurrent operations.
This method decorator expects to have a _semaphore attribute holding an

View File

@ -30,6 +30,9 @@ import socket
import tempfile
import time
import types
import typing
from typing import Any, BinaryIO, Callable, Dict, IO # noqa: H301
from typing import List, Optional, Tuple, Union # noqa: H301
import uuid
from castellan.common.credentials import keystone_password
@ -69,7 +72,7 @@ CONF = cfg.CONF
LOG = logging.getLogger(__name__)
GB = units.Gi
GB: int = units.Gi
# These attributes we will attempt to save for the volume if they exist
# in the source image metadata.
IMAGE_ATTRIBUTES = (
@ -85,11 +88,13 @@ TRACE_API = False
TRACE_METHOD = False
def null_safe_str(s):
def null_safe_str(s: Optional[str]) -> str:
return str(s) if s else ''
def _usage_from_volume(context, volume_ref, **kw):
def _usage_from_volume(context: context.RequestContext,
volume_ref: 'objects.Volume',
**kw) -> dict:
now = timeutils.utcnow()
launched_at = volume_ref['launched_at'] or now
created_at = volume_ref['created_at'] or now
@ -131,7 +136,7 @@ def _usage_from_volume(context, volume_ref, **kw):
return usage_info
def _usage_from_backup(backup, **kw):
def _usage_from_backup(backup: 'objects.Backup', **kw) -> dict:
num_dependent_backups = backup.num_dependent_backups
usage_info = dict(tenant_id=backup.project_id,
user_id=backup.user_id,
@ -156,8 +161,11 @@ def _usage_from_backup(backup, **kw):
@utils.if_notifications_enabled
def notify_about_volume_usage(context, volume, event_suffix,
extra_usage_info=None, host=None):
def notify_about_volume_usage(context: context.RequestContext,
volume: 'objects.Volume',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -171,9 +179,11 @@ def notify_about_volume_usage(context, volume, event_suffix,
@utils.if_notifications_enabled
def notify_about_backup_usage(context, backup, event_suffix,
extra_usage_info=None,
host=None):
def notify_about_backup_usage(context: context.RequestContext,
backup: 'objects.Backup',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -186,7 +196,9 @@ def notify_about_backup_usage(context, backup, event_suffix,
usage_info)
def _usage_from_snapshot(snapshot, context, **extra_usage_info):
def _usage_from_snapshot(snapshot: 'objects.Snapshot',
context: context.RequestContext,
**extra_usage_info) -> dict:
# (niedbalski) a snapshot might be related to a deleted
# volume, if that's the case, the volume information is still
# required for filling the usage_info, so we enforce to read
@ -212,8 +224,11 @@ def _usage_from_snapshot(snapshot, context, **extra_usage_info):
@utils.if_notifications_enabled
def notify_about_snapshot_usage(context, snapshot, event_suffix,
extra_usage_info=None, host=None):
def notify_about_snapshot_usage(context: context.RequestContext,
snapshot: 'objects.Snapshot',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -227,7 +242,8 @@ def notify_about_snapshot_usage(context, snapshot, event_suffix,
usage_info)
def _usage_from_capacity(capacity, **extra_usage_info):
def _usage_from_capacity(capacity: Dict[str, Any],
**extra_usage_info) -> Dict[str, Any]:
capacity_info = {
'name_to_id': capacity['name_to_id'],
@ -244,8 +260,11 @@ def _usage_from_capacity(capacity, **extra_usage_info):
@utils.if_notifications_enabled
def notify_about_capacity_usage(context, capacity, suffix,
extra_usage_info=None, host=None):
def notify_about_capacity_usage(context: context.RequestContext,
capacity: dict,
suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -260,8 +279,11 @@ def notify_about_capacity_usage(context, capacity, suffix,
@utils.if_notifications_enabled
def notify_about_replication_usage(context, volume, suffix,
extra_usage_info=None, host=None):
def notify_about_replication_usage(context: context.RequestContext,
volume: 'objects.Volume',
suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -277,8 +299,11 @@ def notify_about_replication_usage(context, volume, suffix,
@utils.if_notifications_enabled
def notify_about_replication_error(context, volume, suffix,
extra_error_info=None, host=None):
def notify_about_replication_error(context: context.RequestContext,
volume: 'objects.Volume',
suffix: str,
extra_error_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -293,7 +318,7 @@ def notify_about_replication_error(context, volume, suffix,
usage_info)
def _usage_from_consistencygroup(group_ref, **kw):
def _usage_from_consistencygroup(group_ref: 'objects.Group', **kw) -> dict:
usage_info = dict(tenant_id=group_ref.project_id,
user_id=group_ref.user_id,
availability_zone=group_ref.availability_zone,
@ -307,8 +332,11 @@ def _usage_from_consistencygroup(group_ref, **kw):
@utils.if_notifications_enabled
def notify_about_consistencygroup_usage(context, group, event_suffix,
extra_usage_info=None, host=None):
def notify_about_consistencygroup_usage(context: context.RequestContext,
group: 'objects.Group',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -324,7 +352,7 @@ def notify_about_consistencygroup_usage(context, group, event_suffix,
usage_info)
def _usage_from_group(group_ref, **kw):
def _usage_from_group(group_ref: 'objects.Group', **kw) -> dict:
usage_info = dict(tenant_id=group_ref.project_id,
user_id=group_ref.user_id,
availability_zone=group_ref.availability_zone,
@ -339,8 +367,11 @@ def _usage_from_group(group_ref, **kw):
@utils.if_notifications_enabled
def notify_about_group_usage(context, group, event_suffix,
extra_usage_info=None, host=None):
def notify_about_group_usage(context: context.RequestContext,
group: 'objects.Group',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -356,7 +387,7 @@ def notify_about_group_usage(context, group, event_suffix,
usage_info)
def _usage_from_cgsnapshot(cgsnapshot, **kw):
def _usage_from_cgsnapshot(cgsnapshot: 'objects.CGSnapshot', **kw) -> dict:
usage_info = dict(
tenant_id=cgsnapshot.project_id,
user_id=cgsnapshot.user_id,
@ -370,7 +401,8 @@ def _usage_from_cgsnapshot(cgsnapshot, **kw):
return usage_info
def _usage_from_group_snapshot(group_snapshot, **kw):
def _usage_from_group_snapshot(group_snapshot: 'objects.GroupSnapshot',
**kw) -> dict:
usage_info = dict(
tenant_id=group_snapshot.project_id,
user_id=group_snapshot.user_id,
@ -386,8 +418,11 @@ def _usage_from_group_snapshot(group_snapshot, **kw):
@utils.if_notifications_enabled
def notify_about_cgsnapshot_usage(context, cgsnapshot, event_suffix,
extra_usage_info=None, host=None):
def notify_about_cgsnapshot_usage(context: context.RequestContext,
cgsnapshot: 'objects.CGSnapshot',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -404,8 +439,11 @@ def notify_about_cgsnapshot_usage(context, cgsnapshot, event_suffix,
@utils.if_notifications_enabled
def notify_about_group_snapshot_usage(context, group_snapshot, event_suffix,
extra_usage_info=None, host=None):
def notify_about_group_snapshot_usage(context: context.RequestContext,
group_snapshot: 'objects.GroupSnapshot',
event_suffix: str,
extra_usage_info=None,
host: str = None) -> None:
if not host:
host = CONF.host
@ -421,13 +459,14 @@ def notify_about_group_snapshot_usage(context, group_snapshot, event_suffix,
usage_info)
def _check_blocksize(blocksize):
def _check_blocksize(blocksize: Union[str, int]) -> Union[str, int]:
# Check if volume_dd_blocksize is valid
try:
# Rule out zero-sized/negative/float dd blocksize which
# cannot be caught by strutils
if blocksize.startswith(('-', '0')) or '.' in blocksize:
if (blocksize.startswith(('-', '0')) or # type: ignore
'.' in blocksize): # type: ignore
raise ValueError
strutils.string_to_bytes('%sB' % blocksize)
except ValueError:
@ -442,7 +481,8 @@ def _check_blocksize(blocksize):
return blocksize
def check_for_odirect_support(src, dest, flag='oflag=direct'):
def check_for_odirect_support(src: str, dest: str,
flag: str = 'oflag=direct') -> bool:
# Check whether O_DIRECT is supported
try:
@ -459,9 +499,12 @@ def check_for_odirect_support(src, dest, flag='oflag=direct'):
return False
def _copy_volume_with_path(prefix, srcstr, deststr, size_in_m, blocksize,
sync=False, execute=utils.execute, ionice=None,
sparse=False):
def _copy_volume_with_path(prefix, srcstr: str, deststr: str,
size_in_m: int, blocksize: Union[str, int],
sync: bool = False,
execute: Callable = utils.execute,
ionice=None,
sparse: bool = False) -> None:
cmd = prefix[:]
if ionice:
@ -514,16 +557,18 @@ def _copy_volume_with_path(prefix, srcstr, deststr, size_in_m, blocksize,
{'size_in_m': size_in_m, 'mbps': mbps})
def _open_volume_with_path(path, mode):
def _open_volume_with_path(path: str, mode: str) -> IO[Any]:
try:
with utils.temporary_chown(path):
handle = open(path, mode)
return handle
except Exception:
LOG.error("Failed to open volume from %(path)s.", {'path': path})
raise
def _transfer_data(src, dest, length, chunk_size):
def _transfer_data(src: IO, dest: IO,
length: int, chunk_size: int) -> None:
"""Transfer data between files (Python IO objects)."""
chunks = int(math.ceil(length / chunk_size))
@ -554,15 +599,21 @@ def _transfer_data(src, dest, length, chunk_size):
tpool.execute(dest.flush)
def _copy_volume_with_file(src, dest, size_in_m):
def _copy_volume_with_file(src: Union[str, IO],
dest: Union[str, IO],
size_in_m: int) -> None:
src_handle = src
if isinstance(src, str):
src_handle = _open_volume_with_path(src, 'rb')
src_handle = typing.cast(IO, src_handle)
dest_handle = dest
if isinstance(dest, str):
dest_handle = _open_volume_with_path(dest, 'wb')
dest_handle = typing.cast(IO, dest_handle)
if not src_handle:
raise exception.DeviceUnavailable(
_("Failed to copy volume, source device unavailable."))
@ -588,9 +639,12 @@ def _copy_volume_with_file(src, dest, size_in_m):
{'size_in_m': size_in_m, 'mbps': mbps})
def copy_volume(src, dest, size_in_m, blocksize, sync=False,
def copy_volume(src: Union[str, BinaryIO],
dest: Union[str, BinaryIO],
size_in_m: int,
blocksize: Union[str, int], sync=False,
execute=utils.execute, ionice=None, throttle=None,
sparse=False):
sparse=False) -> None:
"""Copy data from the source volume to the destination volume.
The parameters 'src' and 'dest' are both typically of type str, which
@ -617,9 +671,12 @@ def copy_volume(src, dest, size_in_m, blocksize, sync=False,
_copy_volume_with_file(src, dest, size_in_m)
def clear_volume(volume_size, volume_path, volume_clear=None,
volume_clear_size=None, volume_clear_ionice=None,
throttle=None):
def clear_volume(volume_size: int,
volume_path: str,
volume_clear: str = None,
volume_clear_size: int = None,
volume_clear_ionice: str = None,
throttle=None) -> None:
"""Unprovision old volumes to prevent data leaking between users."""
if volume_clear is None:
volume_clear = CONF.volume_clear
@ -649,24 +706,25 @@ def clear_volume(volume_size, volume_path, volume_clear=None,
value=volume_clear)
def supports_thin_provisioning():
def supports_thin_provisioning() -> bool:
return brick_lvm.LVM.supports_thin_provisioning(
utils.get_root_helper())
def get_all_physical_volumes(vg_name=None):
def get_all_physical_volumes(vg_name=None) -> list:
return brick_lvm.LVM.get_all_physical_volumes(
utils.get_root_helper(),
vg_name)
def get_all_volume_groups(vg_name=None):
def get_all_volume_groups(vg_name=None) -> list:
return brick_lvm.LVM.get_all_volume_groups(
utils.get_root_helper(),
vg_name)
def extract_availability_zones_from_volume_type(volume_type):
def extract_availability_zones_from_volume_type(volume_type) \
-> Optional[list]:
if not volume_type:
return None
extra_specs = volume_type.get('extra_specs', {})
@ -683,7 +741,9 @@ DEFAULT_PASSWORD_SYMBOLS = ('23456789', # Removed: 0,1
'abcdefghijkmnopqrstuvwxyz') # Removed: l
def generate_password(length=16, symbolgroups=DEFAULT_PASSWORD_SYMBOLS):
def generate_password(
length: int = 16,
symbolgroups: Tuple[str, ...] = DEFAULT_PASSWORD_SYMBOLS) -> str:
"""Generate a random password from the supplied symbol groups.
At least one symbol from each group will be included. Unpredictable
@ -720,7 +780,9 @@ def generate_password(length=16, symbolgroups=DEFAULT_PASSWORD_SYMBOLS):
return ''.join(password)
def generate_username(length=20, symbolgroups=DEFAULT_PASSWORD_SYMBOLS):
def generate_username(
length: int = 20,
symbolgroups: Tuple[str, ...] = DEFAULT_PASSWORD_SYMBOLS) -> str:
# Use the same implementation as the password generation.
return generate_password(length, symbolgroups)
@ -728,7 +790,9 @@ def generate_username(length=20, symbolgroups=DEFAULT_PASSWORD_SYMBOLS):
DEFAULT_POOL_NAME = '_pool0'
def extract_host(host, level='backend', default_pool_name=False):
def extract_host(host: Optional[str],
level: str = 'backend',
default_pool_name: bool = False) -> Optional[str]:
"""Extract Host, Backend or Pool information from host string.
:param host: String for host, which could include host@backend#pool info
@ -778,8 +842,11 @@ def extract_host(host, level='backend', default_pool_name=False):
else:
return None
return None # not hit
def append_host(host, pool):
def append_host(host: Optional[str],
pool: Optional[str]) -> Optional[str]:
"""Encode pool into host info."""
if not host or not pool:
return host
@ -788,7 +855,7 @@ def append_host(host, pool):
return new_host
def matching_backend_name(src_volume_type, volume_type):
def matching_backend_name(src_volume_type, volume_type) -> bool:
if src_volume_type.get('volume_backend_name') and \
volume_type.get('volume_backend_name'):
return src_volume_type.get('volume_backend_name') == \
@ -797,14 +864,14 @@ def matching_backend_name(src_volume_type, volume_type):
return False
def hosts_are_equivalent(host_1, host_2):
def hosts_are_equivalent(host_1: str, host_2: str) -> bool:
# In case host_1 or host_2 are None
if not (host_1 and host_2):
return host_1 == host_2
return extract_host(host_1) == extract_host(host_2)
def read_proc_mounts():
def read_proc_mounts() -> List[str]:
"""Read the /proc/mounts file.
It's a dummy function but it eases the writing of unit tests as mocking
@ -814,19 +881,20 @@ def read_proc_mounts():
return mounts.readlines()
def extract_id_from_volume_name(vol_name):
regex = re.compile(
def extract_id_from_volume_name(vol_name: str) -> Optional[str]:
regex: typing.Pattern = re.compile(
CONF.volume_name_template.replace('%s', r'(?P<uuid>.+)'))
match = regex.match(vol_name)
return match.group('uuid') if match else None
def check_already_managed_volume(vol_id):
def check_already_managed_volume(vol_id: Optional[str]):
"""Check cinder db for already managed volume.
:param vol_id: volume id parameter
:returns: bool -- return True, if db entry with specified
volume id exists, otherwise return False
:raises: ValueError if vol_id is not a valid uuid string
"""
try:
return (vol_id and isinstance(vol_id, str) and
@ -836,7 +904,7 @@ def check_already_managed_volume(vol_id):
return False
def extract_id_from_snapshot_name(snap_name):
def extract_id_from_snapshot_name(snap_name: str) -> Optional[str]:
"""Return a snapshot's ID from its name on the backend."""
regex = re.compile(
CONF.snapshot_name_template.replace('%s', r'(?P<uuid>.+)'))
@ -844,8 +912,12 @@ def extract_id_from_snapshot_name(snap_name):
return match.group('uuid') if match else None
def paginate_entries_list(entries, marker, limit, offset, sort_keys,
sort_dirs):
def paginate_entries_list(entries: List[Dict],
marker: Optional[Union[dict, str]],
limit: int,
offset: Optional[int],
sort_keys: List[str],
sort_dirs: List[str]) -> list:
"""Paginate a list of entries.
:param entries: list of dictionaries
@ -859,7 +931,8 @@ def paginate_entries_list(entries, marker, limit, offset, sort_keys,
comparers = [(operator.itemgetter(key.strip()), multiplier)
for (key, multiplier) in zip(sort_keys, sort_dirs)]
def comparer(left, right):
def comparer(left, right) -> int:
fn: Callable
for fn, d in comparers:
left_val = fn(left)
right_val = fn(right)
@ -900,7 +973,7 @@ def paginate_entries_list(entries, marker, limit, offset, sort_keys,
return sorted_entries[start_index + offset:range_end + offset]
def convert_config_string_to_dict(config_string):
def convert_config_string_to_dict(config_string: str) -> dict:
"""Convert config file replication string to a dict.
The only supported form is as follows:
@ -924,12 +997,16 @@ def convert_config_string_to_dict(config_string):
return resultant_dict
def create_encryption_key(context, key_manager, volume_type_id):
def create_encryption_key(context: context.RequestContext,
key_manager,
volume_type_id: str) -> Optional[str]:
encryption_key_id = None
if volume_types.is_encrypted(context, volume_type_id):
volume_type_encryption = (
volume_type_encryption: db.sqlalchemy.models.Encryption = (
volume_types.get_volume_type_encryption(context,
volume_type_id))
if volume_type_encryption is None:
raise exception.Invalid(message="Volume type error")
cipher = volume_type_encryption.cipher
length = volume_type_encryption.key_size
algorithm = cipher.split('-')[0] if cipher else None
@ -945,10 +1022,13 @@ def create_encryption_key(context, key_manager, volume_type_id):
LOG.exception("Key manager error")
raise exception.Invalid(message="Key manager error")
typing.cast(str, encryption_key_id)
return encryption_key_id
def delete_encryption_key(context, key_manager, encryption_key_id):
def delete_encryption_key(context: context.RequestContext,
key_manager,
encryption_key_id: str) -> None:
try:
key_manager.delete(context, encryption_key_id)
except castellan_exception.ManagedObjectNotFoundError:
@ -972,7 +1052,9 @@ def delete_encryption_key(context, key_manager, encryption_key_id):
pass
def clone_encryption_key(context, key_manager, encryption_key_id):
def clone_encryption_key(context: context.RequestContext,
key_manager,
encryption_key_id: str) -> str:
clone_key_id = None
if encryption_key_id is not None:
clone_key_id = key_manager.store(
@ -981,19 +1063,19 @@ def clone_encryption_key(context, key_manager, encryption_key_id):
return clone_key_id
def is_boolean_str(str):
def is_boolean_str(str: Optional[str]) -> bool:
spec = (str or '').split()
return (len(spec) == 2 and
spec[0] == '<is>' and strutils.bool_from_string(spec[1]))
def is_replicated_spec(extra_specs):
return (extra_specs and
def is_replicated_spec(extra_specs: dict) -> bool:
return (bool(extra_specs) and
is_boolean_str(extra_specs.get('replication_enabled')))
def is_multiattach_spec(extra_specs):
return (extra_specs and
def is_multiattach_spec(extra_specs: dict) -> bool:
return (bool(extra_specs) and
is_boolean_str(extra_specs.get('multiattach')))
@ -1003,7 +1085,7 @@ def group_get_by_id(group_id):
return group
def is_group_a_cg_snapshot_type(group_or_snap):
def is_group_a_cg_snapshot_type(group_or_snap) -> bool:
LOG.debug("Checking if %s is a consistent snapshot group",
group_or_snap)
if group_or_snap["group_type_id"] is not None:
@ -1015,7 +1097,7 @@ def is_group_a_cg_snapshot_type(group_or_snap):
return False
def is_group_a_type(group, key):
def is_group_a_type(group: 'objects.Group', key: str) -> bool:
if group.group_type_id is not None:
spec = group_types.get_group_type_specs(
group.group_type_id, key=key
@ -1024,7 +1106,9 @@ def is_group_a_type(group, key):
return False
def get_max_over_subscription_ratio(str_value, supports_auto=False):
def get_max_over_subscription_ratio(
str_value: Union[str, float],
supports_auto: bool = False) -> Union[str, float]:
"""Get the max_over_subscription_ratio from a string
As some drivers need to do some calculations with the value and we are now
@ -1044,6 +1128,7 @@ def get_max_over_subscription_ratio(str_value, supports_auto=False):
raise exception.VolumeDriverException(message=msg)
if str_value == 'auto':
str_value = typing.cast(str, str_value)
return str_value
mosr = float(str_value)
@ -1055,7 +1140,8 @@ def get_max_over_subscription_ratio(str_value, supports_auto=False):
return mosr
def check_image_metadata(image_meta, vol_size):
def check_image_metadata(image_meta: Dict[str, Union[str, int]],
vol_size: int) -> None:
"""Validates the image metadata."""
# Check whether image is active
if image_meta['status'] != 'active':
@ -1074,6 +1160,7 @@ def check_image_metadata(image_meta, vol_size):
# Check image min_disk requirement is met for the particular volume
min_disk = image_meta.get('min_disk', 0)
min_disk = typing.cast(int, min_disk)
if vol_size < min_disk:
msg = _('Volume size %(volume_size)sGB cannot be smaller'
' than the image minDisk size %(min_disk)sGB.')
@ -1081,7 +1168,7 @@ def check_image_metadata(image_meta, vol_size):
raise exception.InvalidInput(reason=msg)
def enable_bootable_flag(volume):
def enable_bootable_flag(volume: 'objects.Volume') -> None:
try:
LOG.debug('Marking volume %s as bootable.', volume.id)
volume.bootable = True
@ -1092,7 +1179,8 @@ def enable_bootable_flag(volume):
raise exception.MetadataUpdateFailure(reason=ex)
def get_volume_image_metadata(image_id, image_meta):
def get_volume_image_metadata(image_id: str,
image_meta: Dict[str, Any]) -> dict:
# Save some base attributes into the volume metadata
base_metadata = {
@ -1114,6 +1202,7 @@ def get_volume_image_metadata(image_id, image_meta):
# Save all the image metadata properties into the volume metadata
property_metadata = {}
image_properties = image_meta.get('properties', {})
image_properties = typing.cast(dict, image_properties)
for (key, value) in image_properties.items():
if value is not None:
property_metadata[key] = value
@ -1123,8 +1212,12 @@ def get_volume_image_metadata(image_id, image_meta):
return volume_metadata
def copy_image_to_volume(driver, context, volume, image_meta, image_location,
image_service):
def copy_image_to_volume(driver,
context: context.RequestContext,
volume: 'objects.Volume',
image_meta: dict,
image_location: str,
image_service) -> None:
"""Downloads Glance image to the specified volume."""
image_id = image_meta['id']
LOG.debug("Attempting download of %(image_id)s (%(image_location)s)"
@ -1173,7 +1266,7 @@ def copy_image_to_volume(driver, context, volume, image_meta, image_location,
'image_location': image_location})
def image_conversion_dir():
def image_conversion_dir() -> str:
tmpdir = (CONF.image_conversion_dir or
tempfile.gettempdir())
@ -1184,7 +1277,9 @@ def image_conversion_dir():
return tmpdir
def check_encryption_provider(db, volume, context):
def check_encryption_provider(db,
volume: 'objects.Volume',
context: context.RequestContext) -> dict:
"""Check that this is a LUKS encryption provider.
:returns: encryption dict
@ -1212,14 +1307,14 @@ def check_encryption_provider(db, volume, context):
return encryption
def sanitize_host(host):
def sanitize_host(host: str) -> str:
"""Ensure IPv6 addresses are enclosed in [] for iSCSI portals."""
if netutils.is_valid_ipv6(host):
return '[%s]' % host
return host
def sanitize_hostname(hostname):
def sanitize_hostname(hostname) -> str:
"""Return a hostname which conforms to RFC-952 and RFC-1123 specs."""
hostname = hostname.encode('latin-1', 'ignore')
hostname = hostname.decode('latin-1')
@ -1232,7 +1327,7 @@ def sanitize_hostname(hostname):
return hostname
def resolve_hostname(hostname):
def resolve_hostname(hostname: str) -> str:
"""Resolves host name to IP address.
Resolves a host name (my.data.point.com) to an IP address (10.12.143.11).
@ -1248,7 +1343,9 @@ def resolve_hostname(hostname):
return ip
def update_backup_error(backup, err, status=fields.BackupStatus.ERROR):
def update_backup_error(backup,
err: str,
status=fields.BackupStatus.ERROR) -> None:
backup.status = status
backup.fail_reason = err
backup.save()
@ -1256,7 +1353,7 @@ def update_backup_error(backup, err, status=fields.BackupStatus.ERROR):
# TODO (whoami-rajat): Remove this method when oslo.vmware calls volume_utils
# wrapper of upload_volume instead of image_utils.upload_volume
def get_base_image_ref(volume):
def get_base_image_ref(volume: 'objects.Volume'):
# This method fetches the image_id from volume glance metadata and pass
# it to the driver calling it during upload volume to image operation
base_image_ref = None
@ -1265,9 +1362,12 @@ def get_base_image_ref(volume):
return base_image_ref
def upload_volume(context, image_service, image_meta, volume_path,
volume, volume_format='raw', run_as_root=True,
compress=True):
def upload_volume(context: context.RequestContext,
image_service, image_meta, volume_path,
volume: 'objects.Volume',
volume_format: str = 'raw',
run_as_root: bool = True,
compress: bool = True) -> None:
# retrieve store information from extra-specs
store_id = volume.volume_type.extra_specs.get('image_service:store_id')
@ -1305,7 +1405,8 @@ def get_backend_configuration(backend_name, backend_opts=None):
return config
def brick_get_connector_properties(multipath=False, enforce_multipath=False):
def brick_get_connector_properties(multipath: bool = False,
enforce_multipath: bool = False):
"""Wrapper to automatically set root_helper in brick calls.
:param multipath: A boolean indicating whether the connector can
@ -1323,9 +1424,10 @@ def brick_get_connector_properties(multipath=False, enforce_multipath=False):
enforce_multipath)
def brick_get_connector(protocol, driver=None,
use_multipath=False,
device_scan_attempts=3,
def brick_get_connector(protocol: str,
driver=None,
use_multipath: bool = False,
device_scan_attempts: int = 3,
*args, **kwargs):
"""Wrapper to get a brick connector object.
@ -1342,7 +1444,7 @@ def brick_get_connector(protocol, driver=None,
*args, **kwargs)
def brick_get_encryptor(connection_info, *args, **kwargs):
def brick_get_encryptor(connection_info: dict, *args, **kwargs):
"""Wrapper to get a brick encryptor object."""
root_helper = utils.get_root_helper()
@ -1353,7 +1455,9 @@ def brick_get_encryptor(connection_info, *args, **kwargs):
*args, **kwargs)
def brick_attach_volume_encryptor(context, attach_info, encryption):
def brick_attach_volume_encryptor(context: context.RequestContext,
attach_info: dict,
encryption: dict) -> None:
"""Attach encryption layer."""
connection_info = attach_info['conn']
connection_info['data']['device_path'] = attach_info['device']['path']
@ -1362,7 +1466,7 @@ def brick_attach_volume_encryptor(context, attach_info, encryption):
encryptor.attach_volume(context, **encryption)
def brick_detach_volume_encryptor(attach_info, encryption):
def brick_detach_volume_encryptor(attach_info: dict, encryption: dict) -> None:
"""Detach encryption layer."""
connection_info = attach_info['conn']
connection_info['data']['device_path'] = attach_info['device']['path']

View File

@ -1,6 +1,9 @@
cinder/context.py
cinder/i18n.py
cinder/exception.py
cinder/manager.py
cinder/utils.py
cinder/volume/__init__.py
cinder/volume/manager.py
cinder/volume/volume_types.py
cinder/volume/volume_utils.py