diff --git a/cinder/cmd/manage.py b/cinder/cmd/manage.py index 4aa8ed629a3..b89460be610 100644 --- a/cinder/cmd/manage.py +++ b/cinder/cmd/manage.py @@ -63,7 +63,6 @@ import time from oslo_config import cfg from oslo_db import exception as db_exc -from oslo_db.sqlalchemy import migration from oslo_log import log as logging from oslo_utils import timeutils import tabulate @@ -204,9 +203,7 @@ class DbCommands(object): def version(self): """Print the current database version.""" - print(migration.db_version(db_api.get_engine(), - db_migration.LEGACY_MIGRATIONS_PATH, - db_migration.INIT_VERSION)) + print(db_migration.db_version()) @args('age_in_days', type=int, help='Purge deleted rows older than age in days') diff --git a/cinder/db/migration.py b/cinder/db/migration.py index 32708524d1d..5c5458b644a 100644 --- a/cinder/db/migration.py +++ b/cinder/db/migration.py @@ -18,11 +18,16 @@ import os +from migrate import exceptions as migrate_exceptions +from migrate.versioning import api as migrate_api +from migrate.versioning import repository as migrate_repository from oslo_config import cfg +from oslo_db import exception from oslo_db import options -from oslo_db.sqlalchemy import migration +import sqlalchemy as sa from cinder.db.sqlalchemy import api as db_api +from cinder.i18n import _ options.set_defaults(cfg.CONF) @@ -33,13 +38,114 @@ LEGACY_MIGRATIONS_PATH = os.path.join( ) +def _find_migrate_repo(abs_path): + """Get the project's change script repository + + :param abs_path: Absolute path to migrate repository + """ + if not os.path.exists(abs_path): + raise exception.DBMigrationError("Path %s not found" % abs_path) + return migrate_repository.Repository(abs_path) + + +def _migrate_db_version_control(engine, abs_path, version=None): + """Mark a database as under this repository's version control. + + Once a database is under version control, schema changes should + only be done via change scripts in this repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + + try: + migrate_api.version_control(engine, repository, version) + except migrate_exceptions.InvalidVersionError as ex: + raise exception.DBMigrationError("Invalid version : %s" % ex) + except migrate_exceptions.DatabaseAlreadyControlledError: + raise exception.DBMigrationError("Database is already controlled.") + + return version + + +def _migrate_db_version(engine, abs_path, init_version): + """Show the current version of the repository. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository + :param init_version: Initial database version + """ + repository = _find_migrate_repo(abs_path) + try: + return migrate_api.db_version(engine, repository) + except migrate_exceptions.DatabaseNotControlledError: + pass + + meta = sa.MetaData() + meta.reflect(bind=engine) + tables = meta.tables + if ( + len(tables) == 0 or + 'alembic_version' in tables or + 'migrate_version' in tables + ): + _migrate_db_version_control(engine, abs_path, version=init_version) + return migrate_api.db_version(engine, repository) + + msg = _( + "The database is not under version control, but has tables. " + "Please stamp the current version of the schema manually." + ) + raise exception.DBMigrationError(msg) + + +def _migrate_db_sync(engine, abs_path, version=None, init_version=0): + """Upgrade or downgrade a database. + + Function runs the upgrade() or downgrade() functions in change scripts. + + :param engine: SQLAlchemy engine instance for a given database + :param abs_path: Absolute path to migrate repository. + :param version: Database will upgrade/downgrade until this version. + If None - database will update to the latest available version. + :param init_version: Initial database version + """ + + if version is not None: + try: + version = int(version) + except ValueError: + raise exception.DBMigrationError(_("version should be an integer")) + + current_version = _migrate_db_version(engine, abs_path, init_version) + repository = _find_migrate_repo(abs_path) + + if version is None or version > current_version: + try: + return migrate_api.upgrade(engine, repository, version) + except Exception as ex: + raise exception.DBMigrationError(ex) + else: + return migrate_api.downgrade(engine, repository, version) + + +def db_version(): + """Get database version.""" + return _migrate_db_version( + db_api.get_engine(), + LEGACY_MIGRATIONS_PATH, + INIT_VERSION) + + def db_sync(version=None, engine=None): """Migrate the database to `version` or the most recent version.""" if engine is None: engine = db_api.get_engine() - return migration.db_sync( + return _migrate_db_sync( engine=engine, abs_path=LEGACY_MIGRATIONS_PATH, version=version, diff --git a/cinder/tests/unit/db/test_migration.py b/cinder/tests/unit/db/test_migration.py new file mode 100644 index 00000000000..1a1de9da1ca --- /dev/null +++ b/cinder/tests/unit/db/test_migration.py @@ -0,0 +1,252 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import tempfile +from unittest import mock + +from migrate import exceptions as migrate_exception +from migrate.versioning import api as migrate_api +from migrate.versioning import repository as migrate_repository +from oslo_db import exception as db_exception +from oslo_db.sqlalchemy import enginefacade +from oslo_db.sqlalchemy import test_fixtures as db_fixtures +from oslotest import base as test_base +import sqlalchemy + +from cinder.db import migration +from cinder import utils + + +class TestMigrationCommon( + db_fixtures.OpportunisticDBTestMixin, test_base.BaseTestCase, +): + + def setUp(self): + super().setUp() + + self.engine = enginefacade.writer.get_engine() + + self.path = tempfile.mkdtemp('test_migration') + self.path1 = tempfile.mkdtemp('test_migration') + self.return_value = '/home/openstack/migrations' + self.return_value1 = '/home/extension/migrations' + self.init_version = 1 + self.test_version = 123 + + self.patcher_repo = mock.patch.object(migrate_repository, 'Repository') + self.repository = self.patcher_repo.start() + self.repository.side_effect = [self.return_value, self.return_value1] + + self.mock_api_db = mock.patch.object(migrate_api, 'db_version') + self.mock_api_db_version = self.mock_api_db.start() + self.mock_api_db_version.return_value = self.test_version + + def tearDown(self): + os.rmdir(self.path) + self.mock_api_db.stop() + self.patcher_repo.stop() + super().tearDown() + + def test_find_migrate_repo_path_not_found(self): + self.assertRaises( + db_exception.DBMigrationError, + migration._find_migrate_repo, + "/foo/bar/", + ) + + def test_find_migrate_repo_called_once(self): + my_repository = migration._find_migrate_repo(self.path) + self.repository.assert_called_once_with(self.path) + self.assertEqual(self.return_value, my_repository) + + def test_find_migrate_repo_called_few_times(self): + repo1 = migration._find_migrate_repo(self.path) + repo2 = migration._find_migrate_repo(self.path1) + self.assertNotEqual(repo1, repo2) + + def test_db_version_control(self): + with utils.nested_contexts( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(migrate_api, 'version_control'), + ) as (mock_find_repo, mock_version_control): + mock_find_repo.return_value = self.return_value + + version = migration._migrate_db_version_control( + self.engine, self.path, self.test_version) + + self.assertEqual(self.test_version, version) + mock_version_control.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + @mock.patch.object(migration, '_find_migrate_repo') + @mock.patch.object(migrate_api, 'version_control') + def test_db_version_control_version_less_than_actual_version( + self, mock_version_control, mock_find_repo, + ): + mock_find_repo.return_value = self.return_value + mock_version_control.side_effect = \ + migrate_exception.DatabaseAlreadyControlledError + self.assertRaises( + db_exception.DBMigrationError, + migration._migrate_db_version_control, self.engine, + self.path, self.test_version - 1) + + @mock.patch.object(migration, '_find_migrate_repo') + @mock.patch.object(migrate_api, 'version_control') + def test_db_version_control_version_greater_than_actual_version( + self, mock_version_control, mock_find_repo, + ): + mock_find_repo.return_value = self.return_value + mock_version_control.side_effect = \ + migrate_exception.InvalidVersionError + self.assertRaises( + db_exception.DBMigrationError, + migration._migrate_db_version_control, self.engine, + self.path, self.test_version + 1) + + def test_db_version_return(self): + ret_val = migration._migrate_db_version( + self.engine, self.path, self.init_version) + self.assertEqual(self.test_version, ret_val) + + def test_db_version_raise_not_controlled_error_first(self): + with mock.patch.object( + migration, '_migrate_db_version_control', + ) as mock_ver: + self.mock_api_db_version.side_effect = [ + migrate_exception.DatabaseNotControlledError('oups'), + self.test_version] + + ret_val = migration._migrate_db_version( + self.engine, self.path, self.init_version) + self.assertEqual(self.test_version, ret_val) + mock_ver.assert_called_once_with( + self.engine, self.path, version=self.init_version) + + def test_db_version_raise_not_controlled_error_tables(self): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = \ + migrate_exception.DatabaseNotControlledError('oups') + my_meta = mock.MagicMock() + my_meta.tables = {'a': 1, 'b': 2} + mock_meta.return_value = my_meta + + self.assertRaises( + db_exception.DBMigrationError, migration._migrate_db_version, + self.engine, self.path, self.init_version) + + @mock.patch.object(migrate_api, 'version_control') + def test_db_version_raise_not_controlled_error_no_tables(self, mock_vc): + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = ( + migrate_exception.DatabaseNotControlledError('oups'), + self.init_version) + my_meta = mock.MagicMock() + my_meta.tables = {} + mock_meta.return_value = my_meta + + migration._migrate_db_version( + self.engine, self.path, self.init_version) + + mock_vc.assert_called_once_with( + self.engine, self.return_value1, self.init_version) + + @mock.patch.object(migrate_api, 'version_control') + def test_db_version_raise_not_controlled_alembic_tables(self, mock_vc): + # When there are tables but the alembic control table + # (alembic_version) is present, attempt to version the db. + # This simulates the case where there is are multiple repos (different + # abs_paths) and a different path has been versioned already. + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = [ + migrate_exception.DatabaseNotControlledError('oups'), None] + my_meta = mock.MagicMock() + my_meta.tables = {'alembic_version': 1, 'b': 2} + mock_meta.return_value = my_meta + + migration._migrate_db_version( + self.engine, self.path, self.init_version) + + mock_vc.assert_called_once_with( + self.engine, self.return_value1, self.init_version) + + @mock.patch.object(migrate_api, 'version_control') + def test_db_version_raise_not_controlled_migrate_tables(self, mock_vc): + # When there are tables but the sqlalchemy-migrate control table + # (migrate_version) is present, attempt to version the db. + # This simulates the case where there is are multiple repos (different + # abs_paths) and a different path has been versioned already. + with mock.patch.object(sqlalchemy, 'MetaData') as mock_meta: + self.mock_api_db_version.side_effect = [ + migrate_exception.DatabaseNotControlledError('oups'), None] + my_meta = mock.MagicMock() + my_meta.tables = {'migrate_version': 1, 'b': 2} + mock_meta.return_value = my_meta + + migration._migrate_db_version( + self.engine, self.path, self.init_version) + + mock_vc.assert_called_once_with( + self.engine, self.return_value1, self.init_version) + + def test_db_sync_wrong_version(self): + self.assertRaises( + db_exception.DBMigrationError, + migration._migrate_db_sync, self.engine, self.path, 'foo') + + @mock.patch.object(migrate_api, 'upgrade') + def test_db_sync_script_not_present(self, upgrade): + # For non existent migration script file sqlalchemy-migrate will raise + # VersionNotFoundError which will be wrapped in DBMigrationError. + upgrade.side_effect = migrate_exception.VersionNotFoundError + self.assertRaises( + db_exception.DBMigrationError, + migration._migrate_db_sync, self.engine, self.path, + self.test_version + 1) + + @mock.patch.object(migrate_api, 'upgrade') + def test_db_sync_known_error_raised(self, upgrade): + upgrade.side_effect = migrate_exception.KnownError + self.assertRaises( + db_exception.DBMigrationError, + migration._migrate_db_sync, self.engine, self.path, + self.test_version + 1) + + def test_db_sync_upgrade(self): + init_ver = 55 + with utils.nested_contexts( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(migrate_api, 'upgrade') + ) as (mock_find_repo, mock_upgrade): + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version - 1 + + migration._migrate_db_sync( + self.engine, self.path, self.test_version, init_ver) + + mock_upgrade.assert_called_once_with( + self.engine, self.return_value, self.test_version) + + def test_db_sync_downgrade(self): + with utils.nested_contexts( + mock.patch.object(migration, '_find_migrate_repo'), + mock.patch.object(migrate_api, 'downgrade') + ) as (mock_find_repo, mock_downgrade): + mock_find_repo.return_value = self.return_value + self.mock_api_db_version.return_value = self.test_version + 1 + + migration._migrate_db_sync( + self.engine, self.path, self.test_version) + + mock_downgrade.assert_called_once_with( + self.engine, self.return_value, self.test_version) diff --git a/cinder/tests/unit/test_cmd.py b/cinder/tests/unit/test_cmd.py index b95c22c1f4e..9bd2a409be7 100644 --- a/cinder/tests/unit/test_cmd.py +++ b/cinder/tests/unit/test_cmd.py @@ -380,7 +380,7 @@ class TestCinderManageCmd(test.TestCase): self.assertEqual(cinder_manage.OVO_VERSION, service.object_current_version) - @mock.patch('oslo_db.sqlalchemy.migration.db_version') + @mock.patch('cinder.db.migration.db_version') def test_db_commands_version(self, db_version): db_cmds = cinder_manage.DbCommands() with mock.patch('sys.stdout', new=io.StringIO()): @@ -393,7 +393,7 @@ class TestCinderManageCmd(test.TestCase): exit = self.assertRaises(SystemExit, db_cmds.sync, version + 1) self.assertEqual(1, exit.code) - @mock.patch("oslo_db.sqlalchemy.migration.db_sync") + @mock.patch('cinder.db.migration.db_sync') def test_db_commands_script_not_present(self, db_sync): db_sync.side_effect = oslo_exception.DBMigrationError(None) db_cmds = cinder_manage.DbCommands() diff --git a/cinder/utils.py b/cinder/utils.py index 37c9ffa1f11..5aa4395c928 100644 --- a/cinder/utils.py +++ b/cinder/utils.py @@ -894,6 +894,12 @@ def create_ordereddict(adict: dict) -> OrderedDict: key=operator.itemgetter(0))) +@contextlib.contextmanager +def nested_contexts(*contexts): + with contextlib.ExitStack() as stack: + yield [stack.enter_context(c) for c in contexts] + + class Semaphore(object): """Custom semaphore to workaround eventlet issues with multiprocessing.""" def __init__(self, limit):