Use oslo.db in manila

Use oslo.db library instead of own implementation.

Oslo.db code contains different utils for work with db api, db session,
migrations, test classes for db testing, tools for automatic retry
db.api query if db connection was lost, etc.

Oslo.db code was tested better as it is currently used in many projects,
there will be no need in testing our own implementation.

In many cases our own implementation of work with db duplicates
oslo.db code.

Remove:
- manila/common/sqlalchemyutils.py;
- manila/db/sqlalchemy/utils.py;
- manila/db/sqlalchemy/session.py;
- manila/db/sqlalchemy/migration.py;
- DBError, wrap_db_error, InvalidUnicodeParameter exceptions;
- db_sync, db_version, db_version_control, _find_migrate_repo
function

and replace it with appropriate oslo.db functions.

Add 'joinedload' statement to db queries if necessary.

Fix unit tests, clean up test_migrations.py

Implements bp oslo.db

Change-Id: I48a4da797594cf020f67f78024bd0f86b5abd5ef
This commit is contained in:
Julia Varlamova 2014-07-10 16:29:07 +04:00
parent 0b382f186b
commit f1f8ca0ad1
21 changed files with 164 additions and 1248 deletions

View File

@ -15,6 +15,7 @@
"""The shares api.""" """The shares api."""
from oslo.db import exception as db_exception
import six import six
import webob import webob
from webob import exc from webob import exc
@ -183,7 +184,7 @@ class ShareNetworkController(wsgi.Controller):
share_network = db_api.share_network_update(context, share_network = db_api.share_network_update(context,
id, id,
update_values) update_values)
except exception.DBError: except db_exception.DBError:
msg = "Could not save supplied data due to database error" msg = "Could not save supplied data due to database error"
raise exc.HTTPBadRequest(explanation=msg) raise exc.HTTPBadRequest(explanation=msg)
@ -223,7 +224,7 @@ class ShareNetworkController(wsgi.Controller):
else: else:
try: try:
share_network = db_api.share_network_create(context, values) share_network = db_api.share_network_create(context, values)
except exception.DBError: except db_exception.DBError:
msg = "Could not save supplied data due to database error" msg = "Could not save supplied data due to database error"
raise exc.HTTPBadRequest(explanation=msg) raise exc.HTTPBadRequest(explanation=msg)

View File

