Merge "Replace model_query with direct query call"
This commit is contained in:
commit
55583202bd
@ -80,19 +80,13 @@ def get_backend():
|
||||
return sys.modules[__name__]
|
||||
|
||||
|
||||
def model_query(context, *args):
|
||||
session = context.session
|
||||
query = session.query(*args)
|
||||
return query
|
||||
|
||||
|
||||
def soft_delete_aware_query(context, *args, **kwargs):
|
||||
"""Stack query helper that accounts for context's `show_deleted` field.
|
||||
|
||||
:param show_deleted: if True, overrides context's show_deleted field.
|
||||
"""
|
||||
|
||||
query = model_query(context, *args)
|
||||
query = context.session.query(*args)
|
||||
show_deleted = kwargs.get('show_deleted') or context.show_deleted
|
||||
|
||||
if not show_deleted:
|
||||
@ -101,7 +95,7 @@ def soft_delete_aware_query(context, *args, **kwargs):
|
||||
|
||||
|
||||
def raw_template_get(context, template_id):
|
||||
result = model_query(context, models.RawTemplate).get(template_id)
|
||||
result = context.session.query(models.RawTemplate).get(template_id)
|
||||
|
||||
if not result:
|
||||
raise exception.NotFound(_('raw template with id %s not found') %
|
||||
@ -161,7 +155,7 @@ def raw_template_files_get(context, files_id):
|
||||
|
||||
|
||||
def resource_get(context, resource_id):
|
||||
result = model_query(context, models.Resource).get(resource_id)
|
||||
result = context.session.query(models.Resource).get(resource_id)
|
||||
|
||||
if not result:
|
||||
raise exception.NotFound(_("resource with id %s not found") %
|
||||
@ -170,8 +164,8 @@ def resource_get(context, resource_id):
|
||||
|
||||
|
||||
def resource_get_by_name_and_stack(context, resource_name, stack_id):
|
||||
result = model_query(
|
||||
context, models.Resource
|
||||
result = context.session.query(
|
||||
models.Resource
|
||||
).filter_by(
|
||||
name=resource_name
|
||||
).filter_by(
|
||||
@ -181,7 +175,7 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id):
|
||||
|
||||
|
||||
def resource_get_by_physical_resource_id(context, physical_resource_id):
|
||||
results = (model_query(context, models.Resource)
|
||||
results = (context.session.query(models.Resource)
|
||||
.filter_by(physical_resource_id=physical_resource_id)
|
||||
.all())
|
||||
|
||||
@ -193,7 +187,7 @@ def resource_get_by_physical_resource_id(context, physical_resource_id):
|
||||
|
||||
|
||||
def resource_get_all(context):
|
||||
results = model_query(context, models.Resource).all()
|
||||
results = context.session.query(models.Resource).all()
|
||||
|
||||
if not results:
|
||||
raise exception.NotFound(_('no resources were found'))
|
||||
@ -221,7 +215,7 @@ def resource_data_get_all(context, resource_id, data=None):
|
||||
If data is encrypted, this method will decrypt the results.
|
||||
"""
|
||||
if data is None:
|
||||
data = (model_query(context, models.ResourceData)
|
||||
data = (context.session.query(models.ResourceData)
|
||||
.filter_by(resource_id=resource_id)).all()
|
||||
|
||||
if not data:
|
||||
@ -274,7 +268,7 @@ def stack_tags_delete(context, stack_id):
|
||||
|
||||
|
||||
def stack_tags_get(context, stack_id):
|
||||
result = (model_query(context, models.StackTag)
|
||||
result = (context.session.query(models.StackTag)
|
||||
.filter_by(stack_id=stack_id)
|
||||
.all())
|
||||
return result or None
|
||||
@ -285,7 +279,7 @@ def resource_data_get_by_key(context, resource_id, key):
|
||||
|
||||
Does not decrypt resource_data.
|
||||
"""
|
||||
result = (model_query(context, models.ResourceData)
|
||||
result = (context.session.query(models.ResourceData)
|
||||
.filter_by(resource_id=resource_id)
|
||||
.filter_by(key=key).first())
|
||||
|
||||
@ -314,7 +308,7 @@ def resource_data_set(context, resource_id, key, value, redact=False):
|
||||
|
||||
|
||||
def resource_exchange_stacks(context, resource_id1, resource_id2):
|
||||
query = model_query(context, models.Resource)
|
||||
query = context.session.query(models.Resource)
|
||||
session = query.session
|
||||
session.begin()
|
||||
|
||||
@ -339,8 +333,8 @@ def resource_create(context, values):
|
||||
|
||||
|
||||
def resource_get_all_by_stack(context, stack_id, filters=None):
|
||||
query = model_query(
|
||||
context, models.Resource
|
||||
query = context.session.query(
|
||||
models.Resource
|
||||
).filter_by(
|
||||
stack_id=stack_id
|
||||
).options(orm.joinedload("data"))
|
||||
@ -357,9 +351,9 @@ def resource_get_all_by_stack(context, stack_id, filters=None):
|
||||
|
||||
def resource_get_all_active_by_stack(context, stack_id):
|
||||
filters = {'stack_id': stack_id, 'action': 'DELETE', 'status': 'COMPLETE'}
|
||||
subquery = model_query(context, models.Resource.id).filter_by(**filters)
|
||||
subquery = context.session.query(models.Resource.id).filter_by(**filters)
|
||||
|
||||
results = model_query(context, models.Resource).filter_by(
|
||||
results = context.session.query(models.Resource).filter_by(
|
||||
stack_id=stack_id).filter(
|
||||
models.Resource.id.notin_(subquery.as_scalar())
|
||||
).options(orm.joinedload("data")).all()
|
||||
@ -371,8 +365,8 @@ def resource_get_all_active_by_stack(context, stack_id):
|
||||
|
||||
|
||||
def resource_get_all_by_root_stack(context, stack_id, filters=None):
|
||||
query = model_query(
|
||||
context, models.Resource
|
||||
query = context.session.query(
|
||||
models.Resource
|
||||
).filter_by(
|
||||
root_stack_id=stack_id
|
||||
).options(orm.joinedload("data"))
|
||||
@ -405,7 +399,7 @@ def stack_get_by_name(context, stack_name):
|
||||
|
||||
def stack_get(context, stack_id, show_deleted=False, tenant_safe=True,
|
||||
eager_load=False):
|
||||
query = model_query(context, models.Stack)
|
||||
query = context.session.query(models.Stack)
|
||||
if eager_load:
|
||||
query = query.options(orm.joinedload("raw_template"))
|
||||
result = query.get(stack_id)
|
||||
@ -425,7 +419,7 @@ def stack_get(context, stack_id, show_deleted=False, tenant_safe=True,
|
||||
|
||||
|
||||
def stack_get_status(context, stack_id):
|
||||
query = model_query(context, models.Stack)
|
||||
query = context.session.query(models.Stack)
|
||||
query = query.options(
|
||||
orm.load_only("action", "status", "status_reason", "updated_at"))
|
||||
result = query.filter_by(id=stack_id).first()
|
||||
@ -468,7 +462,7 @@ def _paginate_query(context, query, model, limit=None, sort_keys=None,
|
||||
|
||||
model_marker = None
|
||||
if marker:
|
||||
model_marker = model_query(context, model).get(marker)
|
||||
model_marker = context.session.query(model).get(marker)
|
||||
try:
|
||||
query = utils.paginate_query(query, model, limit, sort_keys,
|
||||
model_marker, sort_dir)
|
||||
@ -691,8 +685,8 @@ def stack_get_root_id(context, stack_id):
|
||||
|
||||
def stack_count_total_resources(context, stack_id):
|
||||
# count all resources which belong to the root stack
|
||||
results = model_query(
|
||||
context, models.Resource
|
||||
results = context.session.query(
|
||||
models.Resource
|
||||
).filter(models.Resource.root_stack_id == stack_id).count()
|
||||
return results
|
||||
|
||||
@ -731,7 +725,7 @@ def user_creds_create(context):
|
||||
|
||||
|
||||
def user_creds_get(context, user_creds_id):
|
||||
db_result = model_query(context, models.UserCreds).get(user_creds_id)
|
||||
db_result = context.session.query(models.UserCreds).get(user_creds_id)
|
||||
if db_result is None:
|
||||
return None
|
||||
# Return a dict copy of db results, do not decrypt details into db_result
|
||||
@ -747,7 +741,7 @@ def user_creds_get(context, user_creds_id):
|
||||
|
||||
@db_utils.retry_on_stale_data_error
|
||||
def user_creds_delete(context, user_creds_id):
|
||||
creds = model_query(context, models.UserCreds).get(user_creds_id)
|
||||
creds = context.session.query(models.UserCreds).get(user_creds_id)
|
||||
if not creds:
|
||||
raise exception.NotFound(
|
||||
_('Attempt to delete user creds with id '
|
||||
@ -758,22 +752,22 @@ def user_creds_delete(context, user_creds_id):
|
||||
|
||||
|
||||
def event_get(context, event_id):
|
||||
result = model_query(context, models.Event).get(event_id)
|
||||
result = context.session.query(models.Event).get(event_id)
|
||||
return result
|
||||
|
||||
|
||||
def event_get_all(context):
|
||||
stacks = soft_delete_aware_query(context, models.Stack)
|
||||
stack_ids = [stack.id for stack in stacks]
|
||||
results = model_query(
|
||||
context, models.Event
|
||||
results = context.session.query(
|
||||
models.Event
|
||||
).filter(models.Event.stack_id.in_(stack_ids)).all()
|
||||
return results
|
||||
|
||||
|
||||
def event_get_all_by_tenant(context, limit=None, marker=None,
|
||||
sort_keys=None, sort_dir=None, filters=None):
|
||||
query = model_query(context, models.Event)
|
||||
query = context.session.query(models.Event)
|
||||
query = db_filters.exact_filter(query, models.Event, filters)
|
||||
query = query.join(
|
||||
models.Event.stack
|
||||
@ -784,7 +778,7 @@ def event_get_all_by_tenant(context, limit=None, marker=None,
|
||||
|
||||
|
||||
def _query_all_by_stack(context, stack_id):
|
||||
query = model_query(context, models.Event).filter_by(stack_id=stack_id)
|
||||
query = context.session.query(models.Event).filter_by(stack_id=stack_id)
|
||||
return query
|
||||
|
||||
|
||||
@ -809,10 +803,10 @@ def _events_paginate_query(context, query, model, limit=None, sort_keys=None,
|
||||
|
||||
model_marker = None
|
||||
if marker:
|
||||
# not to use model_query(context, model).get(marker), because
|
||||
# not to use context.session.query(model).get(marker), because
|
||||
# user can only see the ID(column 'uuid') and the ID as the marker
|
||||
model_marker = model_query(
|
||||
context, model).filter_by(uuid=marker).first()
|
||||
model_marker = context.session.query(
|
||||
model).filter_by(uuid=marker).first()
|
||||
try:
|
||||
query = utils.paginate_query(query, model, limit, sort_keys,
|
||||
model_marker, sort_dir)
|
||||
@ -841,7 +835,7 @@ def _events_filter_and_page_query(context, query,
|
||||
|
||||
|
||||
def event_count_all_by_stack(context, stack_id):
|
||||
query = model_query(context, func.count(models.Event.id))
|
||||
query = context.session.query(func.count(models.Event.id))
|
||||
return query.filter_by(stack_id=stack_id).scalar()
|
||||
|
||||
|
||||
@ -874,24 +868,24 @@ def event_create(context, values):
|
||||
|
||||
|
||||
def watch_rule_get(context, watch_rule_id):
|
||||
result = model_query(context, models.WatchRule).get(watch_rule_id)
|
||||
result = context.session.query(models.WatchRule).get(watch_rule_id)
|
||||
return result
|
||||
|
||||
|
||||
def watch_rule_get_by_name(context, watch_rule_name):
|
||||
result = model_query(
|
||||
context, models.WatchRule).filter_by(name=watch_rule_name).first()
|
||||
result = context.session.query(
|
||||
models.WatchRule).filter_by(name=watch_rule_name).first()
|
||||
return result
|
||||
|
||||
|
||||
def watch_rule_get_all(context):
|
||||
results = model_query(context, models.WatchRule).all()
|
||||
results = context.session.query(models.WatchRule).all()
|
||||
return results
|
||||
|
||||
|
||||
def watch_rule_get_all_by_stack(context, stack_id):
|
||||
results = model_query(
|
||||
context, models.WatchRule).filter_by(stack_id=stack_id).all()
|
||||
results = context.session.query(
|
||||
models.WatchRule).filter_by(stack_id=stack_id).all()
|
||||
return results
|
||||
|
||||
|
||||
@ -936,12 +930,12 @@ def watch_data_create(context, values):
|
||||
|
||||
|
||||
def watch_data_get_all(context):
|
||||
results = model_query(context, models.WatchData).all()
|
||||
results = context.session.query(models.WatchData).all()
|
||||
return results
|
||||
|
||||
|
||||
def watch_data_get_all_by_watch_rule_id(context, watch_rule_id):
|
||||
results = model_query(context, models.WatchData).filter_by(
|
||||
results = context.session.query(models.WatchData).filter_by(
|
||||
watch_rule_id=watch_rule_id).all()
|
||||
return results
|
||||
|
||||
@ -954,7 +948,7 @@ def software_config_create(context, values):
|
||||
|
||||
|
||||
def software_config_get(context, config_id):
|
||||
result = model_query(context, models.SoftwareConfig).get(config_id)
|
||||
result = context.session.query(models.SoftwareConfig).get(config_id)
|
||||
if (result is not None and context is not None and
|
||||
result.tenant != context.tenant_id):
|
||||
result = None
|
||||
@ -967,7 +961,7 @@ def software_config_get(context, config_id):
|
||||
|
||||
def software_config_get_all(context, limit=None, marker=None,
|
||||
tenant_safe=True):
|
||||
query = model_query(context, models.SoftwareConfig)
|
||||
query = context.session.query(models.SoftwareConfig)
|
||||
if tenant_safe and not context.is_admin:
|
||||
query = query.filter_by(tenant=context.tenant_id)
|
||||
return _paginate_query(context, query, models.SoftwareConfig,
|
||||
@ -977,7 +971,7 @@ def software_config_get_all(context, limit=None, marker=None,
|
||||
def software_config_delete(context, config_id):
|
||||
config = software_config_get(context, config_id)
|
||||
# Query if the software config has been referenced by deployment.
|
||||
result = model_query(context, models.SoftwareDeployment).filter_by(
|
||||
result = context.session.query(models.SoftwareDeployment).filter_by(
|
||||
config_id=config_id).first()
|
||||
if result:
|
||||
msg = (_("Software config with id %s can not be deleted as "
|
||||
@ -999,7 +993,8 @@ def software_deployment_create(context, values):
|
||||
|
||||
|
||||
def software_deployment_get(context, deployment_id):
|
||||
result = model_query(context, models.SoftwareDeployment).get(deployment_id)
|
||||
result = context.session.query(
|
||||
models.SoftwareDeployment).get(deployment_id)
|
||||
if (result is not None and context is not None and
|
||||
context.tenant_id not in (result.tenant,
|
||||
result.stack_user_project_id)):
|
||||
@ -1013,8 +1008,8 @@ def software_deployment_get(context, deployment_id):
|
||||
|
||||
def software_deployment_get_all(context, server_id=None):
|
||||
sd = models.SoftwareDeployment
|
||||
query = model_query(
|
||||
context, sd
|
||||
query = context.session.query(
|
||||
sd
|
||||
).filter(sqlalchemy.or_(
|
||||
sd.tenant == context.tenant_id,
|
||||
sd.stack_user_project_id == context.tenant_id)
|
||||
@ -1043,7 +1038,7 @@ def snapshot_create(context, values):
|
||||
|
||||
|
||||
def snapshot_get(context, snapshot_id):
|
||||
result = model_query(context, models.Snapshot).get(snapshot_id)
|
||||
result = context.session.query(models.Snapshot).get(snapshot_id)
|
||||
if (result is not None and context is not None and
|
||||
context.tenant_id != result.tenant):
|
||||
result = None
|
||||
@ -1078,7 +1073,7 @@ def snapshot_delete(context, snapshot_id):
|
||||
|
||||
|
||||
def snapshot_get_all(context, stack_id):
|
||||
return model_query(context, models.Snapshot).filter_by(
|
||||
return context.session.query(models.Snapshot).filter_by(
|
||||
stack_id=stack_id, tenant=context.tenant_id)
|
||||
|
||||
|
||||
@ -1108,19 +1103,19 @@ def service_delete(context, service_id, soft_delete=True):
|
||||
|
||||
|
||||
def service_get(context, service_id):
|
||||
result = model_query(context, models.Service).get(service_id)
|
||||
result = context.session.query(models.Service).get(service_id)
|
||||
if result is None:
|
||||
raise exception.EntityNotFound(entity='Service', name=service_id)
|
||||
return result
|
||||
|
||||
|
||||
def service_get_all(context):
|
||||
return (model_query(context, models.Service).
|
||||
return (context.session.query(models.Service).
|
||||
filter_by(deleted_at=None).all())
|
||||
|
||||
|
||||
def service_get_all_by_args(context, host, binary, hostname):
|
||||
return (model_query(context, models.Service).
|
||||
return (context.session.query(models.Service).
|
||||
filter_by(host=host).
|
||||
filter_by(binary=binary).
|
||||
filter_by(hostname=hostname).all())
|
||||
@ -1250,7 +1245,7 @@ def purge_deleted(age, granularity='days'):
|
||||
|
||||
def sync_point_delete_all_by_stack_and_traversal(context, stack_id,
|
||||
traversal_id):
|
||||
rows_deleted = model_query(context, models.SyncPoint).filter_by(
|
||||
rows_deleted = context.session.query(models.SyncPoint).filter_by(
|
||||
stack_id=stack_id, traversal_id=traversal_id).delete()
|
||||
return rows_deleted
|
||||
|
||||
@ -1267,7 +1262,7 @@ def sync_point_create(context, values):
|
||||
|
||||
def sync_point_get(context, entity_id, traversal_id, is_update):
|
||||
entity_id = str(entity_id)
|
||||
return model_query(context, models.SyncPoint).get(
|
||||
return context.session.query(models.SyncPoint).get(
|
||||
(entity_id, traversal_id, is_update)
|
||||
)
|
||||
|
||||
@ -1276,7 +1271,7 @@ def sync_point_update_input_data(context, entity_id,
|
||||
traversal_id, is_update, atomic_key,
|
||||
input_data):
|
||||
entity_id = str(entity_id)
|
||||
rows_updated = model_query(context, models.SyncPoint).filter_by(
|
||||
rows_updated = context.session.query(models.SyncPoint).filter_by(
|
||||
entity_id=entity_id,
|
||||
traversal_id=traversal_id,
|
||||
is_update=is_update,
|
||||
@ -1499,19 +1494,19 @@ def _get_batch(session, ctxt, query, model, batch_size=50):
|
||||
|
||||
def reset_stack_status(context, stack_id, stack=None):
|
||||
if stack is None:
|
||||
stack = model_query(context, models.Stack).get(stack_id)
|
||||
stack = context.session.query(models.Stack).get(stack_id)
|
||||
|
||||
if stack is None:
|
||||
raise exception.NotFound(_('Stack with id %s not found') % stack_id)
|
||||
|
||||
session = context.session
|
||||
with session.begin():
|
||||
query = model_query(context, models.Resource).filter_by(
|
||||
query = context.session.query(models.Resource).filter_by(
|
||||
status='IN_PROGRESS', stack_id=stack_id)
|
||||
query.update({'status': 'FAILED',
|
||||
'status_reason': 'Stack status manually reset'})
|
||||
|
||||
query = model_query(context, models.ResourceData)
|
||||
query = context.session.query(models.ResourceData)
|
||||
query = query.join(models.Resource)
|
||||
query = query.filter_by(stack_id=stack_id)
|
||||
query = query.filter(
|
||||
@ -1519,11 +1514,11 @@ def reset_stack_status(context, stack_id, stack=None):
|
||||
data_ids = [data.id for data in query]
|
||||
|
||||
if data_ids:
|
||||
query = model_query(context, models.ResourceData)
|
||||
query = context.session.query(models.ResourceData)
|
||||
query = query.filter(models.ResourceData.id.in_(data_ids))
|
||||
query.delete(synchronize_session='fetch')
|
||||
|
||||
query = model_query(context, models.Stack).filter_by(owner_id=stack_id)
|
||||
query = context.session.query(models.Stack).filter_by(owner_id=stack_id)
|
||||
for child in query:
|
||||
reset_stack_status(context, child.id, child)
|
||||
|
||||
|
@ -259,18 +259,17 @@ class SqlAlchemyTest(common.HeatTestCase):
|
||||
self.assertIn(['name', 'id'], args)
|
||||
|
||||
@mock.patch.object(db_api.utils, 'paginate_query')
|
||||
@mock.patch.object(db_api, 'model_query')
|
||||
def test_paginate_query_gets_model_marker(self, mock_query,
|
||||
mock_paginate_query):
|
||||
def test_paginate_query_gets_model_marker(self, mock_paginate_query):
|
||||
query = mock.Mock()
|
||||
model = mock.Mock()
|
||||
marker = mock.Mock()
|
||||
|
||||
mock_query_object = mock.Mock()
|
||||
mock_query_object.get.return_value = 'real_marker'
|
||||
mock_query.return_value = mock_query_object
|
||||
ctx = mock.MagicMock()
|
||||
ctx.session.query.return_value = mock_query_object
|
||||
|
||||
db_api._paginate_query(self.ctx, query, model, marker=marker)
|
||||
db_api._paginate_query(ctx, query, model, marker=marker)
|
||||
mock_query_object.get.assert_called_once_with(marker)
|
||||
args, _ = mock_paginate_query.call_args
|
||||
self.assertIn('real_marker', args)
|
||||
|
Loading…
Reference in New Issue
Block a user