Merge "[zmq] Fix send_cast in AckManager"

This commit is contained in:
Jenkins 2016-10-06 08:04:52 +00:00 committed by Gerrit Code Review
commit cb13e65bed
11 changed files with 97 additions and 153 deletions

View File

@ -34,8 +34,7 @@ class DealerPublisherBase(zmq_publisher_base.PublisherBase):
def __init__(self, conf, matchmaker, sender, receiver): def __init__(self, conf, matchmaker, sender, receiver):
sockets_manager = zmq_sockets_manager.SocketsManager( sockets_manager = zmq_sockets_manager.SocketsManager(
conf, matchmaker, zmq.ROUTER, zmq.DEALER) conf, matchmaker, zmq.DEALER)
self.socket_type = zmq.DEALER
super(DealerPublisherBase, self).__init__( super(DealerPublisherBase, self).__init__(
sockets_manager, sender, receiver) sockets_manager, sender, receiver)

View File

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from oslo_messaging._drivers.zmq_driver.client.publishers.dealer \ from oslo_messaging._drivers.zmq_driver.client.publishers.dealer \
@ -22,8 +21,6 @@ from oslo_messaging._drivers.zmq_driver.client import zmq_routing_table
from oslo_messaging._drivers.zmq_driver.client import zmq_senders from oslo_messaging._drivers.zmq_driver.client import zmq_senders
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_socket
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -55,8 +52,6 @@ class DealerPublisherDirect(zmq_dealer_publisher_base.DealerPublisherBase):
""" """
def __init__(self, conf, matchmaker): def __init__(self, conf, matchmaker):
self.routing_table = zmq_routing_table.RoutingTableAdaptor(
conf, matchmaker, zmq.ROUTER)
sender = zmq_senders.RequestSenderDirect(conf) sender = zmq_senders.RequestSenderDirect(conf)
if conf.oslo_messaging_zmq.rpc_use_acks: if conf.oslo_messaging_zmq.rpc_use_acks:
receiver = zmq_receivers.AckAndReplyReceiverDirect(conf) receiver = zmq_receivers.AckAndReplyReceiverDirect(conf)
@ -65,6 +60,9 @@ class DealerPublisherDirect(zmq_dealer_publisher_base.DealerPublisherBase):
super(DealerPublisherDirect, self).__init__(conf, matchmaker, sender, super(DealerPublisherDirect, self).__init__(conf, matchmaker, sender,
receiver) receiver)
self.routing_table = zmq_routing_table.RoutingTableAdaptor(
conf, matchmaker, zmq.ROUTER)
def _get_round_robin_host_connection(self, target, socket): def _get_round_robin_host_connection(self, target, socket):
host = self.routing_table.get_round_robin_host(target) host = self.routing_table.get_round_robin_host(target)
socket.connect_to_host(host) socket.connect_to_host(host)
@ -74,8 +72,7 @@ class DealerPublisherDirect(zmq_dealer_publisher_base.DealerPublisherBase):
socket.connect_to_host(host) socket.connect_to_host(host)
def acquire_connection(self, request): def acquire_connection(self, request):
socket = zmq_socket.ZmqSocket(self.conf, self.context, socket = self.sockets_manager.get_socket()
self.socket_type, immediate=False)
if request.msg_type in zmq_names.DIRECT_TYPES: if request.msg_type in zmq_names.DIRECT_TYPES:
self._get_round_robin_host_connection(request.target, socket) self._get_round_robin_host_connection(request.target, socket)
elif request.msg_type in zmq_names.MULTISEND_TYPES: elif request.msg_type in zmq_names.MULTISEND_TYPES:

