mypy: annotate volume_utils / utils / exc

Change-Id: I886600b1712f4c9415e59cea7166289c0870e58c
This commit is contained in:
Eric Harney 2020-06-03 14:05:39 -04:00
parent 7441694cd4
commit 8953340820
5 changed files with 318 additions and 176 deletions

View File

@ -22,6 +22,8 @@ SHOULD include dedicated exception logging.
""" """
from typing import Union
from oslo_log import log as logging from oslo_log import log as logging
from oslo_versionedobjects import exception as obj_exc from oslo_versionedobjects import exception as obj_exc
import webob.exc import webob.exc
@ -35,7 +37,8 @@ LOG = logging.getLogger(__name__)
class ConvertedException(webob.exc.WSGIHTTPException): class ConvertedException(webob.exc.WSGIHTTPException):
def __init__(self, code=500, title="", explanation=""): def __init__(self, code: int = 500, title: str = "",
explanation: str = ""):
self.code = code self.code = code
# There is a strict rule about constructing status line for HTTP: # There is a strict rule about constructing status line for HTTP:
# '...Status-Line, consisting of the protocol version followed by a # '...Status-Line, consisting of the protocol version followed by a
@ -66,10 +69,10 @@ class CinderException(Exception):
""" """
message = _("An unknown exception occurred.") message = _("An unknown exception occurred.")
code = 500 code = 500
headers = {} headers: dict = {}
safe = False safe = False
def __init__(self, message=None, **kwargs): def __init__(self, message: Union[str, tuple] = None, **kwargs):
self.kwargs = kwargs self.kwargs = kwargs
self.kwargs['message'] = message self.kwargs['message'] = message
@ -112,7 +115,7 @@ class CinderException(Exception):
# with duplicate keyword exception. # with duplicate keyword exception.
self.kwargs.pop('message', None) self.kwargs.pop('message', None)
def _log_exception(self): def _log_exception(self) -> None:
# kwargs doesn't match a variable in the message # kwargs doesn't match a variable in the message
# log the issue and the kwargs # log the issue and the kwargs
LOG.exception('Exception in string format operation:') LOG.exception('Exception in string format operation:')
@ -120,7 +123,7 @@ class CinderException(Exception):
LOG.error("%(name)s: %(value)s", LOG.error("%(name)s: %(value)s",
{'name': name, 'value': value}) {'name': name, 'value': value})
def _should_format(self): def _should_format(self) -> bool:
return self.kwargs['message'] is None or '%(message)' in self.message return self.kwargs['message'] is None or '%(message)' in self.message

View File

@ -297,12 +297,12 @@ class TemporaryChownTestCase(test.TestCase):
mock_stat.return_value.st_uid = 5678 mock_stat.return_value.st_uid = 5678
test_filename = 'a_file' test_filename = 'a_file'
with utils.temporary_chown(test_filename): with utils.temporary_chown(test_filename):
mock_exec.assert_called_once_with('chown', 1234, test_filename, mock_exec.assert_called_once_with('chown', '1234', test_filename,
run_as_root=True) run_as_root=True)
mock_getuid.assert_called_once_with() mock_getuid.assert_called_once_with()
mock_stat.assert_called_once_with(test_filename) mock_stat.assert_called_once_with(test_filename)
calls = [mock.call('chown', 1234, test_filename, run_as_root=True), calls = [mock.call('chown', '1234', test_filename, run_as_root=True),
mock.call('chown', 5678, test_filename, run_as_root=True)] mock.call('chown', '5678', test_filename, run_as_root=True)]
mock_exec.assert_has_calls(calls) mock_exec.assert_has_calls(calls)
@mock.patch('os.stat') @mock.patch('os.stat')
@ -312,12 +312,12 @@ class TemporaryChownTestCase(test.TestCase):
mock_stat.return_value.st_uid = 5678 mock_stat.return_value.st_uid = 5678
test_filename = 'a_file' test_filename = 'a_file'
with utils.temporary_chown(test_filename, owner_uid=9101): with utils.temporary_chown(test_filename, owner_uid=9101):
mock_exec.assert_called_once_with('chown', 9101, test_filename, mock_exec.assert_called_once_with('chown', '9101', test_filename,
run_as_root=True) run_as_root=True)
self.assertFalse(mock_getuid.called) self.assertFalse(mock_getuid.called)
mock_stat.assert_called_once_with(test_filename) mock_stat.assert_called_once_with(test_filename)
calls = [mock.call('chown', 9101, test_filename, run_as_root=True), calls = [mock.call('chown', '9101', test_filename, run_as_root=True),
mock.call('chown', 5678, test_filename, run_as_root=True)] mock.call('chown', '5678', test_filename, run_as_root=True)]
mock_exec.assert_has_calls(calls) mock_exec.assert_has_calls(calls)
@mock.patch('os.stat') @mock.patch('os.stat')

View File

