Merge "Add the corresponding DB context to all SQL transactions"

Zuul 2022-04-08 13:08:32 +00:00 committed by Gerrit Code Review
commit 430abde13e
32 changed files with 275 additions and 158 deletions
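
The pattern applied throughout these files is the neutron_lib enginefacade API: read-only database access is wrapped either with the @db_api.CONTEXT_READER decorator or with the CONTEXT_READER.using(context) context manager, and writes use the CONTEXT_WRITER counterparts. The following minimal sketch is not part of the patch; ExamplePlugin and its methods are hypothetical and only illustrate the two forms used in the hunks below.

from neutron_lib.db import api as db_api
from neutron_lib.db import model_query

from neutron.db import models_v2


class ExamplePlugin(object):
    """Hypothetical plugin, used only to illustrate the two forms."""

    @db_api.CONTEXT_READER
    def get_networks_count(self, context, filters=None):
        # Decorator form: a reader transaction is opened for the whole
        # method and bound to the "context" argument.
        return model_query.get_collection_count(
            context, models_v2.Network, filters=filters)

    def delete_network(self, context, network_id):
        # Context-manager form: the transaction scope is limited to the
        # "with" block; db_api.CONTEXT_WRITER is the write-side variant.
        with db_api.CONTEXT_WRITER.using(context):
            network = model_query.get_by_id(
                context, models_v2.Network, network_id)
            context.session.delete(network)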

View File

@ -77,7 +77,7 @@ def get_availability_zones_by_agent_type(context, agent_type,
availability_zones):
"""Get list of availability zones based on agent type"""
agents = agent_obj.Agent._get_agents_by_availability_zones_and_agent_type(
agents = agent_obj.Agent.get_agents_by_availability_zones_and_agent_type(
context, agent_type=agent_type, availability_zones=availability_zones)
return set(agent.availability_zone for agent in agents)

View File

@ -259,6 +259,7 @@ class DbBasePluginCommon(object):
res.pop('bulk')
return db_utils.resource_fields(res, fields)
@db_api.CONTEXT_READER
def _get_network(self, context, id):
try:
network = model_query.get_by_id(context, models_v2.Network, id)
@ -266,6 +267,7 @@ class DbBasePluginCommon(object):
raise exceptions.NetworkNotFound(net_id=id)
return network
@db_api.CONTEXT_READER
def _network_exists(self, context, network_id):
query = model_query.query_with_hooks(
context, models_v2.Network, field='id')
@ -284,6 +286,7 @@ class DbBasePluginCommon(object):
raise exceptions.SubnetPoolNotFound(subnetpool_id=id)
return subnetpool
@db_api.CONTEXT_READER
def _get_port(self, context, id, lazy_fields=None):
try:
port = model_query.get_by_id(context, models_v2.Port, id,

View File

@ -230,21 +230,23 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
tenant_to_check = policy['target_project']
if tenant_to_check:
self.ensure_no_tenant_ports_on_network(net['id'], net['tenant_id'],
tenant_to_check)
self.ensure_no_tenant_ports_on_network(
context, net['id'], net['tenant_id'], tenant_to_check)
def ensure_no_tenant_ports_on_network(self, network_id, net_tenant_id,
tenant_id):
ctx_admin = ctx.get_admin_context()
ports = model_query.query_with_hooks(ctx_admin, models_v2.Port).filter(
def ensure_no_tenant_ports_on_network(self, context, network_id,
net_tenant_id, tenant_id):
elevated = context.elevated()
with db_api.CONTEXT_READER.using(elevated):
ports = model_query.query_with_hooks(
elevated, models_v2.Port).filter(
models_v2.Port.network_id == network_id)
if tenant_id == '*':
# for the wildcard we need to get all of the rbac entries to
# see if any allow the remaining ports on the network.
# any port with another RBAC entry covering it or one belonging to
# the same tenant as the network owner is ok
# any port with another RBAC entry covering it or one belonging
# to the same tenant as the network owner is ok
other_rbac_objs = network_obj.NetworkRBAC.get_objects(
ctx_admin, object_id=network_id, action='access_as_shared')
elevated, object_id=network_id, action='access_as_shared')
allowed_tenants = [rbac['target_project'] for rbac
in other_rbac_objs
if rbac.target_project != tenant_id]
@ -255,8 +257,8 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
# if there is a wildcard rule, we can return early because it
# allows any ports
if network_obj.NetworkRBAC.get_object(
ctx_admin, object_id=network_id, action='access_as_shared',
target_project='*'):
elevated, object_id=network_id,
action='access_as_shared', target_project='*'):
return
ports = ports.filter(models_v2.Port.project_id == tenant_id)
if ports.count():
@ -487,8 +489,8 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
registry.publish(resources.NETWORK, events.BEFORE_DELETE, self,
payload=events.DBEventPayload(
context, resource_id=id))
self._ensure_network_not_in_use(context, id)
with db_api.CONTEXT_READER.using(context):
self._ensure_network_not_in_use(context, id)
auto_delete_port_ids = [p.id for p in context.session.query(
models_v2.Port.id).filter_by(network_id=id).filter(
models_v2.Port.device_owner.in_(
@ -647,7 +649,6 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
s_gateway_ip != cur_subnet['gateway_ip'] and
not ipv6_utils.is_ipv6_pd_enabled(s)):
gateway_ip = str(cur_subnet['gateway_ip'])
with db_api.CONTEXT_READER.using(context):
alloc = port_obj.IPAllocation.get_alloc_routerports(
context, cur_subnet['id'], gateway_ip=gateway_ip,
first=True)
@ -1593,6 +1594,7 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
return query
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_ports(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
@ -1612,6 +1614,7 @@ class NeutronDbPluginV2(db_base_plugin_common.DbBasePluginCommon,
return items
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_ports_count(self, context, filters=None):
return self._get_ports_query(context, filters).count()

View File

@ -33,6 +33,7 @@ from neutron._i18n import _
from neutron.db import models_v2
from neutron.extensions import rbac as rbac_ext
from neutron.objects import network as net_obj
from neutron.objects import ports as port_obj
from neutron.objects import router as l3_obj
@ -127,9 +128,9 @@ class External_net_db_mixin(object):
# must make sure we do not have any external gateway ports
# (and thus, possible floating IPs) on this network before
# allowing it to be updated to external=False
if context.session.query(models_v2.Port.id).filter_by(
device_owner=constants.DEVICE_OWNER_ROUTER_GW,
network_id=net_data['id']).first():
if port_obj.Port.count(
context, network_id=net_data['id'],
device_owner=constants.DEVICE_OWNER_ROUTER_GW):
raise extnet_exc.ExternalNetworkInUse(net_id=net_id)
net_obj.ExternalNetwork.delete_objects(
@ -200,10 +201,9 @@ class External_net_db_mixin(object):
if new_project == policy['target_project']:
# nothing to validate if the tenant didn't change
return
gw_ports = context.session.query(models_v2.Port.id).filter_by(
device_owner=constants.DEVICE_OWNER_ROUTER_GW,
network_id=policy['object_id'])
gw_ports = [gw_port[0] for gw_port in gw_ports]
gw_ports = port_obj.Port.get_gateway_port_ids_by_network(
context, policy['object_id'])
if policy['target_project'] != '*':
filters = {
'gw_port_id': gw_ports,

View File

@ -391,10 +391,9 @@ class L3AgentSchedulerDbMixin(l3agentscheduler.L3AgentSchedulerPluginBase,
rb_obj.RouterL3AgentBinding.get_l3_agents_by_router_ids(
context, router_ids))
@db_api.CONTEXT_READER
def list_l3_agents_hosting_router(self, context, router_id):
with db_api.CONTEXT_READER.using(context):
agents = self._get_l3_agents_hosting_routers(
context, [router_id])
agents = self._get_l3_agents_hosting_routers(context, [router_id])
return {'agents': [self._make_agent_dict(agent)
for agent in agents]}

View File

@ -622,6 +622,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
return self._make_router_dict(router, fields)
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_routers(self, context, filters=None, fields=None,
sorts=None, limit=None, marker=None,
page_reverse=False):
@ -636,6 +637,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
page_reverse=page_reverse)
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_routers_count(self, context, filters=None):
return model_query.get_collection_count(
context, l3_models.Router, filters=filters,
@ -1365,6 +1367,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
fip_id = uuidutils.generate_uuid()
f_net_id = fip['floating_network_id']
with db_api.CONTEXT_READER.using(context):
f_net_db = self._core_plugin._get_network(context, f_net_id)
if not f_net_db.external:
msg = _("Network %s is not a valid external network") % f_net_id
@ -1834,6 +1837,7 @@ class L3_NAT_dbonly_mixin(l3.RouterPluginBase,
continue
yield port
@db_api.CONTEXT_READER
def _get_subnets_by_network_list(self, context, network_ids):
if not network_ids:
return {}

View File

@ -561,7 +561,6 @@ class L3_HA_NAT_db_mixin(l3_dvr_db.L3_NAT_with_dvr_db_mixin,
for agent in self.get_l3_agents_hosting_routers(context, [router_id]):
self.remove_router_from_l3_agent(context, agent['id'], router_id)
@db_api.CONTEXT_READER
def get_ha_router_port_bindings(self, context, router_ids, host=None):
if not router_ids:
return []

View File

@ -60,6 +60,7 @@ class IpAvailabilityMixin(object):
total_ips_columns.append(mod.IPAllocationPool.last_ip)
@classmethod
@db_api.CONTEXT_READER
def get_network_ip_availabilities(cls, context, filters=None):
"""Get IP availability stats on a per subnet basis.

View File

@ -42,6 +42,7 @@ def add_model_for_resource(resource, model):
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_WRITER
def add_provisioning_component(context, object_id, object_type, entity):
"""Adds a provisioning block by an entity to a given object.
@ -77,6 +78,7 @@ def add_provisioning_component(context, object_id, object_type, entity):
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_WRITER
def remove_provisioning_component(context, object_id, object_type, entity,
standard_attr_id=None):
"""Remove a provisioning block for an object without triggering a callback.
@ -125,18 +127,21 @@ def provisioning_complete(context, object_id, object_type, entity):
# tricking us into thinking there are remaining provisioning components
if utils.is_session_active(context.session):
raise RuntimeError(_("Must not be called in a transaction"))
with db_api.CONTEXT_WRITER.using(context):
standard_attr_id = _get_standard_attr_id(context, object_id,
object_type)
if not standard_attr_id:
return
if remove_provisioning_component(context, object_id, object_type, entity,
standard_attr_id):
if remove_provisioning_component(context, object_id, object_type,
entity, standard_attr_id):
LOG.debug("Provisioning for %(otype)s %(oid)s completed by entity "
"%(entity)s.", log_dict)
# now with that committed, check if any records are left. if None, emit
# an event that provisioning is complete.
if not pb_obj.ProvisioningBlock.objects_exist(
if pb_obj.ProvisioningBlock.objects_exist(
context, standard_attr_id=standard_attr_id):
return
LOG.debug("Provisioning complete for %(otype)s %(oid)s triggered by "
"entity %(entity)s.", log_dict)
registry.publish(object_type, PROVISIONING_COMPLETE, entity,
@ -145,6 +150,7 @@ def provisioning_complete(context, object_id, object_type, entity):
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def is_object_blocked(context, object_id, object_type):
"""Return boolean indicating if object has a provisioning block.

View File

@ -44,6 +44,7 @@ class ReservationInfo(collections.namedtuple(
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def get_quota_usage_by_resource_and_project(context, resource, project_id):
"""Return usage info for a given resource and project.

View File

@ -441,6 +441,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
"""Server-side RPC mixin using DB for SG notifications and responses."""
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_sg_ids_for_ports(self, context, ports):
if not ports:
return []
@ -451,6 +452,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return query.all()
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_rules_for_ports(self, context, ports):
if not ports:
return []
@ -467,6 +469,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return query.all()
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_ips_for_remote_group(self, context, remote_group_ids):
ips_by_group = {}
if not remote_group_ids:
@ -507,6 +510,7 @@ class SecurityGroupServerRpcMixin(SecurityGroupInfoAPIMixin,
return ips_by_group
@db_api.retry_if_session_inactive()
@db_api.CONTEXT_READER
def _select_ips_for_remote_address_group(self, context,
remote_address_group_ids):
ips_by_group = {}

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
@ -53,6 +54,7 @@ class AddressScope(rbac_db.NeutronRbacObject):
}
@classmethod
@db_api.CONTEXT_READER
def get_network_address_scope(cls, context, network_id, ip_version):
query = context.session.query(cls.db_model)
query = query.join(

View File

@ -13,6 +13,7 @@
# under the License.
from neutron_lib import constants as const
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron_lib.objects import utils as obj_utils
from oslo_utils import versionutils
@ -122,11 +123,10 @@ class Agent(base.NeutronDbObject):
group_by(agent_model.Agent).
filter(agent_model.Agent.id.in_(agent_ids)).
order_by('count'))
agents = [cls._load_object(context, record[0]) for record in query]
return agents
return [cls._load_object(context, record[0]) for record in query]
@classmethod
@db_api.CONTEXT_READER
def get_ha_agents(cls, context, network_id=None, router_id=None):
if not (network_id or router_id):
return []
@ -154,7 +154,8 @@ class Agent(base.NeutronDbObject):
return agents
@classmethod
def _get_agents_by_availability_zones_and_agent_type(
@db_api.CONTEXT_READER
def get_agents_by_availability_zones_and_agent_type(
cls, context, agent_type, availability_zones):
query = context.session.query(
agent_model.Agent).filter_by(

View File

@ -16,12 +16,15 @@ from collections import abc as collections_abc
import copy
import functools
import itertools
import sys
import traceback
from neutron_lib.db import api as db_api
from neutron_lib.db import standard_attr
from neutron_lib import exceptions as n_exc
from neutron_lib.objects import exceptions as o_exc
from neutron_lib.objects.extensions import standardattributes
from oslo_config import cfg
from oslo_db import exception as obj_exc
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import utils as db_utils
@ -39,10 +42,38 @@ from neutron.objects.db import api as obj_db_api
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
_NO_DB_MODEL = object()
# NOTE(ralonsoh): this method is evaluated any time an ORM session
# executes a SQL statement.
# If "autocommit" is disabled (the default value in SQLAlchemy 1.4 and the
# only value in SQLAlchemy 2.0) and there is no active transaction, the
# statement is being run on an "implicit transaction". Under autocommit,
# that transaction is created, executed and discarded immediately; without
# autocommit, a transaction must be explicitly created (writer/reader) and
# remains open until committed or rolled back.
# This evaluation is done only in debug mode to monitor the Neutron code's
# compliance with SQLAlchemy 2.0.
def do_orm_execute(orm_execute_state):
if not orm_execute_state.session.in_transaction():
trace_string = '\n'.join(traceback.format_stack(sys._getframe(1)))
LOG.warning('ORM session: SQL execution without transaction in '
'progress, traceback:\n%s', trace_string)
try:
_debug = cfg.CONF.debug
except cfg.NoSuchOptError:
_debug = False
if _debug:
db_api.sqla_listen(orm.Session, 'do_orm_execute', do_orm_execute)
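
A rough illustration, not part of the change: with debug enabled, an ORM query issued outside any reader/writer transaction trips the listener above, while the same query inside an explicit reader block does not. The admin-context setup is assumed only for the example.

from neutron_lib import context as n_context
from neutron_lib.db import api as db_api

from neutron.db import models_v2

ctx = n_context.get_admin_context()

# Run bare, this query executes on an implicit transaction and, in debug
# mode, logs "SQL execution without transaction in progress":
#     ctx.session.query(models_v2.Network).all()

# Wrapped in an explicit reader transaction, it does not:
with db_api.CONTEXT_READER.using(ctx):
    networks = ctx.session.query(models_v2.Network).all()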
def get_object_class_by_model(model):
for obj_class in NeutronObjectRegistry.obj_classes().values():
obj_class = obj_class[0]
@ -919,6 +950,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject):
self._captured_db_model = None
@classmethod
@db_api.CONTEXT_READER
def count(cls, context, validate_filters=True, **kwargs):
"""Count the number of objects matching filtering criteria.
@ -935,6 +967,7 @@ class NeutronDbObject(NeutronObject, metaclass=DeclarativeObject):
)
@classmethod
@db_api.CONTEXT_READER
def objects_exist(cls, context, validate_filters=True, **kwargs):
"""Check if objects are present in DB.

View File

@ -13,6 +13,7 @@
# TODO(ihrachys): cover the module with functional tests targeting supported
# backends
from neutron_lib.db import api as db_api
from neutron_lib.db import model_query
from neutron_lib import exceptions as n_exc
from neutron_lib.objects import utils as obj_utils
@ -34,6 +35,7 @@ def get_object(obj_cls, context, **kwargs):
return _get_filter_query(obj_cls, context, **kwargs).first()
@db_api.CONTEXT_READER
def count(obj_cls, context, query_field=None, query_limit=None, **kwargs):
if not query_field and obj_cls.primary_keys:
query_field = obj_cls.primary_keys[0]

View File

@ -13,6 +13,7 @@
# under the License.
from neutron_lib import constants
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
@ -40,6 +41,7 @@ class L3HARouterAgentPortBinding(base.NeutronDbObject):
fields_no_update = ['router_id', 'port_id']
@classmethod
@db_api.CONTEXT_READER
def get_l3ha_filter_host_router(cls, context, router_ids, host):
query = context.session.query(l3ha.L3HARouterAgentPortBinding)

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations
# under the License.
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
import sqlalchemy as sa
@ -42,6 +43,7 @@ class RouterL3AgentBinding(base.NeutronDbObject):
# TODO(ihrachys) return OVO objects not models
# TODO(ihrachys) move under Agent object class
@classmethod
@db_api.CONTEXT_READER
def get_l3_agents_by_router_ids(cls, context, router_ids):
query = context.session.query(l3agent.RouterL3AgentBinding)
query = query.options(joinedload('l3_agent')).filter(
@ -49,6 +51,7 @@ class RouterL3AgentBinding(base.NeutronDbObject):
return [db_obj.l3_agent for db_obj in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_down_router_bindings(cls, context, cutoff):
query = (context.session.query(
l3agent.RouterL3AgentBinding).

View File

@ -15,6 +15,7 @@
import itertools
import netaddr
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron.db.models import l3
@ -265,21 +266,23 @@ class PortForwarding(base.NeutronDbObject):
return result
@classmethod
@db_api.CONTEXT_READER
def get_port_forwarding_obj_by_routers(cls, context, router_ids):
query = context.session.query(cls.db_model, l3.FloatingIP)
query = query.join(l3.FloatingIP,
cls.db_model.floatingip_id == l3.FloatingIP.id)
query = query.filter(l3.FloatingIP.router_id.in_(router_ids))
return cls._unique_port_forwarding_iterator(query)
return cls._unique_port_forwarding(query)
@classmethod
def _unique_port_forwarding_iterator(cls, query):
@staticmethod
def _unique_port_forwarding(query):
q = query.order_by(l3.FloatingIP.router_id)
keyfunc = lambda row: row[1]
group_iterator = itertools.groupby(q, keyfunc)
result = []
for key, value in group_iterator:
for row in value:
yield (row[1]['router_id'], row[1]['floating_ip_address'],
row[0]['id'], row[1]['id'])
result.extend([(row[1]['router_id'], row[1]['floating_ip_address'],
row[0]['id'], row[1]['id']) for row in value])
return result

View File

@ -243,6 +243,7 @@ class IPAllocation(base.NeutronDbObject):
alloc_obj.delete()
@classmethod
@db_api.CONTEXT_READER
def get_alloc_routerports(cls, context, subnet_id, gateway_ip=None,
first=False):
alloc_qry = context.session.query(cls.db_model.port_id)
@ -466,6 +467,7 @@ class Port(base.NeutronDbObject):
return port_array
@classmethod
@db_api.CONTEXT_READER
def get_auto_deletable_port_ids_and_proper_port_count_by_segment(
cls, context, segment_id):
@ -584,6 +586,7 @@ class Port(base.NeutronDbObject):
primitive.pop('device_profile', None)
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_router_and_network(cls, context, router_id, owner,
network_id):
"""Returns port objects filtering by router ID, owner and network ID"""
@ -593,6 +596,7 @@ class Port(base.NeutronDbObject):
rports_filter, router_filter)
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_router_and_port(cls, context, router_id, owner, port_id):
"""Returns port objects filtering by router ID, owner and port ID"""
rports_filter = (l3.RouterPort.port_id == port_id, )
@ -645,6 +649,7 @@ class Port(base.NeutronDbObject):
return ports_rports
@classmethod
@db_api.CONTEXT_READER
def get_ports_ids_by_security_groups(cls, context, security_group_ids,
excluded_device_owners=None):
query = context.session.query(sg_models.SecurityGroupPortBinding)
@ -658,6 +663,7 @@ class Port(base.NeutronDbObject):
return [port_binding['port_id'] for port_binding in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_host(cls, context, host):
query = context.session.query(models_v2.Port.id).join(
ml2_models.PortBinding)
@ -666,6 +672,7 @@ class Port(base.NeutronDbObject):
return [port_id[0] for port_id in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_binding_type_and_host(cls, context,
binding_type, host):
query = context.session.query(models_v2.Port).join(
@ -676,6 +683,7 @@ class Port(base.NeutronDbObject):
return [cls._load_object(context, db_obj) for db_obj in query.all()]
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_vnic_type_and_host(
cls, context, vnic_type, host):
query = context.session.query(models_v2.Port).join(
@ -686,6 +694,7 @@ class Port(base.NeutronDbObject):
return [cls._load_object(context, db_obj) for db_obj in query.all()]
@classmethod
@db_api.CONTEXT_READER
def check_network_ports_by_binding_types(
cls, context, network_id, binding_types, negative_search=False):
"""This method is to check whether networks have ports with given
@ -710,6 +719,7 @@ class Port(base.NeutronDbObject):
return bool(query.count())
@classmethod
@db_api.CONTEXT_READER
def get_ports_allocated_by_subnet_id(cls, context, subnet_id):
"""Return ports with fixed IPs in a subnet"""
return context.session.query(models_v2.Port).filter(
@ -731,3 +741,11 @@ class Port(base.NeutronDbObject):
for _binding in port.bindings:
if _binding.get('profile', {}).get('pci_slot') == pci_slot:
return port
@classmethod
@db_api.CONTEXT_READER
def get_gateway_port_ids_by_network(cls, context, network_id):
gw_ports = context.session.query(models_v2.Port.id).filter_by(
device_owner=constants.DEVICE_OWNER_ROUTER_GW,
network_id=network_id)
return [gw_port[0] for gw_port in gw_ports]
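
This new helper replaces the raw session query previously open-coded in the external-network mixin (see the External_net_db_mixin hunk above); a minimal usage sketch mirroring that call site, with context assumed to be a request context and network_id a network UUID:

from neutron.objects import ports as port_obj

gw_port_ids = port_obj.Port.get_gateway_port_ids_by_network(
    context, network_id)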

View File

@ -15,6 +15,7 @@
import abc
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from sqlalchemy import and_
from sqlalchemy import exists
@ -55,6 +56,7 @@ class QosPolicyPortBinding(base.NeutronDbObject, _QosPolicyBindingMixin):
_bound_model_id = db_model.port_id
@classmethod
@db_api.CONTEXT_READER
def get_ports_by_network_id(cls, context, network_id, policy_id=None):
query = context.session.query(models_v2.Port).filter(
models_v2.Port.network_id == network_id)
@ -103,6 +105,7 @@ class QosPolicyFloatingIPBinding(base.NeutronDbObject, _QosPolicyBindingMixin):
_bound_model_id = db_model.fip_id
@classmethod
@db_api.CONTEXT_READER
def get_fips_by_network_id(cls, context, network_id, policy_id=None):
"""Return the FIP belonging to a network, filtered by a QoS policy

View File

@ -15,6 +15,7 @@
import abc
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from oslo_utils import versionutils
from oslo_versionedobjects import fields as obj_fields
@ -39,6 +40,7 @@ class RBACBaseObject(base.NeutronDbObject, metaclass=abc.ABCMeta):
fields_no_update = ['id', 'project_id', 'object_id']
@classmethod
@db_api.CONTEXT_READER
def get_projects(cls, context, object_id=None, action=None,
target_project=None):
clauses = []

View File

@ -18,6 +18,7 @@ import itertools
from neutron_lib.callbacks import events
from neutron_lib.callbacks import registry
from neutron_lib.callbacks import resources
from neutron_lib.db import api as db_api
from neutron_lib import exceptions
from sqlalchemy import and_
@ -104,6 +105,7 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
rbac_db_model.target_project != '*'))))
@classmethod
@db_api.CONTEXT_READER
def _validate_rbac_policy_delete(cls, context, obj_id, target_project):
ctx_admin = context.elevated()
rb_model = cls.rbac_db_cls.db_model
@ -147,13 +149,14 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
if policy['action'] != models.ACCESS_SHARED:
return
target_project = policy['target_project']
db_obj = obj_db_api.get_object(
cls, context.elevated(), id=policy['object_id'])
elevated_context = context.elevated()
with db_api.CONTEXT_READER.using(elevated_context):
db_obj = obj_db_api.get_object(cls, elevated_context,
id=policy['object_id'])
if db_obj.project_id == target_project:
return
cls._validate_rbac_policy_delete(context=context,
obj_id=policy['object_id'],
target_project=target_project)
cls._validate_rbac_policy_delete(context, policy['object_id'],
target_project)
@classmethod
def validate_rbac_policy_create(cls, resource, event, trigger,
@ -199,8 +202,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
# (hopefully) melded with this one.
if object_type != cls.rbac_db_cls.db_model.object_type:
return
db_obj = obj_db_api.get_object(
cls, context.elevated(), id=policy['object_id'])
elevated_context = context.elevated()
with db_api.CONTEXT_READER.using(elevated_context):
db_obj = obj_db_api.get_object(cls, elevated_context,
id=policy['object_id'])
if event in (events.BEFORE_CREATE, events.BEFORE_UPDATE):
if (not context.is_admin and
db_obj['project_id'] != context.project_id):
@ -225,10 +230,10 @@ class RbacNeutronDbObjectMixin(rbac_db_mixin.RbacPluginMixin,
def update_shared(self, is_shared_new, obj_id):
admin_context = self.obj_context.elevated()
shared_prev = obj_db_api.get_object(self.rbac_db_cls, admin_context,
object_id=obj_id,
target_project='*',
action=models.ACCESS_SHARED)
with db_api.CONTEXT_WRITER.using(admin_context):
shared_prev = obj_db_api.get_object(
self.rbac_db_cls, admin_context, object_id=obj_id,
target_project='*', action=models.ACCESS_SHARED)
is_shared_prev = bool(shared_prev)
if is_shared_prev == is_shared_new:
return

View File

@ -17,6 +17,7 @@ import netaddr
from neutron_lib.api.definitions import availability_zone as az_def
from neutron_lib.api.validators import availability_zone as az_validator
from neutron_lib import constants as n_const
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron_lib.utils import net as net_utils
from oslo_utils import versionutils
@ -108,6 +109,7 @@ class RouterExtraAttributes(base.NeutronDbObject):
return result
@classmethod
@db_api.CONTEXT_READER
def get_router_agents_count(cls, context):
# TODO(sshank): This is pulled out from l3_agentschedulers_db.py
# until a way to handle joins is figured out.
@ -146,6 +148,7 @@ class RouterPort(base.NeutronDbObject):
}
@classmethod
@db_api.CONTEXT_READER
def get_router_ids_by_subnetpool(cls, context, subnetpool_id):
query = context.session.query(l3.RouterPort.router_id)
query = query.join(models_v2.Port)
@ -220,6 +223,7 @@ class Router(base.NeutronDbObject):
fields_no_update = ['project_id']
@classmethod
@db_api.CONTEXT_READER
def check_routers_not_owned_by_projects(cls, context, gw_ports, projects):
"""This method is to check whether routers that aren't owned by
existing projects or not
@ -376,6 +380,7 @@ class FloatingIP(base.NeutronDbObject):
primitive.pop('qos_network_policy_id', None)
@classmethod
@db_api.CONTEXT_READER
def get_scoped_floating_ips(cls, context, router_ids):
query = context.session.query(l3.FloatingIP,
models_v2.SubnetPool.address_scope_id)
@ -410,6 +415,7 @@ class FloatingIP(base.NeutronDbObject):
yield (cls._load_object(context, row[0]), row[1])
@classmethod
@db_api.CONTEXT_READER
def get_disassociated_ids_for_net(cls, context, network_id):
query = context.session.query(cls.db_model.id)
query = query.filter_by(

View File

@ -11,6 +11,7 @@
# under the License.
from neutron_lib import context as context_lib
from neutron_lib.db import api as db_api
from neutron_lib.objects import common_types
from neutron_lib.utils import net as net_utils
from oslo_utils import versionutils
@ -239,6 +240,9 @@ class SecurityGroupRule(base.NeutronDbObject):
- The rule belongs to a security group that belongs to the project_id
"""
context = context_lib.get_admin_context()
# NOTE(ralonsoh): do not use a READER decorator in this method. Elevated
# permissions are needed here.
with db_api.CONTEXT_READER.using(context):
query = context.session.query(cls.db_model.id)
query = query.join(
SecurityGroup.db_model,

View File

@ -14,6 +14,7 @@
# under the License.
import netaddr
from neutron_lib.db import api as db_api
from neutron_lib.db import model_query
from neutron_lib.objects import common_types
from oslo_versionedobjects import fields as obj_fields
@ -123,6 +124,7 @@ class SubnetPool(rbac_db.NeutronRbacObject):
# Nothing to validate
return
with db_api.CONTEXT_READER.using(context):
rbac_as_model = rbac_db_models.AddressScopeRBAC
# Ensure that target project has access to AS

View File

@ -315,7 +315,6 @@ def _prevent_segment_delete_with_port_bound(resource, event, trigger,
# don't check for network deletes
return
with db_api.CONTEXT_READER.using(payload.context):
auto_delete_port_ids, proper_port_count = port_obj.Port.\
get_auto_deletable_port_ids_and_proper_port_count_by_segment(
payload.context, segment_id=payload.resource_id)

View File

@ -345,42 +345,52 @@ class EndpointTunnelTypeDriver(ML2TunnelTypeDriver):
def get_endpoint_by_host(self, host):
LOG.debug("get_endpoint_by_host() called for host %s", host)
session = db_api.get_reader_session()
return (session.query(self.endpoint_model).
ctx = context.get_admin_context()
with db_api.CONTEXT_READER.using(ctx):
return (ctx.session.query(self.endpoint_model).
filter_by(host=host).first())
def get_endpoint_by_ip(self, ip):
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
session = db_api.get_reader_session()
return (session.query(self.endpoint_model).
ctx = context.get_admin_context()
with db_api.CONTEXT_READER.using(ctx):
return (ctx.session.query(self.endpoint_model).
filter_by(ip_address=ip).first())
def delete_endpoint(self, ip):
LOG.debug("delete_endpoint() called for ip %s", ip)
session = db_api.get_writer_session()
session.query(self.endpoint_model).filter_by(ip_address=ip).delete()
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
ctx.session.query(self.endpoint_model).filter_by(
ip_address=ip).delete()
def delete_endpoint_by_host_or_ip(self, host, ip):
LOG.debug("delete_endpoint_by_host_or_ip() called for "
"host %(host)s or %(ip)s", {'host': host, 'ip': ip})
session = db_api.get_writer_session()
session.query(self.endpoint_model).filter(
ctx = context.get_admin_context()
with db_api.CONTEXT_WRITER.using(ctx):
ctx.session.query(self.endpoint_model).filter(
or_(self.endpoint_model.host == host,
self.endpoint_model.ip_address == ip)).delete()
def _get_endpoints(self):
LOG.debug("_get_endpoints() called")
session = db_api.get_reader_session()
return session.query(self.endpoint_model)
ctx = context.get_admin_context()
with db_api.CONTEXT_READER.using(ctx):
return ctx.session.query(self.endpoint_model).all()
def _add_endpoint(self, ip, host, **kwargs):
LOG.debug("_add_endpoint() called for ip %s", ip)
session = db_api.get_writer_session()
ctx = context.get_admin_context()
try:
endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs)
endpoint.save(session)
with db_api.CONTEXT_WRITER.using(ctx):
endpoint = self.endpoint_model(ip_address=ip, host=host,
**kwargs)
endpoint.save(ctx.session)
except db_exc.DBDuplicateEntry:
endpoint = (session.query(self.endpoint_model).
with db_api.CONTEXT_READER.using(ctx):
endpoint = (ctx.session.query(self.endpoint_model).
filter_by(ip_address=ip).one())
LOG.warning("Endpoint with ip %s already exists", ip)
return endpoint

View File

@ -2017,6 +2017,7 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2,
@utils.transaction_guard
@db_api.retry_if_session_inactive()
def delete_port(self, context, id, l3_port_check=True):
with db_api.CONTEXT_READER.using(context):
try:
port_db = self._get_port(context, id)
port = self._make_port_dict(port_db)

View File

@ -268,6 +268,7 @@ class TrackedResource(BaseResource):
# Update quota usage
return self._resync(context, project_id, in_use)
@db_api.CONTEXT_WRITER
def count_used(self, context, project_id, resync_usage=True):
"""Returns the current usage count for the resource.

View File

@ -50,6 +50,7 @@ class TagPlugin(tagging.TagPluginBase):
tags = [tag_db.tag for tag_db in db_data.standard_attr.tags]
response_data['tags'] = tags
@db_api.CONTEXT_READER
def _get_resource(self, context, resource, resource_id):
model = resource_model_map[resource]
try:

View File

@ -176,24 +176,24 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
def test_ensure_no_port_in_asterisk(self):
self._create_network(self.tenant_1, self.network_id, True)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, '*')
self.cxt, self.network_id, self.tenant_1, '*')
def test_ensure_no_port_in_tenant_1(self):
self._create_network(self.tenant_1, self.network_id, True)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_1)
self.cxt, self.network_id, self.tenant_1, self.tenant_1)
def test_ensure_no_port_in_tenant_2(self):
self._create_network(self.tenant_1, self.network_id, True)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_2)
self.cxt, self.network_id, self.tenant_1, self.tenant_2)
def test_ensure_port_tenant_1_in_asterisk(self):
self._create_network(self.tenant_1, self.network_id, True)
self._create_subnet(self.tenant_1, self.subnet_1_id, True)
self._create_port(self.tenant_1, self.network_id, self.port_id)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, '*')
self.cxt, self.network_id, self.tenant_1, '*')
def test_ensure_port_tenant_2_in_asterisk(self):
self._create_network(self.tenant_1, self.network_id, True)
@ -201,21 +201,21 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
self._create_port(self.tenant_2, self.network_id, self.port_id)
self.assertRaises(n_exc.InvalidSharedSetting,
self.plugin.ensure_no_tenant_ports_on_network,
self.network_id, self.tenant_1, '*')
self.cxt, self.network_id, self.tenant_1, '*')
def test_ensure_port_tenant_1_in_tenant_1(self):
self._create_network(self.tenant_1, self.network_id, True)
self._create_subnet(self.tenant_1, self.subnet_1_id, True)
self._create_port(self.tenant_1, self.network_id, self.port_id)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_1)
self.cxt, self.network_id, self.tenant_1, self.tenant_1)
def test_ensure_no_share_port_tenant_2_in_tenant_1(self):
self._create_network(self.tenant_1, self.network_id, False)
self._create_subnet(self.tenant_1, self.subnet_1_id, True)
self._create_port(self.tenant_2, self.network_id, self.port_id)
self.plugin.ensure_no_tenant_ports_on_network(
self.network_id, self.tenant_1, self.tenant_1)
self.cxt, self.network_id, self.tenant_1, self.tenant_1)
def test_ensure_no_share_port_tenant_2_in_tenant_2(self):
self._create_network(self.tenant_1, self.network_id, False)
@ -223,4 +223,5 @@ class NetworkRBACTestCase(testlib_api.SqlTestCase):
self._create_port(self.tenant_2, self.network_id, self.port_id)
self.assertRaises(n_exc.InvalidSharedSetting,
self.plugin.ensure_no_tenant_ports_on_network,
self.network_id, self.tenant_1, self.tenant_2)
self.cxt, self.network_id, self.tenant_1,
self.tenant_2)

View File

@ -249,9 +249,7 @@ class RbacNeutronDbObjectTestCase(test_rbac.RBACBaseObjectIfaceTestCase,
'_get_projects_with_shared_access_to_db_obj') as sh_tids:
get_rbac_entries_mock.filter.return_value.count.return_value = 0
self._test_class._validate_rbac_policy_delete(
context=context,
obj_id='fake_obj_id',
target_project='fake_tid1')
context, 'fake_obj_id', 'fake_tid1')
sh_tids.assert_not_called()
@mock.patch.object(_test_class, '_get_db_obj_rbac_entries')