Allow dispatcher to restrict endpoint methods.

Implements access_policy for dispatcher to restrict endpoint methods.

Implements the following access policies:
* LegacyRPCAccessPolicy
* DefaultRPCAccessPolicy
* ExplicitRPCAccessPolicy

* Implement decorator @rpc.expose for use with the
 ExplicitRPCAccessPolicy

* Modify get_rpc_server to allow optional access_policy argument
* Set default access_policy to LegacyRPCAccessPolicy (Nova exposes
 _associate_floating_ip in tempest tests). Added debtcollector
 notification.
* Add test cases for access_policy=None
* Clarify documentation

Change-Id: I42239e6c8a8be158ddf5c3b1773463b7dc93e881
Closes-Bug: 1194279
Closes-Bug: 1555845
This commit is contained in:
Paul Vinciguerra 2016-08-21 19:46:28 -04:00
parent c6ce04c975
commit d3a8f280eb
6 changed files with 227 additions and 20 deletions

View File

@ -8,6 +8,14 @@ Server
.. autofunction:: get_rpc_server .. autofunction:: get_rpc_server
.. autoclass:: RPCAccessPolicyBase
.. autoclass:: LegacyRPCAccessPolicy
.. autoclass:: DefaultRPCAccessPolicy
.. autoclass:: ExplicitRPCAccessPolicy
.. autoclass:: RPCDispatcher .. autoclass:: RPCDispatcher
.. autoclass:: MessageHandlingServer .. autoclass:: MessageHandlingServer
@ -15,6 +23,8 @@ Server
.. autofunction:: expected_exceptions .. autofunction:: expected_exceptions
.. autofunction:: expose
.. autoexception:: ExpectedException .. autoexception:: ExpectedException
.. autofunction:: get_local_context .. autofunction:: get_local_context

View File

@ -18,6 +18,10 @@ __all__ = [
'ExpectedException', 'ExpectedException',
'NoSuchMethod', 'NoSuchMethod',
'RPCClient', 'RPCClient',
'RPCAccessPolicyBase',
'LegacyRPCAccessPolicy',
'DefaultRPCAccessPolicy',
'ExplicitRPCAccessPolicy',
'RPCDispatcher', 'RPCDispatcher',
'RPCDispatcherError', 'RPCDispatcherError',
'RPCVersionCapError', 'RPCVersionCapError',
@ -25,6 +29,7 @@ __all__ = [
'UnsupportedVersion', 'UnsupportedVersion',
'expected_exceptions', 'expected_exceptions',
'get_rpc_server', 'get_rpc_server',
'expose'
] ]
from .client import * from .client import *

View File

