Merge "Add the corresponding DB context to all SQL transactions"

commit 430abde13e
Zuul, 2022-04-08 13:08:32 +00:00 (committed by Gerrit Code Review)
32 changed files with 275 additions and 158 deletions
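The pattern applied throughout the change is to wrap each database access in an explicit reader or writer transaction from neutron_lib's enginefacade, either as a method decorator or as a context manager. A minimal sketch of the two forms used across the diff; the plugin class and method names below are illustrative, not taken from the patch:

    from neutron_lib.db import api as db_api

    from neutron.db import models_v2


    class ExamplePlugin(object):
        """Illustrative only; not a class from this change."""

        @db_api.CONTEXT_READER
        def get_ports_on_network(self, context, network_id):
            # Decorator form: the whole method body runs inside one reader
            # transaction bound to `context`.
            return context.session.query(models_v2.Port).filter_by(
                network_id=network_id).all()

        def delete_ports_on_network(self, context, network_id):
            # Context-manager form: only the enclosed block is transactional;
            # a writer is used because rows are modified.
            with db_api.CONTEXT_WRITER.using(context):
                context.session.query(models_v2.Port).filter_by(
                    network_id=network_id).delete()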


@@ -77,7 +77,7 @@ def get_availability_zones_by_agent_type(context, agent_type,
                                          availability_zones):
     """Get list of availability zones based on agent type"""
-    agents = agent_obj.Agent._get_agents_by_availability_zones_and_agent_type(
+    agents = agent_obj.Agent.get_agents_by_availability_zones_and_agent_type(
         context, agent_type=agent_type, availability_zones=availability_zones)
     return set(agent.availability_zone for agent in agents)


@@ -259,6 +259,7 @@ class DbBasePluginCommon(object):
             res.pop('bulk')
         return db_utils.resource_fields(res, fields)

+    @db_api.CONTEXT_READER
     def _get_network(self, context, id):
         try:
             network = model_query.get_by_id(context, models_v2.Network, id)
@@ -266,6 +267,7 @@ class DbBasePluginCommon(object):
             raise exceptions.NetworkNotFound(net_id=id)
         return network

+    @db_api.CONTEXT_READER
     def _network_exists(self, context, network_id):
         query = model_query.query_with_hooks(
             context, models_v2.Network, field='id')
@@ -284,6 +286,7 @@ class DbBasePluginCommon(object):
             raise exceptions.SubnetPoolNotFound(subnetpool_id=id)
         return subnetpool

+    @db_api.CONTEXT_READER
     def _get_port(self, context, id, lazy_fields=None):
         try:
             port = model_query.get_by_id(context, models_v2.Port, id,


@@ -230,37 +230,39 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
         tenant_to_check = policy['target_project']
         if tenant_to_check:
-            self.ensure_no_tenant_ports_on_network(net['id'], net['tenant_id'],
-                                                   tenant_to_check)
+            self.ensure_no_tenant_ports_on_network(
+                context, net['id'], net['tenant_id'], tenant_to_check)

-    def ensure_no_tenant_ports_on_network(self, network_id, net_tenant_id,
-                                          tenant_id):
-        ctx_admin = ctx.get_admin_context()
-        ports = model_query.query_with_hooks(ctx_admin, models_v2.Port).filter(
-            models_v2.Port.network_id == network_id)
-        if tenant_id == '*':
-            # for the wildcard we need to get all of the rbac entries to
-            # see if any allow the remaining ports on the network.
-            # any port with another RBAC entry covering it or one belonging to
-            # the same tenant as the network owner is ok
-            other_rbac_objs = network_obj.NetworkRBAC.get_objects(
-                ctx_admin, object_id=network_id, action='access_as_shared')
-            allowed_tenants = [rbac['target_project'] for rbac
-                               in other_rbac_objs
-                               if rbac.target_project != tenant_id]
-            allowed_tenants.append(net_tenant_id)
-            ports = ports.filter(
-                ~models_v2.Port.tenant_id.in_(allowed_tenants))
-        else:
-            # if there is a wildcard rule, we can return early because it
-            # allows any ports
-            if network_obj.NetworkRBAC.get_object(
-                    ctx_admin, object_id=network_id, action='access_as_shared',
-                    target_project='*'):
-                return
-            ports = ports.filter(models_v2.Port.project_id == tenant_id)
-        if ports.count():
-            raise exc.InvalidSharedSetting(network=network_id)
+    def ensure_no_tenant_ports_on_network(self, context, network_id,
+                                          net_tenant_id, tenant_id):
+        elevated = context.elevated()
+        with db_api.CONTEXT_READER.using(elevated):
+            ports = model_query.query_with_hooks(
+                elevated, models_v2.Port).filter(
+                models_v2.Port.network_id == network_id)
+            if tenant_id == '*':
+                # for the wildcard we need to get all of the rbac entries to
+                # see if any allow the remaining ports on the network.
+                # any port with another RBAC entry covering it or one belonging
+                # to the same tenant as the network owner is ok
+                other_rbac_objs = network_obj.NetworkRBAC.get_objects(
+                    elevated, object_id=network_id, action='access_as_shared')
+                allowed_tenants = [rbac['target_project'] for rbac
+                                   in other_rbac_objs
+                                   if rbac.target_project != tenant_id]
+                allowed_tenants.append(net_tenant_id)
+                ports = ports.filter(
+                    ~models_v2.Port.tenant_id.in_(allowed_tenants))
+            else:
+                # if there is a wildcard rule, we can return early because it
+                # allows any ports
+                if network_obj.NetworkRBAC.get_object(
+                        elevated, object_id=network_id,
+                        action='access_as_shared', target_project='*'):
+                    return
+                ports = ports.filter(models_v2.Port.project_id == tenant_id)
+            if ports.count():
+                raise exc.InvalidSharedSetting(network=network_id)

     def set_ipam_backend(self):
         self.ipam = ipam_pluggable_backend.IpamPluggableBackend()
@@ -487,8 +489,8 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
         registry.publish(resources.NETWORK, events.BEFORE_DELETE, self,
                          payload=events.DBEventPayload(
                              context, resource_id=id))
-        self._ensure_network_not_in_use(context, id)
         with db_api.CONTEXT_READER.using(context):
+            self._ensure_network_not_in_use(context, id)
             auto_delete_port_ids = [p.id for p in context.session.query(
                 models_v2.Port.id).filter_by(network_id=id).filter(
                     models_v2.Port.device_owner.in_(
@@ -647,10 +649,9 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
                 s_gateway_ip != cur_subnet['gateway_ip'] and
                 not ipv6_utils.is_ipv6_pd_enabled(s)):
             gateway_ip = str(cur_subnet['gateway_ip'])
-            with db_api.CONTEXT_READER.using(context):
-                alloc = port_obj.IPAllocation.get_alloc_routerports(
-                    context, cur_subnet['id'], gateway_ip=gateway_ip,
-                    first=True)
+            alloc = port_obj.IPAllocation.get_alloc_routerports(
+                context, cur_subnet['id'], gateway_ip=gateway_ip,
+                first=True)

             if alloc and alloc.port_id:
                 raise exc.GatewayIpInUse(
@@ -1593,6 +1594,7 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
         return query

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def get_ports(self, context, filters=None, fields=None,
                   sorts=None, limit=None, marker=None,
                   page_reverse=False):
@@ -1612,6 +1614,7 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
         return items

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def get_ports_count(self, context, filters=None):
         return self._get_ports_query(context, filters).count()


@@ -33,6 +33,7 @@ from neutron._i18n import _
 from neutron.db import models_v2
 from neutron.extensions import rbac as rbac_ext
 from neutron.objects import network as net_obj
+from neutron.objects import ports as port_obj
 from neutron.objects import router as l3_obj
@@ -127,9 +128,9 @@ class External_net_db_mixin(object):
             # must make sure we do not have any external gateway ports
             # (and thus, possible floating IPs) on this network before
             # allow it to be update to external=False
-            if context.session.query(models_v2.Port.id).filter_by(
-                    device_owner=constants.DEVICE_OWNER_ROUTER_GW,
-                    network_id=net_data['id']).first():
+            if port_obj.Port.count(
+                    context, network_id=net_data['id'],
+                    device_owner=constants.DEVICE_OWNER_ROUTER_GW):
                 raise extnet_exc.ExternalNetworkInUse(net_id=net_id)

             net_obj.ExternalNetwork.delete_objects(
@@ -200,10 +201,9 @@ class External_net_db_mixin(object):
         if new_project == policy['target_project']:
             # nothing to validate if the tenant didn't change
             return
-        gw_ports = context.session.query(models_v2.Port.id).filter_by(
-            device_owner=constants.DEVICE_OWNER_ROUTER_GW,
-            network_id=policy['object_id'])
-        gw_ports = [gw_port[0] for gw_port in gw_ports]
+        gw_ports = port_obj.Port.get_gateway_port_ids_by_network(
+            context, policy['object_id'])
         if policy['target_project'] != '*':
             filters = {
                 'gw_port_id': gw_ports,


@@ -391,10 +391,9 @@ class L3AgentSchedulerDbMixin(l3agentscheduler.L3AgentSchedulerPluginBase,
             rb_obj.RouterL3AgentBinding.get_l3_agents_by_router_ids(
                 context, router_ids))

+    @db_api.CONTEXT_READER
     def list_l3_agents_hosting_router(self, context, router_id):
-        with db_api.CONTEXT_READER.using(context):
-            agents = self._get_l3_agents_hosting_routers(
-                context, [router_id])
+        agents = self._get_l3_agents_hosting_routers(context, [router_id])
         return {'agents': [self._make_agent_dict(agent)
                            for agent in agents]}


@@ -622,6 +622,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
         return self._make_router_dict(router, fields)

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def get_routers(self, context, filters=None, fields=None,
                     sorts=None, limit=None, marker=None,
                     page_reverse=False):
@@ -636,6 +637,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
                                     page_reverse=page_reverse)

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def get_routers_count(self, context, filters=None):
         return model_query.get_collection_count(
             context, l3_models.Router, filters=filters,
@@ -1365,7 +1367,8 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
         fip_id = uuidutils.generate_uuid()
         f_net_id = fip['floating_network_id']
-        f_net_db = self._core_plugin._get_network(context, f_net_id)
+        with db_api.CONTEXT_READER.using(context):
+            f_net_db = self._core_plugin._get_network(context, f_net_id)
         if not f_net_db.external:
             msg = _("Network %s is not a valid external network") % f_net_id
             raise n_exc.BadRequest(resource='floatingip', msg=msg)
@@ -1834,6 +1837,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
                 continue
             yield port

+    @db_api.CONTEXT_READER
     def _get_subnets_by_network_list(self, context, network_ids):
         if not network_ids:
             return {}


@@ -561,7 +561,6 @@ class L3_HA_NAT_db_mixin(l3_dvr_db.L3_NAT_with_dvr_db_mixin,
         for agent in self.get_l3_agents_hosting_routers(context, [router_id]):
             self.remove_router_from_l3_agent(context, agent['id'], router_id)

-    @db_api.CONTEXT_READER
     def get_ha_router_port_bindings(self, context, router_ids, host=None):
         if not router_ids:
             return []


@@ -60,6 +60,7 @@ class IpAvailabilityMixin(object):
         total_ips_columns.append(mod.IPAllocationPool.last_ip)

     @classmethod
+    @db_api.CONTEXT_READER
     def get_network_ip_availabilities(cls, context, filters=None):
         """Get IP availability stats on a per subnet basis.


@@ -42,6 +42,7 @@ def add_model_for_resource(resource, model):

 @db_api.retry_if_session_inactive()
+@db_api.CONTEXT_WRITER
 def add_provisioning_component(context, object_id, object_type, entity):
     """Adds a provisioning block by an entity to a given object.
@@ -77,6 +78,7 @@ def add_provisioning_component(context, object_id, object_type, entity):

 @db_api.retry_if_session_inactive()
+@db_api.CONTEXT_WRITER
 def remove_provisioning_component(context, object_id, object_type, entity,
                                   standard_attr_id=None):
     """Remove a provisioning block for an object without triggering a callback.
@@ -125,26 +127,30 @@ def provisioning_complete(context, object_id, object_type, entity):
     # tricking us into thinking there are remaining provisioning components
     if utils.is_session_active(context.session):
         raise RuntimeError(_("Must not be called in a transaction"))
-    standard_attr_id = _get_standard_attr_id(context, object_id,
-                                             object_type)
-    if not standard_attr_id:
-        return
-    if remove_provisioning_component(context, object_id, object_type, entity,
-                                     standard_attr_id):
-        LOG.debug("Provisioning for %(otype)s %(oid)s completed by entity "
-                  "%(entity)s.", log_dict)
-    # now with that committed, check if any records are left. if None, emit
-    # an event that provisioning is complete.
-    if not pb_obj.ProvisioningBlock.objects_exist(
-            context, standard_attr_id=standard_attr_id):
-        LOG.debug("Provisioning complete for %(otype)s %(oid)s triggered by "
-                  "entity %(entity)s.", log_dict)
-        registry.publish(object_type, PROVISIONING_COMPLETE, entity,
-                         payload=events.DBEventPayload(
-                             context, resource_id=object_id))
+    with db_api.CONTEXT_WRITER.using(context):
+        standard_attr_id = _get_standard_attr_id(context, object_id,
+                                                 object_type)
+        if not standard_attr_id:
+            return
+        if remove_provisioning_component(context, object_id, object_type,
+                                         entity, standard_attr_id):
+            LOG.debug("Provisioning for %(otype)s %(oid)s completed by entity "
+                      "%(entity)s.", log_dict)
+        # now with that committed, check if any records are left. if None, emit
+        # an event that provisioning is complete.
+        if pb_obj.ProvisioningBlock.objects_exist(
+                context, standard_attr_id=standard_attr_id):
+            return
+
+    LOG.debug("Provisioning complete for %(otype)s %(oid)s triggered by "
+              "entity %(entity)s.", log_dict)
+    registry.publish(object_type, PROVISIONING_COMPLETE, entity,
+                     payload=events.DBEventPayload(
+                         context, resource_id=object_id))

 @db_api.retry_if_session_inactive()
+@db_api.CONTEXT_READER
 def is_object_blocked(context, object_id, object_type):
     """Return boolean indicating if object has a provisioning block.


@@ -44,6 +44,7 @@ class ReservationInfo(collections.namedtuple(

 @db_api.retry_if_session_inactive()
+@db_api.CONTEXT_READER
 def get_quota_usage_by_resource_and_project(context, resource, project_id):
     """Return usage info for a given resource and project.


@@ -441,6 +441,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
     """Server-side RPC mixin using DB for SG notifications and responses."""

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def _select_sg_ids_for_ports(self, context, ports):
         if not ports:
             return []
@@ -451,6 +452,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
         return query.all()

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def _select_rules_for_ports(self, context, ports):
         if not ports:
             return []
@@ -467,6 +469,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
         return query.all()

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def _select_ips_for_remote_group(self, context, remote_group_ids):
         ips_by_group = {}
         if not remote_group_ids:
@@ -507,6 +510,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
         return ips_by_group

     @db_api.retry_if_session_inactive()
+    @db_api.CONTEXT_READER
     def _select_ips_for_remote_address_group(self, context,
                                              remote_address_group_ids):
         ips_by_group = {}


@@ -12,6 +12,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from oslo_versionedobjects import fields as obj_fields
@@ -53,6 +54,7 @@ class AddressScope(rbac_db.NeutronRbacObject):
     }

     @classmethod
+    @db_api.CONTEXT_READER
     def get_network_address_scope(cls, context, network_id, ip_version):
         query = context.session.query(cls.db_model)
         query = query.join(


@@ -13,6 +13,7 @@
 # under the License.

 from neutron_lib import constants as const
+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from neutron_lib.objects import utils as obj_utils
 from oslo_utils import versionutils
@@ -122,11 +123,10 @@ class Agent(base.NeutronDbObject):
                  group_by(agent_model.Agent).
                  filter(agent_model.Agent.id.in_(agent_ids)).
                  order_by('count'))
-        agents = [cls._load_object(context, record[0]) for record in query]
-        return agents
+        return [cls._load_object(context, record[0]) for record in query]

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ha_agents(cls, context, network_id=None, router_id=None):
         if not (network_id or router_id):
             return []
@@ -154,7 +154,8 @@ class Agent(base.NeutronDbObject):
         return agents

     @classmethod
-    def _get_agents_by_availability_zones_and_agent_type(
+    @db_api.CONTEXT_READER
+    def get_agents_by_availability_zones_and_agent_type(
             cls, context, agent_type, availability_zones):
         query = context.session.query(
             agent_model.Agent).filter_by(


@@ -16,12 +16,15 @@ from collections import abc as collections_abc
 import copy
 import functools
 import itertools
+import sys
+import traceback

 from neutron_lib.db import api as db_api
 from neutron_lib.db import standard_attr
 from neutron_lib import exceptions as n_exc
 from neutron_lib.objects import exceptions as o_exc
 from neutron_lib.objects.extensions import standardattributes
+from oslo_config import cfg
 from oslo_db import exception as obj_exc
 from oslo_db.sqlalchemy import enginefacade
 from oslo_db.sqlalchemy import utils as db_utils
@@ -39,10 +42,38 @@ from neutron.objects.db import api as obj_db_api

 LOG = logging.getLogger(__name__)

+CONF = cfg.CONF
+
 _NO_DB_MODEL = object()

+# NOTE(ralonsoh): this is a method evaluated anytime an ORM session is
+# executing a SQL transaction.
+# If "autocommit" is disabled (the default value in SQLAlchemy 1.4 and the
+# only value in SQLAlchemy 2.0) and there is not active transaction, that
+# means the SQL transaction is being run on an "implicit transaction". Under
+# autocommit, this transaction is created, executed and discarded immediately;
+# under non-autocommit, a transaction must be explicitly created
+# (writer/reader) and sticks open.
+# This evaluation is done only in debug mode to monitor the Neutron code
+# compliance to SQLAlchemy 2.0.
+def do_orm_execute(orm_execute_state):
+    if not orm_execute_state.session.in_transaction():
+        trace_string = '\n'.join(traceback.format_stack(sys._getframe(1)))
+        LOG.warning('ORM session: SQL execution without transaction in '
+                    'progress, traceback:\n%s', trace_string)
+
+
+try:
+    _debug = cfg.CONF.debug
+except cfg.NoSuchOptError:
+    _debug = False
+if _debug:
+    db_api.sqla_listen(orm.Session, 'do_orm_execute', do_orm_execute)
+
+
 def get_object_class_by_model(model):
     for obj_class in NeutronObjectRegistry.obj_classes().values():
         obj_class = obj_class[0]
@@ -919,6 +950,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject):
         self._captured_db_model = None

     @classmethod
+    @db_api.CONTEXT_READER
     def count(cls, context, validate_filters=True, **kwargs):
         """Count the number of objects matching filtering criteria.
@@ -935,6 +967,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject):
         )

     @classmethod
+    @db_api.CONTEXT_READER
     def objects_exist(cls, context, validate_filters=True, **kwargs):
         """Check if objects are present in DB.


@@ -13,6 +13,7 @@
 # TODO(ihrachys): cover the module with functional tests targeting supported
 # backends

+from neutron_lib.db import api as db_api
 from neutron_lib.db import model_query
 from neutron_lib import exceptions as n_exc
 from neutron_lib.objects import utils as obj_utils
@@ -34,6 +35,7 @@ def get_object(obj_cls, context, **kwargs):
     return _get_filter_query(obj_cls, context, **kwargs).first()

+@db_api.CONTEXT_READER
 def count(obj_cls, context, query_field=None, query_limit=None, **kwargs):
     if not query_field and obj_cls.primary_keys:
         query_field = obj_cls.primary_keys[0]


@@ -13,6 +13,7 @@
 # under the License.

 from neutron_lib import constants
+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from oslo_versionedobjects import fields as obj_fields
@@ -40,6 +41,7 @@ class L3HARouterAgentPortBinding(base.NeutronDbObject):
     fields_no_update = ['router_id', 'port_id']

     @classmethod
+    @db_api.CONTEXT_READER
     def get_l3ha_filter_host_router(cls, context, router_ids, host):
         query = context.session.query(l3ha.L3HARouterAgentPortBinding)


@@ -10,6 +10,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from oslo_versionedobjects import fields as obj_fields
 import sqlalchemy as sa
@@ -42,6 +43,7 @@ class RouterL3AgentBinding(base.NeutronDbObject):
     # TODO(ihrachys) return OVO objects not models
     # TODO(ihrachys) move under Agent object class
     @classmethod
+    @db_api.CONTEXT_READER
     def get_l3_agents_by_router_ids(cls, context, router_ids):
         query = context.session.query(l3agent.RouterL3AgentBinding)
         query = query.options(joinedload('l3_agent')).filter(
@@ -49,6 +51,7 @@ class RouterL3AgentBinding(base.NeutronDbObject):
         return [db_obj.l3_agent for db_obj in query.all()]

     @classmethod
+    @db_api.CONTEXT_READER
     def get_down_router_bindings(cls, context, cutoff):
         query = (context.session.query(
                  l3agent.RouterL3AgentBinding).


@@ -15,6 +15,7 @@
 import itertools

 import netaddr
+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types

 from neutron.db.models import l3
@@ -265,21 +266,23 @@ class PortForwarding(base.NeutronDbObject):
         return result

     @classmethod
+    @db_api.CONTEXT_READER
     def get_port_forwarding_obj_by_routers(cls, context, router_ids):
         query = context.session.query(cls.db_model, l3.FloatingIP)
         query = query.join(l3.FloatingIP,
                            cls.db_model.floatingip_id == l3.FloatingIP.id)
         query = query.filter(l3.FloatingIP.router_id.in_(router_ids))
-        return cls._unique_port_forwarding_iterator(query)
+        return cls._unique_port_forwarding(query)

-    @classmethod
-    def _unique_port_forwarding_iterator(cls, query):
+    @staticmethod
+    def _unique_port_forwarding(query):
         q = query.order_by(l3.FloatingIP.router_id)
         keyfunc = lambda row: row[1]
         group_iterator = itertools.groupby(q, keyfunc)
-        for key, value in group_iterator:
-            for row in value:
-                yield (row[1]['router_id'], row[1]['floating_ip_address'],
-                       row[0]['id'], row[1]['id'])
+        result = []
+        for key, value in group_iterator:
+            result.extend([(row[1]['router_id'], row[1]['floating_ip_address'],
+                            row[0]['id'], row[1]['id']) for row in value])
+        return result


@@ -243,6 +243,7 @@ class IPAllocation(base.NeutronDbObject):
             alloc_obj.delete()

     @classmethod
+    @db_api.CONTEXT_READER
     def get_alloc_routerports(cls, context, subnet_id, gateway_ip=None,
                               first=False):
         alloc_qry = context.session.query(cls.db_model.port_id)
@@ -466,6 +467,7 @@ class Port(base.NeutronDbObject):
         return port_array

     @classmethod
+    @db_api.CONTEXT_READER
     def get_auto_deletable_port_ids_and_proper_port_count_by_segment(
             cls, context, segment_id):
@@ -584,6 +586,7 @@ class Port(base.NeutronDbObject):
         primitive.pop('device_profile', None)

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_by_router_and_network(cls, context, router_id, owner,
                                         network_id):
         """Returns port objects filtering by router ID, owner and network ID"""
@@ -593,6 +596,7 @@ class Port(base.NeutronDbObject):
                                rports_filter, router_filter)

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_by_router_and_port(cls, context, router_id, owner, port_id):
         """Returns port objects filtering by router ID, owner and port ID"""
         rports_filter = (l3.RouterPort.port_id == port_id, )
@@ -645,6 +649,7 @@ class Port(base.NeutronDbObject):
         return ports_rports

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_ids_by_security_groups(cls, context, security_group_ids,
                                          excluded_device_owners=None):
         query = context.session.query(sg_models.SecurityGroupPortBinding)
@@ -658,6 +663,7 @@ class Port(base.NeutronDbObject):
         return [port_binding['port_id'] for port_binding in query.all()]

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_by_host(cls, context, host):
         query = context.session.query(models_v2.Port.id).join(
             ml2_models.PortBinding)
@@ -666,6 +672,7 @@ class Port(base.NeutronDbObject):
         return [port_id[0] for port_id in query.all()]

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_by_binding_type_and_host(cls, context,
                                            binding_type, host):
         query = context.session.query(models_v2.Port).join(
@@ -676,6 +683,7 @@ class Port(base.NeutronDbObject):
         return [cls._load_object(context, db_obj) for db_obj in query.all()]

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_by_vnic_type_and_host(
             cls, context, vnic_type, host):
         query = context.session.query(models_v2.Port).join(
@@ -686,6 +694,7 @@ class Port(base.NeutronDbObject):
         return [cls._load_object(context, db_obj) for db_obj in query.all()]

     @classmethod
+    @db_api.CONTEXT_READER
     def check_network_ports_by_binding_types(
             cls, context, network_id, binding_types, negative_search=False):
         """This method is to check whether networks have ports with given
@@ -710,6 +719,7 @@ class Port(base.NeutronDbObject):
         return bool(query.count())

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_allocated_by_subnet_id(cls, context, subnet_id):
         """Return ports with fixed IPs in a subnet"""
         return context.session.query(models_v2.Port).filter(
@@ -731,3 +741,11 @@ class Port(base.NeutronDbObject):
         for _binding in port.bindings:
             if _binding.get('profile', {}).get('pci_slot') == pci_slot:
                 return port
+
+    @classmethod
+    @db_api.CONTEXT_READER
+    def get_gateway_port_ids_by_network(cls, context, network_id):
+        gw_ports = context.session.query(models_v2.Port.id).filter_by(
+            device_owner=constants.DEVICE_OWNER_ROUTER_GW,
+            network_id=network_id)
+        return [gw_port[0] for gw_port in gw_ports]


@@ -15,6 +15,7 @@
 import abc

+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from sqlalchemy import and_
 from sqlalchemy import exists
@@ -55,6 +56,7 @@ class QosPolicyPortBinding(base.NeutronDbObject, _QosPolicyBindingMixin):
     _bound_model_id = db_model.port_id

     @classmethod
+    @db_api.CONTEXT_READER
     def get_ports_by_network_id(cls, context, network_id, policy_id=None):
         query = context.session.query(models_v2.Port).filter(
             models_v2.Port.network_id == network_id)
@@ -103,6 +105,7 @@ class QosPolicyFloatingIPBinding(base.NeutronDbObject, _QosPolicyBindingMixin):
     _bound_model_id = db_model.fip_id

     @classmethod
+    @db_api.CONTEXT_READER
     def get_fips_by_network_id(cls, context, network_id, policy_id=None):
         """Return the FIP belonging to a network, filtered by a QoS policy


@@ -15,6 +15,7 @@
 import abc

+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from oslo_utils import versionutils
 from oslo_versionedobjects import fields as obj_fields
@@ -39,6 +40,7 @@ class RBACBaseObject(base.NeutronDbObject, metaclass=abc.ABCMeta):
     fields_no_update = ['id', 'project_id', 'object_id']

     @classmethod
+    @db_api.CONTEXT_READER
     def get_projects(cls, context, object_id=None, action=None,
                      target_project=None):
         clauses = []


@@ -18,6 +18,7 @@ import itertools
 from neutron_lib.callbacks import events
 from neutron_lib.callbacks import registry
 from neutron_lib.callbacks import resources
+from neutron_lib.db import api as db_api
 from neutron_lib import exceptions
 from sqlalchemy import and_
@@ -104,6 +105,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
                      rbac_db_model.target_project != '*'))))

     @classmethod
+    @db_api.CONTEXT_READER
     def _validate_rbac_policy_delete(cls, context, obj_id, target_project):
         ctx_admin = context.elevated()
         rb_model = cls.rbac_db_cls.db_model
@@ -147,13 +149,14 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
         if policy['action'] != models.ACCESS_SHARED:
             return
         target_project = policy['target_project']
-        db_obj = obj_db_api.get_object(
-            cls, context.elevated(), id=policy['object_id'])
+        elevated_context = context.elevated()
+        with db_api.CONTEXT_READER.using(elevated_context):
+            db_obj = obj_db_api.get_object(cls, elevated_context,
+                                           id=policy['object_id'])
         if db_obj.project_id == target_project:
             return
-        cls._validate_rbac_policy_delete(context=context,
-                                         obj_id=policy['object_id'],
-                                         target_project=target_project)
+        cls._validate_rbac_policy_delete(context, policy['object_id'],
+                                         target_project)

     @classmethod
     def validate_rbac_policy_create(cls, resource, event, trigger,
@@ -199,8 +202,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
         # (hopefully) melded with this one.
         if object_type != cls.rbac_db_cls.db_model.object_type:
             return
-        db_obj = obj_db_api.get_object(
-            cls, context.elevated(), id=policy['object_id'])
+        elevated_context = context.elevated()
+        with db_api.CONTEXT_READER.using(elevated_context):
+            db_obj = obj_db_api.get_object(cls, elevated_context,
+                                           id=policy['object_id'])
         if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE):
             if (not context.is_admin and
                     db_obj['project_id'] != context.project_id):
@@ -225,23 +230,23 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,

     def update_shared(self, is_shared_new, obj_id):
         admin_context = self.obj_context.elevated()
-        shared_prev = obj_db_api.get_object(self.rbac_db_cls, admin_context,
-                                            object_id=obj_id,
-                                            target_project='*',
-                                            action=models.ACCESS_SHARED)
-        is_shared_prev = bool(shared_prev)
-        if is_shared_prev == is_shared_new:
-            return
-
-        # 'shared' goes False -> True
-        if not is_shared_prev and is_shared_new:
-            self.attach_rbac(obj_id, self.obj_context.project_id)
-            return
-
-        # 'shared' goes True -> False is actually an attempt to delete
-        # rbac rule for sharing obj_id with target_project = '*'
-        self._validate_rbac_policy_delete(self.obj_context, obj_id, '*')
-        return self.obj_context.session.delete(shared_prev)
+        with db_api.CONTEXT_WRITER.using(admin_context):
+            shared_prev = obj_db_api.get_object(
+                self.rbac_db_cls, admin_context, object_id=obj_id,
+                target_project='*', action=models.ACCESS_SHARED)
+            is_shared_prev = bool(shared_prev)
+            if is_shared_prev == is_shared_new:
+                return
+
+            # 'shared' goes False -> True
+            if not is_shared_prev and is_shared_new:
+                self.attach_rbac(obj_id, self.obj_context.project_id)
+                return
+
+            # 'shared' goes True -> False is actually an attempt to delete
+            # rbac rule for sharing obj_id with target_project = '*'
+            self._validate_rbac_policy_delete(self.obj_context, obj_id, '*')
+            return self.obj_context.session.delete(shared_prev)

     def from_db_object(self, db_obj):
         self._load_shared(db_obj)


@@ -17,6 +17,7 @@ import netaddr
 from neutron_lib.api.definitions import availability_zone as az_def
 from neutron_lib.api.validators import availability_zone as az_validator
 from neutron_lib import constants as n_const
+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from neutron_lib.utils import net as net_utils
 from oslo_utils import versionutils
@@ -108,6 +109,7 @@ class RouterExtraAttributes(base.NeutronDbObject):
             return result

     @classmethod
+    @db_api.CONTEXT_READER
     def get_router_agents_count(cls, context):
         # TODO(sshank): This is pulled out from l3_agentschedulers_db.py
         # until a way to handle joins is figured out.
@@ -146,6 +148,7 @@ class RouterPort(base.NeutronDbObject):
     }

     @classmethod
+    @db_api.CONTEXT_READER
     def get_router_ids_by_subnetpool(cls, context, subnetpool_id):
         query = context.session.query(l3.RouterPort.router_id)
         query = query.join(models_v2.Port)
@@ -220,6 +223,7 @@ class Router(base.NeutronDbObject):
     fields_no_update = ['project_id']

     @classmethod
+    @db_api.CONTEXT_READER
     def check_routers_not_owned_by_projects(cls, context, gw_ports, projects):
         """This method is to check whether routers that aren't owned by
         existing projects or not
@@ -376,6 +380,7 @@ class FloatingIP(base.NeutronDbObject):
         primitive.pop('qos_network_policy_id', None)

     @classmethod
+    @db_api.CONTEXT_READER
     def get_scoped_floating_ips(cls, context, router_ids):
         query = context.session.query(l3.FloatingIP,
                                       models_v2.SubnetPool.address_scope_id)
@@ -410,6 +415,7 @@ class FloatingIP(base.NeutronDbObject):
             yield (cls._load_object(context, row[0]), row[1])

     @classmethod
+    @db_api.CONTEXT_READER
     def get_disassociated_ids_for_net(cls, context, network_id):
         query = context.session.query(cls.db_model.id)
         query = query.filter_by(


@@ -11,6 +11,7 @@
 # under the License.

 from neutron_lib import context as context_lib
+from neutron_lib.db import api as db_api
 from neutron_lib.objects import common_types
 from neutron_lib.utils import net as net_utils
 from oslo_utils import versionutils
@@ -239,11 +240,14 @@ class SecurityGroupRule(base.NeutronDbObject):
         - The rule belongs to a security group that belongs to the project_id
         """
         context = context_lib.get_admin_context()
-        query = context.session.query(cls.db_model.id)
-        query = query.join(
-            SecurityGroup.db_model,
-            cls.db_model.security_group_id == SecurityGroup.db_model.id)
-        clauses = or_(SecurityGroup.db_model.project_id == project_id,
-                      cls.db_model.project_id == project_id)
-        rule_ids = query.filter(clauses).all()
-        return [rule_id[0] for rule_id in rule_ids]
+        # NOTE(ralonsoh): do no use a READER decorator in this method. Elevated
+        # permissions are needed here.
+        with db_api.CONTEXT_READER.using(context):
+            query = context.session.query(cls.db_model.id)
+            query = query.join(
+                SecurityGroup.db_model,
+                cls.db_model.security_group_id == SecurityGroup.db_model.id)
+            clauses = or_(SecurityGroup.db_model.project_id == project_id,
+                          cls.db_model.project_id == project_id)
+            rule_ids = query.filter(clauses).all()
+            return [rule_id[0] for rule_id in rule_ids]


@@ -14,6 +14,7 @@
 # under the License.

 import netaddr
+from neutron_lib.db import api as db_api
 from neutron_lib.db import model_query
 from neutron_lib.objects import common_types
 from oslo_versionedobjects import fields as obj_fields
@@ -123,21 +124,22 @@ class SubnetPool(rbac_db.NeutronRbacObject):
             # Nothing to validate
             return

-        rbac_as_model = rbac_db_models.AddressScopeRBAC
-        # Ensure that target project has access to AS
-        shared_to_target_project_or_to_all = (
-            sa.and_(
-                rbac_as_model.target_project.in_(
-                    ["*", policy['target_project']]
-                ),
-                rbac_as_model.object_id == db_obj["address_scope_id"]
-            )
-        )
-        matching_policies = model_query.query_with_hooks(
-            context, rbac_db_models.AddressScopeRBAC
-        ).filter(shared_to_target_project_or_to_all).count()
+        with db_api.CONTEXT_READER.using(context):
+            rbac_as_model = rbac_db_models.AddressScopeRBAC
+            # Ensure that target project has access to AS
+            shared_to_target_project_or_to_all = (
+                sa.and_(
+                    rbac_as_model.target_project.in_(
+                        ["*", policy['target_project']]
+                    ),
+                    rbac_as_model.object_id == db_obj["address_scope_id"]
+                )
+            )
+            matching_policies = model_query.query_with_hooks(
+                context, rbac_db_models.AddressScopeRBAC
+            ).filter(shared_to_target_project_or_to_all).count()

         if matching_policies == 0:
             raise ext_rbac.RbacPolicyInitError(


@@ -315,10 +315,9 @@ def _prevent_segment_delete_with_port_bound(resource, event, trigger,
         # don't check for network deletes
         return

-    with db_api.CONTEXT_READER.using(payload.context):
-        auto_delete_port_ids, proper_port_count = port_obj.Port.\
-            get_auto_deletable_port_ids_and_proper_port_count_by_segment(
-                payload.context, segment_id=payload.resource_id)
+    auto_delete_port_ids, proper_port_count = port_obj.Port.\
+        get_auto_deletable_port_ids_and_proper_port_count_by_segment(
+            payload.context, segment_id=payload.resource_id)

     if proper_port_count:
         reason = (_("The segment is still bound with %s port(s)") %


@@ -345,44 +345,54 @@ class EndpointTunnelTypeDriver(ML2TunnelTypeDriver):

     def get_endpoint_by_host(self, host):
         LOG.debug("get_endpoint_by_host() called for host %s", host)
-        session = db_api.get_reader_session()
-        return (session.query(self.endpoint_model).
-                filter_by(host=host).first())
+        ctx = context.get_admin_context()
+        with db_api.CONTEXT_READER.using(ctx):
+            return (ctx.session.query(self.endpoint_model).
+                    filter_by(host=host).first())

     def get_endpoint_by_ip(self, ip):
         LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
-        session = db_api.get_reader_session()
-        return (session.query(self.endpoint_model).
-                filter_by(ip_address=ip).first())
+        ctx = context.get_admin_context()
+        with db_api.CONTEXT_READER.using(ctx):
+            return (ctx.session.query(self.endpoint_model).
+                    filter_by(ip_address=ip).first())

     def delete_endpoint(self, ip):
         LOG.debug("delete_endpoint() called for ip %s", ip)
-        session = db_api.get_writer_session()
-        session.query(self.endpoint_model).filter_by(ip_address=ip).delete()
+        ctx = context.get_admin_context()
+        with db_api.CONTEXT_WRITER.using(ctx):
+            ctx.session.query(self.endpoint_model).filter_by(
+                ip_address=ip).delete()

     def delete_endpoint_by_host_or_ip(self, host, ip):
         LOG.debug("delete_endpoint_by_host_or_ip() called for "
                   "host %(host)s or %(ip)s", {'host': host, 'ip': ip})
-        session = db_api.get_writer_session()
-        session.query(self.endpoint_model).filter(
-            or_(self.endpoint_model.host == host,
-                self.endpoint_model.ip_address == ip)).delete()
+        ctx = context.get_admin_context()
+        with db_api.CONTEXT_WRITER.using(ctx):
+            ctx.session.query(self.endpoint_model).filter(
+                or_(self.endpoint_model.host == host,
+                    self.endpoint_model.ip_address == ip)).delete()

     def _get_endpoints(self):
         LOG.debug("_get_endpoints() called")
-        session = db_api.get_reader_session()
-        return session.query(self.endpoint_model)
+        ctx = context.get_admin_context()
+        with db_api.CONTEXT_READER.using(ctx):
+            return ctx.session.query(self.endpoint_model).all()

     def _add_endpoint(self, ip, host, **kwargs):
         LOG.debug("_add_endpoint() called for ip %s", ip)
-        session = db_api.get_writer_session()
+        ctx = context.get_admin_context()
         try:
-            endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs)
-            endpoint.save(session)
+            with db_api.CONTEXT_WRITER.using(ctx):
+                endpoint = self.endpoint_model(ip_address=ip, host=host,
+                                               **kwargs)
+                endpoint.save(ctx.session)
         except db_exc.DBDuplicateEntry:
-            endpoint = (session.query(self.endpoint_model).
-                        filter_by(ip_address=ip).one())
-            LOG.warning("Endpoint with ip %s already exists", ip)
+            with db_api.CONTEXT_READER.using(ctx):
+                endpoint = (ctx.session.query(self.endpoint_model).
+                            filter_by(ip_address=ip).one())
+                LOG.warning("Endpoint with ip %s already exists", ip)
         return endpoint


@@ -2017,12 +2017,13 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
     @utils.transaction_guard
     @db_api.retry_if_session_inactive()
     def delete_port(self, context, id, l3_port_check=True):
-        try:
-            port_db = self._get_port(context, id)
-            port = self._make_port_dict(port_db)
-        except exc.PortNotFound:
-            LOG.debug("The port '%s' was deleted", id)
-            return
+        with db_api.CONTEXT_READER.using(context):
+            try:
+                port_db = self._get_port(context, id)
+                port = self._make_port_dict(port_db)
+            except exc.PortNotFound:
+                LOG.debug("The port '%s' was deleted", id)
+                return

         self._pre_delete_port(context, id, l3_port_check, port)
         # TODO(armax): get rid of the l3 dependency in the with block


@@ -268,6 +268,7 @@ class TrackedResource(BaseResource):
             # Update quota usage
             return self._resync(context, project_id, in_use)

+    @db_api.CONTEXT_WRITER
     def count_used(self, context, project_id, resync_usage=True):
         """Returns the current usage count for the resource.


@@ -50,6 +50,7 @@ class TagPlugin(tagging.TagPluginBase):
             tags = [tag_db.tag for tag_db in db_data.standard_attr.tags]
             response_data['tags'] = tags

+    @db_api.CONTEXT_READER
     def _get_resource(self, context, resource, resource_id):
         model = resource_model_map[resource]
         try:


@@ -176,24 +176,24 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
     def test_ensure_no_port_in_asterisk(self):
         self._create_network(self.tenant_1, self.network_id, True)
         self.plugin.ensure_no_tenant_ports_on_network(
-            self.network_id, self.tenant_1, '*')
+            self.cxt, self.network_id, self.tenant_1, '*')

     def test_ensure_no_port_in_tenant_1(self):
         self._create_network(self.tenant_1, self.network_id, True)
         self.plugin.ensure_no_tenant_ports_on_network(
-            self.network_id, self.tenant_1, self.tenant_1)
+            self.cxt, self.network_id, self.tenant_1, self.tenant_1)

     def test_ensure_no_port_in_tenant_2(self):
         self._create_network(self.tenant_1, self.network_id, True)
         self.plugin.ensure_no_tenant_ports_on_network(
-            self.network_id, self.tenant_1, self.tenant_2)
+            self.cxt, self.network_id, self.tenant_1, self.tenant_2)

     def test_ensure_port_tenant_1_in_asterisk(self):
         self._create_network(self.tenant_1, self.network_id, True)
         self._create_subnet(self.tenant_1, self.subnet_1_id, True)
         self._create_port(self.tenant_1, self.network_id, self.port_id)
         self.plugin.ensure_no_tenant_ports_on_network(
-            self.network_id, self.tenant_1, '*')
+            self.cxt, self.network_id, self.tenant_1, '*')

     def test_ensure_port_tenant_2_in_asterisk(self):
         self._create_network(self.tenant_1, self.network_id, True)
@@ -201,21 +201,21 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
         self._create_port(self.tenant_2, self.network_id, self.port_id)
         self.assertRaises(n_exc.InvalidSharedSetting,
                           self.plugin.ensure_no_tenant_ports_on_network,
-                          self.network_id, self.tenant_1, '*')
+                          self.cxt, self.network_id, self.tenant_1, '*')

     def test_ensure_port_tenant_1_in_tenant_1(self):
         self._create_network(self.tenant_1, self.network_id, True)
         self._create_subnet(self.tenant_1, self.subnet_1_id, True)
         self._create_port(self.tenant_1, self.network_id, self.port_id)
         self.plugin.ensure_no_tenant_ports_on_network(
-            self.network_id, self.tenant_1, self.tenant_1)
+            self.cxt, self.network_id, self.tenant_1, self.tenant_1)

     def test_ensure_no_share_port_tenant_2_in_tenant_1(self):
         self._create_network(self.tenant_1, self.network_id, False)
         self._create_subnet(self.tenant_1, self.subnet_1_id, True)
         self._create_port(self.tenant_2, self.network_id, self.port_id)
         self.plugin.ensure_no_tenant_ports_on_network(
-            self.network_id, self.tenant_1, self.tenant_1)
+            self.cxt, self.network_id, self.tenant_1, self.tenant_1)

     def test_ensure_no_share_port_tenant_2_in_tenant_2(self):
         self._create_network(self.tenant_1, self.network_id, False)
@@ -223,4 +223,5 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
         self._create_port(self.tenant_2, self.network_id, self.port_id)
         self.assertRaises(n_exc.InvalidSharedSetting,
                           self.plugin.ensure_no_tenant_ports_on_network,
-                          self.network_id, self.tenant_1, self.tenant_2)
+                          self.cxt, self.network_id, self.tenant_1,
+                          self.tenant_2)


@@ -249,9 +249,7 @@ class RbacNeutronDbObjectTestCase(test_rbac.RBACBaseObjectIfaceTestCase,
                 '_get_projects_with_shared_access_to_db_obj') as sh_tids:
             get_rbac_entries_mock.filter.return_value.count.return_value = 0
             self._test_class._validate_rbac_policy_delete(
-                context=context,
-                obj_id='fake_obj_id',
-                target_project='fake_tid1')
+                context, 'fake_obj_id', 'fake_tid1')
             sh_tids.assert_not_called()

     @mock.patch.object(_test_class, '_get_db_obj_rbac_entries')