Merge "Refactor type_tunnel/gre/vxlan to reduce duplicate code"
This commit is contained in:
commit
8b427ae869
@ -66,10 +66,11 @@ class GreEndpoints(model_base.BASEV2):
|
||||
return "<GreTunnelEndpoint(%s)>" % self.ip_address
|
||||
|
||||
|
||||
class GreTypeDriver(type_tunnel.TunnelTypeDriver):
|
||||
class GreTypeDriver(type_tunnel.EndpointTunnelTypeDriver):
|
||||
|
||||
def __init__(self):
|
||||
super(GreTypeDriver, self).__init__(GreAllocation)
|
||||
super(GreTypeDriver, self).__init__(
|
||||
GreAllocation, GreEndpoints)
|
||||
|
||||
def get_type(self):
|
||||
return p_const.TYPE_GRE
|
||||
@ -127,45 +128,13 @@ class GreTypeDriver(type_tunnel.TunnelTypeDriver):
|
||||
|
||||
def get_endpoints(self):
|
||||
"""Get every gre endpoints from database."""
|
||||
|
||||
LOG.debug("get_gre_endpoints() called")
|
||||
session = db_api.get_session()
|
||||
|
||||
gre_endpoints = session.query(GreEndpoints)
|
||||
gre_endpoints = self._get_endpoints()
|
||||
return [{'ip_address': gre_endpoint.ip_address,
|
||||
'host': gre_endpoint.host}
|
||||
for gre_endpoint in gre_endpoints]
|
||||
|
||||
def get_endpoint_by_host(self, host):
|
||||
LOG.debug("get_endpoint_by_host() called for host %s", host)
|
||||
session = db_api.get_session()
|
||||
return (session.query(GreEndpoints).
|
||||
filter_by(host=host).first())
|
||||
|
||||
def get_endpoint_by_ip(self, ip):
|
||||
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
return (session.query(GreEndpoints).
|
||||
filter_by(ip_address=ip).first())
|
||||
|
||||
def add_endpoint(self, ip, host):
|
||||
LOG.debug("add_gre_endpoint() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
try:
|
||||
gre_endpoint = GreEndpoints(ip_address=ip, host=host)
|
||||
gre_endpoint.save(session)
|
||||
except db_exc.DBDuplicateEntry:
|
||||
gre_endpoint = (session.query(GreEndpoints).
|
||||
filter_by(ip_address=ip).one())
|
||||
LOG.warning(_LW("Gre endpoint with ip %s already exists"), ip)
|
||||
return gre_endpoint
|
||||
|
||||
def delete_endpoint(self, ip):
|
||||
LOG.debug("delete_gre_endpoint() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
|
||||
with session.begin(subtransactions=True):
|
||||
session.query(GreEndpoints).filter_by(ip_address=ip).delete()
|
||||
return self._add_endpoint(ip, host)
|
||||
|
||||
def get_mtu(self, physical_network=None):
|
||||
mtu = super(GreTypeDriver, self).get_mtu(physical_network)
|
||||
|
@ -15,10 +15,12 @@
|
||||
import abc
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_db import exception as db_exc
|
||||
from oslo_log import log
|
||||
|
||||
from neutron.common import exceptions as exc
|
||||
from neutron.common import topics
|
||||
from neutron.db import api as db_api
|
||||
from neutron.i18n import _LI, _LW
|
||||
from neutron.plugins.common import utils as plugin_utils
|
||||
from neutron.plugins.ml2 import driver_api as api
|
||||
@ -196,6 +198,50 @@ class TunnelTypeDriver(helpers.SegmentTypeDriver):
|
||||
return min(mtu) if mtu else 0
|
||||
|
||||
|
||||
class EndpointTunnelTypeDriver(TunnelTypeDriver):
|
||||
|
||||
def __init__(self, segment_model, endpoint_model):
|
||||
super(EndpointTunnelTypeDriver, self).__init__(segment_model)
|
||||
self.endpoint_model = endpoint_model
|
||||
self.segmentation_key = iter(self.primary_keys).next()
|
||||
|
||||
def get_endpoint_by_host(self, host):
|
||||
LOG.debug("get_endpoint_by_host() called for host %s", host)
|
||||
session = db_api.get_session()
|
||||
return (session.query(self.endpoint_model).
|
||||
filter_by(host=host).first())
|
||||
|
||||
def get_endpoint_by_ip(self, ip):
|
||||
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
return (session.query(self.endpoint_model).
|
||||
filter_by(ip_address=ip).first())
|
||||
|
||||
def delete_endpoint(self, ip):
|
||||
LOG.debug("delete_endpoint() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
with session.begin(subtransactions=True):
|
||||
(session.query(self.endpoint_model).
|
||||
filter_by(ip_address=ip).delete())
|
||||
|
||||
def _get_endpoints(self):
|
||||
LOG.debug("_get_endpoints() called")
|
||||
session = db_api.get_session()
|
||||
return session.query(self.endpoint_model)
|
||||
|
||||
def _add_endpoint(self, ip, host, **kwargs):
|
||||
LOG.debug("_add_endpoint() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
try:
|
||||
endpoint = self.endpoint_model(ip_address=ip, host=host, **kwargs)
|
||||
endpoint.save(session)
|
||||
except db_exc.DBDuplicateEntry:
|
||||
endpoint = (session.query(self.endpoint_model).
|
||||
filter_by(ip_address=ip).one())
|
||||
LOG.warning(_LW("Endpoint with ip %s already exists"), ip)
|
||||
return endpoint
|
||||
|
||||
|
||||
class TunnelRpcCallbackMixin(object):
|
||||
|
||||
def setup_tunnel_callback_mixin(self, notifier, type_manager):
|
||||
|
@ -14,7 +14,6 @@
|
||||
# under the License.
|
||||
|
||||
from oslo_config import cfg
|
||||
from oslo_db import exception as db_exc
|
||||
from oslo_log import log
|
||||
from six import moves
|
||||
import sqlalchemy as sa
|
||||
@ -23,7 +22,7 @@ from sqlalchemy import sql
|
||||
from neutron.common import exceptions as n_exc
|
||||
from neutron.db import api as db_api
|
||||
from neutron.db import model_base
|
||||
from neutron.i18n import _LE, _LW
|
||||
from neutron.i18n import _LE
|
||||
from neutron.plugins.common import constants as p_const
|
||||
from neutron.plugins.ml2.drivers import type_tunnel
|
||||
|
||||
@ -70,10 +69,11 @@ class VxlanEndpoints(model_base.BASEV2):
|
||||
return "<VxlanTunnelEndpoint(%s)>" % self.ip_address
|
||||
|
||||
|
||||
class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
|
||||
class VxlanTypeDriver(type_tunnel.EndpointTunnelTypeDriver):
|
||||
|
||||
def __init__(self):
|
||||
super(VxlanTypeDriver, self).__init__(VxlanAllocation)
|
||||
super(VxlanTypeDriver, self).__init__(
|
||||
VxlanAllocation, VxlanEndpoints)
|
||||
|
||||
def get_type(self):
|
||||
return p_const.TYPE_VXLAN
|
||||
@ -132,48 +132,14 @@ class VxlanTypeDriver(type_tunnel.TunnelTypeDriver):
|
||||
|
||||
def get_endpoints(self):
|
||||
"""Get every vxlan endpoints from database."""
|
||||
|
||||
LOG.debug("get_vxlan_endpoints() called")
|
||||
session = db_api.get_session()
|
||||
|
||||
vxlan_endpoints = session.query(VxlanEndpoints)
|
||||
vxlan_endpoints = self._get_endpoints()
|
||||
return [{'ip_address': vxlan_endpoint.ip_address,
|
||||
'udp_port': vxlan_endpoint.udp_port,
|
||||
'host': vxlan_endpoint.host}
|
||||
for vxlan_endpoint in vxlan_endpoints]
|
||||
|
||||
def get_endpoint_by_host(self, host):
|
||||
LOG.debug("get_endpoint_by_host() called for host %s", host)
|
||||
session = db_api.get_session()
|
||||
return (session.query(VxlanEndpoints).
|
||||
filter_by(host=host).first())
|
||||
|
||||
def get_endpoint_by_ip(self, ip):
|
||||
LOG.debug("get_endpoint_by_ip() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
return (session.query(VxlanEndpoints).
|
||||
filter_by(ip_address=ip).first())
|
||||
|
||||
def add_endpoint(self, ip, host, udp_port=p_const.VXLAN_UDP_PORT):
|
||||
LOG.debug("add_vxlan_endpoint() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
try:
|
||||
vxlan_endpoint = VxlanEndpoints(ip_address=ip,
|
||||
udp_port=udp_port,
|
||||
host=host)
|
||||
vxlan_endpoint.save(session)
|
||||
except db_exc.DBDuplicateEntry:
|
||||
vxlan_endpoint = (session.query(VxlanEndpoints).
|
||||
filter_by(ip_address=ip).one())
|
||||
LOG.warning(_LW("Vxlan endpoint with ip %s already exists"), ip)
|
||||
return vxlan_endpoint
|
||||
|
||||
def delete_endpoint(self, ip):
|
||||
LOG.debug("delete_vxlan_endpoint() called for ip %s", ip)
|
||||
session = db_api.get_session()
|
||||
|
||||
with session.begin(subtransactions=True):
|
||||
session.query(VxlanEndpoints).filter_by(ip_address=ip).delete()
|
||||
return self._add_endpoint(ip, host, udp_port=udp_port)
|
||||
|
||||
def get_mtu(self, physical_network=None):
|
||||
mtu = super(VxlanTypeDriver, self).get_mtu()
|
||||
|
@ -21,6 +21,7 @@ from testtools import matchers
|
||||
from neutron.common import exceptions as exc
|
||||
from neutron.db import api as db
|
||||
from neutron.plugins.ml2 import driver_api as api
|
||||
from neutron.plugins.ml2.drivers import type_tunnel
|
||||
|
||||
TUNNEL_IP_ONE = "10.10.10.10"
|
||||
TUNNEL_IP_TWO = "10.10.10.20"
|
||||
@ -33,7 +34,6 @@ UPDATED_TUNNEL_RANGES = [(TUN_MIN + 5, TUN_MAX + 5)]
|
||||
|
||||
|
||||
class TunnelTypeTestMixin(object):
|
||||
DRIVER_MODULE = None
|
||||
DRIVER_CLASS = None
|
||||
TYPE = None
|
||||
|
||||
@ -208,8 +208,7 @@ class TunnelTypeTestMixin(object):
|
||||
def test_add_endpoint_for_existing_tunnel_ip(self):
|
||||
self.add_endpoint()
|
||||
|
||||
log = getattr(self.DRIVER_MODULE, 'LOG')
|
||||
with mock.patch.object(log, 'warning') as log_warn:
|
||||
with mock.patch.object(type_tunnel.LOG, 'warning') as log_warn:
|
||||
self.add_endpoint()
|
||||
log_warn.assert_called_once_with(mock.ANY, TUNNEL_IP_ONE)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user