View File

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -27,10 +27,8 @@ from oslo_messaging._drivers.zmq_driver.matchmaker import zmq_matchmaker_base
from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_address
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_socket
from oslo_messaging._drivers.zmq_driver import zmq_updater from oslo_messaging._drivers.zmq_driver import zmq_updater
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq() zmq = zmq_async.import_zmq()
@ -47,14 +45,13 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase):
receiver = zmq_receivers.ReplyReceiverProxy(conf) receiver = zmq_receivers.ReplyReceiverProxy(conf)
super(DealerPublisherProxy, self).__init__(conf, matchmaker, sender, super(DealerPublisherProxy, self).__init__(conf, matchmaker, sender,
receiver) receiver)
self.socket = self.sockets_manager.get_socket_to_publishers( self.socket = self.sockets_manager.get_socket_to_publishers(
self._generate_identity()) self._generate_identity())
self.routing_table = zmq_routing_table.RoutingTableAdaptor( self.routing_table = zmq_routing_table.RoutingTableAdaptor(
conf, matchmaker, zmq.DEALER) conf, matchmaker, zmq.DEALER)
self.connection_updater = PublisherConnectionUpdater(
self.connection_updater = \ self.conf, self.matchmaker, self.socket)
PublisherConnectionUpdater(self.conf, self.matchmaker, self.socket)
def _generate_identity(self): def _generate_identity(self):
return six.b(self.conf.oslo_messaging_zmq.rpc_zmq_host + "/" + return six.b(self.conf.oslo_messaging_zmq.rpc_zmq_host + "/" +
@ -84,50 +81,49 @@ class DealerPublisherProxy(zmq_dealer_publisher_base.DealerPublisherBase):
self.sender.send(socket, request) self.sender.send(socket, request)
def cleanup(self): def cleanup(self):
super(DealerPublisherProxy, self).cleanup()
self.routing_table.cleanup()
self.connection_updater.stop() self.connection_updater.stop()
self.routing_table.cleanup()
self.socket.close() self.socket.close()
super(DealerPublisherProxy, self).cleanup()
class PublisherConnectionUpdater(zmq_updater.ConnectionUpdater): class PublisherConnectionUpdater(zmq_updater.ConnectionUpdater):
def _update_connection(self): def _update_connection(self):
publishers = self.matchmaker.get_publishers() publishers = self.matchmaker.get_publishers()
for pub_address, router_address in publishers: for pub_address, fe_router_address in publishers:
self.socket.connect_to_host(router_address) self.socket.connect_to_host(fe_router_address)
class DealerPublisherProxyDynamic( class DealerPublisherProxyDynamic(
zmq_dealer_publisher_base.DealerPublisherBase): zmq_dealer_publisher_base.DealerPublisherBase):
def __init__(self, conf, matchmaker): def __init__(self, conf, matchmaker):
sender = zmq_senders.RequestSenderProxy(conf)
receiver = zmq_receivers.ReplyReceiverDirect(conf)
super(DealerPublisherProxyDynamic, self).__init__(conf, matchmaker,
sender, receiver)
self.publishers = set() self.publishers = set()
self.updater = DynamicPublishersUpdater(conf, matchmaker, self.updater = DynamicPublishersUpdater(conf, matchmaker,
self.publishers) self.publishers)
self.updater.update_publishers() self.updater.update_publishers()
sender = zmq_senders.RequestSenderProxy(conf)
receiver = zmq_receivers.ReplyReceiverDirect(conf)
super(DealerPublisherProxyDynamic, self).__init__(
conf, matchmaker, sender, receiver)
def acquire_connection(self, request): def acquire_connection(self, request):
socket = zmq_socket.ZmqSocket(self.conf, self.context,
self.socket_type, immediate=False)
if not self.publishers: if not self.publishers:
raise zmq_matchmaker_base.MatchmakerUnavailable() raise zmq_matchmaker_base.MatchmakerUnavailable()
socket = self.sockets_manager.get_socket()
socket.connect_to_host(random.choice(tuple(self.publishers))) socket.connect_to_host(random.choice(tuple(self.publishers)))
return socket return socket
def send_request(self, socket, request): def send_request(self, socket, request):
assert request.msg_type in zmq_names.MULTISEND_TYPES request.routing_key = \
request.routing_key = zmq_address.target_to_subscribe_filter( zmq_address.target_to_subscribe_filter(request.target)
request.target)
self.sender.send(socket, request) self.sender.send(socket, request)
def cleanup(self): def cleanup(self):
super(DealerPublisherProxyDynamic, self).cleanup()
self.updater.cleanup() self.updater.cleanup()
super(DealerPublisherProxyDynamic, self).cleanup()
class DynamicPublishersUpdater(zmq_updater.UpdaterBase): class DynamicPublishersUpdater(zmq_updater.UpdaterBase):
@ -140,5 +136,6 @@ class DynamicPublishersUpdater(zmq_updater.UpdaterBase):
self.publishers = publishers self.publishers = publishers
def update_publishers(self): def update_publishers(self):
for _, pub_frontend in self.matchmaker.get_publishers(): publishers = self.matchmaker.get_publishers()
self.publishers.add(pub_frontend) for pub_address, fe_router_address in publishers:
self.publishers.add(fe_router_address)

View File

@ -20,8 +20,8 @@ import six
import oslo_messaging import oslo_messaging
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq() zmq = zmq_async.import_zmq()
@ -49,7 +49,6 @@ class PublisherBase(object):
:param receiver: reply receiver object :param receiver: reply receiver object
:type receiver: zmq_receivers.ReplyReceiver :type receiver: zmq_receivers.ReplyReceiver
""" """
self.context = zmq.Context()
self.sockets_manager = sockets_manager self.sockets_manager = sockets_manager
self.conf = sockets_manager.conf self.conf = sockets_manager.conf
self.matchmaker = sockets_manager.matchmaker self.matchmaker = sockets_manager.matchmaker
@ -94,4 +93,3 @@ class PublisherBase(object):
def cleanup(self): def cleanup(self):
"""Cleanup publisher. Close allocated connections.""" """Cleanup publisher. Close allocated connections."""
self.receiver.stop() self.receiver.stop()
self.sockets_manager.cleanup()

View File

@ -33,15 +33,17 @@ class AckManager(zmq_publisher_manager.PublisherManagerBase):
size=self.conf.oslo_messaging_zmq.rpc_thread_pool_size size=self.conf.oslo_messaging_zmq.rpc_thread_pool_size
) )
def _wait_for_ack(self, ack_future): def _wait_for_ack(self, request, ack_future=None):
request = ack_future.request if ack_future is None:
ack_future = self._schedule_request_for_ack(request)
retries = \ retries = \
request.retry or self.conf.oslo_messaging_zmq.rpc_retry_attempts request.retry or self.conf.oslo_messaging_zmq.rpc_retry_attempts
if retries is None: if retries is None:
retries = -1 retries = -1
timeout = self.conf.oslo_messaging_zmq.rpc_ack_timeout_base timeout = self.conf.oslo_messaging_zmq.rpc_ack_timeout_base
done = False done = ack_future is None
while not done: while not done:
try: try:
reply_id, response = ack_future.result(timeout=timeout) reply_id, response = ack_future.result(timeout=timeout)
@ -72,39 +74,41 @@ class AckManager(zmq_publisher_manager.PublisherManagerBase):
if request.msg_type != zmq_names.CALL_TYPE: if request.msg_type != zmq_names.CALL_TYPE:
self.receiver.untrack_request(request) self.receiver.untrack_request(request)
def _schedule_request_for_ack(self, request): @zmq_publisher_manager.target_not_found_warn
def _send_request(self, request):
socket = self.publisher.acquire_connection(request) socket = self.publisher.acquire_connection(request)
self.publisher.send_request(socket, request) self.publisher.send_request(socket, request)
return socket
def _schedule_request_for_ack(self, request):
socket = self._send_request(request)
if socket is None:
return None
self.receiver.register_socket(socket) self.receiver.register_socket(socket)
futures_by_type = self.receiver.track_request(request) ack_future = self.receiver.track_request(request)[zmq_names.ACK_TYPE]
ack_future = futures_by_type[zmq_names.ACK_TYPE]
ack_future.request = request
ack_future.socket = socket ack_future.socket = socket
return ack_future return ack_future
@zmq_publisher_manager.target_not_found_timeout
def send_call(self, request): def send_call(self, request):
ack_future = self._schedule_request_for_ack(request)
if ack_future is None:
self.publisher._raise_timeout(request)
self._pool.submit(self._wait_for_ack, request, ack_future)
try: try:
ack_future = self._schedule_request_for_ack(request)
self._pool.submit(self._wait_for_ack, ack_future)
return self.publisher.receive_reply(ack_future.socket, request) return self.publisher.receive_reply(ack_future.socket, request)
finally: finally:
if not ack_future.done(): if not ack_future.done():
ack_future.set_result((None, None)) ack_future.set_result((None, None))
@zmq_publisher_manager.target_not_found_warn
def send_cast(self, request): def send_cast(self, request):
ack_future = self._schedule_request_for_ack(request) self._pool.submit(self._wait_for_ack, request)
self._pool.submit(self._wait_for_ack, ack_future)
@zmq_publisher_manager.target_not_found_warn def send_fanout(self, request):
def _send_request(self, request): self._send_request(request)
socket = self.publisher.acquire_connection(request)
self.publisher.send_request(socket, request) def send_notify(self, request):
self._send_request(request)
def cleanup(self): def cleanup(self):
self._pool.shutdown(wait=True) self._pool.shutdown(wait=True)
super(AckManager, self).cleanup() super(AckManager, self).cleanup()
send_fanout = _send_request
send_notify = _send_request

