diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 4406fc6480..80dca9cacc 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -4659,6 +4659,7 @@ def console_get(context, console_id, instance_uuid=None): ################## +@main_context_manager.writer def flavor_create(context, values, projects=None): """Create a new instance type. In order to pass in extra specs, the values dict should contain a 'extra_specs' key/value pair: @@ -4682,21 +4683,19 @@ def flavor_create(context, values, projects=None): if projects is None: projects = [] - session = get_session() - with session.begin(): - try: - instance_type_ref.save() - except db_exc.DBDuplicateEntry as e: - if 'flavorid' in e.columns: - raise exception.FlavorIdExists(flavor_id=values['flavorid']) - raise exception.FlavorExists(name=values['name']) - except Exception as e: - raise db_exc.DBError(e) - for project in set(projects): - access_ref = models.InstanceTypeProjects() - access_ref.update({"instance_type_id": instance_type_ref.id, - "project_id": project}) - access_ref.save() + try: + instance_type_ref.save(context.session) + except db_exc.DBDuplicateEntry as e: + if 'flavorid' in e.columns: + raise exception.FlavorIdExists(flavor_id=values['flavorid']) + raise exception.FlavorExists(name=values['name']) + except Exception as e: + raise db_exc.DBError(e) + for project in set(projects): + access_ref = models.InstanceTypeProjects() + access_ref.update({"instance_type_id": instance_type_ref.id, + "project_id": project}) + access_ref.save(context.session) return _dict_with_extra_specs(instance_type_ref) @@ -4720,8 +4719,8 @@ def _dict_with_extra_specs(inst_type_query): return inst_type_dict -def _flavor_get_query(context, session=None, read_deleted=None): - query = model_query(context, models.InstanceTypes, session=session, +def _flavor_get_query(context, read_deleted=None): + query = model_query(context, models.InstanceTypes, read_deleted=read_deleted).\ options(joinedload('extra_specs')) if not context.is_admin: @@ -4734,6 +4733,7 @@ def _flavor_get_query(context, session=None, read_deleted=None): @require_context +@main_context_manager.reader def flavor_get_all(context, inactive=False, filters=None, sort_key='flavorid', sort_dir='asc', limit=None, marker=None): @@ -4791,23 +4791,22 @@ def flavor_get_all(context, inactive=False, filters=None, return [_dict_with_extra_specs(i) for i in inst_types] -def _flavor_get_id_from_flavor_query(context, flavor_id, session=None): +def _flavor_get_id_from_flavor_query(context, flavor_id): return model_query(context, models.InstanceTypes, (models.InstanceTypes.id,), - read_deleted="no", session=session).\ + read_deleted="no").\ filter_by(flavorid=flavor_id) -def _flavor_get_id_from_flavor(context, flavor_id, session=None): - result = _flavor_get_id_from_flavor_query(context, flavor_id, - session=session).\ - first() +def _flavor_get_id_from_flavor(context, flavor_id): + result = _flavor_get_id_from_flavor_query(context, flavor_id).first() if not result: raise exception.FlavorNotFound(flavor_id=flavor_id) return result[0] @require_context +@main_context_manager.reader def flavor_get(context, id): """Returns a dict describing specific flavor.""" result = _flavor_get_query(context).\ @@ -4819,6 +4818,7 @@ def flavor_get(context, id): @require_context +@main_context_manager.reader def flavor_get_by_name(context, name): """Returns a dict describing specific flavor.""" result = _flavor_get_query(context).\ @@ -4830,6 +4830,7 @@ def flavor_get_by_name(context, name): @require_context +@main_context_manager.reader def flavor_get_by_flavor_id(context, flavor_id, read_deleted): """Returns a dict describing specific flavor_id.""" result = _flavor_get_query(context, read_deleted=read_deleted).\ @@ -4842,43 +4843,40 @@ def flavor_get_by_flavor_id(context, flavor_id, read_deleted): return _dict_with_extra_specs(result) +@main_context_manager.writer def flavor_destroy(context, name): """Marks specific flavor as deleted.""" - session = get_session() - with session.begin(): - ref = model_query(context, models.InstanceTypes, session=session, - read_deleted="no").\ - filter_by(name=name).\ - first() - if not ref: - raise exception.FlavorNotFoundByName(flavor_name=name) + ref = model_query(context, models.InstanceTypes, read_deleted="no").\ + filter_by(name=name).\ + first() + if not ref: + raise exception.FlavorNotFoundByName(flavor_name=name) - ref.soft_delete(session=session) - model_query(context, models.InstanceTypeExtraSpecs, - session=session, read_deleted="no").\ - filter_by(instance_type_id=ref['id']).\ - soft_delete() - model_query(context, models.InstanceTypeProjects, - session=session, read_deleted="no").\ - filter_by(instance_type_id=ref['id']).\ - soft_delete() + ref.soft_delete(context.session) + model_query(context, models.InstanceTypeExtraSpecs, read_deleted="no").\ + filter_by(instance_type_id=ref['id']).\ + soft_delete() + model_query(context, models.InstanceTypeProjects, read_deleted="no").\ + filter_by(instance_type_id=ref['id']).\ + soft_delete() -def _flavor_access_query(context, session=None): - return model_query(context, models.InstanceTypeProjects, session=session, - read_deleted="no") +def _flavor_access_query(context): + return model_query(context, models.InstanceTypeProjects, read_deleted="no") +@main_context_manager.reader def flavor_access_get_by_flavor_id(context, flavor_id): """Get flavor access list by flavor id.""" - instance_type_id_subq = \ - _flavor_get_id_from_flavor_query(context, flavor_id) + instance_type_id_subq = _flavor_get_id_from_flavor_query(context, + flavor_id) access_refs = _flavor_access_query(context).\ filter_by(instance_type_id=instance_type_id_subq).\ all() return access_refs +@main_context_manager.writer def flavor_access_add(context, flavor_id, project_id): """Add given tenant to the flavor access list.""" instance_type_id = _flavor_get_id_from_flavor(context, flavor_id) @@ -4887,13 +4885,14 @@ def flavor_access_add(context, flavor_id, project_id): access_ref.update({"instance_type_id": instance_type_id, "project_id": project_id}) try: - access_ref.save() + access_ref.save(context.session) except db_exc.DBDuplicateEntry: raise exception.FlavorAccessExists(flavor_id=flavor_id, project_id=project_id) return access_ref +@main_context_manager.writer def flavor_access_remove(context, flavor_id, project_id): """Remove given tenant from the flavor access list.""" instance_type_id = _flavor_get_id_from_flavor(context, flavor_id) @@ -4907,22 +4906,24 @@ def flavor_access_remove(context, flavor_id, project_id): project_id=project_id) -def _flavor_extra_specs_get_query(context, flavor_id, session=None): - instance_type_id_subq = \ - _flavor_get_id_from_flavor_query(context, flavor_id) +def _flavor_extra_specs_get_query(context, flavor_id): + instance_type_id_subq = _flavor_get_id_from_flavor_query(context, + flavor_id) - return model_query(context, models.InstanceTypeExtraSpecs, session=session, + return model_query(context, models.InstanceTypeExtraSpecs, read_deleted="no").\ filter_by(instance_type_id=instance_type_id_subq) @require_context +@main_context_manager.reader def flavor_extra_specs_get(context, flavor_id): rows = _flavor_extra_specs_get_query(context, flavor_id).all() return {row['key']: row['value'] for row in rows} @require_context +@main_context_manager.writer def flavor_extra_specs_delete(context, flavor_id, key): result = _flavor_extra_specs_get_query(context, flavor_id).\ filter(models.InstanceTypeExtraSpecs.key == key).\ @@ -4934,34 +4935,34 @@ def flavor_extra_specs_delete(context, flavor_id, key): @require_context +@main_context_manager.writer def flavor_extra_specs_update_or_create(context, flavor_id, specs, max_retries=10): for attempt in range(max_retries): try: - session = get_session() - with session.begin(): - instance_type_id = _flavor_get_id_from_flavor(context, - flavor_id, session) + instance_type_id = _flavor_get_id_from_flavor(context, flavor_id) - spec_refs = model_query(context, models.InstanceTypeExtraSpecs, - session=session, read_deleted="no").\ - filter_by(instance_type_id=instance_type_id).\ - filter(models.InstanceTypeExtraSpecs.key.in_(specs.keys())).\ - all() + spec_refs = model_query(context, models.InstanceTypeExtraSpecs, + read_deleted="no").\ + filter_by(instance_type_id=instance_type_id).\ + filter(models.InstanceTypeExtraSpecs.key.in_(specs.keys())).\ + all() - existing_keys = set() - for spec_ref in spec_refs: - key = spec_ref["key"] - existing_keys.add(key) + existing_keys = set() + for spec_ref in spec_refs: + key = spec_ref["key"] + existing_keys.add(key) + with main_context_manager.writer.savepoint.using(context): spec_ref.update({"value": specs[key]}) - for key, value in specs.items(): - if key in existing_keys: - continue - spec_ref = models.InstanceTypeExtraSpecs() + for key, value in specs.items(): + if key in existing_keys: + continue + spec_ref = models.InstanceTypeExtraSpecs() + with main_context_manager.writer.savepoint.using(context): spec_ref.update({"key": key, "value": value, "instance_type_id": instance_type_id}) - session.add(spec_ref) + context.session.add(spec_ref) return specs except db_exc.DBDuplicateEntry: diff --git a/nova/exception.py b/nova/exception.py index c44a6b45e0..39c752674f 100644 --- a/nova/exception.py +++ b/nova/exception.py @@ -1183,7 +1183,7 @@ class FlavorAccessNotFound(NotFound): class FlavorExtraSpecUpdateCreateFailed(NovaException): - msg_fmt = _("Flavor %(id)d extra spec cannot be updated or created " + msg_fmt = _("Flavor %(id)s extra spec cannot be updated or created " "after %(retries)d retries.") diff --git a/nova/tests/unit/db/test_db_api.py b/nova/tests/unit/db/test_db_api.py index dae90ebb09..3c5fe1ca97 100644 --- a/nova/tests/unit/db/test_db_api.py +++ b/nova/tests/unit/db/test_db_api.py @@ -4001,6 +4001,16 @@ class InstanceTypeTestCase(BaseInstanceTypeTestCase): ignored_keys) self._assertEqualObjects(extra_specs, flavor['extra_specs']) + @mock.patch('sqlalchemy.orm.query.Query.all', return_value=[]) + def test_flavor_create_with_extra_specs_duplicate(self, mock_all): + extra_specs = dict(key='value') + flavorid = 'flavorid' + self._create_flavor({'flavorid': flavorid, 'extra_specs': extra_specs}) + + self.assertRaises(exception.FlavorExtraSpecUpdateCreateFailed, + db.flavor_extra_specs_update_or_create, + self.ctxt, flavorid, extra_specs) + def test_flavor_get_all(self): # NOTE(boris-42): Remove base instance types for it in db.flavor_get_all(self.ctxt): @@ -4283,7 +4293,7 @@ class InstanceTypeExtraSpecsTestCase(BaseInstanceTypeTestCase): def test_flavor_extra_specs_update_or_create_retry(self): def counted(): - def get_id(context, flavorid, session): + def get_id(context, flavorid): get_id.counter += 1 raise db_exc.DBDuplicateEntry get_id.counter = 0