@ -1,128 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2010-2011 OpenStack LLC.
# Copyright 2012 Justin Santa Barbara
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of paginate query."""
import sqlalchemy
from manila import exception
from manila.openstack.common import log as logging
LOG = logging.getLogger(__name__)
# copied from glance/db/sqlalchemy/api.py
def paginate_query(query, model, limit, sort_keys, marker=None,
sort_dir=None, sort_dirs=None):
"""Returns a query with sorting / pagination criteria added.
Pagination works by requiring a unique sort_key, specified by sort_keys.
(If sort_keys is not unique, then we risk looping through values.)
We use the last row in the previous page as the 'marker' for pagination.
So we must return values that follow the passed marker in the order.
With a single-valued sort_key, this would be easy: sort_key > X.
With a compound-values sort_key, (k1, k2, k3) we must do this to repeat
the lexicographical ordering:
(k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
We also have to cope with different sort_directions.
Typically, the id of the last row is used as the client-facing pagination
marker, then the actual marker object must be fetched from the db and
passed in to us as marker.
:param query: the query object to which we should add paging/sorting
:param model: the ORM model class
:param limit: maximum number of items to return
:param sort_keys: array of attributes by which results should be sorted
:param marker: the last item of the previous page; we returns the next
results after this value.
:param sort_dir: direction in which results should be sorted (asc, desc)
:param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
:rtype: sqlalchemy.orm.query.Query
:return: The query with sorting/pagination added.
"""
if 'id' not in sort_keys:
# TODO(justinsb): If this ever gives a false-positive, check
# the actual primary key, rather than assuming its id
LOG.warn(_('Id not in sort_keys; is sort_keys unique?'))
assert(not (sort_dir and sort_dirs))
# Default the sort direction to ascending
if sort_dirs is None and sort_dir is None:
sort_dir = 'asc'
# Ensure a per-column sort direction
if sort_dirs is None:
sort_dirs = [sort_dir for _sort_key in sort_keys]
assert(len(sort_dirs) == len(sort_keys))
# Add sorting
for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs):
sort_dir_func = {
'asc': sqlalchemy.asc,
'desc': sqlalchemy.desc,
}[current_sort_dir]
try:
sort_key_attr = getattr(model, current_sort_key)
except AttributeError:
raise exception.InvalidInput(reason='Invalid sort key')
query = query.order_by(sort_dir_func(sort_key_attr))
# Add pagination
if marker is not None:
marker_values = []
for sort_key in sort_keys:
v = getattr(marker, sort_key)
marker_values.append(v)
# Build up an array of sort criteria as in the docstring
criteria_list = []
for i in xrange(0, len(sort_keys)):
crit_attrs = []
for j in xrange(0, i):
model_attr = getattr(model, sort_keys[j])
crit_attrs.append((model_attr == marker_values[j]))
model_attr = getattr(model, sort_keys[i])
if sort_dirs[i] == 'desc':
crit_attrs.append((model_attr < marker_values[i]))
elif sort_dirs[i] == 'asc':
crit_attrs.append((model_attr > marker_values[i]))
else:
raise ValueError(_("Unknown sort direction, "
"must be 'desc' or 'asc'"))
criteria = sqlalchemy.sql.and_(*crit_attrs)
criteria_list.append(criteria)
f = sqlalchemy.sql.or_(*criteria_list)
query = query.filter(f)
if limit is not None:
query = query.limit(limit)
return query

View File

@ -32,21 +32,20 @@ these objects be simple dictionaries.
**Related Flags** **Related Flags**
:db_backend: string to lookup in the list of LazyPluggable backends. :backend: string to lookup in the list of LazyPluggable backends.
`sqlalchemy` is the only supported backend right now. `sqlalchemy` is the only supported backend right now.
:sql_connection: string specifying the sqlalchemy connection to use, like: :connection: string specifying the sqlalchemy connection to use, like:
`sqlite:///var/lib/manila/manila.sqlite`. `sqlite:///var/lib/manila/manila.sqlite`.
:enable_new_services: when adding a new service to the database, is it in the :enable_new_services: when adding a new service to the database, is it in the
pool of available hardware (Default: True) pool of available hardware (Default: True)
""" """
from oslo.config import cfg from oslo.config import cfg
from oslo.db import api as db_api
from manila import exception from manila import exception
from manila import utils
db_opts = [ db_opts = [
cfg.StrOpt('db_backend', cfg.StrOpt('db_backend',
@ -67,8 +66,9 @@ db_opts = [
CONF = cfg.CONF CONF = cfg.CONF
CONF.register_opts(db_opts) CONF.register_opts(db_opts)
IMPL = utils.LazyPluggable('db_backend', _BACKEND_MAPPING = {'sqlalchemy': 'manila.db.sqlalchemy.api'}
sqlalchemy='manila.db.sqlalchemy.api') IMPL = db_api.DBAPI.from_config(cfg.CONF, backend_mapping=_BACKEND_MAPPING,
lazy=True)
################### ###################

View File

@ -18,21 +18,27 @@
"""Database setup and migration commands.""" """Database setup and migration commands."""
import os
from manila.db.sqlalchemy import api as db_api
from manila import utils from manila import utils
IMPL = utils.LazyPluggable('db_backend', IMPL = utils.LazyPluggable('db_backend',
sqlalchemy='manila.db.sqlalchemy.migration') sqlalchemy='oslo.db.sqlalchemy.migration')
INIT_VERSION = 000 INIT_VERSION = 000
MIGRATE_REPO = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'sqlalchemy/migrate_repo')
def db_sync(version=None): def db_sync(version=None):
"""Migrate the database to `version` or the most recent version.""" """Migrate the database to `version` or the most recent version."""
return IMPL.db_sync(version=version) return IMPL.db_sync(db_api.get_engine(), MIGRATE_REPO, version=version,
init_version=INIT_VERSION)
def db_version(): def db_version():
"""Display the current database version.""" """Display the current database version."""
return IMPL.db_version() return IMPL.db_version(db_api.get_engine(), MIGRATE_REPO, INIT_VERSION)

View File

@ -21,23 +21,23 @@
import datetime import datetime
import functools import functools
import sys
import time import time
import uuid import uuid
import warnings import warnings
from oslo.config import cfg from oslo.config import cfg
from oslo.db import exception as db_exception
from oslo.db import options as db_options
from oslo.db.sqlalchemy import session
import six import six
from sqlalchemy.exc import IntegrityError
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.sql.expression import literal_column from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql import func from sqlalchemy.sql import func
from manila.common import constants from manila.common import constants
from manila.common import sqlalchemyutils
from manila import db
from manila.db.sqlalchemy import models from manila.db.sqlalchemy import models
from manila.db.sqlalchemy.session import get_session
from manila import exception from manila import exception
from manila.openstack.common import log as logging from manila.openstack.common import log as logging
from manila.openstack.common import timeutils from manila.openstack.common import timeutils
@ -50,6 +50,35 @@ LOG = logging.getLogger(__name__)
_DEFAULT_QUOTA_NAME = 'default' _DEFAULT_QUOTA_NAME = 'default'
PER_PROJECT_QUOTAS = [] PER_PROJECT_QUOTAS = []
_FACADE = None
_DEFAULT_SQL_CONNECTION = 'sqlite://'
db_options.set_defaults(cfg.CONF,
connection=_DEFAULT_SQL_CONNECTION)
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
_FACADE = session.EngineFacade.from_config(cfg.CONF)
return _FACADE
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
def get_session(**kwargs):
facade = _create_facade_lazily()
return facade.get_session(**kwargs)
def get_backend():
"""The backend is this module itself."""
return sys.modules[__name__]
def is_admin_context(context): def is_admin_context(context):
"""Indicates if the request context is an administrator.""" """Indicates if the request context is an administrator."""
@ -138,7 +167,7 @@ def require_share_exists(f):
""" """
def wrapper(context, share_id, *args, **kwargs): def wrapper(context, share_id, *args, **kwargs):
db.share_get(context, share_id) share_get(context, share_id)
return f(context, share_id, *args, **kwargs) return f(context, share_id, *args, **kwargs)
wrapper.__name__ = f.__name__ wrapper.__name__ = f.__name__
return wrapper return wrapper
@ -380,8 +409,11 @@ def service_create(context, values):
service_ref.update(values) service_ref.update(values)
if not CONF.enable_new_services: if not CONF.enable_new_services:
service_ref.disabled = True service_ref.disabled = True
service_ref.save()
return service_ref session = get_session()
with session.begin():
service_ref.save(session)
return service_ref
@require_admin_context @require_admin_context
@ -478,7 +510,9 @@ def quota_create(context, project_id, resource, limit, user_id=None):
quota_ref.project_id = project_id quota_ref.project_id = project_id
quota_ref.resource = resource quota_ref.resource = resource
quota_ref.hard_limit = limit quota_ref.hard_limit = limit
quota_ref.save() session = get_session()
with session.begin():
quota_ref.save(session)
return quota_ref return quota_ref
@ -551,7 +585,9 @@ def quota_class_create(context, class_name, resource, limit):
quota_class_ref.class_name = class_name quota_class_ref.class_name = class_name
quota_class_ref.resource = resource quota_class_ref.resource = resource
quota_class_ref.hard_limit = limit quota_class_ref.hard_limit = limit
quota_class_ref.save() session = get_session()
with session.begin():
quota_class_ref.save(session)
return quota_class_ref return quota_class_ref
@ -1271,7 +1307,9 @@ def snapshot_data_get_for_project(context, project_id, user_id, session=None):
func.sum(models.ShareSnapshot.size), func.sum(models.ShareSnapshot.size),
read_deleted="no", read_deleted="no",
session=session).\ session=session).\
filter_by(project_id=project_id) filter_by(project_id=project_id).\
options(joinedload('share'))
if user_id: if user_id:
result = query.filter_by(user_id=user_id).first() result = query.filter_by(user_id=user_id).first()
else: else:
@ -1297,6 +1335,7 @@ def share_snapshot_get(context, snapshot_id, session=None):
result = model_query(context, models.ShareSnapshot, session=session, result = model_query(context, models.ShareSnapshot, session=session,
project_only=True).\ project_only=True).\
filter_by(id=snapshot_id).\ filter_by(id=snapshot_id).\
options(joinedload('share')).\
first() first()
if not result: if not result:
@ -1307,7 +1346,9 @@ def share_snapshot_get(context, snapshot_id, session=None):
@require_admin_context @require_admin_context
def share_snapshot_get_all(context): def share_snapshot_get_all(context):
return model_query(context, models.ShareSnapshot).all() return model_query(context, models.ShareSnapshot).\
options(joinedload('share')).\
all()
@require_context @require_context
@ -1315,6 +1356,7 @@ def share_snapshot_get_all_by_project(context, project_id):
authorize_project_context(context, project_id) authorize_project_context(context, project_id)
return model_query(context, models.ShareSnapshot).\ return model_query(context, models.ShareSnapshot).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
options(joinedload('share')).\
all() all()
@ -1322,7 +1364,9 @@ def share_snapshot_get_all_by_project(context, project_id):
def share_snapshot_get_all_for_share(context, share_id): def share_snapshot_get_all_for_share(context, share_id):
return model_query(context, models.ShareSnapshot, read_deleted='no', return model_query(context, models.ShareSnapshot, read_deleted='no',
project_only=True).\ project_only=True).\
filter_by(share_id=share_id).all() filter_by(share_id=share_id).\
options(joinedload('share')).\
all()
@require_context @require_context
@ -1334,6 +1378,7 @@ def share_snapshot_data_get_for_project(context, project_id, session=None):
read_deleted="no", read_deleted="no",
session=session).\ session=session).\
filter_by(project_id=project_id).\ filter_by(project_id=project_id).\
options(joinedload('share')).\
first() first()
# NOTE(vish): convert None to 0 # NOTE(vish): convert None to 0
@ -1379,7 +1424,8 @@ def share_metadata_update(context, share_id, metadata, delete):
def _share_metadata_get_query(context, share_id, session=None): def _share_metadata_get_query(context, share_id, session=None):
return model_query(context, models.ShareMetadata, session=session, return model_query(context, models.ShareMetadata, session=session,
read_deleted="no").\ read_deleted="no").\
filter_by(share_id=share_id) filter_by(share_id=share_id).\
options(joinedload('share'))
@require_context @require_context
@ -1533,7 +1579,7 @@ def share_network_create(context, values):
session = get_session() session = get_session()
with session.begin(): with session.begin():
network_ref.save(session=session) network_ref.save(session=session)
return network_ref return share_network_get(context, values['id'], session)
@require_context @require_context
@ -1649,8 +1695,9 @@ def share_network_remove_security_service(context, id, security_service_id):
def _server_get_query(context, session=None): def _server_get_query(context, session=None):
if session is None: if session is None:
session = get_session() session = get_session()
return model_query(context, models.ShareServer, session=session)\ return model_query(context, models.ShareServer, session=session).\
.options(joinedload('shares')) options(joinedload('shares'), joinedload('network_allocations'),
joinedload('share_network'))
@require_context @require_context
@ -1725,7 +1772,9 @@ def share_server_backend_details_set(context, share_server_id, server_details):
'value': meta_value, 'value': meta_value,
'share_server_id': share_server_id 'share_server_id': share_server_id
}) })
meta_ref.save() session = get_session()
with session.begin():
meta_ref.save(session)
return server_details return server_details
@ -1843,10 +1892,10 @@ def volume_type_create(context, values):
volume_type_ref = models.VolumeTypes() volume_type_ref = models.VolumeTypes()
volume_type_ref.update(values) volume_type_ref.update(values)
volume_type_ref.save(session=session) volume_type_ref.save(session=session)
except exception.Duplicate: except db_exception.DBDuplicateEntry:
raise exception.VolumeTypeExists(id=values['name']) raise exception.VolumeTypeExists(id=values['name'])
except Exception as e: except Exception as e:
raise exception.DBError(e) raise db_exception.DBError(e)
return volume_type_ref return volume_type_ref
@ -1859,6 +1908,7 @@ def volume_type_get_all(context, inactive=False, filters=None):
rows = model_query(context, models.VolumeTypes, rows = model_query(context, models.VolumeTypes,
read_deleted=read_deleted).\ read_deleted=read_deleted).\
options(joinedload('extra_specs')).\ options(joinedload('extra_specs')).\
options(joinedload('shares')).\
order_by("name").\ order_by("name").\
all() all()
@ -1878,6 +1928,7 @@ def _volume_type_get(context, id, session=None, inactive=False):
read_deleted=read_deleted).\ read_deleted=read_deleted).\
options(joinedload('extra_specs')).\ options(joinedload('extra_specs')).\
filter_by(id=id).\ filter_by(id=id).\
options(joinedload('shares')).\
first() first()
if not result: if not result:
@ -1897,6 +1948,7 @@ def _volume_type_get_by_name(context, name, session=None):
result = model_query(context, models.VolumeTypes, session=session).\ result = model_query(context, models.VolumeTypes, session=session).\
options(joinedload('extra_specs')).\ options(joinedload('extra_specs')).\
filter_by(name=name).\ filter_by(name=name).\
options(joinedload('shares')).\
first() first()
if not result: if not result:
@ -1957,7 +2009,8 @@ def volume_get_active_by_window(context,
def _volume_type_extra_specs_query(context, volume_type_id, session=None): def _volume_type_extra_specs_query(context, volume_type_id, session=None):
return model_query(context, models.VolumeTypeExtraSpecs, session=session, return model_query(context, models.VolumeTypeExtraSpecs, session=session,
read_deleted="no").\ read_deleted="no").\
filter_by(volume_type_id=volume_type_id) filter_by(volume_type_id=volume_type_id).\
options(joinedload('volume_type'))
@require_context @require_context
@ -1991,6 +2044,7 @@ def _volume_type_extra_specs_get_item(context, volume_type_id, key,
result = _volume_type_extra_specs_query( result = _volume_type_extra_specs_query(
context, volume_type_id, session=session).\ context, volume_type_id, session=session).\
filter_by(key=key).\ filter_by(key=key).\
options(joinedload('volume_type')).\
first() first()
if not result: if not result:

View File

@ -1,116 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import distutils.version as dist_version
import os
import migrate
from migrate.versioning import util as migrate_util
from oslo.config import cfg
import sqlalchemy
from manila.db import migration
from manila.db.sqlalchemy.session import get_engine
from manila import exception
from manila.openstack.common import log as logging
LOG = logging.getLogger(__name__)
@migrate_util.decorator
def patched_with_engine(f, *a, **kw):
url = a[0]
engine = migrate_util.construct_engine(url, **kw)
try:
kw['engine'] = engine
return f(*a, **kw)
finally:
if isinstance(engine, migrate_util.Engine) and engine is not url:
migrate_util.log.debug('Disposing SQLAlchemy engine %s', engine)
engine.dispose()
# TODO(jkoelker) When migrate 0.7.3 is released and manila depends
# on that version or higher, this can be removed
MIN_PKG_VERSION = dist_version.StrictVersion('0.7.3')
if (not hasattr(migrate, '__version__') or
dist_version.StrictVersion(migrate.__version__) < MIN_PKG_VERSION):
migrate_util.with_engine = patched_with_engine
# NOTE(jkoelker) Delay importing migrate until we are patched
from migrate import exceptions as versioning_exceptions
from migrate.versioning import api as versioning_api
from migrate.versioning.repository import Repository
CONF = cfg.CONF
_REPOSITORY = None
def db_sync(version=None):
if version is not None:
try:
version = int(version)
except ValueError:
raise exception.Error(_("version should be an integer"))
current_version = db_version()
repository = _find_migrate_repo()
if version is None or version > current_version:
return versioning_api.upgrade(get_engine(), repository, version)
else:
return versioning_api.downgrade(get_engine(), repository,
version)
def db_version():
repository = _find_migrate_repo()
try:
return versioning_api.db_version(get_engine(), repository)
except versioning_exceptions.DatabaseNotControlledError:
# If we aren't version controlled we may already have the database
# in the state from before we started version control, check for that
# and set up version_control appropriately
meta = sqlalchemy.MetaData()
engine = get_engine()
meta.reflect(bind=engine)
tables = meta.tables
if len(tables) == 0:
db_version_control(migration.INIT_VERSION)
return versioning_api.db_version(get_engine(), repository)
else:
raise exception.Error(_("Upgrade DB using Essex release first."))
def db_version_control(version=None):
repository = _find_migrate_repo()
versioning_api.version_control(get_engine(), repository, version)
return version
def _find_migrate_repo():
"""Get the path for the migrate repository."""
global _REPOSITORY
path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
'migrate_repo')
assert os.path.exists(path)
if _REPOSITORY is None:
_REPOSITORY = Repository(path)
return _REPOSITORY

View File

@ -22,6 +22,7 @@ SQLAlchemy models for Manila data.
""" """
from oslo.config import cfg from oslo.config import cfg
from oslo.db.sqlalchemy import models
import six import six
from sqlalchemy import Column, Index, Integer, String, Text, schema from sqlalchemy import Column, Index, Integer, String, Text, schema
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -30,7 +31,6 @@ from sqlalchemy import ForeignKey, DateTime, Boolean, Enum
from sqlalchemy.orm import relationship, backref, object_mapper from sqlalchemy.orm import relationship, backref, object_mapper
from manila.common import constants from manila.common import constants
from manila.db.sqlalchemy.session import get_session
from manila import exception from manila import exception
from manila.openstack.common import timeutils from manila.openstack.common import timeutils
@ -38,73 +38,25 @@ CONF = cfg.CONF
BASE = declarative_base() BASE = declarative_base()
class ManilaBase(object): class ManilaBase(models.ModelBase, models.TimestampMixin):
"""Base class for Manila Models.""" """Base class for Manila Models."""
__table_args__ = {'mysql_engine': 'InnoDB'} __table_args__ = {'mysql_engine': 'InnoDB'}
__table_initialized__ = False
created_at = Column(DateTime, default=lambda: timeutils.utcnow())
updated_at = Column(DateTime, onupdate=lambda: timeutils.utcnow())
deleted_at = Column(DateTime) deleted_at = Column(DateTime)
deleted = Column(Integer, default=0) deleted = Column(Integer, default=0)
metadata = None metadata = None
def save(self, session=None):
"""Save this object."""
if not session:
session = get_session()
# NOTE(boris-42): This part of code should be look like:
# sesssion.add(self)
# session.flush()
# But there is a bug in sqlalchemy and eventlet that
# raises NoneType exception if there is no running
# transaction and rollback is called. As long as
# sqlalchemy has this bug we have to create transaction
# explicity.
with session.begin(subtransactions=True):
try:
session.add(self)
session.flush()
except IntegrityError as e:
raise exception.Duplicate(message=str(e))
def delete(self, session=None): def delete(self, session=None):
"""Delete this object.""" """Delete this object."""
self.deleted = self.id self.deleted = self.id
self.deleted_at = timeutils.utcnow() self.deleted_at = timeutils.utcnow()
self.save(session=session) self.save(session=session)
def __setitem__(self, key, value): def to_dict(self):
setattr(self, key, value) model_dict = {}
for k, v in six.iteritems(self):
def __getitem__(self, key): if not issubclass(type(v), ManilaBase):
return getattr(self, key) model_dict[k] = v
return model_dict
def get(self, key, default=None):
return getattr(self, key, default)
def __iter__(self):
self._i = iter(object_mapper(self).columns)
return self
def next(self):
n = self._i.next().name
return n, getattr(self, n)
def update(self, values):
"""Make the model object behave like a dict."""
for k, v in six.iteritems(values):
setattr(self, k, v)
def iteritems(self):
"""Make the model object behave like a dict.
Includes attributes from joins.
"""
local = dict(self)
joined = dict([(k, v) for k, v in six.iteritems(self.__dict__)
if not k[0] == '_'])
local.update(joined)
return six.iteritems(local)
class Service(BASE, ManilaBase): class Service(BASE, ManilaBase):
@ -500,6 +452,6 @@ def register_models():
ShareAccessMapping, ShareAccessMapping,
ShareSnapshot ShareSnapshot
) )
engine = create_engine(CONF.sql_connection, echo=False) engine = create_engine(CONF.database.connection, echo=False)
for model in models: for model in models:
model.metadata.create_all(engine) model.metadata.create_all(engine)

View File

@ -1,151 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Session Handling for SQLAlchemy backend."""
import time
from oslo.config import cfg
from sqlalchemy.exc import DisconnectionError, OperationalError
import sqlalchemy.interfaces
import sqlalchemy.orm
from sqlalchemy.pool import NullPool, StaticPool
import manila.exception
from manila.openstack.common import log as logging
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
_ENGINE = None
_MAKER = None
def get_session(autocommit=True, expire_on_commit=False):
"""Return a SQLAlchemy session."""
global _MAKER
if _MAKER is None:
engine = get_engine()
_MAKER = get_maker(engine, autocommit, expire_on_commit)
session = _MAKER()
session.query = manila.exception.wrap_db_error(session.query)
session.flush = manila.exception.wrap_db_error(session.flush)
return session
def synchronous_switch_listener(dbapi_conn, connection_rec):
"""Switch sqlite connections to non-synchronous mode"""
dbapi_conn.execute("PRAGMA synchronous = OFF")
def ping_listener(dbapi_conn, connection_rec, connection_proxy):
"""
Ensures that MySQL connections checked out of the
pool are alive.
Borrowed from:
http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f
"""
try:
dbapi_conn.cursor().execute('select 1')
except dbapi_conn.OperationalError as ex:
if ex.args[0] in (2006, 2013, 2014, 2045, 2055):
LOG.warn(_('Got mysql server has gone away: %s'), ex)
raise DisconnectionError("Database server went away")
else:
raise
def is_db_connection_error(args):
"""Return True if error in connecting to db."""
# NOTE(adam_g): This is currently MySQL specific and needs to be extended
# to support Postgres and others.
conn_err_codes = ('2002', '2003', '2006')
for err_code in conn_err_codes:
if args.find(err_code) != -1:
return True
return False
def get_engine():
"""Return a SQLAlchemy engine."""
global _ENGINE
if _ENGINE is None:
connection_dict = sqlalchemy.engine.url.make_url(CONF.sql_connection)
engine_args = {
"pool_recycle": CONF.sql_idle_timeout,
"echo": False,
'convert_unicode': True,
}
# Map our SQL debug level to SQLAlchemy's options
if CONF.sql_connection_debug >= 100:
engine_args['echo'] = 'debug'
elif CONF.sql_connection_debug >= 50:
engine_args['echo'] = True
if "sqlite" in connection_dict.drivername:
engine_args["poolclass"] = NullPool
if CONF.sql_connection == "sqlite://":
engine_args["poolclass"] = StaticPool
engine_args["connect_args"] = {'check_same_thread': False}
_ENGINE = sqlalchemy.create_engine(CONF.sql_connection, **engine_args)
if 'mysql' in connection_dict.drivername:
sqlalchemy.event.listen(_ENGINE, 'checkout', ping_listener)
elif "sqlite" in connection_dict.drivername:
if not CONF.sqlite_synchronous:
sqlalchemy.event.listen(_ENGINE, 'connect',
synchronous_switch_listener)
try:
_ENGINE.connect()
except OperationalError as e:
if not is_db_connection_error(e.args[0]):
raise
remaining = CONF.sql_max_retries
if remaining == -1:
remaining = 'infinite'
while True:
msg = _('SQL connection failed. %s attempts left.')
LOG.warn(msg % remaining)
if remaining != 'infinite':
remaining -= 1
time.sleep(CONF.sql_retry_interval)
try:
_ENGINE.connect()
break
except OperationalError as e:
if ((remaining != 'infinite' and remaining == 0) or
not is_db_connection_error(e.args[0])):
raise
return _ENGINE
def get_maker(engine, autocommit=True, expire_on_commit=False):
"""Return a SQLAlchemy sessionmaker using the given engine."""
return sqlalchemy.orm.sessionmaker(bind=engine,
autocommit=autocommit,
expire_on_commit=expire_on_commit)

View File

@ -1,499 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2013 Boris Pavlovic (boris@pavlovic.me).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import re
from migrate.changeset import UniqueConstraint, ForeignKeyConstraint
import six
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy.engine import reflection
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import MetaData
from sqlalchemy import schema
from sqlalchemy.sql.expression import literal_column
from sqlalchemy.sql.expression import UpdateBase
from sqlalchemy.sql import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.types import NullType
from manila.db.sqlalchemy import api as db
from manila import exception
from manila.openstack.common.gettextutils import _
from manila.openstack.common import log as logging
from manila.openstack.common import timeutils
LOG = logging.getLogger(__name__)
def get_table(engine, name):
"""Returns an sqlalchemy table dynamically from db.
Needed because the models don't work for us in migrations
as models will be far out of sync with the current data.
"""
metadata = MetaData()
metadata.bind = engine
return Table(name, metadata, autoload=True)
class InsertFromSelect(UpdateBase):
def __init__(self, table, select):
self.table = table
self.select = select
@compiles(InsertFromSelect)
def visit_insert_from_select(element, compiler, **kw):
return "INSERT INTO %s %s" % (
compiler.process(element.table, asfrom=True),
compiler.process(element.select))
def _get_not_supported_column(col_name_col_instance, column_name):
try:
column = col_name_col_instance[column_name]
except Exception:
msg = _("Please specify column %s in col_name_col_instance "
"param. It is required because column has unsupported "
"type by sqlite).")
raise exception.ManilaException(msg % column_name)
if not isinstance(column, Column):
msg = _("col_name_col_instance param has wrong type of "
"column instance for column %s It should be instance "
"of sqlalchemy.Column.")
raise exception.ManilaException(msg % column_name)
return column
def _get_unique_constraints_in_sqlite(migrate_engine, table_name):
regexp = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)"
meta = MetaData(bind=migrate_engine)
table = Table(table_name, meta, autoload=True)
sql_data = migrate_engine.execute(
"""
SELECT sql
FROM
sqlite_master
WHERE
type = 'table' AND
name = :table_name;
""",
table_name=table_name
).fetchone()[0]
uniques = set([
schema.UniqueConstraint(
*[getattr(table.c, c.strip(' "'))
for c in cols.split(",")], name=name
)
for name, cols in re.findall(regexp, sql_data)
])
return uniques
def _drop_unique_constraint_in_sqlite(migrate_engine, table_name, uc_name,
**col_name_col_instance):
insp = reflection.Inspector.from_engine(migrate_engine)
meta = MetaData(bind=migrate_engine)
table = Table(table_name, meta, autoload=True)
columns = []
for column in table.columns:
if isinstance(column.type, NullType):
new_column = _get_not_supported_column(col_name_col_instance,
column.name)
columns.append(new_column)
else:
columns.append(column.copy())
uniques = _get_unique_constraints_in_sqlite(migrate_engine, table_name)
table.constraints.update(uniques)
constraints = [constraint for constraint in table.constraints
if not constraint.name == uc_name and
not isinstance(constraint, schema.ForeignKeyConstraint)]
new_table = Table(table_name + "__tmp__", meta, *(columns + constraints))
new_table.create()
indexes = []
for index in insp.get_indexes(table_name):
column_names = [new_table.c[c] for c in index['column_names']]
indexes.append(Index(index["name"],
*column_names,
unique=index["unique"]))
f_keys = []
for fk in insp.get_foreign_keys(table_name):
refcolumns = [fk['referred_table'] + '.' + col
for col in fk['referred_columns']]
f_keys.append(ForeignKeyConstraint(fk['constrained_columns'],
refcolumns, table=new_table, name=fk['name']))
ins = InsertFromSelect(new_table, table.select())
migrate_engine.execute(ins)
table.drop()
[index.create(migrate_engine) for index in indexes]
for fkey in f_keys:
fkey.create()
new_table.rename(table_name)
def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns,
**col_name_col_instance):
"""
This method drops UC from table and works for mysql, postgresql and sqlite.
In mysql and postgresql we are able to use "alter table" constuction. In
sqlite is only one way to drop UC:
1) Create new table with same columns, indexes and constraints
(except one that we want to drop).
2) Copy data from old table to new.
3) Drop old table.
4) Rename new table to the name of old table.
:param migrate_engine: sqlalchemy engine
:param table_name: name of table that contains uniq constarint.
:param uc_name: name of uniq constraint that will be dropped.
:param columns: columns that are in uniq constarint.
:param col_name_col_instance: contains pair column_name=column_instance.
column_instance is instance of Column. These params
are required only for columns that have unsupported
types by sqlite. For example BigInteger.
"""
if migrate_engine.name == "sqlite":
_drop_unique_constraint_in_sqlite(migrate_engine, table_name, uc_name,
**col_name_col_instance)
else:
meta = MetaData()
meta.bind = migrate_engine
t = Table(table_name, meta, autoload=True)
uc = UniqueConstraint(*columns, table=t, name=uc_name)
uc.drop()
def drop_old_duplicate_entries_from_table(migrate_engine, table_name,
use_soft_delete, *uc_column_names):
"""
This method is used to drop all old rows that have the same values for
columns in uc_columns.
"""
meta = MetaData()
meta.bind = migrate_engine
table = Table(table_name, meta, autoload=True)
columns_for_group_by = [table.c[name] for name in uc_column_names]
columns_for_select = [func.max(table.c.id)]
columns_for_select.extend(list(columns_for_group_by))
duplicated_rows_select = select(columns_for_select,
group_by=columns_for_group_by,
having=func.count(table.c.id) > 1)
for row in migrate_engine.execute(duplicated_rows_select):
# NOTE(boris-42): Do not remove row that has the biggest ID.
delete_condition = table.c.id != row[0]
for name in uc_column_names:
delete_condition &= table.c[name] == row[name]
rows_to_delete_select = select([table.c.id]).where(delete_condition)
for row in migrate_engine.execute(rows_to_delete_select).fetchall():
LOG.info(_("Deleted duplicated row with id: %(id)s from table: "
"%(table)s") % dict(id=row[0], table=table_name))
if use_soft_delete:
delete_statement = table.update().\
where(delete_condition).\
values({
'deleted': literal_column('id'),
'updated_at': literal_column('updated_at'),
'deleted_at': timeutils.utcnow()
})
else:
delete_statement = table.delete().where(delete_condition)
migrate_engine.execute(delete_statement)
def _get_default_deleted_value(table):
if isinstance(table.c.id.type, Integer):
return 0
if isinstance(table.c.id.type, String):
return ""
raise exception.ManilaException(_("Unsupported id columns type"))
def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes):
table = get_table(migrate_engine, table_name)
insp = reflection.Inspector.from_engine(migrate_engine)
real_indexes = insp.get_indexes(table_name)
existing_index_names = dict([(index['name'], index['column_names'])
for index in real_indexes])
# NOTE(boris-42): Restore indexes on `deleted` column
for index in indexes:
if 'deleted' not in index['column_names']:
continue
name = index['name']
if name in existing_index_names:
column_names = [table.c[c] for c in existing_index_names[name]]
old_index = Index(name, *column_names, unique=index["unique"])
old_index.drop(migrate_engine)
column_names = [table.c[c] for c in index['column_names']]
new_index = Index(index["name"], *column_names, unique=index["unique"])
new_index.create(migrate_engine)
def change_deleted_column_type_to_boolean(migrate_engine, table_name,
**col_name_col_instance):
if migrate_engine.name == "sqlite":
return _change_deleted_column_type_to_boolean_sqlite(migrate_engine,
table_name,
**col_name_col_instance)
insp = reflection.Inspector.from_engine(migrate_engine)
indexes = insp.get_indexes(table_name)
table = get_table(migrate_engine, table_name)
old_deleted = Column('old_deleted', Boolean, default=False)
old_deleted.create(table, populate_default=False)
table.update().\
where(table.c.deleted == table.c.id).\
values(old_deleted=True).\
execute()
table.c.deleted.drop()
table.c.old_deleted.alter(name="deleted")
_restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name,
**col_name_col_instance):
insp = reflection.Inspector.from_engine(migrate_engine)
table = get_table(migrate_engine, table_name)
columns = []
for column in table.columns:
column_copy = None
if column.name != "deleted":
if isinstance(column.type, NullType):
column_copy = _get_not_supported_column(col_name_col_instance,
column.name)
else:
column_copy = column.copy()
else:
column_copy = Column('deleted', Boolean, default=0)
columns.append(column_copy)
constraints = [constraint.copy() for constraint in table.constraints]
meta = MetaData(bind=migrate_engine)
new_table = Table(table_name + "__tmp__", meta,
*(columns + constraints))
new_table.create()
indexes = []
for index in insp.get_indexes(table_name):
column_names = [new_table.c[c] for c in index['column_names']]
indexes.append(Index(index["name"], *column_names,
unique=index["unique"]))
c_select = []
for c in table.c:
if c.name != "deleted":
c_select.append(c)
else:
c_select.append(table.c.deleted == table.c.id)
ins = InsertFromSelect(new_table, select(c_select))
migrate_engine.execute(ins)
table.drop()
[index.create(migrate_engine) for index in indexes]
new_table.rename(table_name)
new_table.update().\
where(new_table.c.deleted == new_table.c.id).\
values(deleted=True).\
execute()
def change_deleted_column_type_to_id_type(migrate_engine, table_name,
**col_name_col_instance):
if migrate_engine.name == "sqlite":
return _change_deleted_column_type_to_id_type_sqlite(migrate_engine,
table_name,
**col_name_col_instance)
insp = reflection.Inspector.from_engine(migrate_engine)
indexes = insp.get_indexes(table_name)
table = get_table(migrate_engine, table_name)
new_deleted = Column('new_deleted', table.c.id.type,
default=_get_default_deleted_value(table))
new_deleted.create(table, populate_default=True)
table.update().\
where(table.c.deleted == True).\
values(new_deleted=table.c.id).\
execute()
table.c.deleted.drop()
table.c.new_deleted.alter(name="deleted")
_restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes)
def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name,
**col_name_col_instance):
# NOTE(boris-42): sqlaclhemy-migrate can't drop column with check
# constraints in sqlite DB and our `deleted` column has
# 2 check constraints. So there is only one way to remove
# these constraints:
# 1) Create new table with the same columns, constraints
# and indexes. (except deleted column).
# 2) Copy all data from old to new table.
# 3) Drop old table.
# 4) Rename new table to old table name.
insp = reflection.Inspector.from_engine(migrate_engine)
meta = MetaData(bind=migrate_engine)
table = Table(table_name, meta, autoload=True)
default_deleted_value = _get_default_deleted_value(table)
columns = []
for column in table.columns:
column_copy = None
if column.name != "deleted":
if isinstance(column.type, NullType):
column_copy = _get_not_supported_column(col_name_col_instance,
column.name)
else:
column_copy = column.copy()
else:
column_copy = Column('deleted', table.c.id.type,
default=default_deleted_value)
columns.append(column_copy)
def is_deleted_column_constraint(constraint):
# NOTE(boris-42): There is no other way to check is CheckConstraint
# associated with deleted column.
if not isinstance(constraint, CheckConstraint):
return False
sqltext = str(constraint.sqltext)
return (sqltext.endswith("deleted in (0, 1)") or
sqltext.endswith("deleted IN (:deleted_1, :deleted_2)"))
constraints = []
for constraint in table.constraints:
if not is_deleted_column_constraint(constraint):
constraints.append(constraint.copy())
new_table = Table(table_name + "__tmp__", meta,
*(columns + constraints))
new_table.create()
indexes = []
for index in insp.get_indexes(table_name):
column_names = [new_table.c[c] for c in index['column_names']]
indexes.append(Index(index["name"], *column_names,
unique=index["unique"]))
ins = InsertFromSelect(new_table, table.select())
migrate_engine.execute(ins)
table.drop()
[index.create(migrate_engine) for index in indexes]
new_table.rename(table_name)
new_table.update().\
where(new_table.c.deleted == True).\
values(deleted=new_table.c.id).\
execute()
# NOTE(boris-42): Fix value of deleted column: False -> "" or 0.
new_table.update().\
where(new_table.c.deleted == False).\
values(deleted=default_deleted_value).\
execute()
def _add_index(migrate_engine, table, index_name, idx_columns):
index = Index(
index_name, *[getattr(table.c, col) for col in idx_columns]
)
index.create()
def _drop_index(migrate_engine, table, index_name, idx_columns):
index = Index(
index_name, *[getattr(table.c, col) for col in idx_columns]
)
index.drop()
def _change_index_columns(migrate_engine, table, index_name,
new_columns, old_columns):
Index(
index_name,
*[getattr(table.c, col) for col in old_columns]
).drop(migrate_engine)
Index(
index_name,
*[getattr(table.c, col) for col in new_columns]
).create()
def modify_indexes(migrate_engine, data, upgrade=True):
if migrate_engine.name == 'sqlite':
return
meta = MetaData()
meta.bind = migrate_engine
for table_name, indexes in six.iteritems(data):
table = Table(table_name, meta, autoload=True)
for index_name, old_columns, new_columns in indexes:
if not upgrade:
new_columns, old_columns = old_columns, new_columns
if migrate_engine.name == 'postgresql':
if upgrade:
_add_index(migrate_engine, table, index_name, new_columns)
else:
_drop_index(migrate_engine, table, index_name, old_columns)
elif migrate_engine.name == 'mysql':
_change_index_columns(migrate_engine, table, index_name,
new_columns, old_columns)
else:
raise ValueError('Unsupported DB %s' % migrate_engine.name)