View File

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -12,7 +12,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from oslo_messaging._drivers import common from oslo_messaging._drivers import common
from oslo_messaging._drivers.zmq_driver.client import zmq_client_base from oslo_messaging._drivers.zmq_driver.client import zmq_client_base
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
@ -69,8 +68,9 @@ class ZmqClientDirect(zmq_client_base.ZmqClientBase):
super(ZmqClientDirect, self).__init__( super(ZmqClientDirect, self).__init__(
conf, matchmaker, allowed_remote_exmods, conf, matchmaker, allowed_remote_exmods,
publishers={"default": self._create_publisher_direct( publishers={
conf, matchmaker)} "default": self._create_publisher_direct(conf, matchmaker)
}
) )
@ -91,6 +91,7 @@ class ZmqClientProxy(zmq_client_base.ZmqClientBase):
super(ZmqClientProxy, self).__init__( super(ZmqClientProxy, self).__init__(
conf, matchmaker, allowed_remote_exmods, conf, matchmaker, allowed_remote_exmods,
publishers={"default": self._create_publisher_proxy( publishers={
conf, matchmaker)} "default": self._create_publisher_proxy(conf, matchmaker)
}
) )

View File

@ -1,4 +1,4 @@
# Copyright 2015 Mirantis, Inc. # Copyright 2015-2016 Mirantis, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
@ -73,8 +73,7 @@ class ZmqClientBase(object):
def _create_publisher_direct(conf, matchmaker): def _create_publisher_direct(conf, matchmaker):
publisher_direct = zmq_dealer_publisher_direct.DealerPublisherDirect( publisher_direct = zmq_dealer_publisher_direct.DealerPublisherDirect(
conf, matchmaker) conf, matchmaker)
return zmq_publisher_manager.PublisherManagerDynamic( return zmq_publisher_manager.PublisherManagerDynamic(publisher_direct)
publisher_direct)
@staticmethod @staticmethod
def _create_publisher_proxy(conf, matchmaker): def _create_publisher_proxy(conf, matchmaker):
@ -86,9 +85,10 @@ class ZmqClientBase(object):
@staticmethod @staticmethod
def _create_publisher_proxy_dynamic(conf, matchmaker): def _create_publisher_proxy_dynamic(conf, matchmaker):
return zmq_publisher_manager.PublisherManagerDynamic( publisher_proxy = \
zmq_dealer_publisher_proxy.DealerPublisherProxyDynamic( zmq_dealer_publisher_proxy.DealerPublisherProxyDynamic(
conf, matchmaker)) conf, matchmaker)
return zmq_publisher_manager.PublisherManagerDynamic(publisher_proxy)
def cleanup(self): def cleanup(self):
cleaned = set() cleaned = set()

