tests: refactor objects test cases to use db models instead of dicts
This should reflect the code under test better, and is needed for one of patches in the review queue (I130609194f15b89df89e5606fb8193849edd14d8) to pass some of those test cases. Partially-Implements: blueprint adopt-oslo-versioned-objects-for-db Change-Id: Id1ca4ce7b134d9729e68661cedb2f5556e58d6ff
This commit is contained in:
parent
c8fea2b392
commit
9cd230397c
@ -30,7 +30,9 @@ class NetworkPortSecurityDbObjTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
|
||||
def setUp(self):
|
||||
super(NetworkPortSecurityDbObjTestCase, self).setUp()
|
||||
for db_obj, obj_field in zip(self.db_objs, self.obj_fields):
|
||||
for db_obj, obj_field, obj in zip(
|
||||
self.db_objs, self.obj_fields, self.objs):
|
||||
network = self._create_network()
|
||||
db_obj['network_id'] = network['id']
|
||||
obj_field['id'] = network['id']
|
||||
obj['id'] = network['id']
|
||||
|
@ -32,5 +32,5 @@ class NetworkSegmentDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
def setUp(self):
|
||||
super(NetworkSegmentDbObjectTestCase, self).setUp()
|
||||
self._create_test_network()
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['network_id'] = self._network['id']
|
||||
|
@ -34,5 +34,5 @@ class AllowedAddrPairsDbObjTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
self.context = context.get_admin_context()
|
||||
self._create_test_network()
|
||||
self._create_test_port(self._network)
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['port_id'] = self._port['id']
|
||||
|
@ -10,6 +10,8 @@
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import itertools
|
||||
|
||||
from neutron.objects.port.extensions import extra_dhcp_opt
|
||||
from neutron.tests.unit.objects import test_base as obj_test_base
|
||||
from neutron.tests.unit import testlib_api
|
||||
@ -29,7 +31,5 @@ class ExtraDhcpOptDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
super(ExtraDhcpOptDbObjectTestCase, self).setUp()
|
||||
self._create_test_network()
|
||||
self._create_test_port(self._network)
|
||||
for obj in self.db_objs:
|
||||
obj['port_id'] = self._port['id']
|
||||
for obj in self.obj_fields:
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['port_id'] = self._port['id']
|
||||
|
@ -82,7 +82,9 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
|
||||
context_mock.assert_called_once_with()
|
||||
self.get_objects.assert_any_call(
|
||||
admin_context, self._test_class.db_model, _pager=None)
|
||||
self._validate_objects(self.db_objs, objs)
|
||||
self.assertItemsEqual(
|
||||
[test_base.get_obj_db_fields(obj) for obj in self.objs],
|
||||
[test_base.get_obj_db_fields(obj) for obj in objs])
|
||||
|
||||
def test_get_objects_valid_fields(self):
|
||||
admin_context = self.context.elevated()
|
||||
@ -103,7 +105,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
|
||||
get_objects_mock.assert_any_call(
|
||||
admin_context, self._test_class.db_model, _pager=None,
|
||||
**self.valid_field_filter)
|
||||
self._validate_objects([self.db_obj], objs)
|
||||
self._check_equal(objs[0], self.objs[0])
|
||||
|
||||
def test_get_object(self):
|
||||
admin_context = self.context.elevated()
|
||||
@ -114,7 +116,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
|
||||
return_value=admin_context) as context_mock:
|
||||
obj = self._test_class.get_object(self.context, id='fake_id')
|
||||
self.assertTrue(self._is_test_class(obj))
|
||||
self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
|
||||
self._check_equal(obj, self.objs[0])
|
||||
context_mock.assert_called_once_with()
|
||||
get_object_mock.assert_called_once_with(
|
||||
admin_context, self._test_class.db_model, id='fake_id')
|
||||
@ -139,9 +141,8 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
|
||||
self._create_test_port(self._network)
|
||||
|
||||
def _create_test_policy(self):
|
||||
policy_obj = policy.QosPolicy(self.context, **self.db_obj)
|
||||
policy_obj.create()
|
||||
return policy_obj
|
||||
self.objs[0].create()
|
||||
return self.objs[0]
|
||||
|
||||
def _create_test_policy_with_rules(self, rule_type, reload_rules=False):
|
||||
policy_obj = self._create_test_policy()
|
||||
|
@ -27,7 +27,6 @@ from oslo_versionedobjects import fixture
|
||||
import testtools
|
||||
|
||||
from neutron.common import constants
|
||||
from neutron.common import utils as common_utils
|
||||
from neutron import context
|
||||
from neutron.db import db_base_plugin_v2
|
||||
from neutron.db import model_base
|
||||
@ -48,14 +47,12 @@ OBJECTS_BASE_OBJ_FROM_PRIMITIVE = ('oslo_versionedobjects.base.'
|
||||
TIMESTAMP_FIELDS = ['created_at', 'updated_at', 'revision_number']
|
||||
|
||||
|
||||
class FakeModel(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
class FakeModel(dict):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectFieldsModel(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
class ObjectFieldsModel(dict):
|
||||
pass
|
||||
|
||||
|
||||
@obj_base.VersionedObjectRegistry.register_if(False)
|
||||
@ -396,9 +393,11 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = {
|
||||
}
|
||||
|
||||
|
||||
# TODO(ihrachys) consider renaming into e.g. get_obj_persistent_fields
|
||||
def get_obj_db_fields(obj):
|
||||
return {field: getattr(obj, field) for field in obj.fields
|
||||
if field not in obj.synthetic_fields}
|
||||
if field not in obj.synthetic_fields
|
||||
if field in obj}
|
||||
|
||||
|
||||
def get_value(generator, version):
|
||||
@ -429,11 +428,19 @@ class _BaseObjectTestCase(object):
|
||||
# neutron.objects.db.api from core plugin instance
|
||||
self.setup_coreplugin(self.CORE_PLUGIN)
|
||||
self.context = context.get_admin_context()
|
||||
self.db_objs = list(self.get_random_fields() for _ in range(3))
|
||||
self.db_objs = [
|
||||
self._test_class.db_model(**self.get_random_fields())
|
||||
for _ in range(3)
|
||||
]
|
||||
self.db_obj = self.db_objs[0]
|
||||
|
||||
# TODO(ihrachys) remove obj_fields since they duplicate self.objs
|
||||
self.obj_fields = [self._test_class.modify_fields_from_db(db_obj)
|
||||
for db_obj in self.db_objs]
|
||||
self.objs = [
|
||||
self._test_class(self.context, **fields)
|
||||
for fields in self.obj_fields
|
||||
]
|
||||
|
||||
valid_field = [f for f in self._test_class.fields
|
||||
if f not in self._test_class.synthetic_fields][0]
|
||||
@ -447,8 +454,10 @@ class _BaseObjectTestCase(object):
|
||||
synthetic_obj_fields = self.get_random_fields(FakeSmallNeutronObject)
|
||||
self.model_map = {
|
||||
self._test_class.db_model: self.db_objs,
|
||||
ObjectFieldsModel: [synthetic_obj_fields]}
|
||||
ObjectFieldsModel: [ObjectFieldsModel(**synthetic_obj_fields)]}
|
||||
|
||||
# TODO(ihrachys): rename the method to explicitly reflect it returns db
|
||||
# attributes not object fields
|
||||
@classmethod
|
||||
def get_random_fields(cls, obj_cls=None):
|
||||
obj_cls = obj_cls or cls._test_class
|
||||
@ -504,6 +513,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
self.model_map[self._test_class.db_model] = self.db_objs
|
||||
self.pager_map = collections.defaultdict(lambda: None)
|
||||
|
||||
# TODO(ihrachys) document the intent of all common test cases in docstrings
|
||||
def test_get_object(self):
|
||||
with mock.patch.object(obj_db_api, 'get_object',
|
||||
return_value=self.db_obj) as get_object_mock:
|
||||
@ -512,7 +522,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
obj_keys = self.generate_object_keys(self._test_class)
|
||||
obj = self._test_class.get_object(self.context, **obj_keys)
|
||||
self.assertTrue(self._is_test_class(obj))
|
||||
self.assertEqual(self.obj_fields[0], get_obj_db_fields(obj))
|
||||
self._check_equal(obj, self.objs[0])
|
||||
get_object_mock.assert_called_once_with(
|
||||
self.context, self._test_class.db_model,
|
||||
**self._test_class.modify_fields_to_db(obj_keys))
|
||||
@ -550,8 +560,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
obj = self._test_class.get_object(self.context,
|
||||
**obj_keys)
|
||||
self.assertTrue(self._is_test_class(obj))
|
||||
self.assertEqual(self.obj_fields[0],
|
||||
get_obj_db_fields(obj))
|
||||
self._check_equal(obj, self.objs[0])
|
||||
get_object_mock.assert_called_once_with(
|
||||
self.context, self._test_class.db_model,
|
||||
**self._test_class.modify_fields_to_db(obj_keys))
|
||||
@ -574,37 +583,25 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
return mock_calls
|
||||
|
||||
def test_get_objects(self):
|
||||
'''Test that get_objects fetches data from database.'''
|
||||
with mock.patch.object(
|
||||
obj_db_api, 'get_objects',
|
||||
side_effect=self.fake_get_objects) as get_objects_mock:
|
||||
objs = self._test_class.get_objects(self.context)
|
||||
self._validate_objects(self.db_objs, objs)
|
||||
mock_calls = [
|
||||
mock.call(self.context, self._test_class.db_model,
|
||||
_pager=self.pager_map[self._test_class.obj_name()])
|
||||
]
|
||||
mock_calls.extend(self._get_synthetic_fields_get_objects_calls(
|
||||
self.db_objs))
|
||||
get_objects_mock.assert_has_calls(mock_calls)
|
||||
self.assertItemsEqual(
|
||||
[get_obj_db_fields(obj) for obj in self.objs],
|
||||
[get_obj_db_fields(obj) for obj in objs])
|
||||
get_objects_mock.assert_any_call(
|
||||
self.context, self._test_class.db_model,
|
||||
_pager=self.pager_map[self._test_class.obj_name()]
|
||||
)
|
||||
|
||||
def test_get_objects_valid_fields(self):
|
||||
'''Test that a valid filter does not raise an error.'''
|
||||
with mock.patch.object(
|
||||
obj_db_api, 'get_objects',
|
||||
side_effect=self.fake_get_objects) as get_objects_mock:
|
||||
|
||||
objs = self._test_class.get_objects(self.context,
|
||||
**self.valid_field_filter)
|
||||
self._validate_objects(self.db_objs, objs)
|
||||
mock_calls = [
|
||||
mock.call(
|
||||
self.context, self._test_class.db_model,
|
||||
_pager=self.pager_map[self._test_class.obj_name()],
|
||||
**self._test_class.modify_fields_to_db(self.valid_field_filter)
|
||||
)
|
||||
]
|
||||
mock_calls.extend(self._get_synthetic_fields_get_objects_calls(
|
||||
[self.db_obj]))
|
||||
get_objects_mock.assert_has_calls(mock_calls)
|
||||
obj_db_api, 'get_objects', side_effect=self.fake_get_objects):
|
||||
self._test_class.get_objects(self.context,
|
||||
**self.valid_field_filter)
|
||||
|
||||
def test_get_objects_mixed_fields(self):
|
||||
synthetic_fields = (
|
||||
@ -661,19 +658,11 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
self._test_class.count, self.context,
|
||||
fake_field='xxx')
|
||||
|
||||
def _validate_objects(self, expected, observed):
|
||||
self.assertTrue(all(self._is_test_class(obj) for obj in observed))
|
||||
self.assertEqual(
|
||||
sorted([self._test_class.modify_fields_from_db(db_obj)
|
||||
for db_obj in expected],
|
||||
key=common_utils.safe_sort_key),
|
||||
sorted([get_obj_db_fields(obj) for obj in observed],
|
||||
key=common_utils.safe_sort_key))
|
||||
|
||||
def _check_equal(self, obj, db_obj):
|
||||
self.assertEqual(
|
||||
sorted(db_obj),
|
||||
sorted(get_obj_db_fields(obj)))
|
||||
# TODO(ihrachys) swap the order of arguments to reflect the order of
|
||||
# self.assert* methods
|
||||
def _check_equal(self, observed, expected):
|
||||
self.assertItemsEqual(get_obj_db_fields(expected),
|
||||
get_obj_db_fields(observed))
|
||||
|
||||
def test_create(self):
|
||||
with mock.patch.object(obj_db_api, 'create_object',
|
||||
@ -681,21 +670,21 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
with mock.patch.object(obj_db_api, 'get_objects',
|
||||
side_effect=self.fake_get_objects):
|
||||
obj = self._test_class(self.context, **self.obj_fields[0])
|
||||
self._check_equal(obj, self.obj_fields[0])
|
||||
self._check_equal(obj, self.objs[0])
|
||||
obj.create()
|
||||
self._check_equal(obj, self.obj_fields[0])
|
||||
self._check_equal(obj, self.objs[0])
|
||||
create_mock.assert_called_once_with(
|
||||
self.context, self._test_class.db_model, self.db_obj)
|
||||
self.context, self._test_class.db_model,
|
||||
self._test_class.modify_fields_to_db(
|
||||
get_obj_db_fields(self.objs[0])))
|
||||
|
||||
def test_create_updates_from_db_object(self):
|
||||
with mock.patch.object(obj_db_api, 'create_object',
|
||||
return_value=self.db_obj):
|
||||
with mock.patch.object(obj_db_api, 'get_objects',
|
||||
side_effect=self.fake_get_objects):
|
||||
obj = self._test_class(self.context, **self.obj_fields[1])
|
||||
self._check_equal(obj, self.obj_fields[1])
|
||||
obj.create()
|
||||
self._check_equal(obj, self.obj_fields[0])
|
||||
self.objs[1].create()
|
||||
self._check_equal(self.objs[1], self.objs[0])
|
||||
|
||||
def test_create_duplicates(self):
|
||||
with mock.patch.object(obj_db_api, 'create_object',
|
||||
@ -772,7 +761,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
side_effect=self.fake_get_objects):
|
||||
obj = self._test_class(self.context, **self.obj_fields[0])
|
||||
# get new values and fix keys
|
||||
update_mock.return_value = self.db_objs[1].copy()
|
||||
update_mock.return_value = self.db_objs[1]
|
||||
fixed_keys = self._test_class.modify_fields_to_db(
|
||||
obj._get_composite_keys())
|
||||
for key, value in fixed_keys.items():
|
||||
@ -813,14 +802,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
|
||||
obj_db_api, 'get_objects',
|
||||
side_effect=self.fake_get_objects):
|
||||
obj.update()
|
||||
self._check_equal(obj, self.obj_fields[0])
|
||||
self._check_equal(obj, self.objs[0])
|
||||
|
||||
@mock.patch.object(obj_db_api, 'delete_object')
|
||||
def test_delete(self, delete_mock):
|
||||
obj = self._test_class(self.context, **self.obj_fields[0])
|
||||
self._check_equal(obj, self.obj_fields[0])
|
||||
self._check_equal(obj, self.objs[0])
|
||||
obj.delete()
|
||||
self._check_equal(obj, self.obj_fields[0])
|
||||
self._check_equal(obj, self.objs[0])
|
||||
delete_mock.assert_called_once_with(
|
||||
self.context, self._test_class.db_model,
|
||||
**self._test_class.modify_fields_to_db(obj._get_composite_keys()))
|
||||
@ -1024,7 +1013,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
|
||||
continue
|
||||
for db_obj in self.db_objs:
|
||||
objclass_fields = self.get_random_fields(objclass)
|
||||
db_obj[synth_field] = [objclass_fields]
|
||||
db_obj[synth_field] = [objclass.db_model(**objclass_fields)]
|
||||
|
||||
def _create_test_network(self):
|
||||
# TODO(ihrachys): replace with network.create() once we get an object
|
||||
@ -1208,10 +1197,13 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
|
||||
obj = self._make_object(self.obj_fields[0])
|
||||
obj.create()
|
||||
|
||||
for field in remove_timestamps_from_fields(self.obj_fields[0]):
|
||||
filters = {field: [self.obj_fields[0][field]]}
|
||||
for field in remove_timestamps_from_fields(get_obj_db_fields(obj)):
|
||||
filters = {field: [self.objs[0][field]]}
|
||||
new = self._test_class.get_objects(self.context, **filters)
|
||||
self.assertEqual([obj], new, 'Filtering by %s failed.' % field)
|
||||
self.assertItemsEqual(
|
||||
[obj._get_composite_keys()],
|
||||
[obj_._get_composite_keys() for obj_ in new],
|
||||
'Filtering by %s failed.' % field)
|
||||
|
||||
def _get_non_synth_fields(self, objclass, db_attrs):
|
||||
fields = objclass.modify_fields_from_db(db_attrs)
|
||||
|
@ -29,9 +29,8 @@ from neutron.tests.unit.objects import test_base
|
||||
from neutron.tests.unit import testlib_api
|
||||
|
||||
|
||||
class FakeDbModel(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
class FakeDbModel(dict):
|
||||
pass
|
||||
|
||||
|
||||
class FakeRbacModel(rbac_db_models.RBACColumns, model_base.BASEV2):
|
||||
|
@ -118,7 +118,7 @@ class DefaultSecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase,
|
||||
self.sg_obj = securitygroup.SecurityGroup(
|
||||
self.context, **test_base.remove_timestamps_from_fields(sg_fields))
|
||||
self.sg_obj.create()
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['security_group_id'] = self.sg_obj['id']
|
||||
|
||||
|
||||
@ -140,6 +140,6 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
|
||||
self.sg_obj = securitygroup.SecurityGroup(
|
||||
self.context, **test_base.remove_timestamps_from_fields(sg_fields))
|
||||
self.sg_obj.create()
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['security_group_id'] = self.sg_obj['id']
|
||||
obj['remote_group_id'] = self.sg_obj['id']
|
||||
|
@ -39,7 +39,7 @@ class IPAllocationPoolDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
super(IPAllocationPoolDbObjectTestCase, self).setUp()
|
||||
self._create_test_network()
|
||||
self._create_test_subnet(self._network)
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['subnet_id'] = self._subnet['id']
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ class DNSNameServerDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
for db_obj in self.db_objs]
|
||||
self._create_test_network()
|
||||
self._create_test_subnet(self._network)
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['subnet_id'] = self._subnet['id']
|
||||
|
||||
def _is_objects_unique(self):
|
||||
@ -128,7 +128,7 @@ class RouteDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
super(RouteDbObjectTestCase, self).setUp()
|
||||
self._create_test_network()
|
||||
self._create_test_subnet(self._network)
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['subnet_id'] = self._subnet['id']
|
||||
|
||||
|
||||
@ -151,7 +151,7 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
|
||||
super(SubnetDbObjectTestCase, self).setUp()
|
||||
self._create_test_network()
|
||||
self._create_test_segment(self._network)
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['network_id'] = self._network['id']
|
||||
obj['segment_id'] = self._segment['id']
|
||||
|
||||
|
@ -87,5 +87,5 @@ class SubnetPoolPrefixDbObjectTestCase(
|
||||
def setUp(self):
|
||||
super(SubnetPoolPrefixDbObjectTestCase, self).setUp()
|
||||
self._create_test_subnetpool()
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields):
|
||||
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
|
||||
obj['subnetpool_id'] = self._pool.id
|
||||
|
Loading…
x
Reference in New Issue
Block a user