View File

@ -24,7 +24,6 @@ SHOULD include dedicated exception logging.
from oslo.config import cfg from oslo.config import cfg
import six import six
from sqlalchemy import exc as sqa_exc
import webob.exc import webob.exc
from manila.openstack.common import log as logging from manila.openstack.common import log as logging
@ -57,30 +56,6 @@ class Error(Exception):
pass pass
class DBError(Error):
"""Wraps an implementation specific exception."""
def __init__(self, inner_exception=None):
self.inner_exception = inner_exception
super(DBError, self).__init__(str(inner_exception))
def wrap_db_error(f):
def _wrap(*args, **kwargs):
try:
return f(*args, **kwargs)
except UnicodeEncodeError:
raise InvalidUnicodeParameter()
except sqa_exc.IntegrityError as e:
raise Duplicate(message=str(e))
except Duplicate:
raise
except Exception as e:
LOG.exception(_('DB exception wrapped.'))
raise DBError(e)
_wrap.func_name = f.func_name
return _wrap
class ManilaException(Exception): class ManilaException(Exception):
"""Base Manila Exception """Base Manila Exception
@ -165,11 +140,6 @@ class InvalidContentType(Invalid):
message = _("Invalid content type %(content_type)s.") message = _("Invalid content type %(content_type)s.")
class InvalidUnicodeParameter(Invalid):
message = _("Invalid Parameter: "
"Unicode is not supported by the current database.")
# Cannot be templated as the error syntax varies. # Cannot be templated as the error syntax varies.
# msg needs to be constructed when raised. # msg needs to be constructed when raised.
class InvalidParameterValue(Invalid): class InvalidParameterValue(Invalid):
@ -286,11 +256,6 @@ class FileNotFound(NotFound):
message = _("File %(file_path)s could not be found.") message = _("File %(file_path)s could not be found.")
# TODO(bcwaldon): EOL this exception!
class Duplicate(ManilaException):
message = _("Duplicate entry: %(message)s")
class MigrationError(ManilaException): class MigrationError(ManilaException):
message = _("Migration error") + ": %(reason)s" message = _("Migration error") + ": %(reason)s"
@ -352,7 +317,7 @@ class PortLimitExceeded(QuotaError):
message = _("Maximum number of ports exceeded") message = _("Maximum number of ports exceeded")
class ShareAccessExists(Duplicate): class ShareAccessExists(ManilaException):
message = _("Share access %(access_type)s:%(access)s exists") message = _("Share access %(access_type)s:%(access)s exists")