View File

@ -28,16 +28,20 @@ LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq() zmq = zmq_async.import_zmq()
def _drop_message_warn(request):
LOG.warning(_LW("Matchmaker contains no records for specified "
"target %(target)s. Dropping message %(msg_id)s.")
% {"target": request.target,
"msg_id": request.message_id})
def target_not_found_warn(func): def target_not_found_warn(func):
def _target_not_found_warn(self, request, *args, **kwargs): def _target_not_found_warn(self, request, *args, **kwargs):
try: try:
return func(self, request, *args, **kwargs) return func(self, request, *args, **kwargs)
except (zmq_matchmaker_base.MatchmakerUnavailable, except (zmq_matchmaker_base.MatchmakerUnavailable,
retrying.RetryError): retrying.RetryError):
LOG.warning(_LW("Matchmaker contains no records for specified " _drop_message_warn(request)
"target %(target)s. Dropping message %(msg_id)s.")
% {"target": request.target,
"msg_id": request.message_id})
return _target_not_found_warn return _target_not_found_warn
@ -47,6 +51,7 @@ def target_not_found_timeout(func):
return func(self, request, *args, **kwargs) return func(self, request, *args, **kwargs)
except (zmq_matchmaker_base.MatchmakerUnavailable, except (zmq_matchmaker_base.MatchmakerUnavailable,
retrying.RetryError): retrying.RetryError):
_drop_message_warn(request)
self.publisher._raise_timeout(request) self.publisher._raise_timeout(request)
return _target_not_found_timeout return _target_not_found_timeout
@ -72,31 +77,31 @@ class PublisherManagerBase(object):
"""Send call request """Send call request
:param request: request object :param request: request object
:type senders: zmq_request.Request :type request: zmq_request.CallRequest
""" """
@abc.abstractmethod @abc.abstractmethod
def send_cast(self, request): def send_cast(self, request):
"""Send call request """Send cast request
:param request: request object :param request: request object
:type senders: zmq_request.Request :type request: zmq_request.CastRequest
""" """
@abc.abstractmethod @abc.abstractmethod
def send_fanout(self, request): def send_fanout(self, request):
"""Send call request """Send fanout request
:param request: request object :param request: request object
:type senders: zmq_request.Request :type request: zmq_request.FanoutRequest
""" """
@abc.abstractmethod @abc.abstractmethod
def send_notify(self, request): def send_notify(self, request):
"""Send call request """Send notification request
:param request: request object :param request: request object
:type senders: zmq_request.Request :type request: zmq_request.NotificationRequest
""" """
def cleanup(self): def cleanup(self):
@ -107,8 +112,8 @@ class PublisherManagerDynamic(PublisherManagerBase):
@target_not_found_timeout @target_not_found_timeout
def send_call(self, request): def send_call(self, request):
with contextlib.closing( with contextlib.closing(self.publisher.acquire_connection(request)) \
self.publisher.acquire_connection(request)) as socket: as socket:
self.publisher.send_request(socket, request) self.publisher.send_request(socket, request)
reply = self.publisher.receive_reply(socket, request) reply = self.publisher.receive_reply(socket, request)
return reply return reply

