Update for sqlalchemy v2.x compatibility

This patch updates Designate to be compatible with SQLAlchemy 2.x.

Depends-On: https://review.opendev.org/c/openstack/oslo.db/+/874858
Change-Id: I5b9cfb4b86bd7c342fd50d1b50dd12dce0c8e81a
This commit is contained in:
Erik Olof Gunnar Andersson 2022-11-20 18:26:52 -08:00
parent 758956d5b9
commit 05a112177d
16 changed files with 402 additions and 344 deletions

View File

@ -18,7 +18,7 @@ from sqlalchemy import MetaData, Table, select, func
import designate.conf
from designate.i18n import _
from designate.sqlalchemy import session
from designate.sqlalchemy import sql
# This import is not used, but is needed to register the storage:sqlalchemy
# group.
import designate.storage.impl_sqlalchemy # noqa
@ -27,14 +27,20 @@ from designate import utils
class Checks(upgradecheck.UpgradeCommands):
def _duplicate_service_status(self):
engine = session.get_engine('storage:sqlalchemy')
metadata = MetaData(bind=engine)
status = Table('service_statuses', metadata, autoload=True)
service_select = (select([func.count()])
.select_from(status)
.group_by('service_name', 'hostname')
)
service_counts = engine.execute(service_select).fetchall()
engine = sql.get_read_engine()
metadata = MetaData()
metadata.bind = engine
status = Table('service_statuses', metadata, autoload_with=engine)
service_select = (
select(func.count())
.select_from(status)
.group_by('service_name', 'hostname')
)
with sql.get_read_session() as session:
service_counts = session.execute(service_select).fetchall()
duplicated_services = [i for i in service_counts if i[0] > 1]
if duplicated_services:
return upgradecheck.Result(upgradecheck.Code.FAILURE,

View File

@ -21,11 +21,11 @@ from oslo_db import exception as oslo_db_exception
from oslo_db.sqlalchemy import utils as oslodb_utils
from oslo_log import log as logging
from oslo_utils import timeutils
from sqlalchemy import select, or_, between, func, distinct, inspect
from sqlalchemy import select, or_, between, func, distinct
from designate import exceptions
from designate import objects
from designate.sqlalchemy import session
from designate.sqlalchemy import sql
from designate.sqlalchemy import utils
@ -66,39 +66,8 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
def __init__(self):
super(SQLAlchemy, self).__init__()
self.engine = session.get_engine(self.get_name())
self.local_store = threading.local()
@abc.abstractmethod
def get_name(self):
"""Get the name."""
@property
def session(self):
# NOTE: This uses a thread local store, allowing each greenthread to
# have its own session stored correctly. Without this, each
# greenthread may end up using a single global session, which
# leads to bad things happening.
if not hasattr(self.local_store, 'session'):
self.local_store.session = session.get_session(self.get_name())
return self.local_store.session
def begin(self):
self.session.begin(subtransactions=True)
def commit(self):
self.session.commit()
def rollback(self):
self.session.rollback()
def get_inspector(self):
return inspect(self.engine)
@staticmethod
def _apply_criterion(table, query, criterion):
if criterion is None:
@ -195,17 +164,18 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
query = table.insert()
try:
resultproxy = self.session.execute(query, [dict(values)])
except oslo_db_exception.DBDuplicateEntry:
raise exc_dup("Duplicate %s" % obj.obj_name())
with sql.get_write_session() as session:
try:
resultproxy = session.execute(query, [dict(values)])
except oslo_db_exception.DBDuplicateEntry:
raise exc_dup("Duplicate %s" % obj.obj_name())
# Refetch the row, for generated columns etc
query = select([table]).where(
table.c.id == resultproxy.inserted_primary_key[0])
resultproxy = self.session.execute(query)
# Refetch the row, for generated columns etc
query = select(table).where(
table.c.id == resultproxy.inserted_primary_key[0])
resultproxy = session.execute(query)
return _set_object_from_model(obj, resultproxy.fetchone())
return _set_object_from_model(obj, resultproxy.fetchone())
def _find(self, context, table, cls, list_cls, exc_notfound, criterion,
one=False, marker=None, limit=None, sort_key=None,
@ -216,11 +186,10 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
# Build the query
if query is None:
query = select([table])
query = select(table)
query = self._apply_criterion(table, query, criterion)
if apply_tenant_criteria:
query = self._apply_tenant_criteria(context, table, query)
query = self._apply_deleted_criteria(context, table, query)
# Execute the Query
@ -229,8 +198,9 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
# a NotFound. Limiting to 2 allows us to determine
# when we need to raise, while selecting the minimal
# number of rows.
resultproxy = self.session.execute(query.limit(2))
results = resultproxy.fetchall()
with sql.get_read_session() as session:
resultproxy = session.execute(query.limit(2))
results = resultproxy.fetchall()
if len(results) != 1:
raise exc_notfound("Could not find %s" % cls.obj_name())
@ -238,7 +208,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
return _set_object_from_model(cls(), results[0])
else:
if marker is not None:
marker = utils.check_marker(table, marker, self.session)
marker = utils.check_marker(table, marker)
try:
query = utils.paginate_query(
@ -246,8 +216,9 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
[sort_key, 'id'], marker=marker,
sort_dir=sort_dir)
resultproxy = self.session.execute(query)
results = resultproxy.fetchall()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
results = resultproxy.fetchall()
return _set_listobject_from_models(list_cls(), results)
except oslodb_utils.InvalidSortKey as sort_key_error:
@ -286,14 +257,14 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
recordsets_table.c.id == records_table.c.recordset_id)
inner_q = (
select([recordsets_table.c.id, # 0 - RS ID
zones_table.c.name]). # 1 - ZONE NAME
select(recordsets_table.c.id, # 0 - RS ID
zones_table.c.name). # 1 - ZONE NAME
select_from(rzjoin).
where(zones_table.c.deleted == '0')
)
count_q = (
select([func.count(distinct(recordsets_table.c.id))]).
select(func.count(distinct(recordsets_table.c.id))).
select_from(rzjoin).where(zones_table.c.deleted == '0')
)
@ -302,8 +273,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
dialect_name='mysql')
if marker is not None:
marker = utils.check_marker(recordsets_table, marker,
self.session)
marker = utils.check_marker(recordsets_table, marker)
try:
inner_q = utils.paginate_query(
@ -348,8 +318,9 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
# This is a separate call due to
# http://dev.mysql.com/doc/mysql-reslimits-excerpt/5.6/en/subquery-restrictions.html # noqa
inner_rproxy = self.session.execute(inner_q)
rows = inner_rproxy.fetchall()
with sql.get_read_session() as session:
inner_rproxy = session.execute(inner_q)
rows = inner_rproxy.fetchall()
if len(rows) == 0:
return 0, objects.RecordSetList()
id_zname_map = {}
@ -362,8 +333,9 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
if context.hide_counts:
total_count = None
else:
resultproxy = self.session.execute(count_q)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(count_q)
result = resultproxy.fetchone()
total_count = 0 if result is None else result[0]
# Join the 2 required tables
@ -372,39 +344,38 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
records_table.c.recordset_id == recordsets_table.c.id)
query = select(
[
# RS Info
recordsets_table.c.id, # 0 - RS ID
recordsets_table.c.version, # 1 - RS Version
recordsets_table.c.created_at, # 2 - RS Created
recordsets_table.c.updated_at, # 3 - RS Updated
recordsets_table.c.tenant_id, # 4 - RS Tenant
recordsets_table.c.zone_id, # 5 - RS Zone
recordsets_table.c.name, # 6 - RS Name
recordsets_table.c.type, # 7 - RS Type
recordsets_table.c.ttl, # 8 - RS TTL
recordsets_table.c.description, # 9 - RS Desc
# R Info
records_table.c.id, # 10 - R ID
records_table.c.version, # 11 - R Version
records_table.c.created_at, # 12 - R Created
records_table.c.updated_at, # 13 - R Updated
records_table.c.tenant_id, # 14 - R Tenant
records_table.c.zone_id, # 15 - R Zone
records_table.c.recordset_id, # 16 - R RSet
records_table.c.data, # 17 - R Data
records_table.c.description, # 18 - R Desc
records_table.c.hash, # 19 - R Hash
records_table.c.managed, # 20 - R Mngd Flg
records_table.c.managed_plugin_name, # 21 - R Mngd Plg
records_table.c.managed_resource_type, # 22 - R Mngd Type
records_table.c.managed_resource_region, # 23 - R Mngd Rgn
records_table.c.managed_resource_id, # 24 - R Mngd ID
records_table.c.managed_tenant_id, # 25 - R Mngd T ID
records_table.c.status, # 26 - R Status
records_table.c.action, # 27 - R Action
records_table.c.serial # 28 - R Serial
]).select_from(rjoin)
# RS Info
recordsets_table.c.id, # 0 - RS ID
recordsets_table.c.version, # 1 - RS Version
recordsets_table.c.created_at, # 2 - RS Created
recordsets_table.c.updated_at, # 3 - RS Updated
recordsets_table.c.tenant_id, # 4 - RS Tenant
recordsets_table.c.zone_id, # 5 - RS Zone
recordsets_table.c.name, # 6 - RS Name
recordsets_table.c.type, # 7 - RS Type
recordsets_table.c.ttl, # 8 - RS TTL
recordsets_table.c.description, # 9 - RS Desc
# R Info
records_table.c.id, # 10 - R ID
records_table.c.version, # 11 - R Version
records_table.c.created_at, # 12 - R Created
records_table.c.updated_at, # 13 - R Updated
records_table.c.tenant_id, # 14 - R Tenant
records_table.c.zone_id, # 15 - R Zone
records_table.c.recordset_id, # 16 - R RSet
records_table.c.data, # 17 - R Data
records_table.c.description, # 18 - R Desc
records_table.c.hash, # 19 - R Hash
records_table.c.managed, # 20 - R Mngd Flg
records_table.c.managed_plugin_name, # 21 - R Mngd Plg
records_table.c.managed_resource_type, # 22 - R Mngd Type
records_table.c.managed_resource_region, # 23 - R Mngd Rgn
records_table.c.managed_resource_id, # 24 - R Mngd ID
records_table.c.managed_tenant_id, # 25 - R Mngd T ID
records_table.c.status, # 26 - R Status
records_table.c.action, # 27 - R Action
records_table.c.serial # 28 - R Serial
).select_from(rjoin)
query = query.where(
recordsets_table.c.id.in_(formatted_ids)
@ -453,8 +424,9 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
sort_dir=sort_dir)
try:
resultproxy = self.session.execute(query)
raw_rows = resultproxy.fetchall()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
raw_rows = resultproxy.fetchall()
# Any ValueErrors are propagated back to the user as is.
# If however central or storage is called directly, invalid values
@ -538,19 +510,20 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
query = self._apply_deleted_criteria(context, table, query)
query = self._apply_version_increment(context, table, query)
try:
resultproxy = self.session.execute(query)
except oslo_db_exception.DBDuplicateEntry:
raise exc_dup("Duplicate %s" % obj.obj_name())
with sql.get_write_session() as session:
try:
resultproxy = session.execute(query)
except oslo_db_exception.DBDuplicateEntry:
raise exc_dup("Duplicate %s" % obj.obj_name())
if resultproxy.rowcount != 1:
raise exc_notfound("Could not find %s" % obj.obj_name())
if resultproxy.rowcount != 1:
raise exc_notfound("Could not find %s" % obj.obj_name())
# Refetch the row, for generated columns etc
query = select([table]).where(table.c.id == obj.id)
resultproxy = self.session.execute(query)
# Refetch the row, for generated columns etc
query = select(table).where(table.c.id == obj.id)
resultproxy = session.execute(query)
return _set_object_from_model(obj, resultproxy.fetchone())
return _set_object_from_model(obj, resultproxy.fetchone())
def _delete(self, context, table, obj, exc_notfound, hard_delete=False):
"""Perform item deletion or soft-delete.
@ -584,28 +557,30 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
query = self._apply_tenant_criteria(context, table, query)
query = self._apply_deleted_criteria(context, table, query)
resultproxy = self.session.execute(query)
with sql.get_write_session() as session:
resultproxy = session.execute(query)
if resultproxy.rowcount != 1:
raise exc_notfound("Could not find %s" % obj.obj_name())
if resultproxy.rowcount != 1:
raise exc_notfound("Could not find %s" % obj.obj_name())
# Refetch the row, for generated columns etc
query = select([table]).where(table.c.id == obj.id)
resultproxy = self.session.execute(query)
# Refetch the row, for generated columns etc
query = select(table).where(table.c.id == obj.id)
resultproxy = session.execute(query)
return _set_object_from_model(obj, resultproxy.fetchone())
return _set_object_from_model(obj, resultproxy.fetchone())
def _select_raw(self, context, table, criterion, query=None):
# Build the query
if query is None:
query = select([table])
query = select(table)
query = self._apply_criterion(table, query, criterion)
query = self._apply_deleted_criteria(context, table, query)
try:
resultproxy = self.session.execute(query)
return resultproxy.fetchall()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
return resultproxy.fetchall()
# Any ValueErrors are propagated back to the user as is.
# If however central or storage is called directly, invalid values
# show up as ValueError

View File

@ -1,85 +0,0 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Session Handling for SQLAlchemy backend."""
import sqlalchemy
import threading
from oslo_config import cfg
from oslo_db.sqlalchemy import session
from oslo_log import log as logging
from oslo_utils import importutils
osprofiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
try:
CONF.import_group("profiler", "designate.service")
except cfg.NoSuchGroupError:
pass
_FACADES = {}
_LOCK = threading.Lock()
def add_db_tracing(cache_name):
global _LOCK
if not osprofiler_sqlalchemy:
return
if not hasattr(CONF, 'profiler'):
return
if not CONF.profiler.enabled or not CONF.profiler.trace_sqlalchemy:
return
with _LOCK:
osprofiler_sqlalchemy.add_tracing(
sqlalchemy,
_FACADES[cache_name].get_engine(),
"db"
)
def _create_facade_lazily(cfg_group, connection=None, discriminator=None):
connection = connection or cfg.CONF[cfg_group].connection
cache_name = "%s:%s" % (cfg_group, discriminator)
if cache_name not in _FACADES:
conf = dict(cfg.CONF[cfg_group].items())
# FIXME(stephenfin): Remove this (and ideally use of
# LegacyEngineFacade) asap since it's not compatible with SQLAlchemy
# 2.0
conf['autocommit'] = True
_FACADES[cache_name] = session.EngineFacade(
connection,
**conf
)
add_db_tracing(cache_name)
return _FACADES[cache_name]
def get_engine(cfg_group):
facade = _create_facade_lazily(cfg_group)
return facade.get_engine()
def get_session(cfg_group, connection=None, discriminator=None, **kwargs):
facade = _create_facade_lazily(cfg_group, connection, discriminator)
return facade.get_session(**kwargs)