View File

@ -260,9 +260,11 @@ class ShareManager(manager.SchedulerDependentManager):
try: try:
model_update = self.driver.create_snapshot( model_update = self.driver.create_snapshot(
context, snapshot_ref, share_server=share_server) context, snapshot_ref, share_server=share_server)
if model_update: if model_update:
model_dict = model_update.to_dict()
self.db.share_snapshot_update(context, snapshot_ref['id'], self.db.share_snapshot_update(context, snapshot_ref['id'],
model_update) model_dict)
except Exception: except Exception:
with excutils.save_and_reraise_exception(): with excutils.save_and_reraise_exception():

View File

@ -16,6 +16,7 @@
from oslo.config import cfg from oslo.config import cfg
from oslo.db import exception as db_exception
import six import six
from manila import context from manila import context
@ -33,7 +34,7 @@ def create(context, name, extra_specs={}):
type_ref = db.volume_type_create(context, type_ref = db.volume_type_create(context,
dict(name=name, dict(name=name,
extra_specs=extra_specs)) extra_specs=extra_specs))
except exception.DBError as e: except db_exception.DBError as e:
LOG.exception(_('DB error: %s') % e) LOG.exception(_('DB error: %s') % e)
raise exception.VolumeTypeCreateFailed(name=name, raise exception.VolumeTypeCreateFailed(name=name,
extra_specs=extra_specs) extra_specs=extra_specs)

