diff --git a/manila/api/v2/shares.py b/manila/api/v2/shares.py index fae7602c1a..db7cd2ebe1 100644 --- a/manila/api/v2/shares.py +++ b/manila/api/v2/shares.py @@ -13,7 +13,6 @@ # License for the specific language governing permissions and limitations # under the License. -from oslo_utils import strutils import six import webob from webob import exc @@ -30,6 +29,7 @@ from manila import db from manila import exception from manila.i18n import _ from manila import share +from manila import utils class ShareController(shares.ShareMixin, @@ -99,44 +99,16 @@ class ShareController(shares.ShareMixin, except KeyError: raise exc.HTTPBadRequest(explanation=_("Must specify 'host'.")) - force_host_assisted_migration = params.get( - 'force_host_assisted_migration', False) - try: - force_host_assisted_migration = strutils.bool_from_string( - force_host_assisted_migration, strict=True) - except ValueError: - msg = _("Invalid value %s for 'force_host_assisted_migration'. " - "Expecting a boolean.") % force_host_assisted_migration - raise exc.HTTPBadRequest(explanation=msg) + force_host_assisted_migration = utils.get_bool_from_api_params( + 'force_host_assisted_migration', params) new_share_network = None new_share_type = None - preserve_metadata = params.get('preserve_metadata', True) - try: - preserve_metadata = strutils.bool_from_string( - preserve_metadata, strict=True) - except ValueError: - msg = _("Invalid value %s for 'preserve_metadata'. " - "Expecting a boolean.") % preserve_metadata - raise exc.HTTPBadRequest(explanation=msg) - - writable = params.get('writable', True) - try: - writable = strutils.bool_from_string(writable, strict=True) - except ValueError: - msg = _("Invalid value %s for 'writable'. " - "Expecting a boolean.") % writable - raise exc.HTTPBadRequest(explanation=msg) - - nondisruptive = params.get('nondisruptive', False) - try: - nondisruptive = strutils.bool_from_string( - nondisruptive, strict=True) - except ValueError: - msg = _("Invalid value %s for 'nondisruptive'. " - "Expecting a boolean.") % nondisruptive - raise exc.HTTPBadRequest(explanation=msg) + preserve_metadata = utils.get_bool_from_api_params('preserve_metadata', + params, True) + writable = utils.get_bool_from_api_params('writable', params, True) + nondisruptive = utils.get_bool_from_api_params('nondisruptive', params) new_share_network_id = params.get('new_share_network_id', None) if new_share_network_id: diff --git a/manila/tests/test_utils.py b/manila/tests/test_utils.py index f1b39d4dff..5b43c6b054 100644 --- a/manila/tests/test_utils.py +++ b/manila/tests/test_utils.py @@ -25,6 +25,7 @@ import mock from oslo_config import cfg from oslo_utils import timeutils import paramiko +from webob import exc import manila from manila.common import constants @@ -378,6 +379,56 @@ class CidrToNetmaskTestCase(test.TestCase): self.assertRaises(exception.InvalidInput, utils.cidr_to_netmask, cidr) +@ddt.ddt +class ParseBoolValueTestCase(test.TestCase): + + @ddt.data( + ('t', True), + ('on', True), + ('1', True), + ('false', False), + ('n', False), + ('no', False), + ('0', False),) + @ddt.unpack + def test_bool_with_valid_string(self, string, value): + fake_dict = {'fake_key': string} + result = utils.get_bool_from_api_params('fake_key', fake_dict) + self.assertEqual(value, result) + + @ddt.data('None', 'invalid', 'falses') + def test_bool_with_invalid_string(self, string): + fake_dict = {'fake_key': string} + self.assertRaises(exc.HTTPBadRequest, + utils.get_bool_from_api_params, + 'fake_key', fake_dict) + + @ddt.data('undefined', None) + def test_bool_with_key_not_found_raise_error(self, def_val): + fake_dict = {'fake_key1': 'value1'} + self.assertRaises(exc.HTTPBadRequest, + utils.get_bool_from_api_params, + 'fake_key2', + fake_dict, + def_val) + + @ddt.data((False, False, False), + (True, True, False), + ('true', True, False), + ('false', False, False), + ('undefined', 'undefined', False), + (False, False, True), + ('true', True, True)) + @ddt.unpack + def test_bool_with_key_not_found(self, def_val, expected, strict): + fake_dict = {'fake_key1': 'value1'} + invalid_default = utils.get_bool_from_api_params('fake_key2', + fake_dict, + def_val, + strict) + self.assertEqual(expected, invalid_default) + + @ddt.ddt class IsValidIPVersion(test.TestCase): """Test suite for function 'is_valid_ip_address'.""" diff --git a/manila/utils.py b/manila/utils.py index ea21ef21af..d2cbcb6699 100644 --- a/manila/utils.py +++ b/manila/utils.py @@ -38,10 +38,12 @@ from oslo_config import cfg from oslo_log import log from oslo_utils import importutils from oslo_utils import netutils +from oslo_utils import strutils from oslo_utils import timeutils import paramiko import retrying import six +from webob import exc from manila.common import constants from manila.db import api as db_api @@ -461,6 +463,26 @@ def retry(exception, interval=1, retries=10, backoff_rate=2, return _decorator +def get_bool_from_api_params(key, params, default=False, strict=True): + """Parse bool value from request params. + + HTTPBadRequest will be directly raised either of the cases below: + 1. invalid bool string was found by key(with strict on). + 2. key not found while default value is invalid(with strict on). + """ + param = params.get(key, default) + try: + param = strutils.bool_from_string(param, + strict=strict, + default=default) + except ValueError: + msg = _('Invalid value %(param)s for %(param_string)s. ' + 'Expecting a boolean.') % {'param': param, + 'param_string': key} + raise exc.HTTPBadRequest(explanation=msg) + return param + + def require_driver_initialized(func): @functools.wraps(func) def wrapper(self, *args, **kwargs):