Replace model_query with direct query call

The following changes for bug #1479723 may need the flexibility to
create reader queries with sessions other than the context session.
Also this function is no longer saving any typing or doing anything
useful.

Change-Id: I0242febe97cf1da30fb094cf7a395db6c2f4665a
Related-Bug: #1479723
This commit is contained in:
Steve Baker 2016-06-17 12:40:19 +12:00
parent 419c9ab994
commit a65bd2b19d
2 changed files with 65 additions and 71 deletions

View File

@ -80,19 +80,13 @@ def get_backend():
return sys.modules[__name__] return sys.modules[__name__]
def model_query(context, *args):
session = context.session
query = session.query(*args)
return query
def soft_delete_aware_query(context, *args, **kwargs): def soft_delete_aware_query(context, *args, **kwargs):
"""Stack query helper that accounts for context's `show_deleted` field. """Stack query helper that accounts for context's `show_deleted` field.
:param show_deleted: if True, overrides context's show_deleted field. :param show_deleted: if True, overrides context's show_deleted field.
""" """
query = model_query(context, *args) query = context.session.query(*args)
show_deleted = kwargs.get('show_deleted') or context.show_deleted show_deleted = kwargs.get('show_deleted') or context.show_deleted
if not show_deleted: if not show_deleted:
@ -101,7 +95,7 @@ def soft_delete_aware_query(context, *args, **kwargs):
def raw_template_get(context, template_id): def raw_template_get(context, template_id):
result = model_query(context, models.RawTemplate).get(template_id) result = context.session.query(models.RawTemplate).get(template_id)
if not result: if not result:
raise exception.NotFound(_('raw template with id %s not found') % raise exception.NotFound(_('raw template with id %s not found') %
@ -161,7 +155,7 @@ def raw_template_files_get(context, files_id):
def resource_get(context, resource_id): def resource_get(context, resource_id):
result = model_query(context, models.Resource).get(resource_id) result = context.session.query(models.Resource).get(resource_id)
if not result: if not result:
raise exception.NotFound(_("resource with id %s not found") % raise exception.NotFound(_("resource with id %s not found") %
@ -170,8 +164,8 @@ def resource_get(context, resource_id):
def resource_get_by_name_and_stack(context, resource_name, stack_id): def resource_get_by_name_and_stack(context, resource_name, stack_id):
result = model_query( result = context.session.query(
context, models.Resource models.Resource
).filter_by( ).filter_by(
name=resource_name name=resource_name
).filter_by( ).filter_by(
@ -181,7 +175,7 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id):
def resource_get_by_physical_resource_id(context, physical_resource_id): def resource_get_by_physical_resource_id(context, physical_resource_id):
results = (model_query(context, models.Resource) results = (context.session.query(models.Resource)
.filter_by(physical_resource_id=physical_resource_id) .filter_by(physical_resource_id=physical_resource_id)
.all()) .all())
@ -193,7 +187,7 @@ def resource_get_by_physical_resource_id(context, physical_resource_id):
def resource_get_all(context): def resource_get_all(context):
results = model_query(context, models.Resource).all() results = context.session.query(models.Resource).all()
if not results: if not results:
raise exception.NotFound(_('no resources were found')) raise exception.NotFound(_('no resources were found'))
@ -221,7 +215,7 @@ def resource_data_get_all(context, resource_id, data=None):
If data is encrypted, this method will decrypt the results. If data is encrypted, this method will decrypt the results.
""" """
if data is None: if data is None:
data = (model_query(context, models.ResourceData) data = (context.session.query(models.ResourceData)
.filter_by(resource_id=resource_id)).all() .filter_by(resource_id=resource_id)).all()
if not data: if not data:
@ -274,7 +268,7 @@ def stack_tags_delete(context, stack_id):
def stack_tags_get(context, stack_id): def stack_tags_get(context, stack_id):
result = (model_query(context, models.StackTag) result = (context.session.query(models.StackTag)
.filter_by(stack_id=stack_id) .filter_by(stack_id=stack_id)
.all()) .all())
return result or None return result or None
@ -285,7 +279,7 @@ def resource_data_get_by_key(context, resource_id, key):
Does not decrypt resource_data. Does not decrypt resource_data.
""" """
result = (model_query(context, models.ResourceData) result = (context.session.query(models.ResourceData)
.filter_by(resource_id=resource_id) .filter_by(resource_id=resource_id)
.filter_by(key=key).first()) .filter_by(key=key).first())
@ -314,7 +308,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
def resource_exchange_stacks(context, resource_id1, resource_id2): def resource_exchange_stacks(context, resource_id1, resource_id2):
query = model_query(context, models.Resource) query = context.session.query(models.Resource)
session = query.session session = query.session
session.begin() session.begin()
@ -339,8 +333,8 @@ def resource_create(context, values):
def resource_get_all_by_stack(context, stack_id, filters=None): def resource_get_all_by_stack(context, stack_id, filters=None):
query = model_query( query = context.session.query(
context, models.Resource models.Resource
).filter_by( ).filter_by(
stack_id=stack_id stack_id=stack_id
).options(orm.joinedload("data")) ).options(orm.joinedload("data"))
@ -357,9 +351,9 @@ def resource_get_all_by_stack(context, stack_id, filters=None):
def resource_get_all_active_by_stack(context, stack_id): def resource_get_all_active_by_stack(context, stack_id):
filters = {'stack_id': stack_id, 'action': 'DELETE', 'status': 'COMPLETE'} filters = {'stack_id': stack_id, 'action': 'DELETE', 'status': 'COMPLETE'}
subquery = model_query(context, models.Resource.id).filter_by(**filters) subquery = context.session.query(models.Resource.id).filter_by(**filters)
results = model_query(context, models.Resource).filter_by( results = context.session.query(models.Resource).filter_by(
stack_id=stack_id).filter( stack_id=stack_id).filter(
models.Resource.id.notin_(subquery.as_scalar()) models.Resource.id.notin_(subquery.as_scalar())
).options(orm.joinedload("data")).all() ).options(orm.joinedload("data")).all()
@ -371,8 +365,8 @@ def resource_get_all_active_by_stack(context, stack_id):
def resource_get_all_by_root_stack(context, stack_id, filters=None): def resource_get_all_by_root_stack(context, stack_id, filters=None):
query = model_query( query = context.session.query(
context, models.Resource models.Resource
).filter_by( ).filter_by(
root_stack_id=stack_id root_stack_id=stack_id
).options(orm.joinedload("data")) ).options(orm.joinedload("data"))
@ -405,7 +399,7 @@ def stack_get_by_name(context, stack_name):
def stack_get(context, stack_id, show_deleted=False, tenant_safe=True, def stack_get(context, stack_id, show_deleted=False, tenant_safe=True,
eager_load=False): eager_load=False):
query = model_query(context, models.Stack) query = context.session.query(models.Stack)
if eager_load: if eager_load:
query = query.options(orm.joinedload("raw_template")) query = query.options(orm.joinedload("raw_template"))
result = query.get(stack_id) result = query.get(stack_id)
@ -425,7 +419,7 @@ def stack_get(context, stack_id, show_deleted=False, tenant_safe=True,
def stack_get_status(context, stack_id): def stack_get_status(context, stack_id):
query = model_query(context, models.Stack) query = context.session.query(models.Stack)
query = query.options( query = query.options(
orm.load_only("action", "status", "status_reason", "updated_at")) orm.load_only("action", "status", "status_reason", "updated_at"))
result = query.filter_by(id=stack_id).first() result = query.filter_by(id=stack_id).first()
@ -468,7 +462,7 @@ def _paginate_query(context, query, model, limit=None, sort_keys=None,
model_marker = None model_marker = None
if marker: if marker:
model_marker = model_query(context, model).get(marker) model_marker = context.session.query(model).get(marker)
try: try:
query = utils.paginate_query(query, model, limit, sort_keys, query = utils.paginate_query(query, model, limit, sort_keys,
model_marker, sort_dir) model_marker, sort_dir)
@ -691,8 +685,8 @@ def stack_get_root_id(context, stack_id):
def stack_count_total_resources(context, stack_id): def stack_count_total_resources(context, stack_id):
# count all resources which belong to the root stack # count all resources which belong to the root stack
results = model_query( results = context.session.query(
context, models.Resource models.Resource
).filter(models.Resource.root_stack_id == stack_id).count() ).filter(models.Resource.root_stack_id == stack_id).count()
return results return results
@ -731,7 +725,7 @@ def user_creds_create(context):
def user_creds_get(context, user_creds_id): def user_creds_get(context, user_creds_id):
db_result = model_query(context, models.UserCreds).get(user_creds_id) db_result = context.session.query(models.UserCreds).get(user_creds_id)
if db_result is None: if db_result is None:
return None return None
# Return a dict copy of db results, do not decrypt details into db_result # Return a dict copy of db results, do not decrypt details into db_result
@ -747,7 +741,7 @@ def user_creds_get(context, user_creds_id):
@db_utils.retry_on_stale_data_error @db_utils.retry_on_stale_data_error
def user_creds_delete(context, user_creds_id): def user_creds_delete(context, user_creds_id):
creds = model_query(context, models.UserCreds).get(user_creds_id) creds = context.session.query(models.UserCreds).get(user_creds_id)
if not creds: if not creds:
raise exception.NotFound( raise exception.NotFound(
_('Attempt to delete user creds with id ' _('Attempt to delete user creds with id '
@ -758,22 +752,22 @@ def user_creds_delete(context, user_creds_id):
def event_get(context, event_id): def event_get(context, event_id):
result = model_query(context, models.Event).get(event_id) result = context.session.query(models.Event).get(event_id)
return result return result
def event_get_all(context): def event_get_all(context):
stacks = soft_delete_aware_query(context, models.Stack) stacks = soft_delete_aware_query(context, models.Stack)
stack_ids = [stack.id for stack in stacks] stack_ids = [stack.id for stack in stacks]
results = model_query( results = context.session.query(
context, models.Event models.Event
).filter(models.Event.stack_id.in_(stack_ids)).all() ).filter(models.Event.stack_id.in_(stack_ids)).all()
return results return results
def event_get_all_by_tenant(context, limit=None, marker=None, def event_get_all_by_tenant(context, limit=None, marker=None,
sort_keys=None, sort_dir=None, filters=None): sort_keys=None, sort_dir=None, filters=None):
query = model_query(context, models.Event) query = context.session.query(models.Event)
query = db_filters.exact_filter(query, models.Event, filters) query = db_filters.exact_filter(query, models.Event, filters)
query = query.join( query = query.join(
models.Event.stack models.Event.stack
@ -784,7 +778,7 @@ def event_get_all_by_tenant(context, limit=None, marker=None,
def _query_all_by_stack(context, stack_id): def _query_all_by_stack(context, stack_id):
query = model_query(context, models.Event).filter_by(stack_id=stack_id) query = context.session.query(models.Event).filter_by(stack_id=stack_id)
return query return query
@ -809,10 +803,10 @@ def _events_paginate_query(context, query, model, limit=None, sort_keys=None,
model_marker = None model_marker = None
if marker: if marker:
# not to use model_query(context, model).get(marker), because # not to use context.session.query(model).get(marker), because
# user can only see the ID(column 'uuid') and the ID as the marker # user can only see the ID(column 'uuid') and the ID as the marker
model_marker = model_query( model_marker = context.session.query(
context, model).filter_by(uuid=marker).first() model).filter_by(uuid=marker).first()
try: try:
query = utils.paginate_query(query, model, limit, sort_keys, query = utils.paginate_query(query, model, limit, sort_keys,
model_marker, sort_dir) model_marker, sort_dir)
@ -841,7 +835,7 @@ def _events_filter_and_page_query(context, query,
def event_count_all_by_stack(context, stack_id): def event_count_all_by_stack(context, stack_id):
query = model_query(context, func.count(models.Event.id)) query = context.session.query(func.count(models.Event.id))
return query.filter_by(stack_id=stack_id).scalar() return query.filter_by(stack_id=stack_id).scalar()
@ -874,24 +868,24 @@ def event_create(context, values):
def watch_rule_get(context, watch_rule_id): def watch_rule_get(context, watch_rule_id):
result = model_query(context, models.WatchRule).get(watch_rule_id) result = context.session.query(models.WatchRule).get(watch_rule_id)
return result return result
def watch_rule_get_by_name(context, watch_rule_name): def watch_rule_get_by_name(context, watch_rule_name):
result = model_query( result = context.session.query(
context, models.WatchRule).filter_by(name=watch_rule_name).first() models.WatchRule).filter_by(name=watch_rule_name).first()
return result return result
def watch_rule_get_all(context): def watch_rule_get_all(context):
results = model_query(context, models.WatchRule).all() results = context.session.query(models.WatchRule).all()
return results return results
def watch_rule_get_all_by_stack(context, stack_id): def watch_rule_get_all_by_stack(context, stack_id):
results = model_query( results = context.session.query(
context, models.WatchRule).filter_by(stack_id=stack_id).all() models.WatchRule).filter_by(stack_id=stack_id).all()
return results return results
@ -936,12 +930,12 @@ def watch_data_create(context, values):
def watch_data_get_all(context): def watch_data_get_all(context):
results = model_query(context, models.WatchData).all() results = context.session.query(models.WatchData).all()
return results return results
def watch_data_get_all_by_watch_rule_id(context, watch_rule_id): def watch_data_get_all_by_watch_rule_id(context, watch_rule_id):
results = model_query(context, models.WatchData).filter_by( results = context.session.query(models.WatchData).filter_by(
watch_rule_id=watch_rule_id).all() watch_rule_id=watch_rule_id).all()
return results return results
@ -954,7 +948,7 @@ def software_config_create(context, values):
def software_config_get(context, config_id): def software_config_get(context, config_id):
result = model_query(context, models.SoftwareConfig).get(config_id) result = context.session.query(models.SoftwareConfig).get(config_id)
if (result is not None and context is not None and if (result is not None and context is not None and
result.tenant != context.tenant_id): result.tenant != context.tenant_id):
result = None result = None
@ -967,7 +961,7 @@ def software_config_get(context, config_id):
def software_config_get_all(context, limit=None, marker=None, def software_config_get_all(context, limit=None, marker=None,
tenant_safe=True): tenant_safe=True):
query = model_query(context, models.SoftwareConfig) query = context.session.query(models.SoftwareConfig)
if tenant_safe and not context.is_admin: if tenant_safe and not context.is_admin:
query = query.filter_by(tenant=context.tenant_id) query = query.filter_by(tenant=context.tenant_id)
return _paginate_query(context, query, models.SoftwareConfig, return _paginate_query(context, query, models.SoftwareConfig,
@ -977,7 +971,7 @@ def software_config_get_all(context, limit=None, marker=None,
def software_config_delete(context, config_id): def software_config_delete(context, config_id):
config = software_config_get(context, config_id) config = software_config_get(context, config_id)
# Query if the software config has been referenced by deployment. # Query if the software config has been referenced by deployment.
result = model_query(context, models.SoftwareDeployment).filter_by( result = context.session.query(models.SoftwareDeployment).filter_by(
config_id=config_id).first() config_id=config_id).first()
if result: if result:
msg = (_("Software config with id %s can not be deleted as " msg = (_("Software config with id %s can not be deleted as "
@ -999,7 +993,8 @@ def software_deployment_create(context, values):
def software_deployment_get(context, deployment_id): def software_deployment_get(context, deployment_id):
result = model_query(context, models.SoftwareDeployment).get(deployment_id) result = context.session.query(
models.SoftwareDeployment).get(deployment_id)
if (result is not None and context is not None and if (result is not None and context is not None and
context.tenant_id not in (result.tenant, context.tenant_id not in (result.tenant,
result.stack_user_project_id)): result.stack_user_project_id)):
@ -1013,8 +1008,8 @@ def software_deployment_get(context, deployment_id):
def software_deployment_get_all(context, server_id=None): def software_deployment_get_all(context, server_id=None):
sd = models.SoftwareDeployment sd = models.SoftwareDeployment
query = model_query( query = context.session.query(
context, sd sd
).filter(sqlalchemy.or_( ).filter(sqlalchemy.or_(
sd.tenant == context.tenant_id, sd.tenant == context.tenant_id,
sd.stack_user_project_id == context.tenant_id) sd.stack_user_project_id == context.tenant_id)
@ -1043,7 +1038,7 @@ def snapshot_create(context, values):
def snapshot_get(context, snapshot_id): def snapshot_get(context, snapshot_id):
result = model_query(context, models.Snapshot).get(snapshot_id) result = context.session.query(models.Snapshot).get(snapshot_id)
if (result is not None and context is not None and if (result is not None and context is not None and
context.tenant_id != result.tenant): context.tenant_id != result.tenant):
result = None result = None
@ -1078,7 +1073,7 @@ def snapshot_delete(context, snapshot_id):
def snapshot_get_all(context, stack_id): def snapshot_get_all(context, stack_id):
return model_query(context, models.Snapshot).filter_by( return context.session.query(models.Snapshot).filter_by(
stack_id=stack_id, tenant=context.tenant_id) stack_id=stack_id, tenant=context.tenant_id)
@ -1108,19 +1103,19 @@ def service_delete(context, service_id, soft_delete=True):
def service_get(context, service_id): def service_get(context, service_id):
result = model_query(context, models.Service).get(service_id) result = context.session.query(models.Service).get(service_id)
if result is None: if result is None:
raise exception.EntityNotFound(entity='Service', name=service_id) raise exception.EntityNotFound(entity='Service', name=service_id)
return result return result
def service_get_all(context): def service_get_all(context):
return (model_query(context, models.Service). return (context.session.query(models.Service).
filter_by(deleted_at=None).all()) filter_by(deleted_at=None).all())
def service_get_all_by_args(context, host, binary, hostname): def service_get_all_by_args(context, host, binary, hostname):
return (model_query(context, models.Service). return (context.session.query(models.Service).
filter_by(host=host). filter_by(host=host).
filter_by(binary=binary). filter_by(binary=binary).
filter_by(hostname=hostname).all()) filter_by(hostname=hostname).all())
@ -1250,7 +1245,7 @@ def purge_deleted(age, granularity='days'):
def sync_point_delete_all_by_stack_and_traversal(context, stack_id, def sync_point_delete_all_by_stack_and_traversal(context, stack_id,
traversal_id): traversal_id):
rows_deleted = model_query(context, models.SyncPoint).filter_by( rows_deleted = context.session.query(models.SyncPoint).filter_by(
stack_id=stack_id, traversal_id=traversal_id).delete() stack_id=stack_id, traversal_id=traversal_id).delete()
return rows_deleted return rows_deleted
@ -1267,7 +1262,7 @@ def sync_point_create(context, values):
def sync_point_get(context, entity_id, traversal_id, is_update): def sync_point_get(context, entity_id, traversal_id, is_update):
entity_id = str(entity_id) entity_id = str(entity_id)
return model_query(context, models.SyncPoint).get( return context.session.query(models.SyncPoint).get(
(entity_id, traversal_id, is_update) (entity_id, traversal_id, is_update)
) )
@ -1276,7 +1271,7 @@ def sync_point_update_input_data(context, entity_id,
traversal_id, is_update, atomic_key, traversal_id, is_update, atomic_key,
input_data): input_data):
entity_id = str(entity_id) entity_id = str(entity_id)
rows_updated = model_query(context, models.SyncPoint).filter_by( rows_updated = context.session.query(models.SyncPoint).filter_by(
entity_id=entity_id, entity_id=entity_id,
traversal_id=traversal_id, traversal_id=traversal_id,
is_update=is_update, is_update=is_update,
@ -1499,19 +1494,19 @@ def _get_batch(session, ctxt, query, model, batch_size=50):
def reset_stack_status(context, stack_id, stack=None): def reset_stack_status(context, stack_id, stack=None):
if stack is None: if stack is None:
stack = model_query(context, models.Stack).get(stack_id) stack = context.session.query(models.Stack).get(stack_id)
if stack is None: if stack is None:
raise exception.NotFound(_('Stack with id %s not found') % stack_id) raise exception.NotFound(_('Stack with id %s not found') % stack_id)
session = context.session session = context.session
with session.begin(): with session.begin():
query = model_query(context, models.Resource).filter_by( query = context.session.query(models.Resource).filter_by(
status='IN_PROGRESS', stack_id=stack_id) status='IN_PROGRESS', stack_id=stack_id)
query.update({'status': 'FAILED', query.update({'status': 'FAILED',
'status_reason': 'Stack status manually reset'}) 'status_reason': 'Stack status manually reset'})
query = model_query(context, models.ResourceData) query = context.session.query(models.ResourceData)
query = query.join(models.Resource) query = query.join(models.Resource)
query = query.filter_by(stack_id=stack_id) query = query.filter_by(stack_id=stack_id)
query = query.filter( query = query.filter(
@ -1519,11 +1514,11 @@ def reset_stack_status(context, stack_id, stack=None):
data_ids = [data.id for data in query] data_ids = [data.id for data in query]
if data_ids: if data_ids:
query = model_query(context, models.ResourceData) query = context.session.query(models.ResourceData)
query = query.filter(models.ResourceData.id.in_(data_ids)) query = query.filter(models.ResourceData.id.in_(data_ids))
query.delete(synchronize_session='fetch') query.delete(synchronize_session='fetch')
query = model_query(context, models.Stack).filter_by(owner_id=stack_id) query = context.session.query(models.Stack).filter_by(owner_id=stack_id)
for child in query: for child in query:
reset_stack_status(context, child.id, child) reset_stack_status(context, child.id, child)

View File

@ -259,18 +259,17 @@ class SqlAlchemyTest(common.HeatTestCase):
self.assertIn(['name', 'id'], args) self.assertIn(['name', 'id'], args)
@mock.patch.object(db_api.utils, 'paginate_query') @mock.patch.object(db_api.utils, 'paginate_query')
@mock.patch.object(db_api, 'model_query') def test_paginate_query_gets_model_marker(self, mock_paginate_query):
def test_paginate_query_gets_model_marker(self, mock_query,
mock_paginate_query):
query = mock.Mock() query = mock.Mock()
model = mock.Mock() model = mock.Mock()
marker = mock.Mock() marker = mock.Mock()
mock_query_object = mock.Mock() mock_query_object = mock.Mock()
mock_query_object.get.return_value = 'real_marker' mock_query_object.get.return_value = 'real_marker'
mock_query.return_value = mock_query_object ctx = mock.MagicMock()
ctx.session.query.return_value = mock_query_object
db_api._paginate_query(self.ctx, query, model, marker=marker) db_api._paginate_query(ctx, query, model, marker=marker)
mock_query_object.get.assert_called_once_with(marker) mock_query_object.get.assert_called_once_with(marker)
args, _ = mock_paginate_query.call_args args, _ = mock_paginate_query.call_args
self.assertIn('real_marker', args) self.assertIn('real_marker', args)