View File

@ -12,12 +12,11 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import itertools
import logging import logging
import threading import threading
import time import time
import itertools
from oslo_messaging._drivers.zmq_driver.matchmaker import zmq_matchmaker_base from oslo_messaging._drivers.zmq_driver.matchmaker import zmq_matchmaker_base
from oslo_messaging._drivers.zmq_driver import zmq_address from oslo_messaging._drivers.zmq_driver import zmq_address
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
@ -25,10 +24,10 @@ from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_updater from oslo_messaging._drivers.zmq_driver import zmq_updater
from oslo_messaging._i18n import _LW from oslo_messaging._i18n import _LW
zmq = zmq_async.import_zmq()
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
zmq = zmq_async.import_zmq()
class RoutingTableAdaptor(object): class RoutingTableAdaptor(object):
@ -63,8 +62,8 @@ class RoutingTableAdaptor(object):
return host return host
def get_fanout_hosts(self, target): def get_fanout_hosts(self, target):
target_key = zmq_address.target_to_key( target_key = zmq_address.prefix_str(
target, zmq_names.socket_type_str(self.listener_type)) target.topic, zmq_names.socket_type_str(self.listener_type))
LOG.debug("Processing target %s for fanout." % target_key) LOG.debug("Processing target %s for fanout." % target_key)
@ -123,14 +122,13 @@ class RoutingTable(object):
self.targets[target_key] = (hosts_updated, self._create_tm()) self.targets[target_key] = (hosts_updated, self._create_tm())
def get_hosts_round_robin(self, target_key): def get_hosts_round_robin(self, target_key):
while self._contains_hosts(target_key): while self.contains(target_key):
for host in self._get_hosts_rr(target_key): for host in self._get_hosts_rr(target_key):
yield host yield host
def get_hosts_fanout(self, target_key): def get_hosts_fanout(self, target_key):
hosts, _ = self._get_hosts(target_key) hosts, _ = self._get_hosts(target_key)
for host in hosts: return hosts
yield host
def contains(self, target_key): def contains(self, target_key):
with self._lock: with self._lock:
@ -147,10 +145,6 @@ class RoutingTable(object):
_, tm = self.targets.get(target_key) _, tm = self.targets.get(target_key)
return tm return tm
def _contains_hosts(self, target_key):
with self._lock:
return target_key in self.targets
def _is_target_changed(self, target_key, tm_orig): def _is_target_changed(self, target_key, tm_orig):
return self._get_tm(target_key) != tm_orig return self._get_tm(target_key) != tm_orig

View File