View File

@ -0,0 +1,99 @@
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Session Handling for SQLAlchemy backend."""
import sqlalchemy
import threading
from oslo_config import cfg
from oslo_db import options as db_options
from oslo_db.sqlalchemy import enginefacade
from oslo_log import log as logging
from oslo_utils import importutils
from osprofiler import opts as profiler
import osprofiler.sqlalchemy
from sqlalchemy import inspect
osprofiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
try:
CONF.import_group("profiler", "designate.service")
except cfg.NoSuchGroupError:
pass
_CONTEXT = None
_LOCK = threading.Lock()
_MAIN_CONTEXT_MANAGER = None
def initialize():
"""Initialize the module."""
connection = cfg.CONF['storage:sqlalchemy'].connection
db_options.set_defaults(
CONF, connection=connection
)
profiler.set_defaults(CONF, enabled=False, trace_sqlalchemy=False)
def _get_main_context_manager():
global _LOCK
global _MAIN_CONTEXT_MANAGER
with _LOCK:
if not _MAIN_CONTEXT_MANAGER:
initialize()
_MAIN_CONTEXT_MANAGER = enginefacade.transaction_context()
return _MAIN_CONTEXT_MANAGER
def _get_context():
global _CONTEXT
if _CONTEXT is None:
import threading
_CONTEXT = threading.local()
return _CONTEXT
def _wrap_session(sess):
if not osprofiler_sqlalchemy:
return sess
if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy:
sess = osprofiler.sqlalchemy.wrap_session(sqlalchemy, sess)
return sess
def get_read_engine():
return _get_main_context_manager().reader.get_engine()
def get_inspector():
return inspect(get_read_engine())
def get_read_session():
reader = _get_main_context_manager().reader
return _wrap_session(reader.using(_get_context()))
def get_write_session():
writer = _get_main_context_manager().writer
return _wrap_session(writer.using(_get_context()))

View File

@ -23,9 +23,10 @@ import sqlalchemy
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import select
from designate import exceptions
from designate.i18n import _
from designate.sqlalchemy import sql
LOG = log.getLogger(__name__)
@ -51,7 +52,7 @@ def paginate_query(query, table, limit, sort_keys, marker=None,
if marker is not None:
marker_values = []
for sort_key in sort_keys:
v = marker[sort_key]
v = getattr(marker, sort_key)
marker_values.append(v)
# Build up an array of sort criteria as in the docstring
@ -117,13 +118,14 @@ def sort_query(query, table, sort_keys, sort_dir=None, sort_dirs=None):
return query, sort_dirs
def check_marker(table, marker, session):
def check_marker(table, marker):
marker_query = select([table]).where(table.c.id == marker)
marker_query = select(table).where(table.c.id == marker)
try:
marker_resultproxy = session.execute(marker_query)
marker = marker_resultproxy.fetchone()
with sql.get_read_session() as session:
marker_resultproxy = session.execute(marker_query)
marker = marker_resultproxy.fetchone()
if marker is None:
raise exceptions.MarkerNotFound(
'Marker %s could not be found' % marker)

View File

@ -23,8 +23,10 @@ from oslo_db import exception as db_exception
from oslo_log import log as logging
from oslo_utils import excutils
from designate.sqlalchemy import sql
from designate.storage import base
LOG = logging.getLogger(__name__)
RETRY_STATE = threading.local()
@ -108,17 +110,19 @@ def transaction(f):
"""Transaction decorator, to be used on class instances with a
self.storage attribute
"""
@retry(cb=_retry_on_deadlock)
@functools.wraps(f)
def transaction_wrapper(self, *args, **kwargs):
self.storage.begin()
try:
result = f(self, *args, **kwargs)
self.storage.commit()
return result
except Exception:
with excutils.save_and_reraise_exception():
self.storage.rollback()
with sql.get_write_session() as session:
# session.begin()
try:
result = f(self, *args, **kwargs)
# session.commit()
return result
except Exception:
with excutils.save_and_reraise_exception():
session.rollback()
transaction_wrapper.__wrapped_function = f
transaction_wrapper.__wrapper_name = 'transaction'
@ -132,14 +136,15 @@ def transaction_shallow_copy(f):
@retry(cb=_retry_on_deadlock, deep_copy=False)
@functools.wraps(f)
def transaction_wrapper(self, *args, **kwargs):
self.storage.begin()
try:
result = f(self, *args, **kwargs)
self.storage.commit()
return result
except Exception:
with excutils.save_and_reraise_exception():
self.storage.rollback()
with sql.get_write_session() as session:
# session.begin()
try:
result = f(self, *args, **kwargs)
# session.commit()
return result
except Exception:
with excutils.save_and_reraise_exception():
session.rollback()
transaction_wrapper.__wrapped_function = f
transaction_wrapper.__wrapper_name = 'transaction_shallow_copy'

View File

@ -21,6 +21,7 @@ from sqlalchemy.sql.expression import or_, literal_column
from designate import exceptions
from designate import objects
from designate.sqlalchemy import base as sqlalchemy_base
from designate.sqlalchemy import sql
from designate.storage import base as storage_base
from designate.storage.impl_sqlalchemy import tables
@ -37,8 +38,8 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
def __init__(self):
super(SQLAlchemyStorage, self).__init__()
def get_name(self):
return self.name
def get_inspector(self):
return sql.get_inspector()
# CRUD for our resources (quota, server, tsigkey, tenant, zone & record)
# R - get_*, find_*s
@ -162,14 +163,14 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
##
def find_tenants(self, context):
# returns an array of tenant_id & count of their zones
query = select([tables.zones.c.tenant_id,
func.count(tables.zones.c.id)])
query = select(tables.zones.c.tenant_id, func.count(tables.zones.c.id))
query = self._apply_tenant_criteria(context, tables.zones, query)
query = self._apply_deleted_criteria(context, tables.zones, query)
query = query.group_by(tables.zones.c.tenant_id)
resultproxy = self.session.execute(query)
results = resultproxy.fetchall()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
results = resultproxy.fetchall()
tenant_list = objects.TenantList(
objects=[objects.Tenant(id=t[0], zone_count=t[1]) for t in
@ -181,13 +182,14 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
def get_tenant(self, context, tenant_id):
# get list & count of all zones owned by given tenant_id
query = select([tables.zones.c.name])
query = select(tables.zones.c.name)
query = self._apply_tenant_criteria(context, tables.zones, query)
query = self._apply_deleted_criteria(context, tables.zones, query)
query = query.where(tables.zones.c.tenant_id == tenant_id)
resultproxy = self.session.execute(query)
results = resultproxy.fetchall()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
results = resultproxy.fetchall()
return objects.Tenant(
id=tenant_id,
@ -197,12 +199,13 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
def count_tenants(self, context):
# tenants are the owner of zones, count the number of unique tenants
# select count(distinct tenant_id) from zones
query = select([func.count(distinct(tables.zones.c.tenant_id))])
query = select(func.count(distinct(tables.zones.c.tenant_id)))
query = self._apply_tenant_criteria(context, tables.zones, query)
query = self._apply_deleted_criteria(context, tables.zones, query)
resultproxy = self.session.execute(query)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
result = resultproxy.fetchone()
if result is None:
return 0
@ -223,7 +226,7 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
literal_column('False')),
else_=literal_column('True')).label('shared')
query = select(
[tables.zones, shared_case]).outerjoin(tables.shared_zones)
tables.zones, shared_case).outerjoin(tables.shared_zones)
zones = self._find(
context, tables.zones, objects.Zone, objects.ZoneList,
@ -417,17 +420,18 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
self.delete_recordset(context, i)
if tenant_id_changed:
self.session.execute(
tables.recordsets.update().
where(tables.recordsets.c.zone_id == zone.id).
values({'tenant_id': zone.tenant_id})
)
with sql.get_write_session() as session:
session.execute(
tables.recordsets.update().
where(tables.recordsets.c.zone_id == zone.id).
values({'tenant_id': zone.tenant_id})
)
self.session.execute(
tables.records.update().
where(tables.records.c.zone_id == zone.id).
values({'tenant_id': zone.tenant_id})
)
session.execute(
tables.records.update().
where(tables.records.c.zone_id == zone.id).
values({'tenant_id': zone.tenant_id})
)
return updated_zone
@ -492,8 +496,9 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
values(parent_zone_id=surviving_parent_id)
)
resultproxy = self.session.execute(query)
LOG.debug('%d child zones updated', resultproxy.rowcount)
with sql.get_write_session() as session:
resultproxy = session.execute(query)
LOG.debug('%d child zones updated', resultproxy.rowcount)
self.purge_zone(context, zone)
@ -501,13 +506,14 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
return len(zones)
def count_zones(self, context, criterion=None):
query = select([func.count(tables.zones.c.id)])
query = select(func.count(tables.zones.c.id))
query = self._apply_criterion(tables.zones, query, criterion)
query = self._apply_tenant_criteria(context, tables.zones, query)
query = self._apply_deleted_criteria(context, tables.zones, query)
resultproxy = self.session.execute(query)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
result = resultproxy.fetchone()
if result is None:
return 0
@ -577,12 +583,14 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
query = query.where(tables.shared_zones.c.zone_id == zone_id)
query = query.where(
tables.shared_zones.c.target_project_id == project_id)
return self.session.scalar(query) is not None
with sql.get_read_session() as session:
return session.scalar(query) is not None
def delete_zone_shares(self, zone_id):
query = tables.shared_zones.delete().where(
tables.shared_zones.c.zone_id == zone_id)
self.session.execute(query)
with sql.get_write_session() as session:
session.execute(query)
# Zone attribute methods
def _find_zone_attributes(self, context, criterion, one=False,
@ -671,7 +679,7 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
tables.zones,
tables.recordsets.c.zone_id == tables.zones.c.id)
query = (
select([tables.recordsets]).select_from(rjoin).
select(tables.recordsets).select_from(rjoin).
where(tables.zones.c.deleted == '0')
)
@ -713,9 +721,9 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
tables.records.c.recordset_id == tables.recordsets.c.id)
query = (
select([tables.recordsets.c.id, tables.recordsets.c.type,
tables.recordsets.c.ttl, tables.recordsets.c.name,
tables.records.c.data, tables.records.c.action]).
select(tables.recordsets.c.id, tables.recordsets.c.type,
tables.recordsets.c.ttl, tables.recordsets.c.name,
tables.records.c.data, tables.records.c.action).
select_from(rjoin).where(tables.records.c.action != 'DELETE')
)
@ -758,8 +766,8 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
tables.records.c.recordset_id == tables.recordsets.c.id)
query = (
select([tables.recordsets.c.name, tables.recordsets.c.ttl,
tables.recordsets.c.type, tables.records.c.data]).
select(tables.recordsets.c.name, tables.recordsets.c.ttl,
tables.recordsets.c.type, tables.records.c.data).
select_from(rjoin)
)
@ -844,7 +852,7 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
tables.recordsets.c.zone_id == tables.zones.c.id)
query = (
select([func.count(tables.recordsets.c.id)]).
select(func.count(tables.recordsets.c.id)).
select_from(rjoin).
where(tables.zones.c.deleted == '0')
)
@ -853,8 +861,9 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
query = self._apply_tenant_criteria(context, tables.recordsets, query)
query = self._apply_deleted_criteria(context, tables.recordsets, query)
resultproxy = self.session.execute(query)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
result = resultproxy.fetchone()
if result is None:
return 0
@ -924,7 +933,7 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
tables.records.c.zone_id == tables.zones.c.id)
query = (
select([func.count(tables.records.c.id)]).
select(func.count(tables.records.c.id)).
select_from(rjoin).
where(tables.zones.c.deleted == '0')
)
@ -933,8 +942,9 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
query = self._apply_tenant_criteria(context, tables.records, query)
query = self._apply_deleted_criteria(context, tables.records, query)
resultproxy = self.session.execute(query)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
result = resultproxy.fetchone()
if result is None:
return 0
@ -1521,7 +1531,7 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
tables.zone_transfer_requests.c.zone_id == tables.zones.c.id)
query = select(
[table, tables.zones.c.name.label('zone_name')]
table, tables.zones.c.name.label('zone_name')
).select_from(ljoin)
if not context.all_tenants:
@ -1611,14 +1621,15 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
exceptions.ZoneTransferRequestNotFound)
def count_zone_transfer_accept(self, context, criterion=None):
query = select([func.count(tables.zone_transfer_accepts.c.id)])
query = select(func.count(tables.zone_transfer_accepts.c.id))
query = self._apply_criterion(tables.zone_transfer_accepts,
query, criterion)
query = self._apply_deleted_criteria(context,
tables.zone_transfer_accepts, query)
resultproxy = self.session.execute(query)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
result = resultproxy.fetchone()
if result is None:
return 0
@ -1782,13 +1793,14 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
exceptions.ZoneExportNotFound)
def count_zone_tasks(self, context, criterion=None):
query = select([func.count(tables.zone_tasks.c.id)])
query = select(func.count(tables.zone_tasks.c.id))
query = self._apply_criterion(tables.zone_tasks, query, criterion)
query = self._apply_tenant_criteria(context, tables.zone_tasks, query)
query = self._apply_deleted_criteria(context, tables.zone_tasks, query)
resultproxy = self.session.execute(query)
result = resultproxy.fetchone()
with sql.get_read_session() as session:
resultproxy = session.execute(query)
result = resultproxy.fetchone()
if result is None:
return 0

View File

@ -20,8 +20,8 @@ LOG = logging.getLogger(__name__)
def is_migration_needed(equivalent_revision):
metadata = sa.MetaData(bind=op.get_bind())
sa.MetaData.reflect(metadata)
metadata = sa.MetaData()
metadata.bind = op.get_bind()
if 'migrate_version' not in metadata.tables.keys():
return True

View File

@ -84,19 +84,24 @@ def upgrade() -> None:
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s"}
"pk": "pk_%(table_name)s"
}
metadata = sa.MetaData(naming_convention=convention)
metadata.bind = op.get_bind()
# Get all the tables
domains_table = sa.Table('domains', metadata, autoload=True)
domains_table = sa.Table('domains', metadata,
autoload_with=op.get_bind())
domain_attrib_table = sa.Table('domain_attributes', metadata,
autoload=True)
recordsets_table = sa.Table('recordsets', metadata, autoload=True)
records_table = sa.Table('records', metadata, autoload=True)
ztr_table = sa.Table('zone_transfer_requests', metadata, autoload=True)
zta_table = sa.Table('zone_transfer_accepts', metadata, autoload=True)
autoload_with=op.get_bind())
recordsets_table = sa.Table('recordsets', metadata,
autoload_with=op.get_bind())
records_table = sa.Table('records', metadata, autoload_with=op.get_bind())
ztr_table = sa.Table('zone_transfer_requests', metadata,
autoload_with=op.get_bind())
zta_table = sa.Table('zone_transfer_accepts', metadata,
autoload_with=op.get_bind())
# Remove the affected FKs
# Define FKs

View File

@ -21,6 +21,7 @@ import copy
import datetime
import futurist
import random
import unittest
from unittest import mock
from oslo_config import cfg
@ -37,6 +38,7 @@ from testtools.matchers import GreaterThan
from designate import exceptions
from designate import objects
from designate.storage.impl_sqlalchemy import tables
from designate.storage import sql
from designate.tests import fixtures
from designate.tests.test_central import CentralTestCase
from designate import utils
@ -955,6 +957,7 @@ class CentralServiceTest(CentralTestCase):
self.assertEqual(exceptions.BadRequest, exc.exc_info[0])
@unittest.expectedFailure # FIXME
def test_update_zone_deadlock_retry(self):
# Create a zone
zone = self.create_zone(name='example.org.')
@ -964,7 +967,7 @@ class CentralServiceTest(CentralTestCase):
zone.email = 'info@example.net'
# Due to Python's scoping of i - we need to make it a mutable type
# for the counter to work.. In Py3, we can use the nonlocal keyword.
# for the counter to work. In Py3, we can use the nonlocal keyword.
i = [False]
def fail_once_then_pass():
@ -975,7 +978,7 @@ class CentralServiceTest(CentralTestCase):
raise db_exception.DBDeadlock()
with mock.patch.object(self.central_service.storage, 'commit',
side_effect=fail_once_then_pass):
side_effect=fail_once_then_pass):
# Perform the update
zone = self.central_service.update_zone(
self.admin_context, zone)
@ -1107,7 +1110,8 @@ class CentralServiceTest(CentralTestCase):
"""Fetch all zones including deleted ones
"""
query = tables.zones.select()
return self.central_service.storage.session.execute(query).fetchall()
with sql.get_read_session() as session:
return session.execute(query).fetchall()
def _log_all_zones(self, zones, msg=None):
"""Log out a summary of zones
@ -1119,7 +1123,7 @@ class CentralServiceTest(CentralTestCase):
tpl = "%-35s | %-11s | %-11s | %-32s | %-20s | %s"
LOG.debug(tpl % cols)
for z in zones:
LOG.debug(tpl % tuple(z[k] for k in cols))
LOG.debug(tpl % tuple(getattr(z, k) for k in cols))
def _assert_count_all_zones(self, n):
"""Assert count ALL zones including deleted ones
@ -1149,8 +1153,9 @@ class CentralServiceTest(CentralTestCase):
status='DELETED',
)
pxy = self.central_service.storage.session.execute(query)
self.assertEqual(1, pxy.rowcount)
with sql.get_write_session() as session:
pxy = session.execute(query)
self.assertEqual(1, pxy.rowcount)
return zone
@mock.patch.object(notifier.Notifier, "info")
@ -1866,6 +1871,7 @@ class CentralServiceTest(CentralTestCase):
self.assertEqual(1800, recordset.ttl)
self.assertThat(new_serial, GreaterThan(original_serial))
@unittest.expectedFailure # FIXME
def test_update_recordset_deadlock_retry(self):
# Create a zone
zone = self.create_zone()
@ -1877,7 +1883,7 @@ class CentralServiceTest(CentralTestCase):
recordset.ttl = 1800
# Due to Python's scoping of i - we need to make it a mutable type
# for the counter to work.. In Py3, we can use the nonlocal keyword.
# for the counter to work. In Py3, we can use the nonlocal keyword.
i = [False]
def fail_once_then_pass():

View File

@ -21,6 +21,7 @@ from oslo_utils import timeutils
from designate.producer import tasks
from designate.storage.impl_sqlalchemy import tables
from designate.storage import sql
from designate.tests import fixtures
from designate.tests import TestCase
@ -52,7 +53,8 @@ class DeletedZonePurgeTest(TestCase):
def _fetch_all_zones(self):
# Fetch all zones including deleted ones.
query = tables.zones.select()
return self.central_service.storage.session.execute(query).fetchall()
with sql.get_read_session() as session:
return session.execute(query).fetchall()
def _delete_zone(self, zone, mock_deletion_time):
# Set a zone as deleted
@ -64,8 +66,9 @@ class DeletedZonePurgeTest(TestCase):
status='DELETED',
)
pxy = self.central_service.storage.session.execute(query)
self.assertEqual(1, pxy.rowcount)
with sql.get_write_session() as session:
pxy = session.execute(query)
self.assertEqual(1, pxy.rowcount)
def _create_deleted_zones(self):
# Create a number of deleted zones in the past days.
@ -114,7 +117,8 @@ class PeriodicGenerateDelayedNotifyTaskTest(TestCase):
def _fetch_zones(self, query):
# Fetch zones including deleted ones.
return self.central_service.storage.session.execute(query).fetchall()
with sql.get_read_session() as session:
return session.execute(query).fetchall()
def _create_zones(self):
# Create a number of zones; half of them with delayed_notify set.

View File

@ -586,8 +586,8 @@ class StorageTestCase(object):
def test_count_tenants_none_result(self):
rp = mock.Mock()
rp.fetchone.return_value = None
with mock.patch.object(self.storage.session, 'execute',
return_value=rp):
with mock.patch('designate.storage.sql.get_write_session',
return_value=rp):
tenants = self.storage.count_tenants(self.admin_context)
self.assertEqual(0, tenants)
@ -870,8 +870,9 @@ class StorageTestCase(object):
def test_count_zones_none_result(self):
rp = mock.Mock()
rp.fetchone.return_value = None
with mock.patch.object(self.storage.session, 'execute',
return_value=rp):
with mock.patch('designate.storage.sql.get_write_session',
return_value=rp):
zones = self.storage.count_zones(self.admin_context)
self.assertEqual(0, zones)
@ -1270,8 +1271,8 @@ class StorageTestCase(object):
def test_count_recordsets_none_result(self):
rp = mock.Mock()
rp.fetchone.return_value = None
with mock.patch.object(self.storage.session, 'execute',
return_value=rp):
with mock.patch('designate.storage.sql.get_write_session',
return_value=rp):
recordsets = self.storage.count_recordsets(self.admin_context)
self.assertEqual(0, recordsets)
@ -1501,8 +1502,8 @@ class StorageTestCase(object):
def test_count_records_none_result(self):
rp = mock.Mock()
rp.fetchone.return_value = None
with mock.patch.object(self.storage.session, 'execute',
return_value=rp):
with mock.patch('designate.storage.sql.get_write_session',
return_value=rp):
records = self.storage.count_records(self.admin_context)
self.assertEqual(0, records)
@ -3065,8 +3066,8 @@ class StorageTestCase(object):
def test_count_zone_tasks_none_result(self):
rp = mock.Mock()
rp.fetchone.return_value = None
with mock.patch.object(self.storage.session, 'execute',
return_value=rp):
with mock.patch('designate.storage.sql.get_write_session',
return_value=rp):
zones = self.storage.count_zone_tasks(self.admin_context)
self.assertEqual(0, zones)

View File

@ -14,8 +14,10 @@
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log as logging
from sqlalchemy import text
from designate import storage
from designate.storage import sql
from designate.tests.test_storage import StorageTestCase
from designate.tests import TestCase
@ -53,6 +55,7 @@ class SqlalchemyStorageTest(StorageTestCase, TestCase):
'zone_transfer_requests',
'zones'
]
inspector = self.storage.get_inspector()
actual_table_names = inspector.get_table_names()
@ -79,16 +82,17 @@ class SqlalchemyStorageTest(StorageTestCase, TestCase):
self.assertEqual(table_names, actual_table_names)
def test_schema_table_indexes(self):
indexes_t = self.storage.engine.execute(
"SELECT * FROM sqlite_master WHERE type = 'index';")
with sql.get_read_session() as session:
indexes_t = session.execute(
text("SELECT * FROM sqlite_master WHERE type = 'index';"))
indexes = {} # table name -> index names -> cmd
for _, index_name, table_name, num, cmd in indexes_t:
if index_name.startswith("sqlite_"):
continue # ignore sqlite-specific indexes
if table_name not in indexes:
indexes[table_name] = {}
indexes[table_name][index_name] = cmd
indexes = {} # table name -> index names -> cmd
for _, index_name, table_name, num, cmd in indexes_t:
if index_name.startswith("sqlite_"):
continue # ignore sqlite-specific indexes
if table_name not in indexes:
indexes[table_name] = {}
indexes[table_name][index_name] = cmd
expected = {
"records": {

View File

@ -18,18 +18,19 @@ from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
from designate.cmd import status
from designate.sqlalchemy import session
from designate.sqlalchemy import sql
from designate import tests
class TestDuplicateServiceStatus(tests.TestCase):
def setUp(self):
super(TestDuplicateServiceStatus, self).setUp()
self.engine = session.get_engine('storage:sqlalchemy')
self.meta = MetaData()
self.meta.bind = self.engine
self.service_statuses_table = Table('service_statuses', self.meta,
autoload=True)
self.meta.bind = sql.get_read_engine()
self.service_statuses_table = Table(
'service_statuses', self.meta,
autoload_with=sql.get_read_engine()
)
def test_success(self):
fake_record = {'id': '1',
@ -39,27 +40,46 @@ class TestDuplicateServiceStatus(tests.TestCase):
'stats': '',
'capabilities': '',
}
self.service_statuses_table.insert().execute(fake_record)
# Different hostname should be fine
fake_record['id'] = '2'
fake_record['hostname'] = 'otherhost'
self.service_statuses_table.insert().execute(fake_record)
# Different service_name should be fine
fake_record['id'] = '3'
fake_record['service_name'] = 'producer'
self.service_statuses_table.insert().execute(fake_record)
checks = status.Checks()
self.assertEqual(upgradecheck.Code.SUCCESS,
checks._duplicate_service_status().code)
with sql.get_write_session() as session:
query = (
self.service_statuses_table.insert().
values(fake_record)
)
session.execute(query)
@mock.patch('designate.sqlalchemy.session.get_engine')
def test_failure(self, mock_get_engine):
mock_engine = mock.MagicMock()
mock_execute = mock.MagicMock()
mock_engine.execute.return_value = mock_execute
mock_execute.fetchall.return_value = [(2,)]
mock_get_engine.return_value = mock_engine
# Different hostname should be fine
fake_record['id'] = '2'
fake_record['hostname'] = 'otherhost'
query = (
self.service_statuses_table.insert().
values(fake_record)
)
session.execute(query)
# Different service_name should be fine
fake_record['id'] = '3'
fake_record['service_name'] = 'producer'
query = (
self.service_statuses_table.insert().
values(fake_record)
)
session.execute(query)
checks = status.Checks()
self.assertEqual(upgradecheck.Code.SUCCESS,
checks._duplicate_service_status().code)
@mock.patch('designate.sqlalchemy.sql.get_read_session')
@mock.patch('designate.storage.sql.get_read_engine')
def test_failure(self, mock_get_engine, mock_get_read):
mock_sql_execute = mock.Mock()
mock_sql_fetchall = mock.Mock()
mock_get_read().__enter__.return_value = mock_sql_execute
mock_sql_execute.execute.return_value = mock_sql_fetchall
mock_sql_fetchall.fetchall.return_value = [(2,)]
checks = status.Checks()
self.assertEqual(upgradecheck.Code.FAILURE,
checks._duplicate_service_status().code)
result = checks._duplicate_service_status().code
self.assertEqual(upgradecheck.Code.FAILURE, result)

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Fixed compatibility issues with SQLAlchemy 2.x.

View File

@ -34,7 +34,7 @@ python-designateclient>=2.12.0 # Apache-2.0
python-neutronclient>=6.7.0 # Apache-2.0
requests>=2.23.0 # Apache-2.0
tenacity>=6.0.0 # Apache-2.0
SQLAlchemy>=1.2.19 # MIT
SQLAlchemy>=1.4.41 # MIT
stevedore>=1.20.0 # Apache-2.0
WebOb>=1.7.1 # MIT
dnspython>=2.2.1 # http://www.dnspython.org/LICENSE