diff --git a/heat/db/sqlalchemy/api.py b/heat/db/sqlalchemy/api.py index e87334fdad..5a3e7a6bd6 100644 --- a/heat/db/sqlalchemy/api.py +++ b/heat/db/sqlalchemy/api.py @@ -80,19 +80,13 @@ def get_backend(): return sys.modules[__name__] -def model_query(context, *args): - session = context.session - query = session.query(*args) - return query - - def soft_delete_aware_query(context, *args, **kwargs): """Stack query helper that accounts for context's `show_deleted` field. :param show_deleted: if True, overrides context's show_deleted field. """ - query = model_query(context, *args) + query = context.session.query(*args) show_deleted = kwargs.get('show_deleted') or context.show_deleted if not show_deleted: @@ -101,7 +95,7 @@ def soft_delete_aware_query(context, *args, **kwargs): def raw_template_get(context, template_id): - result = model_query(context, models.RawTemplate).get(template_id) + result = context.session.query(models.RawTemplate).get(template_id) if not result: raise exception.NotFound(_('raw template with id %s not found') % @@ -161,7 +155,7 @@ def raw_template_files_get(context, files_id): def resource_get(context, resource_id): - result = model_query(context, models.Resource).get(resource_id) + result = context.session.query(models.Resource).get(resource_id) if not result: raise exception.NotFound(_("resource with id %s not found") % @@ -170,8 +164,8 @@ def resource_get(context, resource_id): def resource_get_by_name_and_stack(context, resource_name, stack_id): - result = model_query( - context, models.Resource + result = context.session.query( + models.Resource ).filter_by( name=resource_name ).filter_by( @@ -181,7 +175,7 @@ def resource_get_by_name_and_stack(context, resource_name, stack_id): def resource_get_by_physical_resource_id(context, physical_resource_id): - results = (model_query(context, models.Resource) + results = (context.session.query(models.Resource) .filter_by(physical_resource_id=physical_resource_id) .all()) @@ -193,7 +187,7 @@ def resource_get_by_physical_resource_id(context, physical_resource_id): def resource_get_all(context): - results = model_query(context, models.Resource).all() + results = context.session.query(models.Resource).all() if not results: raise exception.NotFound(_('no resources were found')) @@ -221,7 +215,7 @@ def resource_data_get_all(context, resource_id, data=None): If data is encrypted, this method will decrypt the results. """ if data is None: - data = (model_query(context, models.ResourceData) + data = (context.session.query(models.ResourceData) .filter_by(resource_id=resource_id)).all() if not data: @@ -274,7 +268,7 @@ def stack_tags_delete(context, stack_id): def stack_tags_get(context, stack_id): - result = (model_query(context, models.StackTag) + result = (context.session.query(models.StackTag) .filter_by(stack_id=stack_id) .all()) return result or None @@ -285,7 +279,7 @@ def resource_data_get_by_key(context, resource_id, key): Does not decrypt resource_data. """ - result = (model_query(context, models.ResourceData) + result = (context.session.query(models.ResourceData) .filter_by(resource_id=resource_id) .filter_by(key=key).first()) @@ -314,7 +308,7 @@ def resource_data_set(context, resource_id, key, value, redact=False): def resource_exchange_stacks(context, resource_id1, resource_id2): - query = model_query(context, models.Resource) + query = context.session.query(models.Resource) session = query.session session.begin() @@ -339,8 +333,8 @@ def resource_create(context, values): def resource_get_all_by_stack(context, stack_id, filters=None): - query = model_query( - context, models.Resource + query = context.session.query( + models.Resource ).filter_by( stack_id=stack_id ).options(orm.joinedload("data")) @@ -357,9 +351,9 @@ def resource_get_all_by_stack(context, stack_id, filters=None): def resource_get_all_active_by_stack(context, stack_id): filters = {'stack_id': stack_id, 'action': 'DELETE', 'status': 'COMPLETE'} - subquery = model_query(context, models.Resource.id).filter_by(**filters) + subquery = context.session.query(models.Resource.id).filter_by(**filters) - results = model_query(context, models.Resource).filter_by( + results = context.session.query(models.Resource).filter_by( stack_id=stack_id).filter( models.Resource.id.notin_(subquery.as_scalar()) ).options(orm.joinedload("data")).all() @@ -371,8 +365,8 @@ def resource_get_all_active_by_stack(context, stack_id): def resource_get_all_by_root_stack(context, stack_id, filters=None): - query = model_query( - context, models.Resource + query = context.session.query( + models.Resource ).filter_by( root_stack_id=stack_id ).options(orm.joinedload("data")) @@ -405,7 +399,7 @@ def stack_get_by_name(context, stack_name): def stack_get(context, stack_id, show_deleted=False, tenant_safe=True, eager_load=False): - query = model_query(context, models.Stack) + query = context.session.query(models.Stack) if eager_load: query = query.options(orm.joinedload("raw_template")) result = query.get(stack_id) @@ -425,7 +419,7 @@ def stack_get(context, stack_id, show_deleted=False, tenant_safe=True, def stack_get_status(context, stack_id): - query = model_query(context, models.Stack) + query = context.session.query(models.Stack) query = query.options( orm.load_only("action", "status", "status_reason", "updated_at")) result = query.filter_by(id=stack_id).first() @@ -468,7 +462,7 @@ def _paginate_query(context, query, model, limit=None, sort_keys=None, model_marker = None if marker: - model_marker = model_query(context, model).get(marker) + model_marker = context.session.query(model).get(marker) try: query = utils.paginate_query(query, model, limit, sort_keys, model_marker, sort_dir) @@ -691,8 +685,8 @@ def stack_get_root_id(context, stack_id): def stack_count_total_resources(context, stack_id): # count all resources which belong to the root stack - results = model_query( - context, models.Resource + results = context.session.query( + models.Resource ).filter(models.Resource.root_stack_id == stack_id).count() return results @@ -731,7 +725,7 @@ def user_creds_create(context): def user_creds_get(context, user_creds_id): - db_result = model_query(context, models.UserCreds).get(user_creds_id) + db_result = context.session.query(models.UserCreds).get(user_creds_id) if db_result is None: return None # Return a dict copy of db results, do not decrypt details into db_result @@ -747,7 +741,7 @@ def user_creds_get(context, user_creds_id): @db_utils.retry_on_stale_data_error def user_creds_delete(context, user_creds_id): - creds = model_query(context, models.UserCreds).get(user_creds_id) + creds = context.session.query(models.UserCreds).get(user_creds_id) if not creds: raise exception.NotFound( _('Attempt to delete user creds with id ' @@ -758,22 +752,22 @@ def user_creds_delete(context, user_creds_id): def event_get(context, event_id): - result = model_query(context, models.Event).get(event_id) + result = context.session.query(models.Event).get(event_id) return result def event_get_all(context): stacks = soft_delete_aware_query(context, models.Stack) stack_ids = [stack.id for stack in stacks] - results = model_query( - context, models.Event + results = context.session.query( + models.Event ).filter(models.Event.stack_id.in_(stack_ids)).all() return results def event_get_all_by_tenant(context, limit=None, marker=None, sort_keys=None, sort_dir=None, filters=None): - query = model_query(context, models.Event) + query = context.session.query(models.Event) query = db_filters.exact_filter(query, models.Event, filters) query = query.join( models.Event.stack @@ -784,7 +778,7 @@ def event_get_all_by_tenant(context, limit=None, marker=None, def _query_all_by_stack(context, stack_id): - query = model_query(context, models.Event).filter_by(stack_id=stack_id) + query = context.session.query(models.Event).filter_by(stack_id=stack_id) return query @@ -809,10 +803,10 @@ def _events_paginate_query(context, query, model, limit=None, sort_keys=None, model_marker = None if marker: - # not to use model_query(context, model).get(marker), because + # not to use context.session.query(model).get(marker), because # user can only see the ID(column 'uuid') and the ID as the marker - model_marker = model_query( - context, model).filter_by(uuid=marker).first() + model_marker = context.session.query( + model).filter_by(uuid=marker).first() try: query = utils.paginate_query(query, model, limit, sort_keys, model_marker, sort_dir) @@ -841,7 +835,7 @@ def _events_filter_and_page_query(context, query, def event_count_all_by_stack(context, stack_id): - query = model_query(context, func.count(models.Event.id)) + query = context.session.query(func.count(models.Event.id)) return query.filter_by(stack_id=stack_id).scalar() @@ -874,24 +868,24 @@ def event_create(context, values): def watch_rule_get(context, watch_rule_id): - result = model_query(context, models.WatchRule).get(watch_rule_id) + result = context.session.query(models.WatchRule).get(watch_rule_id) return result def watch_rule_get_by_name(context, watch_rule_name): - result = model_query( - context, models.WatchRule).filter_by(name=watch_rule_name).first() + result = context.session.query( + models.WatchRule).filter_by(name=watch_rule_name).first() return result def watch_rule_get_all(context): - results = model_query(context, models.WatchRule).all() + results = context.session.query(models.WatchRule).all() return results def watch_rule_get_all_by_stack(context, stack_id): - results = model_query( - context, models.WatchRule).filter_by(stack_id=stack_id).all() + results = context.session.query( + models.WatchRule).filter_by(stack_id=stack_id).all() return results @@ -936,12 +930,12 @@ def watch_data_create(context, values): def watch_data_get_all(context): - results = model_query(context, models.WatchData).all() + results = context.session.query(models.WatchData).all() return results def watch_data_get_all_by_watch_rule_id(context, watch_rule_id): - results = model_query(context, models.WatchData).filter_by( + results = context.session.query(models.WatchData).filter_by( watch_rule_id=watch_rule_id).all() return results @@ -954,7 +948,7 @@ def software_config_create(context, values): def software_config_get(context, config_id): - result = model_query(context, models.SoftwareConfig).get(config_id) + result = context.session.query(models.SoftwareConfig).get(config_id) if (result is not None and context is not None and result.tenant != context.tenant_id): result = None @@ -967,7 +961,7 @@ def software_config_get(context, config_id): def software_config_get_all(context, limit=None, marker=None, tenant_safe=True): - query = model_query(context, models.SoftwareConfig) + query = context.session.query(models.SoftwareConfig) if tenant_safe and not context.is_admin: query = query.filter_by(tenant=context.tenant_id) return _paginate_query(context, query, models.SoftwareConfig, @@ -977,7 +971,7 @@ def software_config_get_all(context, limit=None, marker=None, def software_config_delete(context, config_id): config = software_config_get(context, config_id) # Query if the software config has been referenced by deployment. - result = model_query(context, models.SoftwareDeployment).filter_by( + result = context.session.query(models.SoftwareDeployment).filter_by( config_id=config_id).first() if result: msg = (_("Software config with id %s can not be deleted as " @@ -999,7 +993,8 @@ def software_deployment_create(context, values): def software_deployment_get(context, deployment_id): - result = model_query(context, models.SoftwareDeployment).get(deployment_id) + result = context.session.query( + models.SoftwareDeployment).get(deployment_id) if (result is not None and context is not None and context.tenant_id not in (result.tenant, result.stack_user_project_id)): @@ -1013,8 +1008,8 @@ def software_deployment_get(context, deployment_id): def software_deployment_get_all(context, server_id=None): sd = models.SoftwareDeployment - query = model_query( - context, sd + query = context.session.query( + sd ).filter(sqlalchemy.or_( sd.tenant == context.tenant_id, sd.stack_user_project_id == context.tenant_id) @@ -1043,7 +1038,7 @@ def snapshot_create(context, values): def snapshot_get(context, snapshot_id): - result = model_query(context, models.Snapshot).get(snapshot_id) + result = context.session.query(models.Snapshot).get(snapshot_id) if (result is not None and context is not None and context.tenant_id != result.tenant): result = None @@ -1078,7 +1073,7 @@ def snapshot_delete(context, snapshot_id): def snapshot_get_all(context, stack_id): - return model_query(context, models.Snapshot).filter_by( + return context.session.query(models.Snapshot).filter_by( stack_id=stack_id, tenant=context.tenant_id) @@ -1108,19 +1103,19 @@ def service_delete(context, service_id, soft_delete=True): def service_get(context, service_id): - result = model_query(context, models.Service).get(service_id) + result = context.session.query(models.Service).get(service_id) if result is None: raise exception.EntityNotFound(entity='Service', name=service_id) return result def service_get_all(context): - return (model_query(context, models.Service). + return (context.session.query(models.Service). filter_by(deleted_at=None).all()) def service_get_all_by_args(context, host, binary, hostname): - return (model_query(context, models.Service). + return (context.session.query(models.Service). filter_by(host=host). filter_by(binary=binary). filter_by(hostname=hostname).all()) @@ -1250,7 +1245,7 @@ def purge_deleted(age, granularity='days'): def sync_point_delete_all_by_stack_and_traversal(context, stack_id, traversal_id): - rows_deleted = model_query(context, models.SyncPoint).filter_by( + rows_deleted = context.session.query(models.SyncPoint).filter_by( stack_id=stack_id, traversal_id=traversal_id).delete() return rows_deleted @@ -1267,7 +1262,7 @@ def sync_point_create(context, values): def sync_point_get(context, entity_id, traversal_id, is_update): entity_id = str(entity_id) - return model_query(context, models.SyncPoint).get( + return context.session.query(models.SyncPoint).get( (entity_id, traversal_id, is_update) ) @@ -1276,7 +1271,7 @@ def sync_point_update_input_data(context, entity_id, traversal_id, is_update, atomic_key, input_data): entity_id = str(entity_id) - rows_updated = model_query(context, models.SyncPoint).filter_by( + rows_updated = context.session.query(models.SyncPoint).filter_by( entity_id=entity_id, traversal_id=traversal_id, is_update=is_update, @@ -1499,19 +1494,19 @@ def _get_batch(session, ctxt, query, model, batch_size=50): def reset_stack_status(context, stack_id, stack=None): if stack is None: - stack = model_query(context, models.Stack).get(stack_id) + stack = context.session.query(models.Stack).get(stack_id) if stack is None: raise exception.NotFound(_('Stack with id %s not found') % stack_id) session = context.session with session.begin(): - query = model_query(context, models.Resource).filter_by( + query = context.session.query(models.Resource).filter_by( status='IN_PROGRESS', stack_id=stack_id) query.update({'status': 'FAILED', 'status_reason': 'Stack status manually reset'}) - query = model_query(context, models.ResourceData) + query = context.session.query(models.ResourceData) query = query.join(models.Resource) query = query.filter_by(stack_id=stack_id) query = query.filter( @@ -1519,11 +1514,11 @@ def reset_stack_status(context, stack_id, stack=None): data_ids = [data.id for data in query] if data_ids: - query = model_query(context, models.ResourceData) + query = context.session.query(models.ResourceData) query = query.filter(models.ResourceData.id.in_(data_ids)) query.delete(synchronize_session='fetch') - query = model_query(context, models.Stack).filter_by(owner_id=stack_id) + query = context.session.query(models.Stack).filter_by(owner_id=stack_id) for child in query: reset_stack_status(context, child.id, child) diff --git a/heat/tests/db/test_sqlalchemy_api.py b/heat/tests/db/test_sqlalchemy_api.py index 02d0e0cc48..5114b99ec2 100644 --- a/heat/tests/db/test_sqlalchemy_api.py +++ b/heat/tests/db/test_sqlalchemy_api.py @@ -259,18 +259,17 @@ class SqlAlchemyTest(common.HeatTestCase): self.assertIn(['name', 'id'], args) @mock.patch.object(db_api.utils, 'paginate_query') - @mock.patch.object(db_api, 'model_query') - def test_paginate_query_gets_model_marker(self, mock_query, - mock_paginate_query): + def test_paginate_query_gets_model_marker(self, mock_paginate_query): query = mock.Mock() model = mock.Mock() marker = mock.Mock() mock_query_object = mock.Mock() mock_query_object.get.return_value = 'real_marker' - mock_query.return_value = mock_query_object + ctx = mock.MagicMock() + ctx.session.query.return_value = mock_query_object - db_api._paginate_query(self.ctx, query, model, marker=marker) + db_api._paginate_query(ctx, query, model, marker=marker) mock_query_object.get.assert_called_once_with(marker) args, _ = mock_paginate_query.call_args self.assertIn('real_marker', args)