@ -12,10 +12,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import time
from oslo_messaging._drivers.zmq_driver import zmq_async from oslo_messaging._drivers.zmq_driver import zmq_async
from oslo_messaging._drivers.zmq_driver import zmq_names
from oslo_messaging._drivers.zmq_driver import zmq_socket from oslo_messaging._drivers.zmq_driver import zmq_socket
zmq = zmq_async.import_zmq() zmq = zmq_async.import_zmq()
@ -23,61 +20,17 @@ zmq = zmq_async.import_zmq()
class SocketsManager(object): class SocketsManager(object):
def __init__(self, conf, matchmaker, listener_type, socket_type): def __init__(self, conf, matchmaker, socket_type):
self.conf = conf self.conf = conf
self.matchmaker = matchmaker self.matchmaker = matchmaker
self.listener_type = listener_type
self.socket_type = socket_type self.socket_type = socket_type
self.zmq_context = zmq.Context() self.zmq_context = zmq.Context()
self.outbound_sockets = {}
self.socket_to_publishers = None self.socket_to_publishers = None
self.socket_to_routers = None self.socket_to_routers = None
def get_hosts(self, target): def get_socket(self):
return self.matchmaker.get_hosts_retry( socket = zmq_socket.ZmqSocket(self.conf, self.zmq_context,
target, zmq_names.socket_type_str(self.listener_type)) self.socket_type, immediate=False)
def get_hosts_fanout(self, target):
return self.matchmaker.get_hosts_fanout_retry(
target, zmq_names.socket_type_str(self.listener_type))
@staticmethod
def _key_from_target(target):
return target.topic if target.fanout else str(target)
def _get_hosts_and_track(self, socket, target):
self._get_hosts_and_connect(socket, target)
self._track_socket(socket, target)
def _get_hosts_and_connect(self, socket, target):
get_hosts = self.get_hosts_fanout if target.fanout else self.get_hosts
hosts = get_hosts(target)
self._connect_to_hosts(socket, hosts)
def _track_socket(self, socket, target):
key = self._key_from_target(target)
self.outbound_sockets[key] = (socket, time.time())
def _connect_to_hosts(self, socket, hosts):
for host in hosts:
socket.connect_to_host(host)
def _check_for_new_hosts(self, target):
key = self._key_from_target(target)
socket, tm = self.outbound_sockets[key]
if 0 <= self.conf.oslo_messaging_zmq.zmq_target_expire \
<= time.time() - tm:
self._get_hosts_and_track(socket, target)
return socket
def get_socket(self, target):
key = self._key_from_target(target)
if key in self.outbound_sockets:
socket = self._check_for_new_hosts(target)
else:
socket = zmq_socket.ZmqSocket(self.conf, self.zmq_context,
self.socket_type, immediate=False)
self._get_hosts_and_track(socket, target)
return socket return socket
def get_socket_to_publishers(self, identity=None): def get_socket_to_publishers(self, identity=None):
@ -88,8 +41,8 @@ class SocketsManager(object):
immediate=self.conf.oslo_messaging_zmq.zmq_immediate, immediate=self.conf.oslo_messaging_zmq.zmq_immediate,
identity=identity) identity=identity)
publishers = self.matchmaker.get_publishers() publishers = self.matchmaker.get_publishers()
for pub_address, router_address in publishers: for pub_address, fe_router_address in publishers:
self.socket_to_publishers.connect_to_host(router_address) self.socket_to_publishers.connect_to_host(fe_router_address)
return self.socket_to_publishers return self.socket_to_publishers
def get_socket_to_routers(self, identity=None): def get_socket_to_routers(self, identity=None):
@ -100,10 +53,6 @@ class SocketsManager(object):
immediate=self.conf.oslo_messaging_zmq.zmq_immediate, immediate=self.conf.oslo_messaging_zmq.zmq_immediate,
identity=identity) identity=identity)
routers = self.matchmaker.get_routers() routers = self.matchmaker.get_routers()
for router_address in routers: for be_router_address in routers:
self.socket_to_routers.connect_to_host(router_address) self.socket_to_routers.connect_to_host(be_router_address)
return self.socket_to_routers return self.socket_to_routers
def cleanup(self):
for socket, tm in self.outbound_sockets.values():
socket.close()

View File

@ -41,7 +41,7 @@ class DealerConsumer(zmq_consumer_base.SingleSocketConsumer):
def __init__(self, conf, poller, server): def __init__(self, conf, poller, server):
self.reply_sender = zmq_senders.ReplySenderProxy(conf) self.reply_sender = zmq_senders.ReplySenderProxy(conf)
self.sockets_manager = zmq_sockets_manager.SocketsManager( self.sockets_manager = zmq_sockets_manager.SocketsManager(
conf, server.matchmaker, zmq.ROUTER, zmq.DEALER) conf, server.matchmaker, zmq.DEALER)
self.host = None self.host = None
super(DealerConsumer, self).__init__(conf, poller, server, zmq.DEALER) super(DealerConsumer, self).__init__(conf, poller, server, zmq.DEALER)
self.connection_updater = ConsumerConnectionUpdater( self.connection_updater = ConsumerConnectionUpdater(