Merge "Replace model_query with direct query call"

This commit is contained in:
Jenkins 2016-06-30 15:56:45 +00:00 committed by Gerrit Code Review
commit 55583202bd
2 changed files with 65 additions and 71 deletions

View File

@ -80,19 +80,13 @@ def get_backend():
return sys.modules[__name__]
def model_query(context, *args):
session = context.session
query = session.query(*args)
return query
def soft_delete_aware_query(context, *args, **kwargs):
"""Stack query helper that accounts for context's `show_deleted` field.
:param show_deleted: if True, overrides context's show_deleted field.
"""
query = model_query(context, *args)
query = context.session.query(*args)
show_deleted = kwargs.get('show_deleted') or context.show_deleted
if not show_deleted:
@ -101,7 +95,7 @@ def soft_delete_aware_query(context, *args, **kwargs):
def raw_template_get(context, template_id):
result = model_query(context, models.RawTemplate).get(template_id)
result = context.session.query(models.RawTemplate).get(template_id)
if not result:
raise exception.NotFound(_('raw template with id %s not found') %
@ -161,7 +155,7 @@ def raw_template_files_get(context, files_id):
def resource_get(context, resource_id):
result = model_query(context, models.Resource).get(resource_id)
result = context.session.query(models.Resource).get(resource_id)
if not result:
raise exception.NotFound(_("resource with id %s not found") %
@ -170,8 +164,8 @@ def resource_get(context, resource_id):
def resource_get_by_name_and_stack(context, resource_name, stack_id):
result = model_query(
context, models.Resource
result = context.session.query(
models.Resource
).filter_by(
name=resource_name
).filter_by(
@ -181,7 +175,7 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id):
def resource_get_by_physical_resource_id(context, physical_resource_id):
results = (model_query(context, models.Resource)
results = (context.session.query(models.Resource)
.filter_by(physical_resource_id=physical_resource_id)
.all())
@ -193,7 +187,7 @@ def resource_get_by_physical_resource_id(context, physical_resource_id):
def resource_get_all(context):
results = model_query(context, models.Resource).all()
results = context.session.query(models.Resource).all()
if not results:
raise exception.NotFound(_('no resources were found'))
@ -221,7 +215,7 @@ def resource_data_get_all(context, resource_id, data=None):
If data is encrypted, this method will decrypt the results.
"""
if data is None:
data = (model_query(context, models.ResourceData)
data = (context.session.query(models.ResourceData)
.filter_by(resource_id=resource_id)).all()
if not data:
@ -274,7 +268,7 @@ def stack_tags_delete(context, stack_id):
def stack_tags_get(context, stack_id):
result = (model_query(context, models.StackTag)
result = (context.session.query(models.StackTag)
.filter_by(stack_id=stack_id)
.all())
return result or None
@ -285,7 +279,7 @@ def resource_data_get_by_key(context, resource_id, key):
Does not decrypt resource_data.
"""
result = (model_query(context, models.ResourceData)
result = (context.session.query(models.ResourceData)
.filter_by(resource_id=resource_id)
.filter_by(key=key).first())
@ -314,7 +308,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
def resource_exchange_stacks(context, resource_id1, resource_id2):
query = model_query(context, models.Resource)
query = context.session.query(models.Resource)
session = query.session
session.begin()
@ -339,8 +333,8 @@ def resource_create(context, values):
def resource_get_all_by_stack(context, stack_id, filters=None):
query = model_query(
context, models.Resource
query = context.session.query(
models.Resource
).filter_by(
stack_id=stack_id
).options(orm.joinedload("data"))
@ -357,9 +351,9 @@ def resource_get_all_by_stack(context, stack_id, filters=None):
def resource_get_all_active_by_stack(context, stack_id):
filters = {'stack_id': stack_id, 'action': 'DELETE', 'status': 'COMPLETE'}
subquery = model_query(context, models.Resource.id).filter_by(**filters)
subquery = context.session.query(models.Resource.id).filter_by(**filters)
results = model_query(context, models.Resource).filter_by(
results = context.session.query(models.Resource).filter_by(
stack_id=stack_id).filter(
models.Resource.id.notin_(subquery.as_scalar())
).options(orm.joinedload("data")).all()
@ -371,8 +365,8 @@ def resource_get_all_active_by_stack(context, stack_id):
def resource_get_all_by_root_stack(context, stack_id, filters=None):
query = model_query(
context, models.Resource
query = context.session.query(
models.Resource
).filter_by(
root_stack_id=stack_id
).options(orm.joinedload("data"))
@ -405,7 +399,7 @@ def stack_get_by_name(context, stack_name):
def stack_get(context, stack_id, show_deleted=False, tenant_safe=True,
eager_load=False):
query = model_query(context, models.Stack)
query = context.session.query(models.Stack)
if eager_load:
query = query.options(orm.joinedload("raw_template"))
result = query.get(stack_id)
@ -425,7 +419,7 @@ def stack_get(context, stack_id, show_deleted=False, tenant_safe=True,
def stack_get_status(context, stack_id):
query = model_query(context, models.Stack)
query = context.session.query(models.Stack)
query = query.options(
orm.load_only("action", "status", "status_reason", "updated_at"))
result = query.filter_by(id=stack_id).first()
@ -468,7 +462,7 @@ def _paginate_query(context, query, model, limit=None, sort_keys=None,
model_marker = None
if marker:
model_marker = model_query(context, model).get(marker)
model_marker = context.session.query(model).get(marker)
try:
query = utils.paginate_query(query, model, limit, sort_keys,
model_marker, sort_dir)
@ -691,8 +685,8 @@ def stack_get_root_id(context, stack_id):
def stack_count_total_resources(context, stack_id):
# count all resources which belong to the root stack
results = model_query(
context, models.Resource
results = context.session.query(
models.Resource
).filter(models.Resource.root_stack_id == stack_id).count()
return results
@ -731,7 +725,7 @@ def user_creds_create(context):
def user_creds_get(context, user_creds_id):
db_result = model_query(context, models.UserCreds).get(user_creds_id)
db_result = context.session.query(models.UserCreds).get(user_creds_id)
if db_result is None:
return None
# Return a dict copy of db results, do not decrypt details into db_result
@ -747,7 +741,7 @@ def user_creds_get(context, user_creds_id):
@db_utils.retry_on_stale_data_error
def user_creds_delete(context, user_creds_id):
creds = model_query(context, models.UserCreds).get(user_creds_id)
creds = context.session.query(models.UserCreds).get(user_creds_id)
if not creds:
raise exception.NotFound(
_('Attempt to delete user creds with id '
@ -758,22 +752,22 @@ def user_creds_delete(context, user_creds_id):
def event_get(context, event_id):
result = model_query(context, models.Event).get(event_id)
result = context.session.query(models.Event).get(event_id)
return result
def event_get_all(context):
stacks = soft_delete_aware_query(context, models.Stack)
stack_ids = [stack.id for stack in stacks]
results = model_query(
context, models.Event
results = context.session.query(
models.Event
).filter(models.Event.stack_id.in_(stack_ids)).all()
return results
def event_get_all_by_tenant(context, limit=None, marker=None,
sort_keys=None, sort_dir=None, filters=None):
query = model_query(context, models.Event)
query = context.session.query(models.Event)
query = db_filters.exact_filter(query, models.Event, filters)
query = query.join(
models.Event.stack
@ -784,7 +778,7 @@ def event_get_all_by_tenant(context, limit=None, marker=None,
def _query_all_by_stack(context, stack_id):
query = model_query(context, models.Event).filter_by(stack_id=stack_id)
query = context.session.query(models.Event).filter_by(stack_id=stack_id)
return query
@ -809,10 +803,10 @@ def _events_paginate_query(context, query, model, limit=None, sort_keys=None,
model_marker = None
if marker:
# not to use model_query(context, model).get(marker), because
# not to use context.session.query(model).get(marker), because
# user can only see the ID(column 'uuid') and the ID as the marker
model_marker = model_query(
context, model).filter_by(uuid=marker).first()
model_marker = context.session.query(
model).filter_by(uuid=marker).first()
try:
query = utils.paginate_query(query, model, limit, sort_keys,
model_marker, sort_dir)
@ -841,7 +835,7 @@ def _events_filter_and_page_query(context, query,
def event_count_all_by_stack(context, stack_id):
query = model_query(context, func.count(models.Event.id))
query = context.session.query(func.count(models.Event.id))
return query.filter_by(stack_id=stack_id).scalar()
@ -874,24 +868,24 @@ def event_create(context, values):
def watch_rule_get(context, watch_rule_id):
result = model_query(context, models.WatchRule).get(watch_rule_id)
result = context.session.query(models.WatchRule).get(watch_rule_id)
return result
def watch_rule_get_by_name(context, watch_rule_name):
result = model_query(
context, models.WatchRule).filter_by(name=watch_rule_name).first()
result = context.session.query(
models.WatchRule).filter_by(name=watch_rule_name).first()
return result
def watch_rule_get_all(context):
results = model_query(context, models.WatchRule).all()
results = context.session.query(models.WatchRule).all()
return results
def watch_rule_get_all_by_stack(context, stack_id):
results = model_query(
context, models.WatchRule).filter_by(stack_id=stack_id).all()
results = context.session.query(
models.WatchRule).filter_by(stack_id=stack_id).all()
return results
@ -936,12 +930,12 @@ def watch_data_create(context, values):
def watch_data_get_all(context):
results = model_query(context, models.WatchData).all()
results = context.session.query(models.WatchData).all()
return results
def watch_data_get_all_by_watch_rule_id(context, watch_rule_id):
results = model_query(context, models.WatchData).filter_by(
results = context.session.query(models.WatchData).filter_by(
watch_rule_id=watch_rule_id).all()
return results
@ -954,7 +948,7 @@ def software_config_create(context, values):
def software_config_get(context, config_id):
result = model_query(context, models.SoftwareConfig).get(config_id)
result = context.session.query(models.SoftwareConfig).get(config_id)
if (result is not None and context is not None and
result.tenant != context.tenant_id):
result = None
@ -967,7 +961,7 @@ def software_config_get(context, config_id):
def software_config_get_all(context, limit=None, marker=None,
tenant_safe=True):
query = model_query(context, models.SoftwareConfig)
query = context.session.query(models.SoftwareConfig)
if tenant_safe and not context.is_admin:
query = query.filter_by(tenant=context.tenant_id)
return _paginate_query(context, query, models.SoftwareConfig,
@ -977,7 +971,7 @@ def software_config_get_all(context, limit=None, marker=None,
def software_config_delete(context, config_id):
config = software_config_get(context, config_id)
# Query if the software config has been referenced by deployment.
result = model_query(context, models.SoftwareDeployment).filter_by(
result = context.session.query(models.SoftwareDeployment).filter_by(
config_id=config_id).first()
if result:
msg = (_("Software config with id %s can not be deleted as "
@ -999,7 +993,8 @@ def software_deployment_create(context, values):
def software_deployment_get(context, deployment_id):
result = model_query(context, models.SoftwareDeployment).get(deployment_id)
result = context.session.query(
models.SoftwareDeployment).get(deployment_id)
if (result is not None and context is not None and
context.tenant_id not in (result.tenant,
result.stack_user_project_id)):
@ -1013,8 +1008,8 @@ def software_deployment_get(context, deployment_id):
def software_deployment_get_all(context, server_id=None):
sd = models.SoftwareDeployment
query = model_query(
context, sd
query = context.session.query(
sd
).filter(sqlalchemy.or_(
sd.tenant == context.tenant_id,
sd.stack_user_project_id == context.tenant_id)
@ -1043,7 +1038,7 @@ def snapshot_create(context, values):
def snapshot_get(context, snapshot_id):
result = model_query(context, models.Snapshot).get(snapshot_id)
result = context.session.query(models.Snapshot).get(snapshot_id)
if (result is not None and context is not None and
context.tenant_id != result.tenant):
result = None
@ -1078,7 +1073,7 @@ def snapshot_delete(context, snapshot_id):
def snapshot_get_all(context, stack_id):
return model_query(context, models.Snapshot).filter_by(
return context.session.query(models.Snapshot).filter_by(
stack_id=stack_id, tenant=context.tenant_id)
@ -1108,19 +1103,19 @@ def service_delete(context, service_id, soft_delete=True):
def service_get(context, service_id):
result = model_query(context, models.Service).get(service_id)
result = context.session.query(models.Service).get(service_id)
if result is None:
raise exception.EntityNotFound(entity='Service', name=service_id)
return result
def service_get_all(context):
return (model_query(context, models.Service).
return (context.session.query(models.Service).
filter_by(deleted_at=None).all())
def service_get_all_by_args(context, host, binary, hostname):
return (model_query(context, models.Service).
return (context.session.query(models.Service).
filter_by(host=host).
filter_by(binary=binary).
filter_by(hostname=hostname).all())
@ -1250,7 +1245,7 @@ def purge_deleted(age, granularity='days'):
def sync_point_delete_all_by_stack_and_traversal(context, stack_id,
traversal_id):
rows_deleted = model_query(context, models.SyncPoint).filter_by(
rows_deleted = context.session.query(models.SyncPoint).filter_by(
stack_id=stack_id, traversal_id=traversal_id).delete()
return rows_deleted
@ -1267,7 +1262,7 @@ def sync_point_create(context, values):
def sync_point_get(context, entity_id, traversal_id, is_update):
entity_id = str(entity_id)
return model_query(context, models.SyncPoint).get(
return context.session.query(models.SyncPoint).get(
(entity_id, traversal_id, is_update)
)
@ -1276,7 +1271,7 @@ def sync_point_update_input_data(context, entity_id,
traversal_id, is_update, atomic_key,
input_data):
entity_id = str(entity_id)
rows_updated = model_query(context, models.SyncPoint).filter_by(
rows_updated = context.session.query(models.SyncPoint).filter_by(
entity_id=entity_id,
traversal_id=traversal_id,
is_update=is_update,
@ -1499,19 +1494,19 @@ def _get_batch(session, ctxt, query, model, batch_size=50):
def reset_stack_status(context, stack_id, stack=None):
if stack is None:
stack = model_query(context, models.Stack).get(stack_id)
stack = context.session.query(models.Stack).get(stack_id)
if stack is None:
raise exception.NotFound(_('Stack with id %s not found') % stack_id)
session = context.session
with session.begin():
query = model_query(context, models.Resource).filter_by(
query = context.session.query(models.Resource).filter_by(
status='IN_PROGRESS', stack_id=stack_id)
query.update({'status': 'FAILED',
'status_reason': 'Stack status manually reset'})
query = model_query(context, models.ResourceData)
query = context.session.query(models.ResourceData)
query = query.join(models.Resource)
query = query.filter_by(stack_id=stack_id)
query = query.filter(
@ -1519,11 +1514,11 @@ def reset_stack_status(context, stack_id, stack=None):
data_ids = [data.id for data in query]
if data_ids:
query = model_query(context, models.ResourceData)
query = context.session.query(models.ResourceData)
query = query.filter(models.ResourceData.id.in_(data_ids))
query.delete(synchronize_session='fetch')
query = model_query(context, models.Stack).filter_by(owner_id=stack_id)
query = context.session.query(models.Stack).filter_by(owner_id=stack_id)
for child in query:
reset_stack_status(context, child.id, child)

View File

@ -259,18 +259,17 @@ class SqlAlchemyTest(common.HeatTestCase):
self.assertIn(['name', 'id'], args)
@mock.patch.object(db_api.utils, 'paginate_query')
@mock.patch.object(db_api, 'model_query')
def test_paginate_query_gets_model_marker(self, mock_query,
mock_paginate_query):
def test_paginate_query_gets_model_marker(self, mock_paginate_query):
query = mock.Mock()
model = mock.Mock()
marker = mock.Mock()
mock_query_object = mock.Mock()
mock_query_object.get.return_value = 'real_marker'
mock_query.return_value = mock_query_object
ctx = mock.MagicMock()
ctx.session.query.return_value = mock_query_object
db_api._paginate_query(self.ctx, query, model, marker=marker)
db_api._paginate_query(ctx, query, model, marker=marker)
mock_query_object.get.assert_called_once_with(marker)
args, _ = mock_paginate_query.call_args
self.assertIn('real_marker', args)