View File

@ -34,7 +34,7 @@ import six
import testtools import testtools
from manila.db import migration from manila.db import migration
from manila.db.sqlalchemy import session as sqla_session from manila.db.sqlalchemy import api as db_api
from manila.openstack.common import importutils from manila.openstack.common import importutils
from manila.openstack.common import log as logging from manila.openstack.common import log as logging
from manila.openstack.common import timeutils from manila.openstack.common import timeutils
@ -129,9 +129,9 @@ class TestCase(testtools.TestCase):
global _DB_CACHE global _DB_CACHE
if not _DB_CACHE: if not _DB_CACHE:
_DB_CACHE = Database( _DB_CACHE = Database(
sqla_session, db_api,
migration, migration,
sql_connection=CONF.sql_connection, sql_connection=CONF.database.connection,
sqlite_db=CONF.sqlite_db, sqlite_db=CONF.sqlite_db,
sqlite_clean_db=CONF.sqlite_clean_db, sqlite_clean_db=CONF.sqlite_clean_db,
) )

View File

@ -14,6 +14,7 @@
# under the License. # under the License.
import mock import mock
from oslo.db import exception as db_exception
from webob import exc as webob_exc from webob import exc as webob_exc
from manila.api.v1 import share_networks from manila.api.v1 import share_networks
@ -108,7 +109,7 @@ class ShareNetworkAPITest(test.TestCase):
def test_create_db_api_exception(self): def test_create_db_api_exception(self):
with mock.patch.object(db_api, with mock.patch.object(db_api,
'share_network_create', 'share_network_create',
mock.Mock(side_effect=exception.DBError)): mock.Mock(side_effect=db_exception.DBError)):
self.assertRaises(webob_exc.HTTPBadRequest, self.assertRaises(webob_exc.HTTPBadRequest,
self.controller.create, self.controller.create,
self.req, self.req,
@ -274,7 +275,7 @@ class ShareNetworkAPITest(test.TestCase):
with mock.patch.object(db_api, with mock.patch.object(db_api,
'share_network_update', 'share_network_update',
mock.Mock(side_effect=exception.DBError)): mock.Mock(side_effect=db_exception.DBError)):
self.assertRaises(webob_exc.HTTPBadRequest, self.assertRaises(webob_exc.HTTPBadRequest,
self.controller.update, self.controller.update,
self.req, self.req,

View File

@ -22,7 +22,7 @@ CONF = cfg.CONF
def set_defaults(conf): def set_defaults(conf):
conf.set_default('connection_type', 'fake') conf.set_default('connection_type', 'fake')
conf.set_default('verbose', True) conf.set_default('verbose', True)
conf.set_default('sql_connection', "sqlite://") conf.set_default('connection', "sqlite://", group='database')
conf.set_default('sqlite_synchronous', False) conf.set_default('sqlite_synchronous', False)
conf.set_default('policy_file', 'manila/tests/policy.json') conf.set_default('policy_file', 'manila/tests/policy.json')
conf.set_default('share_export_ip', '0.0.0.0') conf.set_default('share_export_ip', '0.0.0.0')

View File

@ -13,6 +13,8 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from oslo.db import exception as db_exception
from manila.common import constants from manila.common import constants
from manila import context from manila import context
from manila.db import api as db_api from manila.db import api as db_api
@ -56,7 +58,7 @@ class SecurityServiceDBTest(test.TestCase):
db_api.security_service_create(self.fake_context, db_api.security_service_create(self.fake_context,
security_service_dict) security_service_dict)
self.assertRaises(exception.Duplicate, self.assertRaises(db_exception.DBDuplicateEntry,
db_api.security_service_create, db_api.security_service_create,
self.fake_context, self.fake_context,
security_service_dict) security_service_dict)

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from oslo.db import exception as db_exception
import six import six
from manila.common import constants from manila.common import constants
@ -82,7 +83,7 @@ class ShareNetworkDBTest(test.TestCase):
def test_create_with_duplicated_id(self): def test_create_with_duplicated_id(self):
db_api.share_network_create(self.fake_context, self.share_nw_dict) db_api.share_network_create(self.fake_context, self.share_nw_dict)
self.assertRaises(exception.Duplicate, self.assertRaises(db_exception.DBDuplicateEntry,
db_api.share_network_create, db_api.share_network_create,
self.fake_context, self.fake_context,
self.share_nw_dict) self.share_nw_dict)

View File

@ -119,11 +119,6 @@ class ManilaExceptionResponseCode400(test.TestCase):
self.assertEqual(e.code, 400) self.assertEqual(e.code, 400)
self.assertIn(content_type, e.msg) self.assertIn(content_type, e.msg)
def test_invalid_unicode_parameter(self):
# Verify response code for exception.InvalidUnicodeParameter
e = exception.InvalidUnicodeParameter()
self.assertEqual(e.code, 400)
def test_invalid_parameter_value(self): def test_invalid_parameter_value(self):
# Verify response code for exception.InvalidParameterValue # Verify response code for exception.InvalidParameterValue
err = "fake_err" err = "fake_err"

View File

@ -24,19 +24,18 @@ properly both upgrading and downgrading, and that no data loss occurs
if possible. if possible.
""" """
import commands
import ConfigParser import ConfigParser
import os import os
import urlparse import shutil
import uuid import tempfile
from migrate.versioning import api as migration_api
from migrate.versioning import repository from migrate.versioning import repository
from oslo.db.sqlalchemy import test_migrations
import sqlalchemy import sqlalchemy
import testtools import testtools
import manila.db.migration as migration
import manila.db.sqlalchemy.migrate_repo import manila.db.sqlalchemy.migrate_repo
from manila.db.sqlalchemy.migration import versioning_api as migration_api
from manila.openstack.common import log as logging from manila.openstack.common import log as logging
from manila import test from manila import test
@ -85,141 +84,46 @@ def _is_backend_avail(backend,
return True return True
def _have_mysql(): class TestMigrations(test.TestCase,
present = os.environ.get('NOVA_TEST_MYSQL_PRESENT') test_migrations.BaseMigrationTestCase,
if present is None: test_migrations.WalkVersionsMixin):
return _is_backend_avail('mysql')
return present.lower() in ('', 'true')
def get_table(engine, name):
"""Returns an sqlalchemy table dynamically from db.
Needed because the models don't work for us in migrations
as models will be far out of sync with the current data.
"""
metadata = sqlalchemy.schema.MetaData()
metadata.bind = engine
return sqlalchemy.Table(name, metadata, autoload=True)
class TestMigrations(test.TestCase):
"""Test sqlalchemy-migrate migrations.""" """Test sqlalchemy-migrate migrations."""
DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__), def __init__(self, *args, **kwargs):
'test_migrations.conf') super(TestMigrations, self).__init__(*args, **kwargs)
# Test machines can set the MANILA_TEST_MIGRATIONS_CONF variable
# to override the location of the config file for migration testing self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__),
CONFIG_FILE_PATH = os.environ.get('MANILA_TEST_MIGRATIONS_CONF', 'test_migrations.conf')
DEFAULT_CONFIG_FILE) # Test machines can set the MANILA_TEST_MIGRATIONS_CONF variable
MIGRATE_FILE = manila.db.sqlalchemy.migrate_repo.__file__ # to override the location of the config file for migration testing
REPOSITORY = repository.Repository( self.CONFIG_FILE_PATH = os.environ.get('MANILA_TEST_MIGRATIONS_CONF',
os.path.abspath(os.path.dirname(MIGRATE_FILE))) self.DEFAULT_CONFIG_FILE)
self.MIGRATE_FILE = manila.db.sqlalchemy.migrate_repo.__file__
self.REPOSITORY = repository.Repository(
os.path.abspath(os.path.dirname(self.MIGRATE_FILE)))
self.migration_api = migration_api
self.INIT_VERSION = 000
def setUp(self): def setUp(self):
super(TestMigrations, self).setUp() if not os.environ.get("OSLO_LOCK_PATH"):
lock_dir = tempfile.mkdtemp()
os.environ["OSLO_LOCK_PATH"] = lock_dir
self.addCleanup(self._cleanup)
self.snake_walk = False self.snake_walk = False
self.test_databases = {}
# Load test databases from the config file. Only do this
# once. No need to re-run this on each test...
LOG.debug('config_path is %s' % TestMigrations.CONFIG_FILE_PATH)
if not self.test_databases: if not self.test_databases:
if os.path.exists(TestMigrations.CONFIG_FILE_PATH): super(TestMigrations, self).setUp()
cp = ConfigParser.RawConfigParser() cp = ConfigParser.RawConfigParser()
try: try:
cp.read(TestMigrations.CONFIG_FILE_PATH) cp.read(self.CONFIG_FILE_PATH)
defaults = cp.defaults() self.snake_walk = cp.getboolean('walk_style', 'snake_walk')
for key, value in defaults.items(): except ConfigParser.ParsingError as e:
self.test_databases[key] = value
self.snake_walk = cp.getboolean('walk_style', 'snake_walk')
except ConfigParser.ParsingError as e:
self.fail("Failed to read test_migrations.conf config " self.fail("Failed to read test_migrations.conf config "
"file. Got error: %s" % e) "file. Got error: %s" % e)
else:
self.fail("Failed to find test_migrations.conf config "
"file.")
self.engines = {} def _cleanup(self):
for key, value in self.test_databases.items(): shutil.rmtree(os.environ["OSLO_LOCK_PATH"], ignore_errors=True)
self.engines[key] = sqlalchemy.create_engine(value) del os.environ["OSLO_LOCK_PATH"]
# We start each test case with a completely blank slate.
self._reset_databases()
def tearDown(self):
# We destroy the test data store between each test case,
# and recreate it, which ensures that we have no side-effects
# from the tests
self._reset_databases()
super(TestMigrations, self).tearDown()
def _reset_databases(self):
def execute_cmd(cmd=None):
status, output = commands.getstatusoutput(cmd)
LOG.debug(output)
self.assertEqual(0, status)
for key, engine in self.engines.items():
conn_string = self.test_databases[key]
conn_pieces = urlparse.urlparse(conn_string)
engine.dispose()
if conn_string.startswith('sqlite'):
# We can just delete the SQLite database, which is
# the easiest and cleanest solution
db_path = conn_pieces.path.strip('/')
if os.path.exists(db_path):
os.unlink(db_path)
# No need to recreate the SQLite DB. SQLite will
# create it for us if it's not there...
elif conn_string.startswith('mysql'):
# We can execute the MySQL client to destroy and re-create
# the MYSQL database, which is easier and less error-prone
# than using SQLAlchemy to do this via MetaData...trust me.
database = conn_pieces.path.strip('/')
loc_pieces = conn_pieces.netloc.split('@')
host = loc_pieces[1]
auth_pieces = loc_pieces[0].split(':')
user = auth_pieces[0]
password = ""
if len(auth_pieces) > 1:
if auth_pieces[1].strip():
password = "-p\"%s\"" % auth_pieces[1]
sql = ("drop database if exists %(database)s; "
"create database %(database)s;") % locals()
cmd = ("mysql -u \"%(user)s\" %(password)s -h %(host)s "
"-e \"%(sql)s\"") % locals()
execute_cmd(cmd)
elif conn_string.startswith('postgresql'):
database = conn_pieces.path.strip('/')
loc_pieces = conn_pieces.netloc.split('@')
host = loc_pieces[1]
auth_pieces = loc_pieces[0].split(':')
user = auth_pieces[0]
password = ""
if len(auth_pieces) > 1:
password = auth_pieces[1].strip()
# note(krtaylor): File creation problems with tests in
# venv using .pgpass authentication, changed to
# PGPASSWORD environment variable which is no longer
# planned to be deprecated
os.environ['PGPASSWORD'] = password
os.environ['PGUSER'] = user
# note(boris-42): We must create and drop database, we can't
# drop database which we have connected to, so for such
# operations there is a special database template1.
sqlcmd = ("psql -w -U %(user)s -h %(host)s -c"
" '%(sql)s' -d template1")
sql = ("drop database if exists %(database)s;") % locals()
droptable = sqlcmd % locals()
execute_cmd(droptable)
sql = ("create database %(database)s;") % locals()
createtable = sqlcmd % locals()
execute_cmd(createtable)
os.unsetenv('PGPASSWORD')
os.unsetenv('PGUSER')
def test_walk_versions(self): def test_walk_versions(self):
""" """
@ -237,13 +141,17 @@ class TestMigrations(test.TestCase):
if _is_mysql_avail(user="openstack_cifail"): if _is_mysql_avail(user="openstack_cifail"):
self.fail("Shouldn't have connected") self.fail("Shouldn't have connected")
@testtools.skipUnless(_have_mysql(), "mysql not available") @testtools.skipUnless(test_migrations._have_mysql("openstack_citest",
"openstack_citest",
"openstack_citest"),
"mysql not available")
def test_mysql_innodb(self): def test_mysql_innodb(self):
""" """
Test that table creation on mysql only builds InnoDB tables Test that table creation on mysql only builds InnoDB tables
""" """
# add this to the global lists to make reset work with it, it's removed # add this to the global lists to make parent _reset_databases method
# automaticaly in tearDown so no need to clean it up here. # work with it, it's removed automaticaly in parent tearDown method so
# no need to clean it up here.
connect_string = _get_connect_string('mysql') connect_string = _get_connect_string('mysql')
engine = sqlalchemy.create_engine(connect_string) engine = sqlalchemy.create_engine(connect_string)
self.engines["mysqlcitest"] = engine self.engines["mysqlcitest"] = engine
@ -291,83 +199,3 @@ class TestMigrations(test.TestCase):
# build a fully populated postgresql database with all the tables # build a fully populated postgresql database with all the tables
self._reset_databases() self._reset_databases()
self._walk_versions(engine, False, False) self._walk_versions(engine, False, False)
def _walk_versions(self, engine=None, snake_walk=False, downgrade=True):
# Determine latest version script from the repo, then
# upgrade from 1 through to the latest, with no data
# in the databases. This just checks that the schema itself
# upgrades successfully.
# Place the database under version control
migration_api.version_control(engine,
TestMigrations.REPOSITORY,
migration.INIT_VERSION)
self.assertEqual(migration.INIT_VERSION,
migration_api.db_version(engine,
TestMigrations.REPOSITORY))
migration_api.upgrade(engine, TestMigrations.REPOSITORY,
migration.INIT_VERSION + 1)
LOG.debug('latest version is %s' % TestMigrations.REPOSITORY.latest)
for version in xrange(migration.INIT_VERSION + 2,
TestMigrations.REPOSITORY.latest + 1):
# upgrade -> downgrade -> upgrade
self._migrate_up(engine, version, with_data=True)
if snake_walk:
self._migrate_down(engine, version - 1)
self._migrate_up(engine, version)
if downgrade:
# Now walk it back down to 0 from the latest, testing
# the downgrade paths.
for version in reversed(
xrange(migration.INIT_VERSION + 1,
TestMigrations.REPOSITORY.latest)):
# downgrade -> upgrade -> downgrade
self._migrate_down(engine, version)
if snake_walk:
self._migrate_up(engine, version + 1)
self._migrate_down(engine, version)
def _migrate_down(self, engine, version):
migration_api.downgrade(engine,
TestMigrations.REPOSITORY,
version)
self.assertEqual(version,
migration_api.db_version(engine,
TestMigrations.REPOSITORY))
def _migrate_up(self, engine, version, with_data=False):
"""migrate up to a new version of the db.
We allow for data insertion and post checks at every
migration version with special _prerun_### and
_check_### functions in the main test.
"""
# NOTE(sdague): try block is here because it's impossible to debug
# where a failed data migration happens otherwise
try:
if with_data:
data = None
prerun = getattr(self, "_prerun_%3.3d" % version, None)
if prerun:
data = prerun(engine)
migration_api.upgrade(engine,
TestMigrations.REPOSITORY,
version)
self.assertEqual(
version,
migration_api.db_version(engine,
TestMigrations.REPOSITORY))
if with_data:
check = getattr(self, "_check_%3.3d" % version, None)
if check:
check(engine, data)
except Exception:
LOG.error("Failed to migrate to version %s on engine %s" %
(version, engine))
raise

View File

@ -9,6 +9,7 @@ kombu>=2.4.8
lockfile>=0.8 lockfile>=0.8
lxml>=2.3 lxml>=2.3
oslo.config>=1.2.1 oslo.config>=1.2.1
oslo.db>=0.2.0
oslo.messaging>=1.3.0 oslo.messaging>=1.3.0
paramiko>=1.13.0 paramiko>=1.13.0
Paste Paste

View File

@ -5,6 +5,7 @@ discover
fixtures>=0.3.14 fixtures>=0.3.14
mock>=1.0 mock>=1.0
MySQL-python MySQL-python
oslotest
psycopg2 psycopg2
python-subunit python-subunit
sphinx>=1.1.2,!=1.2.0,<1.3 sphinx>=1.1.2,!=1.2.0,<1.3