diff --git a/nova/db/sqlalchemy/migration.py b/nova/db/sqlalchemy/migration.py index b4be0c9b77..5c0b41a7dc 100644 --- a/nova/db/sqlalchemy/migration.py +++ b/nova/db/sqlalchemy/migration.py @@ -37,10 +37,24 @@ LOG = logging.getLogger(__name__) def get_engine(database='main', context=None): if database == 'main': return db_session.get_engine(context=context) + if database == 'api': return db_session.get_api_engine() +def find_migrate_repo(database='main'): + """Get the path for the migrate repository.""" + global _REPOSITORY + rel_path = 'migrate_repo' + if database == 'api': + rel_path = os.path.join('api_migrations', 'migrate_repo') + path = os.path.join(os.path.abspath(os.path.dirname(__file__)), rel_path) + assert os.path.exists(path) + if _REPOSITORY.get(database) is None: + _REPOSITORY[database] = Repository(path) + return _REPOSITORY[database] + + def db_sync(version=None, database='main', context=None): """Migrate the database to `version` or the most recent version.""" if version is not None: @@ -50,18 +64,17 @@ def db_sync(version=None, database='main', context=None): raise exception.NovaException(_("version should be an integer")) current_version = db_version(database, context=context) - repository = _find_migrate_repo(database) + repository = find_migrate_repo(database) + engine = get_engine(database, context=context) if version is None or version > current_version: - return versioning_api.upgrade(get_engine(database, context=context), - repository, version) + return versioning_api.upgrade(engine, repository, version) else: - return versioning_api.downgrade(get_engine(database, context=context), - repository, version) + return versioning_api.downgrade(engine, repository, version) def db_version(database='main', context=None): """Display the current database version.""" - repository = _find_migrate_repo(database) + repository = find_migrate_repo(database) # NOTE(mdbooth): This is a crude workaround for races in _db_version. The 2 # races we have seen in practise are: @@ -96,20 +109,17 @@ def db_version(database='main', context=None): def _db_version(repository, database, context): + engine = get_engine(database, context=context) try: - return versioning_api.db_version(get_engine(database, context=context), - repository) + return versioning_api.db_version(engine, repository) except versioning_exceptions.DatabaseNotControlledError as exc: meta = sqlalchemy.MetaData() - engine = get_engine(database, context=context) meta.reflect(bind=engine) tables = meta.tables if len(tables) == 0: - db_version_control(INIT_VERSION[database], - database, - context=context) - return versioning_api.db_version( - get_engine(database, context=context), repository) + db_version_control( + INIT_VERSION[database], database, context=context) + return versioning_api.db_version(engine, repository) else: LOG.exception(exc) # Some pre-Essex DB's may not be version controlled. @@ -124,22 +134,7 @@ def db_initial_version(database='main'): def db_version_control(version=None, database='main', context=None): - repository = _find_migrate_repo(database) - versioning_api.version_control(get_engine(database, context=context), - repository, - version) + repository = find_migrate_repo(database) + engine = get_engine(database, context=context) + versioning_api.version_control(engine, repository, version) return version - - -def _find_migrate_repo(database='main'): - """Get the path for the migrate repository.""" - global _REPOSITORY - rel_path = 'migrate_repo' - if database == 'api': - rel_path = os.path.join('api_migrations', 'migrate_repo') - path = os.path.join(os.path.abspath(os.path.dirname(__file__)), - rel_path) - assert os.path.exists(path) - if _REPOSITORY.get(database) is None: - _REPOSITORY[database] = Repository(path) - return _REPOSITORY[database] diff --git a/nova/tests/unit/db/test_sqlalchemy_migration.py b/nova/tests/unit/db/test_sqlalchemy_migration.py index e6f3c3f32c..bfeb72ef37 100644 --- a/nova/tests/unit/db/test_sqlalchemy_migration.py +++ b/nova/tests/unit/db/test_sqlalchemy_migration.py @@ -23,7 +23,7 @@ from nova import test @mock.patch.object(migration, 'db_version', return_value=2) -@mock.patch.object(migration, '_find_migrate_repo', return_value='repo') +@mock.patch.object(migration, 'find_migrate_repo', return_value='repo') @mock.patch.object(versioning_api, 'upgrade') @mock.patch.object(versioning_api, 'downgrade') @mock.patch.object(migration, 'get_engine', return_value='engine') @@ -50,7 +50,7 @@ class TestDbSync(test.NoDBTestCase): self.assertFalse(mock_upgrade.called) -@mock.patch.object(migration, '_find_migrate_repo', return_value='repo') +@mock.patch.object(migration, 'find_migrate_repo', return_value='repo') @mock.patch.object(versioning_api, 'db_version') @mock.patch.object(migration, 'get_engine') class TestDbVersion(test.NoDBTestCase): @@ -63,8 +63,9 @@ class TestDbVersion(test.NoDBTestCase): mock_find_repo.assert_called_once_with(database) mock_db_version.assert_called_once_with('engine', 'repo') - def test_not_controlled(self, mock_get_engine, mock_db_version, - mock_find_repo): + def test_not_controlled( + self, mock_get_engine, mock_db_version, mock_find_repo, + ): database = 'api' mock_get_engine.side_effect = ['engine', 'engine', 'engine'] exc = versioning_exceptions.DatabaseNotControlledError() @@ -79,7 +80,7 @@ class TestDbVersion(test.NoDBTestCase): migration.INIT_VERSION['api'], database, context=None) db_version_calls = [mock.call('engine', 'repo')] * 2 self.assertEqual(db_version_calls, mock_db_version.call_args_list) - engine_calls = [mock.call(database, context=None)] * 3 + engine_calls = [mock.call(database, context=None)] self.assertEqual(engine_calls, mock_get_engine.call_args_list) def test_db_version_init_race(self, mock_get_engine, mock_db_version, @@ -104,7 +105,7 @@ class TestDbVersion(test.NoDBTestCase): migration.INIT_VERSION['api'], database, context=None) db_version_calls = [mock.call('engine', 'repo')] * 2 self.assertEqual(db_version_calls, mock_db_version.call_args_list) - engine_calls = [mock.call(database, context=None)] * 3 + engine_calls = [mock.call(database, context=None)] * 2 self.assertEqual(engine_calls, mock_get_engine.call_args_list) def test_db_version_raise_on_error(self, mock_get_engine, mock_db_version, @@ -127,7 +128,7 @@ class TestDbVersion(test.NoDBTestCase): migration.db_version, database) -@mock.patch.object(migration, '_find_migrate_repo', return_value='repo') +@mock.patch.object(migration, 'find_migrate_repo', return_value='repo') @mock.patch.object(migration, 'get_engine', return_value='engine') @mock.patch.object(versioning_api, 'version_control') class TestDbVersionControl(test.NoDBTestCase):