Move most new rabbit driver code into amqpdriver

All of this code should be reusable by the rabbit driver.

Change-Id: Ib09e467313c9b68f1eba6b615e6fce83f44fee70
This commit is contained in:
Mark McLoughlin 2013-07-29 07:20:01 +01:00
parent 7bd60904f8
commit 85a386765f
2 changed files with 331 additions and 300 deletions
oslo/messaging/_drivers

@ -0,0 +1,324 @@
# Copyright 2013 Red Hat, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
__all__ = ['AMQPDriverBase']
import logging
import Queue
import threading
import uuid
from oslo.messaging._drivers import amqp as rpc_amqp
from oslo.messaging._drivers import base
from oslo.messaging._drivers import common as rpc_common
from oslo.messaging import _urls as urls
LOG = logging.getLogger(__name__)
class AMQPIncomingMessage(base.IncomingMessage):
def __init__(self, listener, ctxt, message, msg_id, reply_q):
super(AMQPIncomingMessage, self).__init__(listener, ctxt, message)
self.msg_id = msg_id
self.reply_q = reply_q
def _send_reply(self, conn, reply=None, failure=None, ending=False):
# FIXME(markmc): is the reply format really driver specific?
msg = {'result': reply, 'failure': failure}
# FIXME(markmc): given that we're not supporting multicall ...
if ending:
msg['ending'] = True
rpc_amqp._add_unique_id(msg)
# If a reply_q exists, add the msg_id to the reply and pass the
# reply_q to direct_send() to use it as the response queue.
# Otherwise use the msg_id for backward compatibilty.
if self.reply_q:
msg['_msg_id'] = self.msg_id
conn.direct_send(self.reply_q, rpc_common.serialize_msg(msg))
else:
conn.direct_send(self.msg_id, rpc_common.serialize_msg(msg))
def reply(self, reply=None, failure=None):
LOG.info("reply")
with self.listener.driver._get_connection() as conn:
self._send_reply(conn, reply, failure)
self._send_reply(conn, ending=True)
def done(self):
LOG.info("done")
# FIXME(markmc): I'm not sure we need this method ... we've already
# acked the message at this point
class AMQPListener(base.Listener):
def __init__(self, driver, target, conn):
super(AMQPListener, self).__init__(driver, target)
self.conn = conn
self.msg_id_cache = rpc_amqp._MsgIdCache()
self.incoming = []
def __call__(self, message):
# FIXME(markmc): del local.store.context
# FIXME(markmc): logging isn't driver specific
rpc_common._safe_log(LOG.debug, 'received %s', message)
self.msg_id_cache.check_duplicate_message(message)
ctxt = rpc_amqp.unpack_context(self.conf, message)
self.incoming.append(AMQPIncomingMessage(self,
ctxt.to_dict(),
message,
ctxt.msg_id,
ctxt.reply_q))
def poll(self):
while True:
if self.incoming:
return self.incoming.pop(0)
# FIXME(markmc): timeout?
self.conn.consume(limit=1)
class ReplyWaiters(object):
def __init__(self):
self._queues = {}
self._wrn_threshhold = 10
def get(self, msg_id):
return self._queues.get(msg_id)
def put(self, msg_id, message_data):
queue = self._queues.get(msg_id)
if not queue:
LOG.warn('No calling threads waiting for msg_id : %(msg_id)s'
', message : %(data)s', {'msg_id': msg_id,
'data': message_data})
LOG.warn('_queues: %s' % str(self._queues))
else:
queue.put(message_data)
def wake_all(self, except_id):
for msg_id in self._queues:
if msg_id == except_id:
continue
self.put(msg_id, None)
def add(self, msg_id, queue):
self._queues[msg_id] = queue
if len(self._queues) > self._wrn_threshhold:
LOG.warn('Number of call queues is greater than warning '
'threshhold: %d. There could be a leak.' %
self._wrn_threshhold)
self._wrn_threshhold *= 2
def remove(self, msg_id):
del self._queues[msg_id]
class ReplyWaiter(object):
def __init__(self, conf, reply_q, conn):
self.conf = conf
self.conn = conn
self.reply_q = reply_q
self.conn_lock = threading.Lock()
self.incoming = []
self.msg_id_cache = rpc_amqp._MsgIdCache()
self.waiters = ReplyWaiters()
conn.declare_direct_consumer(reply_q, self)
def __call__(self, message):
self.incoming.append(message)
def listen(self, msg_id):
queue = Queue.Queue()
self.waiters.add(msg_id, queue)
def unlisten(self, msg_id):
self.waiters.remove(msg_id)
def _process_reply(self, data):
result = None
ending = False
self.msg_id_cache.check_duplicate_message(data)
if data['failure']:
failure = data['failure']
result = rpc_common.deserialize_remote_exception(self.conf,
failure)
elif data.get('ending', False):
ending = True
else:
result = data['result']
return result, ending
def _poll_connection(self, msg_id):
while True:
while self.incoming:
message_data = self.incoming.pop(0)
if message_data.pop('_msg_id', None) == msg_id:
return self._process_reply(message_data)
self.waiters.put(msg_id, message_data)
# FIXME(markmc): timeout?
self.conn.consume(limit=1)
def _poll_queue(self, msg_id):
while True:
# FIXME(markmc): timeout?
message = self.waiters.get(msg_id)
if message is None:
return None, None, True # lock was released
reply, ending = self._process_reply(message)
return reply, ending, False
def wait(self, msg_id):
# NOTE(markmc): multiple threads may call this
# First thread calls consume, when it gets its reply
# it wakes up other threads and they call consume
# If a thread gets a message destined for another
# thread, it wakes up the other thread
final_reply = None
while True:
if self.conn_lock.acquire(blocking=False):
try:
reply, ending = self._poll_connection(msg_id)
if reply:
final_reply = reply
elif ending:
return final_reply
finally:
self.conn_lock.release()
self.waiters.wake_all(msg_id)
else:
reply, ending, trylock = self._poll_queue(msg_id)
if trylock:
continue
if reply:
final_reply = reply
elif ending:
return final_reply
class AMQPDriverBase(base.BaseDriver):
def __init__(self, conf, connection_pool, url=None, default_exchange=None):
super(AMQPDriverBase, self).__init__(conf, url, default_exchange)
self._default_exchange = urls.exchange_from_url(url, default_exchange)
# FIXME(markmc): temp hack
if self._default_exchange:
self.conf.set_override('control_exchange', self._default_exchange)
self._connection_pool = connection_pool
self._reply_q_lock = threading.Lock()
self._reply_q = None
self._reply_q_conn = None
self._waiter = None
def _get_connection(self, pooled=True):
return rpc_amqp.ConnectionContext(self.conf,
self._connection_pool,
pooled=pooled)
def _get_reply_q(self):
with self._reply_q_lock:
if self._reply_q is not None:
return self._reply_q
reply_q = 'reply_' + uuid.uuid4().hex
conn = self._get_connection(pooled=False)
self._waiter = ReplyWaiter(self.conf, reply_q, conn)
self._reply_q = reply_q
self._reply_q_conn = conn
return self._reply_q
def send(self, target, ctxt, message,
wait_for_reply=None, timeout=None, envelope=False):
# FIXME(markmc): remove this temporary hack
class Context(object):
def __init__(self, d):
self.d = d
def to_dict(self):
return self.d
context = Context(ctxt)
msg = message
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug('MSG_ID is %s' % (msg_id))
rpc_amqp._add_unique_id(msg)
rpc_amqp.pack_context(msg, context)
msg.update({'_reply_q': self._get_reply_q()})
# FIXME(markmc): handle envelope param
msg = rpc_common.serialize_msg(msg)
if wait_for_reply:
self._waiter.listen(msg_id)
try:
with self._get_connection() as conn:
# FIXME(markmc): check that target.topic is set
if target.fanout:
conn.fanout_send(target.topic, msg)
else:
topic = target.topic
if target.server:
topic = '%s.%s' % (target.topic, target.server)
conn.topic_send(topic, msg, timeout=timeout)
if wait_for_reply:
# FIXME(markmc): timeout?
return self._waiter.wait(msg_id)
finally:
if wait_for_reply:
self._waiter.unlisten(msg_id)
def listen(self, target):
# FIXME(markmc): check that topic.target and topic.server is set
conn = self._get_connection(pooled=False)
listener = AMQPListener(self, target, conn)
conn.declare_topic_consumer(target.topic, listener)
conn.declare_topic_consumer('%s.%s' % (target.topic, target.server),
listener)
conn.declare_fanout_consumer(target.topic, listener)
return listener