@ -18,17 +18,24 @@
__all__ = [ __all__ = [
'NoSuchMethod', 'NoSuchMethod',
'RPCAccessPolicyBase',
'LegacyRPCAccessPolicy',
'DefaultRPCAccessPolicy',
'ExplicitRPCAccessPolicy',
'RPCDispatcher', 'RPCDispatcher',
'RPCDispatcherError', 'RPCDispatcherError',
'UnsupportedVersion', 'UnsupportedVersion',
'ExpectedException', 'ExpectedException',
] ]
from abc import ABCMeta
from abc import abstractmethod
import logging import logging
import sys import sys
import six import six
from debtcollector.updating import updated_kwarg_default_value
from oslo_messaging import _utils as utils from oslo_messaging import _utils as utils
from oslo_messaging import dispatcher from oslo_messaging import dispatcher
from oslo_messaging import serializer as msg_serializer from oslo_messaging import serializer as msg_serializer
@ -74,6 +81,52 @@ class UnsupportedVersion(RPCDispatcherError):
self.method = method self.method = method
@six.add_metaclass(ABCMeta)
class RPCAccessPolicyBase(object):
"""Determines which endpoint methods may be invoked via RPC"""
@abstractmethod
def is_allowed(self, endpoint, method):
"""Applies an access policy to the rpc method
:param endpoint: the instance of a rpc endpoint
:param method: the method of the endpoint
:return: True if the method may be invoked via RPC, else False.
"""
class LegacyRPCAccessPolicy(RPCAccessPolicyBase):
"""The legacy access policy allows RPC access to all callable endpoint
methods including private methods (methods prefixed by '_')
"""
def is_allowed(self, endpoint, method):
return True
class DefaultRPCAccessPolicy(RPCAccessPolicyBase):
"""The default access policy prevents RPC calls to private methods
(methods prefixed by '_')
.. note::
LegacyRPCAdapterPolicy currently needs to be the default while we have
projects that rely on exposing private methods.
"""
def is_allowed(self, endpoint, method):
return not method.startswith('_')
class ExplicitRPCAccessPolicy(RPCAccessPolicyBase):
"""Policy which requires decorated endpoint methods to allow dispatch"""
def is_allowed(self, endpoint, method):
if hasattr(endpoint, method):
return hasattr(getattr(endpoint, method), 'exposed')
return False
class RPCDispatcher(dispatcher.DispatcherBase): class RPCDispatcher(dispatcher.DispatcherBase):
"""A message dispatcher which understands RPC messages. """A message dispatcher which understands RPC messages.
@ -86,13 +139,24 @@ class RPCDispatcher(dispatcher.DispatcherBase):
in the message and matches those against a list of available endpoints. in the message and matches those against a list of available endpoints.
Endpoints may have a target attribute describing the namespace and version Endpoints may have a target attribute describing the namespace and version
of the methods exposed by that object. All public methods on an endpoint of the methods exposed by that object.
object are remotely invokable by clients.
The RPCDispatcher may have an access_policy attribute which determines
which of the endpoint methods are to be dispatched.
The default access_policy dispatches all public methods
on an endpoint object.
""" """
@updated_kwarg_default_value('access_policy', None, DefaultRPCAccessPolicy,
def __init__(self, endpoints, serializer): message='access_policy defaults to '
'LegacyRPCAccessPolicy which '
'exposes private methods. Explicitly '
'set access_policy to '
'DefaultRPCAccessPolicy or '
'ExplicitRPCAccessPolicy.',
version='?')
def __init__(self, endpoints, serializer, access_policy=None):
"""Construct a rpc server dispatcher. """Construct a rpc server dispatcher.
:param endpoints: list of endpoint objects for dispatching to :param endpoints: list of endpoint objects for dispatching to
@ -102,6 +166,16 @@ class RPCDispatcher(dispatcher.DispatcherBase):
self.endpoints = endpoints self.endpoints = endpoints
self.serializer = serializer or msg_serializer.NoOpSerializer() self.serializer = serializer or msg_serializer.NoOpSerializer()
self._default_target = msg_target.Target() self._default_target = msg_target.Target()
if access_policy is not None:
if issubclass(access_policy, RPCAccessPolicyBase):
self.access_policy = access_policy()
else:
raise TypeError('access_policy must be a subclass of '
'RPCAccessPolicyBase')
else:
# TODO(pvinci): Change to DefaultRPCAccessPolicy when setting to
# DefaultRCPAccessPolicy no longer breaks in tempest tests.
self.access_policy = LegacyRPCAccessPolicy()
@staticmethod @staticmethod
def _is_namespace(target, namespace): def _is_namespace(target, namespace):
@ -147,7 +221,8 @@ class RPCDispatcher(dispatcher.DispatcherBase):
continue continue
if hasattr(endpoint, method): if hasattr(endpoint, method):
return self._do_dispatch(endpoint, method, ctxt, args) if self.access_policy.is_allowed(endpoint, method):
return self._do_dispatch(endpoint, method, ctxt, args)
found_compatible = True found_compatible = True

View File

