tests: refactor objects test cases to use db models instead of dicts

This should reflect the code under test better, and is needed for one of
patches in the review queue (I130609194f15b89df89e5606fb8193849edd14d8)
to pass some of those test cases.

Partially-Implements: blueprint adopt-oslo-versioned-objects-for-db
Change-Id: Id1ca4ce7b134d9729e68661cedb2f5556e58d6ff
This commit is contained in:
Ihar Hrachyshka 2016-08-29 14:37:40 +00:00
parent c8fea2b392
commit 9cd230397c
10 changed files with 81 additions and 87 deletions

View File

@ -30,7 +30,9 @@ class NetworkPortSecurityDbObjTestCase(obj_test_base.BaseDbObjectTestCase,
def setUp(self):
super(NetworkPortSecurityDbObjTestCase, self).setUp()
for db_obj, obj_field in zip(self.db_objs, self.obj_fields):
for db_obj, obj_field, obj in zip(
self.db_objs, self.obj_fields, self.objs):
network = self._create_network()
db_obj['network_id'] = network['id']
obj_field['id'] = network['id']
obj['id'] = network['id']

View File

@ -32,5 +32,5 @@ class NetworkSegmentDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
def setUp(self):
super(NetworkSegmentDbObjectTestCase, self).setUp()
self._create_test_network()
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['network_id'] = self._network['id']

View File

@ -34,5 +34,5 @@ class AllowedAddrPairsDbObjTestCase(obj_test_base.BaseDbObjectTestCase,
self.context = context.get_admin_context()
self._create_test_network()
self._create_test_port(self._network)
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['port_id'] = self._port['id']

View File

@ -10,6 +10,8 @@
# License for the specific language governing permissions and limitations
# under the License.
import itertools
from neutron.objects.port.extensions import extra_dhcp_opt
from neutron.tests.unit.objects import test_base as obj_test_base
from neutron.tests.unit import testlib_api
@ -29,7 +31,5 @@ class ExtraDhcpOptDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
super(ExtraDhcpOptDbObjectTestCase, self).setUp()
self._create_test_network()
self._create_test_port(self._network)
for obj in self.db_objs:
obj['port_id'] = self._port['id']
for obj in self.obj_fields:
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['port_id'] = self._port['id']

View File

@ -82,7 +82,9 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
context_mock.assert_called_once_with()
self.get_objects.assert_any_call(
admin_context, self._test_class.db_model, _pager=None)
self._validate_objects(self.db_objs, objs)
self.assertItemsEqual(
[test_base.get_obj_db_fields(obj) for obj in self.objs],
[test_base.get_obj_db_fields(obj) for obj in objs])
def test_get_objects_valid_fields(self):
admin_context = self.context.elevated()
@ -103,7 +105,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
get_objects_mock.assert_any_call(
admin_context, self._test_class.db_model, _pager=None,
**self.valid_field_filter)
self._validate_objects([self.db_obj], objs)
self._check_equal(objs[0], self.objs[0])
def test_get_object(self):
admin_context = self.context.elevated()
@ -114,7 +116,7 @@ class QosPolicyObjectTestCase(test_base.BaseObjectIfaceTestCase):
return_value=admin_context) as context_mock:
obj = self._test_class.get_object(self.context, id='fake_id')
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.db_obj, test_base.get_obj_db_fields(obj))
self._check_equal(obj, self.objs[0])
context_mock.assert_called_once_with()
get_object_mock.assert_called_once_with(
admin_context, self._test_class.db_model, id='fake_id')
@ -139,9 +141,8 @@ class QosPolicyDbObjectTestCase(test_base.BaseDbObjectTestCase,
self._create_test_port(self._network)
def _create_test_policy(self):
policy_obj = policy.QosPolicy(self.context, **self.db_obj)
policy_obj.create()
return policy_obj
self.objs[0].create()
return self.objs[0]
def _create_test_policy_with_rules(self, rule_type, reload_rules=False):
policy_obj = self._create_test_policy()

View File

@ -27,7 +27,6 @@ from oslo_versionedobjects import fixture
import testtools
from neutron.common import constants
from neutron.common import utils as common_utils
from neutron import context
from neutron.db import db_base_plugin_v2
from neutron.db import model_base
@ -48,14 +47,12 @@ OBJECTS_BASE_OBJ_FROM_PRIMITIVE = ('oslo_versionedobjects.base.'
TIMESTAMP_FIELDS = ['created_at', 'updated_at', 'revision_number']
class FakeModel(object):
def __init__(self, *args, **kwargs):
pass
class FakeModel(dict):
pass
class ObjectFieldsModel(object):
def __init__(self, *args, **kwargs):
pass
class ObjectFieldsModel(dict):
pass
@obj_base.VersionedObjectRegistry.register_if(False)
@ -396,9 +393,11 @@ FIELD_TYPE_VALUE_GENERATOR_MAP = {
}
# TODO(ihrachys) consider renaming into e.g. get_obj_persistent_fields
def get_obj_db_fields(obj):
return {field: getattr(obj, field) for field in obj.fields
if field not in obj.synthetic_fields}
if field not in obj.synthetic_fields
if field in obj}
def get_value(generator, version):
@ -429,11 +428,19 @@ class _BaseObjectTestCase(object):
# neutron.objects.db.api from core plugin instance
self.setup_coreplugin(self.CORE_PLUGIN)
self.context = context.get_admin_context()
self.db_objs = list(self.get_random_fields() for _ in range(3))
self.db_objs = [
self._test_class.db_model(**self.get_random_fields())
for _ in range(3)
]
self.db_obj = self.db_objs[0]
# TODO(ihrachys) remove obj_fields since they duplicate self.objs
self.obj_fields = [self._test_class.modify_fields_from_db(db_obj)
for db_obj in self.db_objs]
self.objs = [
self._test_class(self.context, **fields)
for fields in self.obj_fields
]
valid_field = [f for f in self._test_class.fields
if f not in self._test_class.synthetic_fields][0]
@ -447,8 +454,10 @@ class _BaseObjectTestCase(object):
synthetic_obj_fields = self.get_random_fields(FakeSmallNeutronObject)
self.model_map = {
self._test_class.db_model: self.db_objs,
ObjectFieldsModel: [synthetic_obj_fields]}
ObjectFieldsModel: [ObjectFieldsModel(**synthetic_obj_fields)]}
# TODO(ihrachys): rename the method to explicitly reflect it returns db
# attributes not object fields
@classmethod
def get_random_fields(cls, obj_cls=None):
obj_cls = obj_cls or cls._test_class
@ -504,6 +513,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self.model_map[self._test_class.db_model] = self.db_objs
self.pager_map = collections.defaultdict(lambda: None)
# TODO(ihrachys) document the intent of all common test cases in docstrings
def test_get_object(self):
with mock.patch.object(obj_db_api, 'get_object',
return_value=self.db_obj) as get_object_mock:
@ -512,7 +522,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj_keys = self.generate_object_keys(self._test_class)
obj = self._test_class.get_object(self.context, **obj_keys)
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.obj_fields[0], get_obj_db_fields(obj))
self._check_equal(obj, self.objs[0])
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model,
**self._test_class.modify_fields_to_db(obj_keys))
@ -550,8 +560,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj = self._test_class.get_object(self.context,
**obj_keys)
self.assertTrue(self._is_test_class(obj))
self.assertEqual(self.obj_fields[0],
get_obj_db_fields(obj))
self._check_equal(obj, self.objs[0])
get_object_mock.assert_called_once_with(
self.context, self._test_class.db_model,
**self._test_class.modify_fields_to_db(obj_keys))
@ -574,37 +583,25 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
return mock_calls
def test_get_objects(self):
'''Test that get_objects fetches data from database.'''
with mock.patch.object(
obj_db_api, 'get_objects',
side_effect=self.fake_get_objects) as get_objects_mock:
objs = self._test_class.get_objects(self.context)
self._validate_objects(self.db_objs, objs)
mock_calls = [
mock.call(self.context, self._test_class.db_model,
_pager=self.pager_map[self._test_class.obj_name()])
]
mock_calls.extend(self._get_synthetic_fields_get_objects_calls(
self.db_objs))
get_objects_mock.assert_has_calls(mock_calls)
self.assertItemsEqual(
[get_obj_db_fields(obj) for obj in self.objs],
[get_obj_db_fields(obj) for obj in objs])
get_objects_mock.assert_any_call(
self.context, self._test_class.db_model,
_pager=self.pager_map[self._test_class.obj_name()]
)
def test_get_objects_valid_fields(self):
'''Test that a valid filter does not raise an error.'''
with mock.patch.object(
obj_db_api, 'get_objects',
side_effect=self.fake_get_objects) as get_objects_mock:
objs = self._test_class.get_objects(self.context,
**self.valid_field_filter)
self._validate_objects(self.db_objs, objs)
mock_calls = [
mock.call(
self.context, self._test_class.db_model,
_pager=self.pager_map[self._test_class.obj_name()],
**self._test_class.modify_fields_to_db(self.valid_field_filter)
)
]
mock_calls.extend(self._get_synthetic_fields_get_objects_calls(
[self.db_obj]))
get_objects_mock.assert_has_calls(mock_calls)
obj_db_api, 'get_objects', side_effect=self.fake_get_objects):
self._test_class.get_objects(self.context,
**self.valid_field_filter)
def test_get_objects_mixed_fields(self):
synthetic_fields = (
@ -661,19 +658,11 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
self._test_class.count, self.context,
fake_field='xxx')
def _validate_objects(self, expected, observed):
self.assertTrue(all(self._is_test_class(obj) for obj in observed))
self.assertEqual(
sorted([self._test_class.modify_fields_from_db(db_obj)
for db_obj in expected],
key=common_utils.safe_sort_key),
sorted([get_obj_db_fields(obj) for obj in observed],
key=common_utils.safe_sort_key))
def _check_equal(self, obj, db_obj):
self.assertEqual(
sorted(db_obj),
sorted(get_obj_db_fields(obj)))
# TODO(ihrachys) swap the order of arguments to reflect the order of
# self.assert* methods
def _check_equal(self, observed, expected):
self.assertItemsEqual(get_obj_db_fields(expected),
get_obj_db_fields(observed))
def test_create(self):
with mock.patch.object(obj_db_api, 'create_object',
@ -681,21 +670,21 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
with mock.patch.object(obj_db_api, 'get_objects',
side_effect=self.fake_get_objects):
obj = self._test_class(self.context, **self.obj_fields[0])
self._check_equal(obj, self.obj_fields[0])
self._check_equal(obj, self.objs[0])
obj.create()
self._check_equal(obj, self.obj_fields[0])
self._check_equal(obj, self.objs[0])
create_mock.assert_called_once_with(
self.context, self._test_class.db_model, self.db_obj)
self.context, self._test_class.db_model,
self._test_class.modify_fields_to_db(
get_obj_db_fields(self.objs[0])))
def test_create_updates_from_db_object(self):
with mock.patch.object(obj_db_api, 'create_object',
return_value=self.db_obj):
with mock.patch.object(obj_db_api, 'get_objects',
side_effect=self.fake_get_objects):
obj = self._test_class(self.context, **self.obj_fields[1])
self._check_equal(obj, self.obj_fields[1])
obj.create()
self._check_equal(obj, self.obj_fields[0])
self.objs[1].create()
self._check_equal(self.objs[1], self.objs[0])
def test_create_duplicates(self):
with mock.patch.object(obj_db_api, 'create_object',
@ -772,7 +761,7 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
side_effect=self.fake_get_objects):
obj = self._test_class(self.context, **self.obj_fields[0])
# get new values and fix keys
update_mock.return_value = self.db_objs[1].copy()
update_mock.return_value = self.db_objs[1]
fixed_keys = self._test_class.modify_fields_to_db(
obj._get_composite_keys())
for key, value in fixed_keys.items():
@ -813,14 +802,14 @@ class BaseObjectIfaceTestCase(_BaseObjectTestCase, test_base.BaseTestCase):
obj_db_api, 'get_objects',
side_effect=self.fake_get_objects):
obj.update()
self._check_equal(obj, self.obj_fields[0])
self._check_equal(obj, self.objs[0])
@mock.patch.object(obj_db_api, 'delete_object')
def test_delete(self, delete_mock):
obj = self._test_class(self.context, **self.obj_fields[0])
self._check_equal(obj, self.obj_fields[0])
self._check_equal(obj, self.objs[0])
obj.delete()
self._check_equal(obj, self.obj_fields[0])
self._check_equal(obj, self.objs[0])
delete_mock.assert_called_once_with(
self.context, self._test_class.db_model,
**self._test_class.modify_fields_to_db(obj._get_composite_keys()))
@ -1024,7 +1013,7 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
continue
for db_obj in self.db_objs:
objclass_fields = self.get_random_fields(objclass)
db_obj[synth_field] = [objclass_fields]
db_obj[synth_field] = [objclass.db_model(**objclass_fields)]
def _create_test_network(self):
# TODO(ihrachys): replace with network.create() once we get an object
@ -1208,10 +1197,13 @@ class BaseDbObjectTestCase(_BaseObjectTestCase,
obj = self._make_object(self.obj_fields[0])
obj.create()
for field in remove_timestamps_from_fields(self.obj_fields[0]):
filters = {field: [self.obj_fields[0][field]]}
for field in remove_timestamps_from_fields(get_obj_db_fields(obj)):
filters = {field: [self.objs[0][field]]}
new = self._test_class.get_objects(self.context, **filters)
self.assertEqual([obj], new, 'Filtering by %s failed.' % field)
self.assertItemsEqual(
[obj._get_composite_keys()],
[obj_._get_composite_keys() for obj_ in new],
'Filtering by %s failed.' % field)
def _get_non_synth_fields(self, objclass, db_attrs):
fields = objclass.modify_fields_from_db(db_attrs)

View File

@ -29,9 +29,8 @@ from neutron.tests.unit.objects import test_base
from neutron.tests.unit import testlib_api
class FakeDbModel(object):
def __init__(self, *args, **kwargs):
pass
class FakeDbModel(dict):
pass
class FakeRbacModel(rbac_db_models.RBACColumns, model_base.BASEV2):

View File

@ -118,7 +118,7 @@ class DefaultSecurityGroupDbObjTestCase(test_base.BaseDbObjectTestCase,
self.sg_obj = securitygroup.SecurityGroup(
self.context, **test_base.remove_timestamps_from_fields(sg_fields))
self.sg_obj.create()
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['security_group_id'] = self.sg_obj['id']
@ -140,6 +140,6 @@ class SecurityGroupRuleDbObjTestCase(test_base.BaseDbObjectTestCase,
self.sg_obj = securitygroup.SecurityGroup(
self.context, **test_base.remove_timestamps_from_fields(sg_fields))
self.sg_obj.create()
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['security_group_id'] = self.sg_obj['id']
obj['remote_group_id'] = self.sg_obj['id']

View File

@ -39,7 +39,7 @@ class IPAllocationPoolDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
super(IPAllocationPoolDbObjectTestCase, self).setUp()
self._create_test_network()
self._create_test_subnet(self._network)
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['subnet_id'] = self._subnet['id']
@ -69,7 +69,7 @@ class DNSNameServerDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
for db_obj in self.db_objs]
self._create_test_network()
self._create_test_subnet(self._network)
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['subnet_id'] = self._subnet['id']
def _is_objects_unique(self):
@ -128,7 +128,7 @@ class RouteDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
super(RouteDbObjectTestCase, self).setUp()
self._create_test_network()
self._create_test_subnet(self._network)
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['subnet_id'] = self._subnet['id']
@ -151,7 +151,7 @@ class SubnetDbObjectTestCase(obj_test_base.BaseDbObjectTestCase,
super(SubnetDbObjectTestCase, self).setUp()
self._create_test_network()
self._create_test_segment(self._network)
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['network_id'] = self._network['id']
obj['segment_id'] = self._segment['id']

View File

@ -87,5 +87,5 @@ class SubnetPoolPrefixDbObjectTestCase(
def setUp(self):
super(SubnetPoolPrefixDbObjectTestCase, self).setUp()
self._create_test_subnetpool()
for obj in itertools.chain(self.db_objs, self.obj_fields):
for obj in itertools.chain(self.db_objs, self.obj_fields, self.objs):
obj['subnetpool_id'] = self._pool.id