@ -22,6 +22,7 @@ import contextlib
import datetime import datetime
import functools import functools
import inspect import inspect
import logging as py_logging
import math import math
import multiprocessing import multiprocessing
import operator import operator
@ -32,6 +33,9 @@ import shutil
import stat import stat
import sys import sys
import tempfile import tempfile
import typing
from typing import Callable, Dict, Iterable, Iterator, List # noqa: H301
from typing import Optional, Tuple, Type, Union # noqa: H301
import eventlet import eventlet
from eventlet import tpool from eventlet import tpool
@ -59,7 +63,7 @@ INFINITE_UNKNOWN_VALUES = ('infinite', 'unknown')
synchronized = lockutils.synchronized_with_prefix('cinder-') synchronized = lockutils.synchronized_with_prefix('cinder-')
def as_int(obj, quiet=True): def as_int(obj: Union[int, float, str], quiet: bool = True) -> int:
# Try "2" -> 2 # Try "2" -> 2
try: try:
return int(obj) return int(obj)
@ -73,10 +77,12 @@ def as_int(obj, quiet=True):
# Eck, not sure what this is then. # Eck, not sure what this is then.
if not quiet: if not quiet:
raise TypeError(_("Can not translate %s to integer.") % (obj)) raise TypeError(_("Can not translate %s to integer.") % (obj))
obj = typing.cast(int, obj)
return obj return obj
def check_exclusive_options(**kwargs): def check_exclusive_options(**kwargs: dict) -> None:
"""Checks that only one of the provided options is actually not-none. """Checks that only one of the provided options is actually not-none.
Iterates over all the kwargs passed in and checks that only one of said Iterates over all the kwargs passed in and checks that only one of said
@ -99,24 +105,24 @@ def check_exclusive_options(**kwargs):
# #
# Ex: 'the_key' -> 'the key' # Ex: 'the_key' -> 'the key'
if pretty_keys: if pretty_keys:
names = [k.replace('_', ' ') for k in kwargs] tnames = [k.replace('_', ' ') for k in kwargs]
else: else:
names = kwargs.keys() tnames = list(kwargs.keys())
names = ", ".join(sorted(names)) names = ", ".join(sorted(tnames))
msg = (_("May specify only one of %s") % (names)) msg = (_("May specify only one of %s") % (names))
raise exception.InvalidInput(reason=msg) raise exception.InvalidInput(reason=msg)
def execute(*cmd, **kwargs): def execute(*cmd: str, **kwargs) -> Tuple[str, str]:
"""Convenience wrapper around oslo's execute() method.""" """Convenience wrapper around oslo's execute() method."""
if 'run_as_root' in kwargs and 'root_helper' not in kwargs: if 'run_as_root' in kwargs and 'root_helper' not in kwargs:
kwargs['root_helper'] = get_root_helper() kwargs['root_helper'] = get_root_helper()
return processutils.execute(*cmd, **kwargs) return processutils.execute(*cmd, **kwargs)
def check_ssh_injection(cmd_list): def check_ssh_injection(cmd_list: List[str]) -> None:
ssh_injection_pattern = ['`', '$', '|', '||', ';', '&', '&&', '>', '>>', ssh_injection_pattern: Tuple[str, ...] = ('`', '$', '|', '||', ';', '&',
'<'] '&&', '>', '>>', '<')
# Check whether injection attacks exist # Check whether injection attacks exist
for arg in cmd_list: for arg in cmd_list:
@ -149,7 +155,8 @@ def check_ssh_injection(cmd_list):
raise exception.SSHInjectionThreat(command=cmd_list) raise exception.SSHInjectionThreat(command=cmd_list)
def check_metadata_properties(metadata=None): def check_metadata_properties(
metadata: Optional[Dict[str, str]]) -> None:
"""Checks that the volume metadata properties are valid.""" """Checks that the volume metadata properties are valid."""
if not metadata: if not metadata:
@ -175,7 +182,9 @@ def check_metadata_properties(metadata=None):
raise exception.InvalidVolumeMetadataSize(reason=msg) raise exception.InvalidVolumeMetadataSize(reason=msg)
def last_completed_audit_period(unit=None): def last_completed_audit_period(unit: str = None) -> \
Tuple[Union[datetime.datetime, datetime.timedelta],
Union[datetime.datetime, datetime.timedelta]]:
"""This method gives you the most recently *completed* audit period. """This method gives you the most recently *completed* audit period.
arguments: arguments:
@ -196,11 +205,15 @@ def last_completed_audit_period(unit=None):
if not unit: if not unit:
unit = CONF.volume_usage_audit_period unit = CONF.volume_usage_audit_period
offset = 0 unit = typing.cast(str, unit)
offset: Union[str, int] = 0
if '@' in unit: if '@' in unit:
unit, offset = unit.split("@", 1) unit, offset = unit.split("@", 1)
offset = int(offset) offset = int(offset)
offset = typing.cast(int, offset)
rightnow = timeutils.utcnow() rightnow = timeutils.utcnow()
if unit not in ('month', 'day', 'year', 'hour'): if unit not in ('month', 'day', 'year', 'hour'):
raise ValueError('Time period must be hour, day, month or year') raise ValueError('Time period must be hour, day, month or year')
@ -262,7 +275,7 @@ def last_completed_audit_period(unit=None):
return (begin, end) return (begin, end)
def monkey_patch(): def monkey_patch() -> None:
"""Patches decorators for all functions in a specified module. """Patches decorators for all functions in a specified module.
If the CONF.monkey_patch set as True, If the CONF.monkey_patch set as True,
@ -309,7 +322,7 @@ def monkey_patch():
decorator("%s.%s" % (module, key), func)) decorator("%s.%s" % (module, key), func))
def make_dev_path(dev, partition=None, base='/dev'): def make_dev_path(dev: str, partition: str = None, base: str = '/dev') -> str:
"""Return a path to a particular device. """Return a path to a particular device.
>>> make_dev_path('xvdc') >>> make_dev_path('xvdc')
@ -324,7 +337,7 @@ def make_dev_path(dev, partition=None, base='/dev'):
return path return path
def robust_file_write(directory, filename, data): def robust_file_write(directory: str, filename: str, data: str) -> None:
"""Robust file write. """Robust file write.
Use "write to temp file and rename" model for writing the Use "write to temp file and rename" model for writing the
@ -360,15 +373,16 @@ def robust_file_write(directory, filename, data):
with excutils.save_and_reraise_exception(): with excutils.save_and_reraise_exception():
LOG.error("Failed to write persistence file: %(path)s.", LOG.error("Failed to write persistence file: %(path)s.",
{'path': os.path.join(directory, filename)}) {'path': os.path.join(directory, filename)})
if os.path.isfile(tempname): if tempname is not None:
os.unlink(tempname) if os.path.isfile(tempname):
os.unlink(tempname)
finally: finally:
if dirfd: if dirfd is not None:
os.close(dirfd) os.close(dirfd)
@contextlib.contextmanager @contextlib.contextmanager
def temporary_chown(path, owner_uid=None): def temporary_chown(path: str, owner_uid: int = None) -> Iterator[None]:
"""Temporarily chown a path. """Temporarily chown a path.
:params owner_uid: UID of temporary owner (defaults to current user) :params owner_uid: UID of temporary owner (defaults to current user)
@ -386,16 +400,16 @@ def temporary_chown(path, owner_uid=None):
orig_uid = os.stat(path).st_uid orig_uid = os.stat(path).st_uid
if orig_uid != owner_uid: if orig_uid != owner_uid:
execute('chown', owner_uid, path, run_as_root=True) execute('chown', str(owner_uid), path, run_as_root=True)
try: try:
yield yield
finally: finally:
if orig_uid != owner_uid: if orig_uid != owner_uid:
execute('chown', orig_uid, path, run_as_root=True) execute('chown', str(orig_uid), path, run_as_root=True)
@contextlib.contextmanager @contextlib.contextmanager
def tempdir(**kwargs): def tempdir(**kwargs) -> Iterator[str]:
tmpdir = tempfile.mkdtemp(**kwargs) tmpdir = tempfile.mkdtemp(**kwargs)
try: try:
yield tmpdir yield tmpdir
@ -406,11 +420,11 @@ def tempdir(**kwargs):
LOG.debug('Could not remove tmpdir: %s', str(e)) LOG.debug('Could not remove tmpdir: %s', str(e))
def get_root_helper(): def get_root_helper() -> str:
return 'sudo cinder-rootwrap %s' % CONF.rootwrap_config return 'sudo cinder-rootwrap %s' % CONF.rootwrap_config
def require_driver_initialized(driver): def require_driver_initialized(driver) -> None:
"""Verifies if `driver` is initialized """Verifies if `driver` is initialized
If the driver is not initialized, an exception will be raised. If the driver is not initialized, an exception will be raised.
@ -427,7 +441,7 @@ def require_driver_initialized(driver):
log_unsupported_driver_warning(driver) log_unsupported_driver_warning(driver)
def log_unsupported_driver_warning(driver): def log_unsupported_driver_warning(driver) -> None:
"""Annoy the log about unsupported drivers.""" """Annoy the log about unsupported drivers."""
if not driver.supported: if not driver.supported:
# Check to see if the driver is flagged as supported. # Check to see if the driver is flagged as supported.
@ -440,22 +454,24 @@ def log_unsupported_driver_warning(driver):
'id': driver.__class__.__name__}) 'id': driver.__class__.__name__})
def get_file_mode(path): def get_file_mode(path: str) -> int:
"""This primarily exists to make unit testing easier.""" """This primarily exists to make unit testing easier."""
return stat.S_IMODE(os.stat(path).st_mode) return stat.S_IMODE(os.stat(path).st_mode)
def get_file_gid(path): def get_file_gid(path: str) -> int:
"""This primarily exists to make unit testing easier.""" """This primarily exists to make unit testing easier."""
return os.stat(path).st_gid return os.stat(path).st_gid
def get_file_size(path): def get_file_size(path: str) -> int:
"""Returns the file size.""" """Returns the file size."""
return os.stat(path).st_size return os.stat(path).st_size
def _get_disk_of_partition(devpath, st=None): def _get_disk_of_partition(
devpath: str,
st: os.stat_result = None) -> Tuple[str, os.stat_result]:
"""Gets a disk device path and status from partition path. """Gets a disk device path and status from partition path.
Returns a disk device path from a partition device path, and stat for Returns a disk device path from a partition device path, and stat for
@ -478,7 +494,9 @@ def _get_disk_of_partition(devpath, st=None):
return (devpath, st) return (devpath, st)
def get_bool_param(param_string, params, default=False): def get_bool_param(param_string: str,
params: dict,
default: bool = False) -> bool:
param = params.get(param_string, default) param = params.get(param_string, default)
if not strutils.is_valid_boolstr(param): if not strutils.is_valid_boolstr(param):
msg = _("Value '%(param)s' for '%(param_string)s' is not " msg = _("Value '%(param)s' for '%(param_string)s' is not "
@ -488,7 +506,8 @@ def get_bool_param(param_string, params, default=False):
return strutils.bool_from_string(param, strict=True) return strutils.bool_from_string(param, strict=True)
def get_blkdev_major_minor(path, lookup_for_file=True): def get_blkdev_major_minor(path: str,
lookup_for_file: bool = True) -> Optional[str]:
"""Get 'major:minor' number of block device. """Get 'major:minor' number of block device.
Get the device's 'major:minor' number of a block device to control Get the device's 'major:minor' number of a block device to control
@ -516,8 +535,9 @@ def get_blkdev_major_minor(path, lookup_for_file=True):
raise exception.CinderException(msg) raise exception.CinderException(msg)
def check_string_length(value, name, min_length=0, max_length=None, def check_string_length(value: str, name: str, min_length: int = 0,
allow_all_spaces=True): max_length: int = None,
allow_all_spaces: bool = True) -> None:
"""Check the length of specified string. """Check the length of specified string.
:param value: the value of the string :param value: the value of the string
@ -537,7 +557,7 @@ def check_string_length(value, name, min_length=0, max_length=None,
raise exception.InvalidInput(reason=msg) raise exception.InvalidInput(reason=msg)
def is_blk_device(dev): def is_blk_device(dev: str) -> bool:
try: try:
if stat.S_ISBLK(os.stat(dev).st_mode): if stat.S_ISBLK(os.stat(dev).st_mode):
return True return True
@ -548,30 +568,30 @@ def is_blk_device(dev):
class ComparableMixin(object): class ComparableMixin(object):
def _compare(self, other, method): def _compare(self, other: object, method: Callable):
try: try:
return method(self._cmpkey(), other._cmpkey()) return method(self._cmpkey(), other._cmpkey()) # type: ignore
except (AttributeError, TypeError): except (AttributeError, TypeError):
# _cmpkey not implemented, or return different type, # _cmpkey not implemented, or return different type,
# so I can't compare with "other". # so I can't compare with "other".
return NotImplemented return NotImplemented
def __lt__(self, other): def __lt__(self, other: object):
return self._compare(other, lambda s, o: s < o) return self._compare(other, lambda s, o: s < o)
def __le__(self, other): def __le__(self, other: object):
return self._compare(other, lambda s, o: s <= o) return self._compare(other, lambda s, o: s <= o)
def __eq__(self, other): def __eq__(self, other: object):
return self._compare(other, lambda s, o: s == o) return self._compare(other, lambda s, o: s == o)
def __ge__(self, other): def __ge__(self, other: object):
return self._compare(other, lambda s, o: s >= o) return self._compare(other, lambda s, o: s >= o)
def __gt__(self, other): def __gt__(self, other: object):
return self._compare(other, lambda s, o: s > o) return self._compare(other, lambda s, o: s > o)
def __ne__(self, other): def __ne__(self, other: object):
return self._compare(other, lambda s, o: s != o) return self._compare(other, lambda s, o: s != o)
@ -586,8 +606,12 @@ class retry_if_exit_code(tenacity.retry_if_exception):
exc.exit_code in self.codes) exc.exit_code in self.codes)
def retry(retry_param, interval=1, retries=3, backoff_rate=2, def retry(retry_param: Optional[Type[Exception]],
wait_random=False, retry=tenacity.retry_if_exception_type): interval: int = 1,
retries: int = 3,
backoff_rate: int = 2,
wait_random: bool = False,
retry=tenacity.retry_if_exception_type) -> Callable:
if retries < 1: if retries < 1:
raise ValueError('Retries must be greater than or ' raise ValueError('Retries must be greater than or '
@ -599,7 +623,7 @@ def retry(retry_param, interval=1, retries=3, backoff_rate=2,
wait = tenacity.wait_exponential( wait = tenacity.wait_exponential(
multiplier=interval, min=0, exp_base=backoff_rate) multiplier=interval, min=0, exp_base=backoff_rate)
def _decorator(f): def _decorator(f: Callable) -> Callable:
@functools.wraps(f) @functools.wraps(f)
def _wrapper(*args, **kwargs): def _wrapper(*args, **kwargs):
@ -618,7 +642,7 @@ def retry(retry_param, interval=1, retries=3, backoff_rate=2,
return _decorator return _decorator
def convert_str(text): def convert_str(text: Union[str, bytes]) -> str:
"""Convert to native string. """Convert to native string.
Convert bytes and Unicode strings to native strings: Convert bytes and Unicode strings to native strings:
@ -633,7 +657,8 @@ def convert_str(text):
return text return text
def build_or_str(elements, str_format=None): def build_or_str(elements: Union[None, str, Iterable[str]],
str_format: str = None) -> str:
"""Builds a string of elements joined by 'or'. """Builds a string of elements joined by 'or'.
Will join strings with the 'or' word and if a str_format is provided it Will join strings with the 'or' word and if a str_format is provided it
@ -651,18 +676,21 @@ def build_or_str(elements, str_format=None):
if not isinstance(elements, str): if not isinstance(elements, str):
elements = _(' or ').join(elements) elements = _(' or ').join(elements)
elements = typing.cast(str, elements)
if str_format: if str_format:
return str_format % elements return str_format % elements
return elements return elements
def calculate_virtual_free_capacity(total_capacity, def calculate_virtual_free_capacity(total_capacity: float,
free_capacity, free_capacity: float,
provisioned_capacity, provisioned_capacity: float,
thin_provisioning_support, thin_provisioning_support: bool,
max_over_subscription_ratio, max_over_subscription_ratio: float,
reserved_percentage, reserved_percentage: float,
thin): thin: bool) -> float:
"""Calculate the virtual free capacity based on thin provisioning support. """Calculate the virtual free capacity based on thin provisioning support.
:param total_capacity: total_capacity_gb of a host_state or pool. :param total_capacity: total_capacity_gb of a host_state or pool.
@ -693,8 +721,9 @@ def calculate_virtual_free_capacity(total_capacity,
return free return free
def calculate_max_over_subscription_ratio(capability, def calculate_max_over_subscription_ratio(
global_max_over_subscription_ratio): capability: dict,
global_max_over_subscription_ratio: float) -> float:
# provisioned_capacity_gb is the apparent total capacity of # provisioned_capacity_gb is the apparent total capacity of
# all the volumes created on a backend, which is greater than # all the volumes created on a backend, which is greater than
# or equal to allocated_capacity_gb, which is the apparent # or equal to allocated_capacity_gb, which is the apparent
@ -752,7 +781,7 @@ def calculate_max_over_subscription_ratio(capability,
return max_over_subscription_ratio return max_over_subscription_ratio
def validate_dictionary_string_length(specs): def validate_dictionary_string_length(specs: dict) -> None:
"""Check the length of each key and value of dictionary.""" """Check the length of each key and value of dictionary."""
if not isinstance(specs, dict): if not isinstance(specs, dict):
msg = _('specs must be a dictionary.') msg = _('specs must be a dictionary.')
@ -768,7 +797,8 @@ def validate_dictionary_string_length(specs):
min_length=0, max_length=255) min_length=0, max_length=255)
def service_expired_time(with_timezone=False): def service_expired_time(
with_timezone: Optional[bool] = False) -> datetime.datetime:
return (timeutils.utcnow(with_timezone=with_timezone) - return (timeutils.utcnow(with_timezone=with_timezone) -
datetime.timedelta(seconds=CONF.service_down_time)) datetime.timedelta(seconds=CONF.service_down_time))
@ -794,7 +824,7 @@ def notifications_enabled(conf):
return notifications_driver and notifications_driver != {'noop'} return notifications_driver and notifications_driver != {'noop'}
def if_notifications_enabled(f): def if_notifications_enabled(f: Callable) -> Callable:
"""Calls decorated method only if notifications are enabled.""" """Calls decorated method only if notifications are enabled."""
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
@ -807,7 +837,7 @@ def if_notifications_enabled(f):
LOG_LEVELS = ('INFO', 'WARNING', 'ERROR', 'DEBUG') LOG_LEVELS = ('INFO', 'WARNING', 'ERROR', 'DEBUG')
def get_log_method(level_string): def get_log_method(level_string: str) -> int:
level_string = level_string or '' level_string = level_string or ''
upper_level_string = level_string.upper() upper_level_string = level_string.upper()
if upper_level_string not in LOG_LEVELS: if upper_level_string not in LOG_LEVELS:
@ -816,7 +846,7 @@ def get_log_method(level_string):
return getattr(logging, upper_level_string) return getattr(logging, upper_level_string)
def set_log_levels(prefix, level_string): def set_log_levels(prefix: str, level_string: str) -> None:
level = get_log_method(level_string) level = get_log_method(level_string)
prefix = prefix or '' prefix = prefix or ''
@ -825,18 +855,18 @@ def set_log_levels(prefix, level_string):
v.logger.setLevel(level) v.logger.setLevel(level)
def get_log_levels(prefix): def get_log_levels(prefix: str) -> dict:
prefix = prefix or '' prefix = prefix or ''
return {k: logging.logging.getLevelName(v.logger.getEffectiveLevel()) return {k: py_logging.getLevelName(v.logger.getEffectiveLevel())
for k, v in logging.get_loggers().items() for k, v in logging.get_loggers().items()
if k and k.startswith(prefix)} if k and k.startswith(prefix)}
def paths_normcase_equal(path_a, path_b): def paths_normcase_equal(path_a: str, path_b: str) -> bool:
return os.path.normcase(path_a) == os.path.normcase(path_b) return os.path.normcase(path_a) == os.path.normcase(path_b)
def create_ordereddict(adict): def create_ordereddict(adict: dict) -> OrderedDict:
"""Given a dict, return a sorted OrderedDict.""" """Given a dict, return a sorted OrderedDict."""
return OrderedDict(sorted(adict.items(), return OrderedDict(sorted(adict.items(),
key=operator.itemgetter(0))) key=operator.itemgetter(0)))
@ -859,7 +889,9 @@ class Semaphore(object):
return self.semaphore.__exit__(*args) return self.semaphore.__exit__(*args)
def semaphore_factory(limit, concurrent_processes): def semaphore_factory(limit: int,
concurrent_processes: int) -> Union[eventlet.Semaphore,
Semaphore]:
"""Get a semaphore to limit concurrent operations. """Get a semaphore to limit concurrent operations.
The semaphore depends on the limit we want to set and the concurrent The semaphore depends on the limit we want to set and the concurrent
@ -876,7 +908,7 @@ def semaphore_factory(limit, concurrent_processes):
return contextlib.suppress() return contextlib.suppress()
def limit_operations(func): def limit_operations(func: Callable) -> Callable:
"""Decorator to limit the number of concurrent operations. """Decorator to limit the number of concurrent operations.
This method decorator expects to have a _semaphore attribute holding an This method decorator expects to have a _semaphore attribute holding an

View File

@ -30,6 +30,9 @@ import socket
import tempfile import tempfile
import time import time
import types import types
import typing
from typing import Any, BinaryIO, Callable, Dict, IO # noqa: H301
from typing import List, Optional, Tuple, Union # noqa: H301
import uuid import uuid
from castellan.common.credentials import keystone_password from castellan.common.credentials import keystone_password
@ -69,7 +72,7 @@ CONF = cfg.CONF
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
GB = units.Gi GB: int = units.Gi
# These attributes we will attempt to save for the volume if they exist # These attributes we will attempt to save for the volume if they exist
# in the source image metadata. # in the source image metadata.
IMAGE_ATTRIBUTES = ( IMAGE_ATTRIBUTES = (
@ -85,11 +88,13 @@ TRACE_API = False
TRACE_METHOD = False TRACE_METHOD = False
def null_safe_str(s): def null_safe_str(s: Optional[str]) -> str:
return str(s) if s else '' return str(s) if s else ''
def _usage_from_volume(context, volume_ref, **kw): def _usage_from_volume(context: context.RequestContext,
volume_ref: 'objects.Volume',
**kw) -> dict:
now = timeutils.utcnow() now = timeutils.utcnow()
launched_at = volume_ref['launched_at'] or now launched_at = volume_ref['launched_at'] or now
created_at = volume_ref['created_at'] or now created_at = volume_ref['created_at'] or now
@ -131,7 +136,7 @@ def _usage_from_volume(context, volume_ref, **kw):
return usage_info return usage_info
def _usage_from_backup(backup, **kw): def _usage_from_backup(backup: 'objects.Backup', **kw) -> dict:
num_dependent_backups = backup.num_dependent_backups num_dependent_backups = backup.num_dependent_backups
usage_info = dict(tenant_id=backup.project_id, usage_info = dict(tenant_id=backup.project_id,
user_id=backup.user_id, user_id=backup.user_id,
@ -156,8 +161,11 @@ def _usage_from_backup(backup, **kw):
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_volume_usage(context, volume, event_suffix, def notify_about_volume_usage(context: context.RequestContext,
extra_usage_info=None, host=None): volume: 'objects.Volume',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -171,9 +179,11 @@ def notify_about_volume_usage(context, volume, event_suffix,
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_backup_usage(context, backup, event_suffix, def notify_about_backup_usage(context: context.RequestContext,
extra_usage_info=None, backup: 'objects.Backup',
host=None): event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -186,7 +196,9 @@ def notify_about_backup_usage(context, backup, event_suffix,
usage_info) usage_info)
def _usage_from_snapshot(snapshot, context, **extra_usage_info): def _usage_from_snapshot(snapshot: 'objects.Snapshot',
context: context.RequestContext,
**extra_usage_info) -> dict:
# (niedbalski) a snapshot might be related to a deleted # (niedbalski) a snapshot might be related to a deleted
# volume, if that's the case, the volume information is still # volume, if that's the case, the volume information is still
# required for filling the usage_info, so we enforce to read # required for filling the usage_info, so we enforce to read
@ -212,8 +224,11 @@ def _usage_from_snapshot(snapshot, context, **extra_usage_info):
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_snapshot_usage(context, snapshot, event_suffix, def notify_about_snapshot_usage(context: context.RequestContext,
extra_usage_info=None, host=None): snapshot: 'objects.Snapshot',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -227,7 +242,8 @@ def notify_about_snapshot_usage(context, snapshot, event_suffix,
usage_info) usage_info)
def _usage_from_capacity(capacity, **extra_usage_info): def _usage_from_capacity(capacity: Dict[str, Any],
**extra_usage_info) -> Dict[str, Any]:
capacity_info = { capacity_info = {
'name_to_id': capacity['name_to_id'], 'name_to_id': capacity['name_to_id'],
@ -244,8 +260,11 @@ def _usage_from_capacity(capacity, **extra_usage_info):
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_capacity_usage(context, capacity, suffix, def notify_about_capacity_usage(context: context.RequestContext,
extra_usage_info=None, host=None): capacity: dict,
suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -260,8 +279,11 @@ def notify_about_capacity_usage(context, capacity, suffix,
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_replication_usage(context, volume, suffix, def notify_about_replication_usage(context: context.RequestContext,
extra_usage_info=None, host=None): volume: 'objects.Volume',
suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -277,8 +299,11 @@ def notify_about_replication_usage(context, volume, suffix,
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_replication_error(context, volume, suffix, def notify_about_replication_error(context: context.RequestContext,
extra_error_info=None, host=None): volume: 'objects.Volume',
suffix: str,
extra_error_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -293,7 +318,7 @@ def notify_about_replication_error(context, volume, suffix,
usage_info) usage_info)
def _usage_from_consistencygroup(group_ref, **kw): def _usage_from_consistencygroup(group_ref: 'objects.Group', **kw) -> dict:
usage_info = dict(tenant_id=group_ref.project_id, usage_info = dict(tenant_id=group_ref.project_id,
user_id=group_ref.user_id, user_id=group_ref.user_id,
availability_zone=group_ref.availability_zone, availability_zone=group_ref.availability_zone,
@ -307,8 +332,11 @@ def _usage_from_consistencygroup(group_ref, **kw):
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_consistencygroup_usage(context, group, event_suffix, def notify_about_consistencygroup_usage(context: context.RequestContext,
extra_usage_info=None, host=None): group: 'objects.Group',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -324,7 +352,7 @@ def notify_about_consistencygroup_usage(context, group, event_suffix,
usage_info) usage_info)
def _usage_from_group(group_ref, **kw): def _usage_from_group(group_ref: 'objects.Group', **kw) -> dict:
usage_info = dict(tenant_id=group_ref.project_id, usage_info = dict(tenant_id=group_ref.project_id,
user_id=group_ref.user_id, user_id=group_ref.user_id,
availability_zone=group_ref.availability_zone, availability_zone=group_ref.availability_zone,
@ -339,8 +367,11 @@ def _usage_from_group(group_ref, **kw):
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_group_usage(context, group, event_suffix, def notify_about_group_usage(context: context.RequestContext,
extra_usage_info=None, host=None): group: 'objects.Group',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -356,7 +387,7 @@ def notify_about_group_usage(context, group, event_suffix,
usage_info) usage_info)
def _usage_from_cgsnapshot(cgsnapshot, **kw): def _usage_from_cgsnapshot(cgsnapshot: 'objects.CGSnapshot', **kw) -> dict:
usage_info = dict( usage_info = dict(
tenant_id=cgsnapshot.project_id, tenant_id=cgsnapshot.project_id,
user_id=cgsnapshot.user_id, user_id=cgsnapshot.user_id,
@ -370,7 +401,8 @@ def _usage_from_cgsnapshot(cgsnapshot, **kw):
return usage_info return usage_info
def _usage_from_group_snapshot(group_snapshot, **kw): def _usage_from_group_snapshot(group_snapshot: 'objects.GroupSnapshot',
**kw) -> dict:
usage_info = dict( usage_info = dict(
tenant_id=group_snapshot.project_id, tenant_id=group_snapshot.project_id,
user_id=group_snapshot.user_id, user_id=group_snapshot.user_id,
@ -386,8 +418,11 @@ def _usage_from_group_snapshot(group_snapshot, **kw):
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_cgsnapshot_usage(context, cgsnapshot, event_suffix, def notify_about_cgsnapshot_usage(context: context.RequestContext,
extra_usage_info=None, host=None): cgsnapshot: 'objects.CGSnapshot',
event_suffix: str,
extra_usage_info: dict = None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -404,8 +439,11 @@ def notify_about_cgsnapshot_usage(context, cgsnapshot, event_suffix,
@utils.if_notifications_enabled @utils.if_notifications_enabled
def notify_about_group_snapshot_usage(context, group_snapshot, event_suffix, def notify_about_group_snapshot_usage(context: context.RequestContext,
extra_usage_info=None, host=None): group_snapshot: 'objects.GroupSnapshot',
event_suffix: str,
extra_usage_info=None,
host: str = None) -> None:
if not host: if not host:
host = CONF.host host = CONF.host
@ -421,13 +459,14 @@ def notify_about_group_snapshot_usage(context, group_snapshot, event_suffix,
usage_info) usage_info)
def _check_blocksize(blocksize): def _check_blocksize(blocksize: Union[str, int]) -> Union[str, int]:
# Check if volume_dd_blocksize is valid # Check if volume_dd_blocksize is valid
try: try:
# Rule out zero-sized/negative/float dd blocksize which # Rule out zero-sized/negative/float dd blocksize which
# cannot be caught by strutils # cannot be caught by strutils
if blocksize.startswith(('-', '0')) or '.' in blocksize: if (blocksize.startswith(('-', '0')) or # type: ignore
'.' in blocksize): # type: ignore
raise ValueError raise ValueError
strutils.string_to_bytes('%sB' % blocksize) strutils.string_to_bytes('%sB' % blocksize)
except ValueError: except ValueError:
@ -442,7 +481,8 @@ def _check_blocksize(blocksize):
return blocksize return blocksize
def check_for_odirect_support(src, dest, flag='oflag=direct'): def check_for_odirect_support(src: str, dest: str,
flag: str = 'oflag=direct') -> bool:
# Check whether O_DIRECT is supported # Check whether O_DIRECT is supported
try: try:
@ -459,9 +499,12 @@ def check_for_odirect_support(src, dest, flag='oflag=direct'):
return False return False
def _copy_volume_with_path(prefix, srcstr, deststr, size_in_m, blocksize, def _copy_volume_with_path(prefix, srcstr: str, deststr: str,
sync=False, execute=utils.execute, ionice=None, size_in_m: int, blocksize: Union[str, int],
sparse=False): sync: bool = False,
execute: Callable = utils.execute,
ionice=None,
sparse: bool = False) -> None:
cmd = prefix[:] cmd = prefix[:]
if ionice: if ionice:
@ -514,16 +557,18 @@ def _copy_volume_with_path(prefix, srcstr, deststr, size_in_m, blocksize,
{'size_in_m': size_in_m, 'mbps': mbps}) {'size_in_m': size_in_m, 'mbps': mbps})
def _open_volume_with_path(path, mode): def _open_volume_with_path(path: str, mode: str) -> IO[Any]:
try: try:
with utils.temporary_chown(path): with utils.temporary_chown(path):
handle = open(path, mode) handle = open(path, mode)
return handle return handle
except Exception: except Exception:
LOG.error("Failed to open volume from %(path)s.", {'path': path}) LOG.error("Failed to open volume from %(path)s.", {'path': path})
raise
def _transfer_data(src, dest, length, chunk_size): def _transfer_data(src: IO, dest: IO,
length: int, chunk_size: int) -> None:
"""Transfer data between files (Python IO objects).""" """Transfer data between files (Python IO objects)."""
chunks = int(math.ceil(length / chunk_size)) chunks = int(math.ceil(length / chunk_size))
@ -554,15 +599,21 @@ def _transfer_data(src, dest, length, chunk_size):
tpool.execute(dest.flush) tpool.execute(dest.flush)
def _copy_volume_with_file(src, dest, size_in_m): def _copy_volume_with_file(src: Union[str, IO],
dest: Union[str, IO],
size_in_m: int) -> None:
src_handle = src src_handle = src
if isinstance(src, str): if isinstance(src, str):
src_handle = _open_volume_with_path(src, 'rb') src_handle = _open_volume_with_path(src, 'rb')
src_handle = typing.cast(IO, src_handle)
dest_handle = dest dest_handle = dest
if isinstance(dest, str): if isinstance(dest, str):
dest_handle = _open_volume_with_path(dest, 'wb') dest_handle = _open_volume_with_path(dest, 'wb')
dest_handle = typing.cast(IO, dest_handle)
if not src_handle: if not src_handle:
raise exception.DeviceUnavailable( raise exception.DeviceUnavailable(
_("Failed to copy volume, source device unavailable.")) _("Failed to copy volume, source device unavailable."))
@ -588,9 +639,12 @@ def _copy_volume_with_file(src, dest, size_in_m):
{'size_in_m': size_in_m, 'mbps': mbps}) {'size_in_m': size_in_m, 'mbps': mbps})
def copy_volume(src, dest, size_in_m, blocksize, sync=False, def copy_volume(src: Union[str, BinaryIO],
dest: Union[str, BinaryIO],
size_in_m: int,
blocksize: Union[str, int], sync=False,
execute=utils.execute, ionice=None, throttle=None, execute=utils.execute, ionice=None, throttle=None,
sparse=False): sparse=False) -> None:
"""Copy data from the source volume to the destination volume. """Copy data from the source volume to the destination volume.
The parameters 'src' and 'dest' are both typically of type str, which The parameters 'src' and 'dest' are both typically of type str, which
@ -617,9 +671,12 @@ def copy_volume(src, dest, size_in_m, blocksize, sync=False,
_copy_volume_with_file(src, dest, size_in_m) _copy_volume_with_file(src, dest, size_in_m)
def clear_volume(volume_size, volume_path, volume_clear=None, def clear_volume(volume_size: int,
volume_clear_size=None, volume_clear_ionice=None, volume_path: str,
throttle=None): volume_clear: str = None,
volume_clear_size: int = None,
volume_clear_ionice: str = None,
throttle=None) -> None:
"""Unprovision old volumes to prevent data leaking between users.""" """Unprovision old volumes to prevent data leaking between users."""
if volume_clear is None: if volume_clear is None:
volume_clear = CONF.volume_clear volume_clear = CONF.volume_clear
@ -649,24 +706,25 @@ def clear_volume(volume_size, volume_path, volume_clear=None,
value=volume_clear) value=volume_clear)
def supports_thin_provisioning(): def supports_thin_provisioning() -> bool:
return brick_lvm.LVM.supports_thin_provisioning( return brick_lvm.LVM.supports_thin_provisioning(
utils.get_root_helper()) utils.get_root_helper())
def get_all_physical_volumes(vg_name=None): def get_all_physical_volumes(vg_name=None) -> list:
return brick_lvm.LVM.get_all_physical_volumes( return brick_lvm.LVM.get_all_physical_volumes(
utils.get_root_helper(), utils.get_root_helper(),
vg_name) vg_name)
def get_all_volume_groups(vg_name=None): def get_all_volume_groups(vg_name=None) -> list:
return brick_lvm.LVM.get_all_volume_groups( return brick_lvm.LVM.get_all_volume_groups(
utils.get_root_helper(), utils.get_root_helper(),
vg_name) vg_name)
def extract_availability_zones_from_volume_type(volume_type): def extract_availability_zones_from_volume_type(volume_type) \
-> Optional[list]:
if not volume_type: if not volume_type:
return None return None
extra_specs = volume_type.get('extra_specs', {}) extra_specs = volume_type.get('extra_specs', {})
@ -683,7 +741,9 @@ DEFAULT_PASSWORD_SYMBOLS = ('23456789', # Removed: 0,1
'abcdefghijkmnopqrstuvwxyz') # Removed: l 'abcdefghijkmnopqrstuvwxyz') # Removed: l
def generate_password(length=16, symbolgroups=DEFAULT_PASSWORD_SYMBOLS): def generate_password(
length: int = 16,
symbolgroups: Tuple[str, ...] = DEFAULT_PASSWORD_SYMBOLS) -> str:
"""Generate a random password from the supplied symbol groups. """Generate a random password from the supplied symbol groups.
At least one symbol from each group will be included. Unpredictable At least one symbol from each group will be included. Unpredictable
@ -720,7 +780,9 @@ def generate_password(length=16, symbolgroups=DEFAULT_PASSWORD_SYMBOLS):
return ''.join(password) return ''.join(password)
def generate_username(length=20, symbolgroups=DEFAULT_PASSWORD_SYMBOLS): def generate_username(
length: int = 20,
symbolgroups: Tuple[str, ...] = DEFAULT_PASSWORD_SYMBOLS) -> str:
# Use the same implementation as the password generation. # Use the same implementation as the password generation.
return generate_password(length, symbolgroups) return generate_password(length, symbolgroups)
@ -728,7 +790,9 @@ def generate_username(length=20, symbolgroups=DEFAULT_PASSWORD_SYMBOLS):
DEFAULT_POOL_NAME = '_pool0' DEFAULT_POOL_NAME = '_pool0'
def extract_host(host, level='backend', default_pool_name=False): def extract_host(host: Optional[str],
level: str = 'backend',
default_pool_name: bool = False) -> Optional[str]:
"""Extract Host, Backend or Pool information from host string. """Extract Host, Backend or Pool information from host string.
:param host: String for host, which could include host@backend#pool info :param host: String for host, which could include host@backend#pool info
@ -778,8 +842,11 @@ def extract_host(host, level='backend', default_pool_name=False):
else: else:
return None return None
return None # not hit
def append_host(host, pool):
def append_host(host: Optional[str],
pool: Optional[str]) -> Optional[str]:
"""Encode pool into host info.""" """Encode pool into host info."""
if not host or not pool: if not host or not pool:
return host return host
@ -788,7 +855,7 @@ def append_host(host, pool):
return new_host return new_host
def matching_backend_name(src_volume_type, volume_type): def matching_backend_name(src_volume_type, volume_type) -> bool:
if src_volume_type.get('volume_backend_name') and \ if src_volume_type.get('volume_backend_name') and \
volume_type.get('volume_backend_name'): volume_type.get('volume_backend_name'):
return src_volume_type.get('volume_backend_name') == \ return src_volume_type.get('volume_backend_name') == \
@ -797,14 +864,14 @@ def matching_backend_name(src_volume_type, volume_type):
return False return False
def hosts_are_equivalent(host_1, host_2): def hosts_are_equivalent(host_1: str, host_2: str) -> bool:
# In case host_1 or host_2 are None # In case host_1 or host_2 are None
if not (host_1 and host_2): if not (host_1 and host_2):
return host_1 == host_2 return host_1 == host_2
return extract_host(host_1) == extract_host(host_2) return extract_host(host_1) == extract_host(host_2)
def read_proc_mounts(): def read_proc_mounts() -> List[str]:
"""Read the /proc/mounts file. """Read the /proc/mounts file.
It's a dummy function but it eases the writing of unit tests as mocking It's a dummy function but it eases the writing of unit tests as mocking
@ -814,19 +881,20 @@ def read_proc_mounts():
return mounts.readlines() return mounts.readlines()
def extract_id_from_volume_name(vol_name): def extract_id_from_volume_name(vol_name: str) -> Optional[str]:
regex = re.compile( regex: typing.Pattern = re.compile(
CONF.volume_name_template.replace('%s', r'(?P<uuid>.+)')) CONF.volume_name_template.replace('%s', r'(?P<uuid>.+)'))
match = regex.match(vol_name) match = regex.match(vol_name)
return match.group('uuid') if match else None return match.group('uuid') if match else None
def check_already_managed_volume(vol_id): def check_already_managed_volume(vol_id: Optional[str]):
"""Check cinder db for already managed volume. """Check cinder db for already managed volume.
:param vol_id: volume id parameter :param vol_id: volume id parameter
:returns: bool -- return True, if db entry with specified :returns: bool -- return True, if db entry with specified
volume id exists, otherwise return False volume id exists, otherwise return False
:raises: ValueError if vol_id is not a valid uuid string
""" """
try: try:
return (vol_id and isinstance(vol_id, str) and return (vol_id and isinstance(vol_id, str) and
@ -836,7 +904,7 @@ def check_already_managed_volume(vol_id):
return False return False
def extract_id_from_snapshot_name(snap_name): def extract_id_from_snapshot_name(snap_name: str) -> Optional[str]:
"""Return a snapshot's ID from its name on the backend.""" """Return a snapshot's ID from its name on the backend."""
regex = re.compile( regex = re.compile(
CONF.snapshot_name_template.replace('%s', r'(?P<uuid>.+)')) CONF.snapshot_name_template.replace('%s', r'(?P<uuid>.+)'))
@ -844,8 +912,12 @@ def extract_id_from_snapshot_name(snap_name):
return match.group('uuid') if match else None return match.group('uuid') if match else None
def paginate_entries_list(entries, marker, limit, offset, sort_keys, def paginate_entries_list(entries: List[Dict],
sort_dirs): marker: Optional[Union[dict, str]],
limit: int,
offset: Optional[int],
sort_keys: List[str],
sort_dirs: List[str]) -> list:
"""Paginate a list of entries. """Paginate a list of entries.
:param entries: list of dictionaries :param entries: list of dictionaries
@ -859,7 +931,8 @@ def paginate_entries_list(entries, marker, limit, offset, sort_keys,
comparers = [(operator.itemgetter(key.strip()), multiplier) comparers = [(operator.itemgetter(key.strip()), multiplier)
for (key, multiplier) in zip(sort_keys, sort_dirs)] for (key, multiplier) in zip(sort_keys, sort_dirs)]
def comparer(left, right): def comparer(left, right) -> int:
fn: Callable
for fn, d in comparers: for fn, d in comparers:
left_val = fn(left) left_val = fn(left)
right_val = fn(right) right_val = fn(right)
@ -900,7 +973,7 @@ def paginate_entries_list(entries, marker, limit, offset, sort_keys,
return sorted_entries[start_index + offset:range_end + offset] return sorted_entries[start_index + offset:range_end + offset]
def convert_config_string_to_dict(config_string): def convert_config_string_to_dict(config_string: str) -> dict:
"""Convert config file replication string to a dict. """Convert config file replication string to a dict.
The only supported form is as follows: The only supported form is as follows:
@ -924,12 +997,16 @@ def convert_config_string_to_dict(config_string):
return resultant_dict return resultant_dict
def create_encryption_key(context, key_manager, volume_type_id): def create_encryption_key(context: context.RequestContext,
key_manager,
volume_type_id: str) -> Optional[str]:
encryption_key_id = None encryption_key_id = None
if volume_types.is_encrypted(context, volume_type_id): if volume_types.is_encrypted(context, volume_type_id):
volume_type_encryption = ( volume_type_encryption: db.sqlalchemy.models.Encryption = (
volume_types.get_volume_type_encryption(context, volume_types.get_volume_type_encryption(context,
volume_type_id)) volume_type_id))
if volume_type_encryption is None:
raise exception.Invalid(message="Volume type error")
cipher = volume_type_encryption.cipher cipher = volume_type_encryption.cipher
length = volume_type_encryption.key_size length = volume_type_encryption.key_size
algorithm = cipher.split('-')[0] if cipher else None algorithm = cipher.split('-')[0] if cipher else None
@ -945,10 +1022,13 @@ def create_encryption_key(context, key_manager, volume_type_id):
LOG.exception("Key manager error") LOG.exception("Key manager error")
raise exception.Invalid(message="Key manager error") raise exception.Invalid(message="Key manager error")
typing.cast(str, encryption_key_id)
return encryption_key_id return encryption_key_id
def delete_encryption_key(context, key_manager, encryption_key_id): def delete_encryption_key(context: context.RequestContext,
key_manager,
encryption_key_id: str) -> None:
try: try:
key_manager.delete(context, encryption_key_id) key_manager.delete(context, encryption_key_id)
except castellan_exception.ManagedObjectNotFoundError: except castellan_exception.ManagedObjectNotFoundError:
@ -972,7 +1052,9 @@ def delete_encryption_key(context, key_manager, encryption_key_id):
pass pass
def clone_encryption_key(context, key_manager, encryption_key_id): def clone_encryption_key(context: context.RequestContext,
key_manager,
encryption_key_id: str) -> str:
clone_key_id = None clone_key_id = None
if encryption_key_id is not None: if encryption_key_id is not None:
clone_key_id = key_manager.store( clone_key_id = key_manager.store(
@ -981,19 +1063,19 @@ def clone_encryption_key(context, key_manager, encryption_key_id):
return clone_key_id return clone_key_id
def is_boolean_str(str): def is_boolean_str(str: Optional[str]) -> bool:
spec = (str or '').split() spec = (str or '').split()
return (len(spec) == 2 and return (len(spec) == 2 and
spec[0] == '<is>' and strutils.bool_from_string(spec[1])) spec[0] == '<is>' and strutils.bool_from_string(spec[1]))
def is_replicated_spec(extra_specs): def is_replicated_spec(extra_specs: dict) -> bool:
return (extra_specs and return (bool(extra_specs) and
is_boolean_str(extra_specs.get('replication_enabled'))) is_boolean_str(extra_specs.get('replication_enabled')))
def is_multiattach_spec(extra_specs): def is_multiattach_spec(extra_specs: dict) -> bool:
return (extra_specs and return (bool(extra_specs) and
is_boolean_str(extra_specs.get('multiattach'))) is_boolean_str(extra_specs.get('multiattach')))
@ -1003,7 +1085,7 @@ def group_get_by_id(group_id):
return group return group
def is_group_a_cg_snapshot_type(group_or_snap): def is_group_a_cg_snapshot_type(group_or_snap) -> bool:
LOG.debug("Checking if %s is a consistent snapshot group", LOG.debug("Checking if %s is a consistent snapshot group",
group_or_snap) group_or_snap)
if group_or_snap["group_type_id"] is not None: if group_or_snap["group_type_id"] is not None:
@ -1015,7 +1097,7 @@ def is_group_a_cg_snapshot_type(group_or_snap):
return False return False
def is_group_a_type(group, key): def is_group_a_type(group: 'objects.Group', key: str) -> bool:
if group.group_type_id is not None: if group.group_type_id is not None:
spec = group_types.get_group_type_specs( spec = group_types.get_group_type_specs(
group.group_type_id, key=key group.group_type_id, key=key
@ -1024,7 +1106,9 @@ def is_group_a_type(group, key):
return False return False
def get_max_over_subscription_ratio(str_value, supports_auto=False): def get_max_over_subscription_ratio(
str_value: Union[str, float],
supports_auto: bool = False) -> Union[str, float]:
"""Get the max_over_subscription_ratio from a string """Get the max_over_subscription_ratio from a string
As some drivers need to do some calculations with the value and we are now As some drivers need to do some calculations with the value and we are now
@ -1044,6 +1128,7 @@ def get_max_over_subscription_ratio(str_value, supports_auto=False):
raise exception.VolumeDriverException(message=msg) raise exception.VolumeDriverException(message=msg)
if str_value == 'auto': if str_value == 'auto':
str_value = typing.cast(str, str_value)
return str_value return str_value
mosr = float(str_value) mosr = float(str_value)
@ -1055,7 +1140,8 @@ def get_max_over_subscription_ratio(str_value, supports_auto=False):
return mosr return mosr
def check_image_metadata(image_meta, vol_size): def check_image_metadata(image_meta: Dict[str, Union[str, int]],
vol_size: int) -> None:
"""Validates the image metadata.""" """Validates the image metadata."""
# Check whether image is active # Check whether image is active
if image_meta['status'] != 'active': if image_meta['status'] != 'active':
@ -1074,6 +1160,7 @@ def check_image_metadata(image_meta, vol_size):
# Check image min_disk requirement is met for the particular volume # Check image min_disk requirement is met for the particular volume
min_disk = image_meta.get('min_disk', 0) min_disk = image_meta.get('min_disk', 0)
min_disk = typing.cast(int, min_disk)
if vol_size < min_disk: if vol_size < min_disk:
msg = _('Volume size %(volume_size)sGB cannot be smaller' msg = _('Volume size %(volume_size)sGB cannot be smaller'
' than the image minDisk size %(min_disk)sGB.') ' than the image minDisk size %(min_disk)sGB.')
@ -1081,7 +1168,7 @@ def check_image_metadata(image_meta, vol_size):
raise exception.InvalidInput(reason=msg) raise exception.InvalidInput(reason=msg)
def enable_bootable_flag(volume): def enable_bootable_flag(volume: 'objects.Volume') -> None:
try: try:
LOG.debug('Marking volume %s as bootable.', volume.id) LOG.debug('Marking volume %s as bootable.', volume.id)
volume.bootable = True volume.bootable = True
@ -1092,7 +1179,8 @@ def enable_bootable_flag(volume):
raise exception.MetadataUpdateFailure(reason=ex) raise exception.MetadataUpdateFailure(reason=ex)
def get_volume_image_metadata(image_id, image_meta): def get_volume_image_metadata(image_id: str,
image_meta: Dict[str, Any]) -> dict:
# Save some base attributes into the volume metadata # Save some base attributes into the volume metadata
base_metadata = { base_metadata = {
@ -1114,6 +1202,7 @@ def get_volume_image_metadata(image_id, image_meta):
# Save all the image metadata properties into the volume metadata # Save all the image metadata properties into the volume metadata
property_metadata = {} property_metadata = {}
image_properties = image_meta.get('properties', {}) image_properties = image_meta.get('properties', {})
image_properties = typing.cast(dict, image_properties)
for (key, value) in image_properties.items(): for (key, value) in image_properties.items():
if value is not None: if value is not None:
property_metadata[key] = value property_metadata[key] = value
@ -1123,8 +1212,12 @@ def get_volume_image_metadata(image_id, image_meta):
return volume_metadata return volume_metadata
def copy_image_to_volume(driver, context, volume, image_meta, image_location, def copy_image_to_volume(driver,
image_service): context: context.RequestContext,
volume: 'objects.Volume',
image_meta: dict,
image_location: str,
image_service) -> None:
"""Downloads Glance image to the specified volume.""" """Downloads Glance image to the specified volume."""
image_id = image_meta['id'] image_id = image_meta['id']
LOG.debug("Attempting download of %(image_id)s (%(image_location)s)" LOG.debug("Attempting download of %(image_id)s (%(image_location)s)"
@ -1173,7 +1266,7 @@ def copy_image_to_volume(driver, context, volume, image_meta, image_location,
'image_location': image_location}) 'image_location': image_location})
def image_conversion_dir(): def image_conversion_dir() -> str:
tmpdir = (CONF.image_conversion_dir or tmpdir = (CONF.image_conversion_dir or
tempfile.gettempdir()) tempfile.gettempdir())
@ -1184,7 +1277,9 @@ def image_conversion_dir():
return tmpdir return tmpdir
def check_encryption_provider(db, volume, context): def check_encryption_provider(db,
volume: 'objects.Volume',
context: context.RequestContext) -> dict:
"""Check that this is a LUKS encryption provider. """Check that this is a LUKS encryption provider.
:returns: encryption dict :returns: encryption dict
@ -1212,14 +1307,14 @@ def check_encryption_provider(db, volume, context):
return encryption return encryption
def sanitize_host(host): def sanitize_host(host: str) -> str:
"""Ensure IPv6 addresses are enclosed in [] for iSCSI portals.""" """Ensure IPv6 addresses are enclosed in [] for iSCSI portals."""
if netutils.is_valid_ipv6(host): if netutils.is_valid_ipv6(host):
return '[%s]' % host return '[%s]' % host
return host return host
def sanitize_hostname(hostname): def sanitize_hostname(hostname) -> str:
"""Return a hostname which conforms to RFC-952 and RFC-1123 specs.""" """Return a hostname which conforms to RFC-952 and RFC-1123 specs."""
hostname = hostname.encode('latin-1', 'ignore') hostname = hostname.encode('latin-1', 'ignore')
hostname = hostname.decode('latin-1') hostname = hostname.decode('latin-1')
@ -1232,7 +1327,7 @@ def sanitize_hostname(hostname):
return hostname return hostname
def resolve_hostname(hostname): def resolve_hostname(hostname: str) -> str:
"""Resolves host name to IP address. """Resolves host name to IP address.
Resolves a host name (my.data.point.com) to an IP address (10.12.143.11). Resolves a host name (my.data.point.com) to an IP address (10.12.143.11).
@ -1248,7 +1343,9 @@ def resolve_hostname(hostname):
return ip return ip
def update_backup_error(backup, err, status=fields.BackupStatus.ERROR): def update_backup_error(backup,
err: str,
status=fields.BackupStatus.ERROR) -> None:
backup.status = status backup.status = status
backup.fail_reason = err backup.fail_reason = err
backup.save() backup.save()
@ -1256,7 +1353,7 @@ def update_backup_error(backup, err, status=fields.BackupStatus.ERROR):
# TODO (whoami-rajat): Remove this method when oslo.vmware calls volume_utils # TODO (whoami-rajat): Remove this method when oslo.vmware calls volume_utils
# wrapper of upload_volume instead of image_utils.upload_volume # wrapper of upload_volume instead of image_utils.upload_volume
def get_base_image_ref(volume): def get_base_image_ref(volume: 'objects.Volume'):
# This method fetches the image_id from volume glance metadata and pass # This method fetches the image_id from volume glance metadata and pass
# it to the driver calling it during upload volume to image operation # it to the driver calling it during upload volume to image operation
base_image_ref = None base_image_ref = None
@ -1265,9 +1362,12 @@ def get_base_image_ref(volume):
return base_image_ref return base_image_ref
def upload_volume(context, image_service, image_meta, volume_path, def upload_volume(context: context.RequestContext,
volume, volume_format='raw', run_as_root=True, image_service, image_meta, volume_path,
compress=True): volume: 'objects.Volume',
volume_format: str = 'raw',
run_as_root: bool = True,
compress: bool = True) -> None:
# retrieve store information from extra-specs # retrieve store information from extra-specs
store_id = volume.volume_type.extra_specs.get('image_service:store_id') store_id = volume.volume_type.extra_specs.get('image_service:store_id')
@ -1305,7 +1405,8 @@ def get_backend_configuration(backend_name, backend_opts=None):
return config return config
def brick_get_connector_properties(multipath=False, enforce_multipath=False): def brick_get_connector_properties(multipath: bool = False,
enforce_multipath: bool = False):
"""Wrapper to automatically set root_helper in brick calls. """Wrapper to automatically set root_helper in brick calls.
:param multipath: A boolean indicating whether the connector can :param multipath: A boolean indicating whether the connector can
@ -1323,9 +1424,10 @@ def brick_get_connector_properties(multipath=False, enforce_multipath=False):
enforce_multipath) enforce_multipath)
def brick_get_connector(protocol, driver=None, def brick_get_connector(protocol: str,
use_multipath=False, driver=None,
device_scan_attempts=3, use_multipath: bool = False,
device_scan_attempts: int = 3,
*args, **kwargs): *args, **kwargs):
"""Wrapper to get a brick connector object. """Wrapper to get a brick connector object.
@ -1342,7 +1444,7 @@ def brick_get_connector(protocol, driver=None,
*args, **kwargs) *args, **kwargs)
def brick_get_encryptor(connection_info, *args, **kwargs): def brick_get_encryptor(connection_info: dict, *args, **kwargs):
"""Wrapper to get a brick encryptor object.""" """Wrapper to get a brick encryptor object."""
root_helper = utils.get_root_helper() root_helper = utils.get_root_helper()
@ -1353,7 +1455,9 @@ def brick_get_encryptor(connection_info, *args, **kwargs):
*args, **kwargs) *args, **kwargs)
def brick_attach_volume_encryptor(context, attach_info, encryption): def brick_attach_volume_encryptor(context: context.RequestContext,
attach_info: dict,
encryption: dict) -> None:
"""Attach encryption layer.""" """Attach encryption layer."""
connection_info = attach_info['conn'] connection_info = attach_info['conn']
connection_info['data']['device_path'] = attach_info['device']['path'] connection_info['data']['device_path'] = attach_info['device']['path']
@ -1362,7 +1466,7 @@ def brick_attach_volume_encryptor(context, attach_info, encryption):
encryptor.attach_volume(context, **encryption) encryptor.attach_volume(context, **encryption)
def brick_detach_volume_encryptor(attach_info, encryption): def brick_detach_volume_encryptor(attach_info: dict, encryption: dict) -> None:
"""Detach encryption layer.""" """Detach encryption layer."""
connection_info = attach_info['conn'] connection_info = attach_info['conn']
connection_info['data']['device_path'] = attach_info['device']['path'] connection_info['data']['device_path'] = attach_info['device']['path']

View File

@ -1,6 +1,9 @@
cinder/context.py cinder/context.py
cinder/i18n.py cinder/i18n.py
cinder/exception.py
cinder/manager.py cinder/manager.py
cinder/utils.py
cinder/volume/__init__.py cinder/volume/__init__.py
cinder/volume/manager.py cinder/volume/manager.py
cinder/volume/volume_types.py cinder/volume/volume_types.py
cinder/volume/volume_utils.py