diff --git a/oslo/messaging/_drivers/impl_rabbit.py b/oslo/messaging/_drivers/impl_rabbit.py index 202db86a0..3d2987070 100644 --- a/oslo/messaging/_drivers/impl_rabbit.py +++ b/oslo/messaging/_drivers/impl_rabbit.py @@ -17,8 +17,10 @@ import functools import itertools import logging +import Queue import socket import ssl +import threading import time import uuid @@ -31,7 +33,9 @@ import kombu.messaging from oslo.config import cfg from oslo.messaging._drivers import amqp as rpc_amqp +from oslo.messaging._drivers import base from oslo.messaging._drivers import common as rpc_common +from oslo.messaging import _urls as urls from oslo.messaging.openstack.common import excutils from oslo.messaging.openstack.common import network_utils from oslo.messaging.openstack.common import sslutils @@ -39,6 +43,16 @@ from oslo.messaging.openstack.common import sslutils # FIXME(markmc): remove this _ = lambda s: s +# FIXME(markmc): these were toplevel in openstack.common.rpc +rabbit_opts = [ + cfg.IntOpt('rpc_conn_pool_size', + default=30, + help='Size of RPC connection pool'), + cfg.BoolOpt('fake_rabbit', + default=False, + help='If passed, use a fake RabbitMQ provider'), +] + kombu_opts = [ cfg.StrOpt('kombu_ssl_version', default='', @@ -864,3 +878,306 @@ def notify(conf, context, topic, msg, envelope): def cleanup(): return rpc_amqp.cleanup(Connection.pool) + + +class RabbitIncomingMessage(base.IncomingMessage): + + def __init__(self, listener, ctxt, message, msg_id, reply_q): + super(RabbitIncomingMessage, self).__init__(listener, ctxt, message) + + self.msg_id = msg_id + self.reply_q = reply_q + + def _send_reply(self, conn, reply=None, failure=None, ending=False): + # FIXME(markmc): is the reply format really driver specific? + msg = {'result': reply, 'failure': failure} + + # FIXME(markmc): given that we're not supporting multicall ... + if ending: + msg['ending'] = True + + rpc_amqp._add_unique_id(msg) + + # If a reply_q exists, add the msg_id to the reply and pass the + # reply_q to direct_send() to use it as the response queue. + # Otherwise use the msg_id for backward compatibilty. + if self.reply_q: + msg['_msg_id'] = self.msg_id + conn.direct_send(self.reply_q, rpc_common.serialize_msg(msg)) + else: + conn.direct_send(self.msg_id, rpc_common.serialize_msg(msg)) + + def reply(self, reply=None, failure=None): + LOG.info("reply") + with self.listener.driver._get_connection() as conn: + self._send_reply(conn, reply, failure) + self._send_reply(conn, ending=True) + + def done(self): + LOG.info("done") + # FIXME(markmc): I'm not sure we need this method ... we've already + # acked the message at this point + + +class RabbitListener(base.Listener): + + def __init__(self, driver, target, conn): + super(RabbitListener, self).__init__(driver, target) + self.conn = conn + self.msg_id_cache = rpc_amqp._MsgIdCache() + self.incoming = [] + + def __call__(self, message): + # FIXME(markmc): del local.store.context + + # FIXME(markmc): logging isn't driver specific + rpc_common._safe_log(LOG.debug, _('received %s'), message) + + self.msg_id_cache.check_duplicate_message(message) + ctxt = rpc_amqp.unpack_context(self.conf, message) + + self.incoming.append(RabbitIncomingMessage(self, + ctxt.to_dict(), + message, + ctxt.msg_id, + ctxt.reply_q)) + + def poll(self): + while True: + if self.incoming: + return self.incoming.pop(0) + + # FIXME(markmc): timeout? + self.conn.consume(limit=1) + + +class ReplyWaiters(object): + + def __init__(self): + self._queues = {} + self._wrn_threshhold = 10 + + def get(self, msg_id): + return self._queues.get(msg_id) + + def put(self, msg_id, message_data): + queue = self._queues.get(msg_id) + if not queue: + LOG.warn(_('No calling threads waiting for msg_id : %(msg_id)s' + ', message : %(data)s'), {'msg_id': msg_id, + 'data': message_data}) + LOG.warn(_('_queues: %s') % str(self._queues)) + else: + queue.put(message_data) + + def wake_all(self, except_id): + for msg_id in self._queues: + if msg_id == except_id: + continue + self.put(msg_id, None) + + def add(self, msg_id, queue): + self._queues[msg_id] = queue + if len(self._queues) > self._wrn_threshhold: + LOG.warn(_('Number of call queues is greater than warning ' + 'threshhold: %d. There could be a leak.') % + self._wrn_threshhold) + self._wrn_threshhold *= 2 + + def remove(self, msg_id): + del self._queues[msg_id] + + +class RabbitWaiter(object): + + def __init__(self, conf, reply_q, conn): + self.conf = conf + self.conn = conn + self.reply_q = reply_q + + self.conn_lock = threading.Lock() + self.incoming = [] + self.msg_id_cache = rpc_amqp._MsgIdCache() + self.waiters = ReplyWaiters() + + conn.declare_direct_consumer(reply_q, self) + + def __call__(self, message): + self.incoming.append(message) + + def listen(self, msg_id): + queue = Queue.Queue() + self.waiters.add(msg_id, queue) + + def unlisten(self, msg_id): + self.waiters.remove(msg_id) + + def _process_reply(self, data): + result = None + ending = False + self.msg_id_cache.check_duplicate_message(data) + if data['failure']: + failure = data['failure'] + result = rpc_common.deserialize_remote_exception(self.conf, + failure) + elif data.get('ending', False): + ending = True + else: + result = data['result'] + return result, ending + + def _poll_connection(self, msg_id): + while True: + while self.incoming: + message_data = self.incoming.pop(0) + if message_data.pop('_msg_id', None) == msg_id: + return self._process_reply(message_data) + + self.waiters.put(msg_id, message_data) + + # FIXME(markmc): timeout? + self.conn.consume(limit=1) + + def _poll_queue(self, msg_id): + while True: + # FIXME(markmc): timeout? + message = self.waiters.get(msg_id) + if message is None: + return None, None, True # lock was released + + reply, ending = self._process_reply(message) + return reply, ending, False + + def wait(self, msg_id): + # NOTE(markmc): multiple threads may call this + # First thread calls consume, when it gets its reply + # it wakes up other threads and they call consume + # If a thread gets a message destined for another + # thread, it wakes up the other thread + final_reply = None + while True: + if self.conn_lock.acquire(blocking=False): + try: + reply, ending = self._poll_connection(msg_id) + if reply: + final_reply = reply + elif ending: + return final_reply + finally: + self.conn_lock.release() + self.waiters.wake_all(msg_id) + else: + reply, ending, trylock = self._poll_queue(msg_id) + if trylock: + continue + if reply: + final_reply = reply + elif ending: + return final_reply + + +class RabbitDriver(base.BaseDriver): + + def __init__(self, conf, url=None, default_exchange=None): + super(RabbitDriver, self).__init__(conf, url, default_exchange) + + self.conf.register_opts(kombu_opts) + self.conf.register_opts(rabbit_opts) + self.conf.register_opts(rpc_amqp.amqp_opts) + + self._default_exchange = urls.exchange_from_url(url, default_exchange) + + # FIXME(markmc): temp hack + if self._default_exchange: + self.conf.set_override('control_exchange', self._default_exchange) + + # FIXME(markmc): get connection params from URL in addition to conf + # FIXME(markmc): close connections + self._connection_pool = rpc_amqp.get_connection_pool(self.conf, + Connection) + + self._reply_q_lock = threading.Lock() + self._reply_q = None + self._reply_q_conn = None + self._waiter = None + + def _get_connection(self, pooled=True): + return rpc_amqp.ConnectionContext(self.conf, + self._connection_pool, + pooled=pooled) + + def _get_reply_q(self): + with self._reply_q_lock: + if self._reply_q is not None: + return self._reply_q + + reply_q = 'reply_' + uuid.uuid4().hex + + conn = self._get_connection(pooled=False) + + self._waiter = RabbitWaiter(self.conf, reply_q, conn) + + self._reply_q = reply_q + self._reply_q_conn = conn + + return self._reply_q + + def send(self, target, ctxt, message, + wait_for_reply=None, timeout=None, envelope=False): + + # FIXME(markmc): remove this temporary hack + class Context(object): + def __init__(self, d): + self.d = d + + def to_dict(self): + return self.d + + context = Context(ctxt) + msg = message + + msg_id = uuid.uuid4().hex + msg.update({'_msg_id': msg_id}) + LOG.debug(_('MSG_ID is %s') % (msg_id)) + rpc_amqp._add_unique_id(msg) + rpc_amqp.pack_context(msg, context) + + msg.update({'_reply_q': self._get_reply_q()}) + + # FIXME(markmc): handle envelope param + msg = rpc_common.serialize_msg(msg) + + if wait_for_reply: + self._waiter.listen(msg_id) + + try: + with self._get_connection() as conn: + # FIXME(markmc): check that target.topic is set + if target.fanout: + conn.fanout_send(target.topic, msg) + else: + topic = target.topic + if target.server: + topic = '%s.%s' % (target.topic, target.server) + conn.topic_send(topic, msg, timeout=timeout) + + if wait_for_reply: + # FIXME(markmc): timeout? + return self._waiter.wait(msg_id) + finally: + if wait_for_reply: + self._waiter.unlisten(msg_id) + + def listen(self, target): + # FIXME(markmc): check that topic.target and topic.server is set + + conn = self._get_connection(pooled=False) + + listener = RabbitListener(self, target, conn) + + conn.declare_topic_consumer(target.topic, listener) + conn.declare_topic_consumer('%s.%s' % (target.topic, target.server), + listener) + conn.declare_fanout_consumer(target.topic, listener) + + return listener