@ -100,6 +100,7 @@ to - primitive types.
__all__ = [ __all__ = [
'get_rpc_server', 'get_rpc_server',
'expected_exceptions', 'expected_exceptions',
'expose'
] ]
import logging import logging
@ -156,7 +157,7 @@ class RPCServer(msg_server.MessageHandlingServer):
def get_rpc_server(transport, target, endpoints, def get_rpc_server(transport, target, endpoints,
executor='blocking', serializer=None): executor='blocking', serializer=None, access_policy=None):
"""Construct an RPC server. """Construct an RPC server.
The executor parameter controls how incoming messages will be received and The executor parameter controls how incoming messages will be received and
@ -177,8 +178,12 @@ def get_rpc_server(transport, target, endpoints,
:type executor: str :type executor: str
:param serializer: an optional entity serializer :param serializer: an optional entity serializer
:type serializer: Serializer :type serializer: Serializer
:param access_policy: an optional access policy.
Defaults to LegacyRPCAccessPolicy
:type access_policy: RPCAccessPolicyBase
""" """
dispatcher = rpc_dispatcher.RPCDispatcher(endpoints, serializer) dispatcher = rpc_dispatcher.RPCDispatcher(endpoints, serializer,
access_policy)
return RPCServer(transport, target, dispatcher, executor) return RPCServer(transport, target, dispatcher, executor)
@ -207,3 +212,25 @@ def expected_exceptions(*exceptions):
raise rpc_dispatcher.ExpectedException() raise rpc_dispatcher.ExpectedException()
return inner return inner
return outer return outer
def expose(func):
"""Decorator for RPC endpoint methods that are exposed to the RPC client.
If the dispatcher's access_policy is set to ExplicitRPCAccessPolicy then
endpoint methods need to be explicitly exposed.::
# foo() cannot be invoked by an RPC client
def foo(self):
pass
# bar() can be invoked by an RPC client
@rpc.expose
def bar(self):
pass
"""
func.exposed = True
return func

View File

@ -1,4 +1,3 @@
# Copyright 2013 Red Hat, Inc. # Copyright 2013 Red Hat, Inc.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -16,6 +15,7 @@
import testscenarios import testscenarios
import oslo_messaging import oslo_messaging
from oslo_messaging import rpc
from oslo_messaging import serializer as msg_serializer from oslo_messaging import serializer as msg_serializer
from oslo_messaging.tests import utils as test_utils from oslo_messaging.tests import utils as test_utils
from six.moves import mock from six.moves import mock
@ -24,92 +24,161 @@ load_tests = testscenarios.load_tests_apply_scenarios
class _FakeEndpoint(object): class _FakeEndpoint(object):
def __init__(self, target=None): def __init__(self, target=None):
self.target = target self.target = target
def foo(self, ctxt, **kwargs): def foo(self, ctxt, **kwargs):
pass pass
@rpc.expose
def bar(self, ctxt, **kwargs): def bar(self, ctxt, **kwargs):
pass pass
def _foobar(self, ctxt, **kwargs):
pass
class TestDispatcher(test_utils.BaseTestCase): class TestDispatcher(test_utils.BaseTestCase):
scenarios = [ scenarios = [
('no_endpoints', ('no_endpoints',
dict(endpoints=[], dict(endpoints=[],
access_policy=None,
dispatch_to=None, dispatch_to=None,
ctxt={}, msg=dict(method='foo'), ctxt={}, msg=dict(method='foo'),
exposed_methods=['foo', 'bar', '_foobar'],
success=False, ex=oslo_messaging.UnsupportedVersion)), success=False, ex=oslo_messaging.UnsupportedVersion)),
('default_target', ('default_target',
dict(endpoints=[{}], dict(endpoints=[{}],
access_policy=None,
dispatch_to=dict(endpoint=0, method='foo'), dispatch_to=dict(endpoint=0, method='foo'),
ctxt={}, msg=dict(method='foo'), ctxt={}, msg=dict(method='foo'),
exposed_methods=['foo', 'bar', '_foobar'],
success=True, ex=None)), success=True, ex=None)),
('default_target_ctxt_and_args', ('default_target_ctxt_and_args',
dict(endpoints=[{}], dict(endpoints=[{}],
access_policy=oslo_messaging.LegacyRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='bar'), dispatch_to=dict(endpoint=0, method='bar'),
ctxt=dict(user='bob'), msg=dict(method='bar', ctxt=dict(user='bob'), msg=dict(method='bar',
args=dict(blaa=True)), args=dict(blaa=True)),
exposed_methods=['foo', 'bar', '_foobar'],
success=True, ex=None)), success=True, ex=None)),
('default_target_namespace', ('default_target_namespace',
dict(endpoints=[{}], dict(endpoints=[{}],
access_policy=oslo_messaging.LegacyRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='foo'), dispatch_to=dict(endpoint=0, method='foo'),
ctxt={}, msg=dict(method='foo', namespace=None), ctxt={}, msg=dict(method='foo', namespace=None),
exposed_methods=['foo', 'bar', '_foobar'],
success=True, ex=None)), success=True, ex=None)),
('default_target_version', ('default_target_version',
dict(endpoints=[{}], dict(endpoints=[{}],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='foo'), dispatch_to=dict(endpoint=0, method='foo'),
ctxt={}, msg=dict(method='foo', version='1.0'), ctxt={}, msg=dict(method='foo', version='1.0'),
exposed_methods=['foo', 'bar'],
success=True, ex=None)), success=True, ex=None)),
('default_target_no_such_method', ('default_target_no_such_method',
dict(endpoints=[{}], dict(endpoints=[{}],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=None, dispatch_to=None,
ctxt={}, msg=dict(method='foobar'), ctxt={}, msg=dict(method='foobar'),
exposed_methods=['foo', 'bar'],
success=False, ex=oslo_messaging.NoSuchMethod)), success=False, ex=oslo_messaging.NoSuchMethod)),
('namespace', ('namespace',
dict(endpoints=[{}, dict(namespace='testns')], dict(endpoints=[{}, dict(namespace='testns')],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=dict(endpoint=1, method='foo'), dispatch_to=dict(endpoint=1, method='foo'),
ctxt={}, msg=dict(method='foo', namespace='testns'), ctxt={}, msg=dict(method='foo', namespace='testns'),
exposed_methods=['foo', 'bar'],
success=True, ex=None)), success=True, ex=None)),
('namespace_mismatch', ('namespace_mismatch',
dict(endpoints=[{}, dict(namespace='testns')], dict(endpoints=[{}, dict(namespace='testns')],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=None, dispatch_to=None,
ctxt={}, msg=dict(method='foo', namespace='nstest'), ctxt={}, msg=dict(method='foo', namespace='nstest'),
exposed_methods=['foo', 'bar'],
success=False, ex=oslo_messaging.UnsupportedVersion)), success=False, ex=oslo_messaging.UnsupportedVersion)),
('version', ('version',
dict(endpoints=[dict(version='1.5'), dict(version='3.4')], dict(endpoints=[dict(version='1.5'), dict(version='3.4')],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=dict(endpoint=1, method='foo'), dispatch_to=dict(endpoint=1, method='foo'),
ctxt={}, msg=dict(method='foo', version='3.2'), ctxt={}, msg=dict(method='foo', version='3.2'),
exposed_methods=['foo', 'bar'],
success=True, ex=None)), success=True, ex=None)),
('version_mismatch', ('version_mismatch',
dict(endpoints=[dict(version='1.5'), dict(version='3.0')], dict(endpoints=[dict(version='1.5'), dict(version='3.0')],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=None, dispatch_to=None,
ctxt={}, msg=dict(method='foo', version='3.2'), ctxt={}, msg=dict(method='foo', version='3.2'),
exposed_methods=['foo', 'bar'],
success=False, ex=oslo_messaging.UnsupportedVersion)), success=False, ex=oslo_messaging.UnsupportedVersion)),
('message_in_null_namespace_with_multiple_namespaces', ('message_in_null_namespace_with_multiple_namespaces',
dict(endpoints=[dict(namespace='testns', dict(endpoints=[dict(namespace='testns',
legacy_namespaces=[None])], legacy_namespaces=[None])],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='foo'), dispatch_to=dict(endpoint=0, method='foo'),
ctxt={}, msg=dict(method='foo', namespace=None), ctxt={}, msg=dict(method='foo', namespace=None),
exposed_methods=['foo', 'bar'],
success=True, ex=None)), success=True, ex=None)),
('message_in_wrong_namespace_with_multiple_namespaces', ('message_in_wrong_namespace_with_multiple_namespaces',
dict(endpoints=[dict(namespace='testns', dict(endpoints=[dict(namespace='testns',
legacy_namespaces=['second', None])], legacy_namespaces=['second', None])],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=None, dispatch_to=None,
ctxt={}, msg=dict(method='foo', namespace='wrong'), ctxt={}, msg=dict(method='foo', namespace='wrong'),
exposed_methods=['foo', 'bar'],
success=False, ex=oslo_messaging.UnsupportedVersion)), success=False, ex=oslo_messaging.UnsupportedVersion)),
('message_with_endpoint_no_private_and_public_method',
dict(endpoints=[dict(namespace='testns',
legacy_namespaces=['second', None])],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='foo'),
ctxt={}, msg=dict(method='foo', namespace='testns'),
exposed_methods=['foo', 'bar'],
success=True, ex=None)),
('message_with_endpoint_no_private_and_private_method',
dict(endpoints=[dict(namespace='testns',
legacy_namespaces=['second', None], )],
access_policy=oslo_messaging.DefaultRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='_foobar'),
ctxt={}, msg=dict(method='_foobar', namespace='testns'),
exposed_methods=['foo', 'bar'],
success=False, ex=oslo_messaging.NoSuchMethod)),
('message_with_endpoint_explicitly_exposed_without_exposed_method',
dict(endpoints=[dict(namespace='testns',
legacy_namespaces=['second', None], )],
access_policy=oslo_messaging.ExplicitRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='foo'),
ctxt={}, msg=dict(method='foo', namespace='testns'),
exposed_methods=['bar'],
success=False, ex=oslo_messaging.NoSuchMethod)),
('message_with_endpoint_explicitly_exposed_with_exposed_method',
dict(endpoints=[dict(namespace='testns',
legacy_namespaces=['second', None], )],
access_policy=oslo_messaging.ExplicitRPCAccessPolicy,
dispatch_to=dict(endpoint=0, method='bar'),
ctxt={}, msg=dict(method='bar', namespace='testns'),
exposed_methods=['bar'],
success=True, ex=None)),
] ]
def test_dispatcher(self): def test_dispatcher(self):
endpoints = [mock.Mock(spec=_FakeEndpoint,
target=oslo_messaging.Target(**e)) def _set_endpoint_mock_properties(endpoint):
for e in self.endpoints] endpoint.foo = mock.Mock(spec=dir(_FakeEndpoint.foo))
# mock doesn't pick up the decorated method.
endpoint.bar = mock.Mock(spec=dir(_FakeEndpoint.bar))
endpoint.bar.exposed = mock.PropertyMock(return_value=True)
endpoint._foobar = mock.Mock(spec=dir(_FakeEndpoint._foobar))
return endpoint
endpoints = [_set_endpoint_mock_properties(mock.Mock(
spec=_FakeEndpoint, target=oslo_messaging.Target(**e)))
for e in self.endpoints]
serializer = None serializer = None
dispatcher = oslo_messaging.RPCDispatcher(endpoints, serializer) dispatcher = oslo_messaging.RPCDispatcher(endpoints, serializer,
self.access_policy)
incoming = mock.Mock(ctxt=self.ctxt, message=self.msg) incoming = mock.Mock(ctxt=self.ctxt, message=self.msg)
@ -130,22 +199,23 @@ class TestDispatcher(test_utils.BaseTestCase):
self.assertEqual(self.msg.get('method'), ex.method) self.assertEqual(self.msg.get('method'), ex.method)
else: else:
self.assertTrue(self.success, self.assertTrue(self.success,
"Not expected success of operation durung testing") "Unexpected success of operation during testing")
self.assertIsNotNone(res) self.assertIsNotNone(res)
for n, endpoint in enumerate(endpoints): for n, endpoint in enumerate(endpoints):
for method_name in ['foo', 'bar']: for method_name in self.exposed_methods:
method = getattr(endpoint, method_name) method = getattr(endpoint, method_name)
if self.dispatch_to and n == self.dispatch_to['endpoint'] and \ if self.dispatch_to and n == self.dispatch_to['endpoint'] and \
method_name == self.dispatch_to['method']: method_name == self.dispatch_to['method'] and \
method_name in self.exposed_methods:
method.assert_called_once_with( method.assert_called_once_with(
self.ctxt, **self.msg.get('args', {})) self.ctxt, **self.msg.get('args', {}))
else: else:
self.assertEqual(0, method.call_count) self.assertEqual(0, method.call_count,
'method: {}'.format(method))
class TestSerializer(test_utils.BaseTestCase): class TestSerializer(test_utils.BaseTestCase):
scenarios = [ scenarios = [
('no_args_or_retval', ('no_args_or_retval',
dict(ctxt={}, dctxt={}, args={}, retval=None)), dict(ctxt={}, dctxt={}, args={}, retval=None)),
@ -174,7 +244,7 @@ class TestSerializer(test_utils.BaseTestCase):
for arg in self.args: for arg in self.args:
serializer.deserialize_entity(self.dctxt, arg).AndReturn('d' + arg) serializer.deserialize_entity(self.dctxt, arg).AndReturn('d' + arg)
serializer.serialize_entity(self.dctxt, self.retval).\ serializer.serialize_entity(self.dctxt, self.retval). \
AndReturn('s' + self.retval if self.retval else None) AndReturn('s' + self.retval if self.retval else None)
self.mox.ReplayAll() self.mox.ReplayAll()

View File

@ -21,6 +21,7 @@ import testscenarios
import mock import mock
import oslo_messaging import oslo_messaging
from oslo_messaging import rpc
from oslo_messaging.rpc import server as rpc_server_module from oslo_messaging.rpc import server as rpc_server_module
from oslo_messaging import server as server_module from oslo_messaging import server as server_module
from oslo_messaging.tests import utils as test_utils from oslo_messaging.tests import utils as test_utils
@ -861,3 +862,22 @@ class TestServerLocking(test_utils.BaseTestCase):
# We timed out. Ensure we didn't log anything. # We timed out. Ensure we didn't log anything.
self.assertFalse(mock_log.warning.called) self.assertFalse(mock_log.warning.called)
class TestRPCExposeDecorator(test_utils.BaseTestCase):
def foo(self):
pass
@rpc.expose
def bar(self):
"""bar docstring"""
pass
def test_undecorated(self):
self.assertRaises(AttributeError, lambda: self.foo.exposed)
def test_decorated(self):
self.assertEqual(True, self.bar.exposed)
self.assertEqual("""bar docstring""", self.bar.__doc__)
self.assertEqual('bar', self.bar.__name__)