@ -17,10 +17,8 @@
import functools import functools
import itertools import itertools
import logging import logging
import Queue
import socket import socket
import ssl import ssl
import threading
import time import time
import uuid import uuid
@ -33,9 +31,8 @@ import kombu.messaging
from oslo.config import cfg from oslo.config import cfg
from oslo.messaging._drivers import amqp as rpc_amqp from oslo.messaging._drivers import amqp as rpc_amqp
from oslo.messaging._drivers import base from oslo.messaging._drivers import amqpdriver
from oslo.messaging._drivers import common as rpc_common from oslo.messaging._drivers import common as rpc_common
from oslo.messaging import _urls as urls
from oslo.messaging.openstack.common import excutils from oslo.messaging.openstack.common import excutils
from oslo.messaging.openstack.common import network_utils from oslo.messaging.openstack.common import network_utils
from oslo.messaging.openstack.common import sslutils from oslo.messaging.openstack.common import sslutils
@ -874,303 +871,13 @@ def cleanup():
return rpc_amqp.cleanup(Connection.pool) return rpc_amqp.cleanup(Connection.pool)
class RabbitIncomingMessage(base.IncomingMessage): class RabbitDriver(amqpdriver.AMQPDriverBase):
def __init__(self, listener, ctxt, message, msg_id, reply_q):
super(RabbitIncomingMessage, self).__init__(listener, ctxt, message)
self.msg_id = msg_id
self.reply_q = reply_q
def _send_reply(self, conn, reply=None, failure=None, ending=False):
# FIXME(markmc): is the reply format really driver specific?
msg = {'result': reply, 'failure': failure}
# FIXME(markmc): given that we're not supporting multicall ...
if ending:
msg['ending'] = True
rpc_amqp._add_unique_id(msg)
# If a reply_q exists, add the msg_id to the reply and pass the
# reply_q to direct_send() to use it as the response queue.
# Otherwise use the msg_id for backward compatibilty.
if self.reply_q:
msg['_msg_id'] = self.msg_id
conn.direct_send(self.reply_q, rpc_common.serialize_msg(msg))
else:
conn.direct_send(self.msg_id, rpc_common.serialize_msg(msg))
def reply(self, reply=None, failure=None):
LOG.info("reply")
with self.listener.driver._get_connection() as conn:
self._send_reply(conn, reply, failure)
self._send_reply(conn, ending=True)
def done(self):
LOG.info("done")
# FIXME(markmc): I'm not sure we need this method ... we've already
# acked the message at this point
class RabbitListener(base.Listener):
def __init__(self, driver, target, conn):
super(RabbitListener, self).__init__(driver, target)
self.conn = conn
self.msg_id_cache = rpc_amqp._MsgIdCache()
self.incoming = []
def __call__(self, message):
# FIXME(markmc): del local.store.context
# FIXME(markmc): logging isn't driver specific
rpc_common._safe_log(LOG.debug, _('received %s'), message)
self.msg_id_cache.check_duplicate_message(message)
ctxt = rpc_amqp.unpack_context(self.conf, message)
self.incoming.append(RabbitIncomingMessage(self,
ctxt.to_dict(),
message,
ctxt.msg_id,
ctxt.reply_q))
def poll(self):
while True:
if self.incoming:
return self.incoming.pop(0)
# FIXME(markmc): timeout?
self.conn.consume(limit=1)
class ReplyWaiters(object):
def __init__(self):
self._queues = {}
self._wrn_threshhold = 10
def get(self, msg_id):
return self._queues.get(msg_id)
def put(self, msg_id, message_data):
queue = self._queues.get(msg_id)
if not queue:
LOG.warn(_('No calling threads waiting for msg_id : %(msg_id)s'
', message : %(data)s'), {'msg_id': msg_id,
'data': message_data})
LOG.warn(_('_queues: %s') % str(self._queues))
else:
queue.put(message_data)
def wake_all(self, except_id):
for msg_id in self._queues:
if msg_id == except_id:
continue
self.put(msg_id, None)
def add(self, msg_id, queue):
self._queues[msg_id] = queue
if len(self._queues) > self._wrn_threshhold:
LOG.warn(_('Number of call queues is greater than warning '
'threshhold: %d. There could be a leak.') %
self._wrn_threshhold)
self._wrn_threshhold *= 2
def remove(self, msg_id):
del self._queues[msg_id]
class RabbitWaiter(object):
def __init__(self, conf, reply_q, conn):
self.conf = conf
self.conn = conn
self.reply_q = reply_q
self.conn_lock = threading.Lock()
self.incoming = []
self.msg_id_cache = rpc_amqp._MsgIdCache()
self.waiters = ReplyWaiters()
conn.declare_direct_consumer(reply_q, self)
def __call__(self, message):
self.incoming.append(message)
def listen(self, msg_id):
queue = Queue.Queue()
self.waiters.add(msg_id, queue)
def unlisten(self, msg_id):
self.waiters.remove(msg_id)
def _process_reply(self, data):
result = None
ending = False
self.msg_id_cache.check_duplicate_message(data)
if data['failure']:
failure = data['failure']
result = rpc_common.deserialize_remote_exception(self.conf,
failure)
elif data.get('ending', False):
ending = True
else:
result = data['result']
return result, ending
def _poll_connection(self, msg_id):
while True:
while self.incoming:
message_data = self.incoming.pop(0)
if message_data.pop('_msg_id', None) == msg_id:
return self._process_reply(message_data)
self.waiters.put(msg_id, message_data)
# FIXME(markmc): timeout?
self.conn.consume(limit=1)
def _poll_queue(self, msg_id):
while True:
# FIXME(markmc): timeout?
message = self.waiters.get(msg_id)
if message is None:
return None, None, True # lock was released
reply, ending = self._process_reply(message)
return reply, ending, False
def wait(self, msg_id):
# NOTE(markmc): multiple threads may call this
# First thread calls consume, when it gets its reply
# it wakes up other threads and they call consume
# If a thread gets a message destined for another
# thread, it wakes up the other thread
final_reply = None
while True:
if self.conn_lock.acquire(blocking=False):
try:
reply, ending = self._poll_connection(msg_id)
if reply:
final_reply = reply
elif ending:
return final_reply
finally:
self.conn_lock.release()
self.waiters.wake_all(msg_id)
else:
reply, ending, trylock = self._poll_queue(msg_id)
if trylock:
continue
if reply:
final_reply = reply
elif ending:
return final_reply
class RabbitDriver(base.BaseDriver):
def __init__(self, conf, url=None, default_exchange=None): def __init__(self, conf, url=None, default_exchange=None):
super(RabbitDriver, self).__init__(conf, url, default_exchange) conf.register_opts(rabbit_opts)
conf.register_opts(rpc_amqp.amqp_opts)
self.conf.register_opts(rabbit_opts) connection_pool = rpc_amqp.get_connection_pool(conf, Connection)
self.conf.register_opts(rpc_amqp.amqp_opts)
self._default_exchange = urls.exchange_from_url(url, default_exchange) super(RabbitDriver, self).__init__(conf, connection_pool,
url, default_exchange)
# FIXME(markmc): temp hack
if self._default_exchange:
self.conf.set_override('control_exchange', self._default_exchange)
# FIXME(markmc): get connection params from URL in addition to conf
# FIXME(markmc): close connections
self._connection_pool = rpc_amqp.get_connection_pool(self.conf,
Connection)
self._reply_q_lock = threading.Lock()
self._reply_q = None
self._reply_q_conn = None
self._waiter = None
def _get_connection(self, pooled=True):
return rpc_amqp.ConnectionContext(self.conf,
self._connection_pool,
pooled=pooled)
def _get_reply_q(self):
with self._reply_q_lock:
if self._reply_q is not None:
return self._reply_q
reply_q = 'reply_' + uuid.uuid4().hex
conn = self._get_connection(pooled=False)
self._waiter = RabbitWaiter(self.conf, reply_q, conn)
self._reply_q = reply_q
self._reply_q_conn = conn
return self._reply_q
def send(self, target, ctxt, message,
wait_for_reply=None, timeout=None, envelope=False):
# FIXME(markmc): remove this temporary hack
class Context(object):
def __init__(self, d):
self.d = d
def to_dict(self):
return self.d
context = Context(ctxt)
msg = message
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug(_('MSG_ID is %s') % (msg_id))
rpc_amqp._add_unique_id(msg)
rpc_amqp.pack_context(msg, context)
msg.update({'_reply_q': self._get_reply_q()})
# FIXME(markmc): handle envelope param
msg = rpc_common.serialize_msg(msg)
if wait_for_reply:
self._waiter.listen(msg_id)
try:
with self._get_connection() as conn:
# FIXME(markmc): check that target.topic is set
if target.fanout:
conn.fanout_send(target.topic, msg)
else:
topic = target.topic
if target.server:
topic = '%s.%s' % (target.topic, target.server)
conn.topic_send(topic, msg, timeout=timeout)
if wait_for_reply:
# FIXME(markmc): timeout?
return self._waiter.wait(msg_id)
finally:
if wait_for_reply:
self._waiter.unlisten(msg_id)
def listen(self, target):
# FIXME(markmc): check that topic.target and topic.server is set
conn = self._get_connection(pooled=False)
listener = RabbitListener(self, target, conn)
conn.declare_topic_consumer(target.topic, listener)
conn.declare_topic_consumer('%s.%s' % (target.topic, target.server),
listener)
conn.declare_fanout_consumer(target.topic, listener)
return listener