# Copyright (c) 2015 Ericsson AB. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. # # Copyright (c) 2019-2021 Wind River Systems, Inc. # # SPDX-License-Identifier: Apache-2.0 # """ Implementation of SQLAlchemy backend. """ import sys from oslo_db.sqlalchemy import enginefacade from oslo_log import log as logging from sqlalchemy import Table, MetaData from sqlalchemy.sql import select from dcdbsync.common import exceptions as exception from dcdbsync.common.i18n import _ LOG = logging.getLogger(__name__) _main_context_manager = None def _get_main_context_manager(): global _main_context_manager if not _main_context_manager: _main_context_manager = enginefacade.transaction_context() return _main_context_manager _CONTEXT = None def _get_context(): global _CONTEXT if _CONTEXT is None: import threading _CONTEXT = threading.local() return _CONTEXT class TableRegistry(object): def __init__(self): self.metadata = MetaData() def get(self, connection, tablename): try: table = self.metadata.tables[tablename] except KeyError: table = Table( tablename, self.metadata, autoload_with=connection ) return table registry = TableRegistry() def get_read_connection(): reader = _get_main_context_manager().reader return reader.connection.using(_get_context()) def get_write_connection(): writer = _get_main_context_manager().writer return writer.connection.using(_get_context()) def row2dict(table, row): d = {} for c in table.columns: c_value = getattr(row, c.name) d[c.name] = c_value return d def index2column(r_table, index_name): column = None for c in r_table.columns: if c.name == index_name: column = c break return column def query(connection, table, index_name=None, index_value=None): global registry r_table = registry.get(connection, table) if index_name and index_value: c = index2column(r_table, index_name) stmt = select([r_table]).where(c == index_value) else: stmt = select([r_table]) records = [] result = connection.execute(stmt) for row in result: # convert the row into a dictionary d = row2dict(r_table, row) records.append(d) return records def insert(connection, table, data): global registry r_table = registry.get(connection, table) stmt = r_table.insert() connection.execute(stmt, data) def delete(connection, table, index_name, index_value): global registry r_table = registry.get(connection, table) c = index2column(r_table, index_name) stmt = r_table.delete().where(c == index_value) connection.execute(stmt) def update(connection, table, index_name, index_value, data): global registry r_table = registry.get(connection, table) c = index2column(r_table, index_name) stmt = r_table.update().where(c == index_value).values(data) connection.execute(stmt) def get_backend(): """The backend is this module itself.""" return sys.modules[__name__] def is_admin_context(context): """Indicate if the request context is an administrator.""" if not context: LOG.warning(_('Use of empty request context is deprecated'), DeprecationWarning) raise Exception('die') return context.is_admin def is_user_context(context): """Indicate if the request context is a normal user.""" if not context: return False if context.is_admin: return False if not context.user or not context.project: return False return True def require_admin_context(f): """Decorator to require admin request context. The first argument to the wrapped function must be the context. """ def wrapper(*args, **kwargs): if not is_admin_context(args[0]): raise exception.AdminRequired() return f(*args, **kwargs) return wrapper def require_context(f): """Decorator to require *any* user or admin context. This does no authorization for user or project access matching, see :py:func:`authorize_project_context` and :py:func:`authorize_user_context`. The first argument to the wrapped function must be the context. """ def wrapper(*args, **kwargs): if not is_admin_context(args[0]) and not is_user_context(args[0]): raise exception.NotAuthorized() return f(*args, **kwargs) return wrapper ################### # identity users ################### @require_context def user_get_all(context): result = [] with get_read_connection() as conn: # user table users = query(conn, 'user') # local_user table local_users = query(conn, 'local_user') # password table passwords = query(conn, 'password') for local_user in local_users: user = {'user': user for user in users if user['id'] == local_user['user_id']} user_passwords = {'password': [password for password in passwords if password['local_user_id'] == local_user['id']]} user_consolidated = dict(list({'local_user': local_user}.items()) + list(user.items()) + list(user_passwords.items())) result.append(user_consolidated) return result @require_context def user_get(context, user_id): result = {} with get_read_connection() as conn: # user table users = query(conn, 'user', 'id', user_id) if not users: raise exception.UserNotFound(user_id=user_id) result['user'] = users[0] # local_user table local_users = query(conn, 'local_user', 'user_id', user_id) if not local_users: raise exception.UserNotFound(user_id=user_id) result['local_user'] = local_users[0] # password table result['password'] = [] if result['local_user']: result['password'] = query(conn, 'password', 'local_user_id', result['local_user'].get('id')) return result @require_admin_context def user_create(context, payload): users = [payload['user']] local_users = [payload['local_user']] passwords = payload['password'] with get_write_connection() as conn: insert(conn, 'user', users) # ignore auto generated id for local_user in local_users: local_user.pop('id', None) insert(conn, 'local_user', local_users) inserted_local_users = query(conn, 'local_user', 'user_id', payload['local_user']['user_id']) if not inserted_local_users: raise exception.UserNotFound(user_id=payload['local_user'] ['user_id']) for password in passwords: # ignore auto generated id password.pop('id', None) password['local_user_id'] = inserted_local_users[0]['id'] insert(conn, 'password', passwords) return user_get(context, payload['user']['id']) @require_admin_context def user_update(context, user_id, payload): with get_write_connection() as conn: # user table table = 'user' new_user_id = user_id if table in payload: user = payload[table] update(conn, table, 'id', user_id, user) new_user_id = user.get('id') # local_user table table = 'local_user' if table in payload: local_user = payload[table] # ignore auto generated id local_user.pop('id', None) update(conn, table, 'user_id', user_id, local_user) updated_local_users = query(conn, table, 'user_id', new_user_id) if not updated_local_users: raise exception.UserNotFound(user_id=payload[table]['user_id']) # password table table = 'password' if table in payload: delete(conn, table, 'local_user_id', updated_local_users[0]['id']) passwords = payload[table] for password in passwords: # ignore auto generated ids password.pop('id', None) password['local_user_id'] = \ updated_local_users[0]['id'] insert(conn, table, password) # Need to update the actor_id in assignment and system_assignment # along with the user_id in user_group_membership tables if the # user id is updated if user_id != new_user_id: assignment = {'actor_id': new_user_id} user_group_membership = {'user_id': new_user_id} update(conn, 'assignment', 'actor_id', user_id, assignment) update(conn, 'system_assignment', 'actor_id', user_id, assignment) update(conn, 'user_group_membership', 'user_id', user_id, user_group_membership) return user_get(context, new_user_id) ################### # identity groups ################### @require_context def group_get_all(context): result = [] with get_read_connection() as conn: # groups table groups = query(conn, 'group') # user_group_membership table user_group_memberships = query(conn, 'user_group_membership') for group in groups: local_user_id_list = [membership['user_id'] for membership in user_group_memberships if membership['group_id'] == group['id']] local_user_id_list.sort() local_user_ids = {'local_user_ids': local_user_id_list} group_consolidated = dict(list({'group': group}.items()) + list(local_user_ids.items())) result.append(group_consolidated) return result @require_context def group_get(context, group_id): result = {} with get_read_connection() as conn: local_user_id_list = [] # group table group = query(conn, 'group', 'id', group_id) if not group: raise exception.GroupNotFound(group_id=group_id) result['group'] = group[0] # user_group_membership table user_group_memberships = query(conn, 'user_group_membership', 'group_id', group_id) for user_group_membership in user_group_memberships: local_user = query(conn, 'local_user', 'user_id', user_group_membership.get('user_id')) if not local_user: raise exception.UserNotFound(user_id=user_group_membership.get('user_id')) local_user_id_list.append(local_user[0]['user_id']) result['local_user_ids'] = local_user_id_list return result @require_admin_context def group_create(context, payload): group = payload['group'] local_user_ids = payload['local_user_ids'] with get_write_connection() as conn: insert(conn, 'group', group) for local_user_id in local_user_ids: user_group_membership = {'user_id': local_user_id, 'group_id': group['id']} insert(conn, 'user_group_membership', user_group_membership) return group_get(context, payload['group']['id']) @require_admin_context def group_update(context, group_id, payload): with get_write_connection() as conn: new_group_id = group_id if 'group' in payload and 'local_user_ids' in payload: group = payload['group'] new_group_id = group.get('id') # local_user_id_list is a sorted list of user IDs that # belong to this group local_user_id_list = payload['local_user_ids'] user_group_memberships = query(conn, 'user_group_membership', 'group_id', group_id) existing_user_list = [user_group_membership['user_id'] for user_group_membership in user_group_memberships] existing_user_list.sort() deleted = False if (group_id != new_group_id) or (local_user_id_list != existing_user_list): # Foreign key constraint exists on 'group_id' of user_group_membership table # and 'id' of group table. So delete user group membership records before # updating group if groups IDs are different # Alternatively, if there is a discrepency in the user group memberships, # delete and re-create them delete(conn, 'user_group_membership', 'group_id', group_id) deleted = True # Update group table update(conn, 'group', 'id', group_id, group) if deleted: for local_user_id in local_user_id_list: item = {'user_id': local_user_id, 'group_id': new_group_id} insert(conn, 'user_group_membership', item) # Need to update the actor_id in assignment and system_assignment # tables if the group id is updated if group_id != new_group_id: assignment = {'actor_id': new_group_id} update(conn, 'assignment', 'actor_id', group_id, assignment) update(conn, 'system_assignment', 'actor_id', group_id, assignment) return group_get(context, new_group_id) ################### # identity projects ################### @require_context def project_get_all(context): result = [] with get_read_connection() as conn: # project table projects = query(conn, 'project') for project in projects: project_consolidated = {'project': project} result.append(project_consolidated) return result @require_context def project_get(context, project_id): result = {} with get_read_connection() as conn: # project table projects = query(conn, 'project', 'id', project_id) if not projects: raise exception.ProjectNotFound(project_id=project_id) result['project'] = projects[0] return result @require_admin_context def project_create(context, payload): projects = [payload['project']] with get_write_connection() as conn: insert(conn, 'project', projects) return project_get(context, payload['project']['id']) @require_admin_context def project_update(context, project_id, payload): with get_write_connection() as conn: # project table table = 'project' new_project_id = project_id if table in payload: domain_ref_projects = [] parent_ref_projects = [] domain_ref_users = [] domain_ref_local_users = [] project = payload[table] new_project_id = project.get('id') if project_id != new_project_id: domain_ref_projects = query(conn, 'project', 'domain_id', project_id) delete(conn, 'project', 'domain_id', project_id) parent_ref_projects = query(conn, 'project', 'parent_id', project_id) delete(conn, 'project', 'parent_id', project_id) # For user table: CONSTRAINT `user_ibfk_1` # FOREIGN KEY(`domain_id`) REFERENCES `project`(`id`) domain_ref_users = query(conn, 'user', 'domain_id', project_id) domain_ref_local_users = query(conn, 'local_user', 'domain_id', project_id) delete(conn, 'user', 'domain_id', project_id) # Update project table update(conn, table, 'id', project_id, project) # Update saved records from project table and insert them back if domain_ref_projects: for domain_ref_project in domain_ref_projects: domain_ref_project['domain_id'] = new_project_id if domain_ref_project['parent_id'] == project_id: domain_ref_project['parent_id'] = new_project_id insert(conn, 'project', domain_ref_projects) if parent_ref_projects: for parent_ref_project in parent_ref_projects: parent_ref_project['parent_id'] = new_project_id if parent_ref_project['domain_id'] == project_id: parent_ref_project['domain_id'] = new_project_id insert(conn, 'project', parent_ref_projects) if domain_ref_users: for domain_ref_user in domain_ref_users: domain_ref_user['domain_id'] = new_project_id insert(conn, 'user', domain_ref_users) if domain_ref_local_users: for domain_ref_local_user in domain_ref_local_users: domain_ref_local_user['domain_id'] = new_project_id insert(conn, 'local_user', domain_ref_local_users) # Need to update the target_id in assignment table # if the project id is updated if project_id != new_project_id: table = 'assignment' assignment = {'target_id': new_project_id} update(conn, table, 'target_id', project_id, assignment) return project_get(context, new_project_id) ################### # identity roles ################### @require_context def role_get_all(context): result = [] with get_read_connection() as conn: # role table roles = query(conn, 'role') for role in roles: role_consolidated = {'role': role} result.append(role_consolidated) return result @require_context def role_get(context, role_id): result = {} with get_read_connection() as conn: # role table roles = query(conn, 'role', 'id', role_id) if not roles: raise exception.RoleNotFound(role_id=role_id) result['role'] = roles[0] return result @require_admin_context def role_create(context, payload): roles = [payload['role']] with get_write_connection() as conn: insert(conn, 'role', roles) return role_get(context, payload['role']['id']) @require_admin_context def role_update(context, role_id, payload): with get_write_connection() as conn: # role table table = 'role' new_role_id = role_id if table in payload: prior_roles = [] implied_roles = [] role = payload[table] new_role_id = role.get('id') if role_id != new_role_id: # implied_role table has foreign key references to role table. # The foreign key references are on DELETE CASCADE only. To # avoid foreign key constraints violation, save these records # from implied_role table, delete them, update role table, # update and insert them back after role table is updated. prior_roles = query(conn, 'implied_role', 'prior_role_id', role_id) delete(conn, 'implied_role', 'prior_role_id', role_id) implied_roles = query(conn, 'implied_role', 'implied_role_id', role_id) delete(conn, 'implied_role', 'implied_role_id', role_id) # Update role table update(conn, table, 'id', role_id, role) # Update saved records from implied_role table and insert them back if prior_roles: for prior_role in prior_roles: prior_role['prior_role_id'] = new_role_id insert(conn, 'implied_role', prior_roles) if implied_roles: for implied_role in implied_roles: implied_role['implied_role_id'] = new_role_id insert(conn, 'implied_role', implied_roles) # Need to update the role_id in assignment and system_assignment tables # if the role id is updated if role_id != new_role_id: assignment = {'role_id': new_role_id} update(conn, 'assignment', 'role_id', role_id, assignment) update(conn, 'system_assignment', 'role_id', role_id, assignment) return role_get(context, new_role_id) ################################## # identity token revocation events ################################## @require_context def revoke_event_get_all(context): result = [] with get_read_connection() as conn: # revocation_event table revoke_events = query(conn, 'revocation_event') for revoke_event in revoke_events: revoke_event_consolidated = {'revocation_event': revoke_event} result.append(revoke_event_consolidated) return result @require_context def revoke_event_get_by_audit(context, audit_id): result = {} with get_read_connection() as conn: # revocation_event table revoke_events = query(conn, 'revocation_event', 'audit_id', audit_id) if not revoke_events: raise exception.RevokeEventNotFound() result['revocation_event'] = revoke_events[0] return result @require_context def revoke_event_get_by_user(context, user_id, issued_before): result = {} with get_read_connection() as conn: # revocation_event table events = query(conn, 'revocation_event', 'user_id', user_id) revoke_events = [event for event in events if str(event['issued_before']) == issued_before] if not revoke_events: raise exception.RevokeEventNotFound() result['revocation_event'] = revoke_events[0] return result @require_admin_context def revoke_event_create(context, payload): revoke_event = payload['revocation_event'] # ignore auto generated id revoke_event.pop('id', None) revoke_events = [revoke_event] with get_write_connection() as conn: insert(conn, 'revocation_event', revoke_events) result = {} if revoke_event.get('audit_id') is not None: result = revoke_event_get_by_audit(context, revoke_event.get('audit_id')) elif (revoke_event.get('user_id') is not None) and \ (revoke_event.get('issued_before') is not None): result = revoke_event_get_by_user(context, revoke_event.get('user_id'), revoke_event.get('issued_before')) return result @require_admin_context def revoke_event_delete_by_audit(context, audit_id): with get_write_connection() as conn: delete(conn, 'revocation_event', 'audit_id', audit_id) @require_admin_context def revoke_event_delete_by_user(context, user_id, issued_before): result = revoke_event_get_by_user(context, user_id, issued_before) event_id = result['revocation_event']['id'] with get_write_connection() as conn: delete(conn, 'revocation_event', 'id', event_id)