Merge "db: Clean up migration code"
This commit is contained in:
@@ -37,10 +37,24 @@ LOG = logging.getLogger(__name__)
|
||||
def get_engine(database='main', context=None):
|
||||
if database == 'main':
|
||||
return db_session.get_engine(context=context)
|
||||
|
||||
if database == 'api':
|
||||
return db_session.get_api_engine()
|
||||
|
||||
|
||||
def find_migrate_repo(database='main'):
|
||||
"""Get the path for the migrate repository."""
|
||||
global _REPOSITORY
|
||||
rel_path = 'migrate_repo'
|
||||
if database == 'api':
|
||||
rel_path = os.path.join('api_migrations', 'migrate_repo')
|
||||
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), rel_path)
|
||||
assert os.path.exists(path)
|
||||
if _REPOSITORY.get(database) is None:
|
||||
_REPOSITORY[database] = Repository(path)
|
||||
return _REPOSITORY[database]
|
||||
|
||||
|
||||
def db_sync(version=None, database='main', context=None):
|
||||
"""Migrate the database to `version` or the most recent version."""
|
||||
if version is not None:
|
||||
@@ -50,18 +64,17 @@ def db_sync(version=None, database='main', context=None):
|
||||
raise exception.NovaException(_("version should be an integer"))
|
||||
|
||||
current_version = db_version(database, context=context)
|
||||
repository = _find_migrate_repo(database)
|
||||
repository = find_migrate_repo(database)
|
||||
engine = get_engine(database, context=context)
|
||||
if version is None or version > current_version:
|
||||
return versioning_api.upgrade(get_engine(database, context=context),
|
||||
repository, version)
|
||||
return versioning_api.upgrade(engine, repository, version)
|
||||
else:
|
||||
return versioning_api.downgrade(get_engine(database, context=context),
|
||||
repository, version)
|
||||
return versioning_api.downgrade(engine, repository, version)
|
||||
|
||||
|
||||
def db_version(database='main', context=None):
|
||||
"""Display the current database version."""
|
||||
repository = _find_migrate_repo(database)
|
||||
repository = find_migrate_repo(database)
|
||||
|
||||
# NOTE(mdbooth): This is a crude workaround for races in _db_version. The 2
|
||||
# races we have seen in practise are:
|
||||
@@ -96,20 +109,17 @@ def db_version(database='main', context=None):
|
||||
|
||||
|
||||
def _db_version(repository, database, context):
|
||||
engine = get_engine(database, context=context)
|
||||
try:
|
||||
return versioning_api.db_version(get_engine(database, context=context),
|
||||
repository)
|
||||
return versioning_api.db_version(engine, repository)
|
||||
except versioning_exceptions.DatabaseNotControlledError as exc:
|
||||
meta = sqlalchemy.MetaData()
|
||||
engine = get_engine(database, context=context)
|
||||
meta.reflect(bind=engine)
|
||||
tables = meta.tables
|
||||
if len(tables) == 0:
|
||||
db_version_control(INIT_VERSION[database],
|
||||
database,
|
||||
context=context)
|
||||
return versioning_api.db_version(
|
||||
get_engine(database, context=context), repository)
|
||||
db_version_control(
|
||||
INIT_VERSION[database], database, context=context)
|
||||
return versioning_api.db_version(engine, repository)
|
||||
else:
|
||||
LOG.exception(exc)
|
||||
# Some pre-Essex DB's may not be version controlled.
|
||||
@@ -124,22 +134,7 @@ def db_initial_version(database='main'):
|
||||
|
||||
|
||||
def db_version_control(version=None, database='main', context=None):
|
||||
repository = _find_migrate_repo(database)
|
||||
versioning_api.version_control(get_engine(database, context=context),
|
||||
repository,
|
||||
version)
|
||||
repository = find_migrate_repo(database)
|
||||
engine = get_engine(database, context=context)
|
||||
versioning_api.version_control(engine, repository, version)
|
||||
return version
|
||||
|
||||
|
||||
def _find_migrate_repo(database='main'):
|
||||
"""Get the path for the migrate repository."""
|
||||
global _REPOSITORY
|
||||
rel_path = 'migrate_repo'
|
||||
if database == 'api':
|
||||
rel_path = os.path.join('api_migrations', 'migrate_repo')
|
||||
path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
|
||||
rel_path)
|
||||
assert os.path.exists(path)
|
||||
if _REPOSITORY.get(database) is None:
|
||||
_REPOSITORY[database] = Repository(path)
|
||||
return _REPOSITORY[database]
|
||||
|
||||
@@ -23,7 +23,7 @@ from nova import test
|
||||
|
||||
|
||||
@mock.patch.object(migration, 'db_version', return_value=2)
|
||||
@mock.patch.object(migration, '_find_migrate_repo', return_value='repo')
|
||||
@mock.patch.object(migration, 'find_migrate_repo', return_value='repo')
|
||||
@mock.patch.object(versioning_api, 'upgrade')
|
||||
@mock.patch.object(versioning_api, 'downgrade')
|
||||
@mock.patch.object(migration, 'get_engine', return_value='engine')
|
||||
@@ -50,7 +50,7 @@ class TestDbSync(test.NoDBTestCase):
|
||||
self.assertFalse(mock_upgrade.called)
|
||||
|
||||
|
||||
@mock.patch.object(migration, '_find_migrate_repo', return_value='repo')
|
||||
@mock.patch.object(migration, 'find_migrate_repo', return_value='repo')
|
||||
@mock.patch.object(versioning_api, 'db_version')
|
||||
@mock.patch.object(migration, 'get_engine')
|
||||
class TestDbVersion(test.NoDBTestCase):
|
||||
@@ -63,8 +63,9 @@ class TestDbVersion(test.NoDBTestCase):
|
||||
mock_find_repo.assert_called_once_with(database)
|
||||
mock_db_version.assert_called_once_with('engine', 'repo')
|
||||
|
||||
def test_not_controlled(self, mock_get_engine, mock_db_version,
|
||||
mock_find_repo):
|
||||
def test_not_controlled(
|
||||
self, mock_get_engine, mock_db_version, mock_find_repo,
|
||||
):
|
||||
database = 'api'
|
||||
mock_get_engine.side_effect = ['engine', 'engine', 'engine']
|
||||
exc = versioning_exceptions.DatabaseNotControlledError()
|
||||
@@ -79,7 +80,7 @@ class TestDbVersion(test.NoDBTestCase):
|
||||
migration.INIT_VERSION['api'], database, context=None)
|
||||
db_version_calls = [mock.call('engine', 'repo')] * 2
|
||||
self.assertEqual(db_version_calls, mock_db_version.call_args_list)
|
||||
engine_calls = [mock.call(database, context=None)] * 3
|
||||
engine_calls = [mock.call(database, context=None)]
|
||||
self.assertEqual(engine_calls, mock_get_engine.call_args_list)
|
||||
|
||||
def test_db_version_init_race(self, mock_get_engine, mock_db_version,
|
||||
@@ -104,7 +105,7 @@ class TestDbVersion(test.NoDBTestCase):
|
||||
migration.INIT_VERSION['api'], database, context=None)
|
||||
db_version_calls = [mock.call('engine', 'repo')] * 2
|
||||
self.assertEqual(db_version_calls, mock_db_version.call_args_list)
|
||||
engine_calls = [mock.call(database, context=None)] * 3
|
||||
engine_calls = [mock.call(database, context=None)] * 2
|
||||
self.assertEqual(engine_calls, mock_get_engine.call_args_list)
|
||||
|
||||
def test_db_version_raise_on_error(self, mock_get_engine, mock_db_version,
|
||||
@@ -127,7 +128,7 @@ class TestDbVersion(test.NoDBTestCase):
|
||||
migration.db_version, database)
|
||||
|
||||
|
||||
@mock.patch.object(migration, '_find_migrate_repo', return_value='repo')
|
||||
@mock.patch.object(migration, 'find_migrate_repo', return_value='repo')
|
||||
@mock.patch.object(migration, 'get_engine', return_value='engine')
|
||||
@mock.patch.object(versioning_api, 'version_control')
|
||||
class TestDbVersionControl(test.NoDBTestCase):
|
||||
|
||||
Reference in New Issue
Block a user