diff --git a/swift/account/server.py b/swift/account/server.py index 360e05d047..0b80b3d5e2 100644 --- a/swift/account/server.py +++ b/swift/account/server.py @@ -31,7 +31,7 @@ import simplejson import swift.common.db from swift.common.db import AccountBroker -from swift.common.utils import get_logger, get_param, hash_path, \ +from swift.common.utils import get_logger, get_param, hash_path, public, \ normalize_timestamp, split_path, storage_directory, TRUE_VALUES from swift.common.constraints import ACCOUNT_LISTING_LIMIT, \ check_mount, check_float, check_utf8 @@ -63,6 +63,7 @@ class AccountController(object): db_path = os.path.join(self.root, drive, db_dir, hsh + '.db') return AccountBroker(db_path, account=account, logger=self.logger) + @public def DELETE(self, req): """Handle HTTP DELETE request.""" start_time = time.time() @@ -88,6 +89,7 @@ class AccountController(object): self.logger.timing_since('DELETE.timing', start_time) return HTTPNoContent(request=req) + @public def PUT(self, req): """Handle HTTP PUT request.""" start_time = time.time() @@ -149,6 +151,7 @@ class AccountController(object): else: return HTTPAccepted(request=req) + @public def HEAD(self, req): """Handle HTTP HEAD request.""" # TODO(refactor): The account server used to provide a 'account and @@ -192,6 +195,7 @@ class AccountController(object): self.logger.timing_since('HEAD.timing', start_time) return HTTPNoContent(request=req, headers=headers) + @public def GET(self, req): """Handle HTTP GET request.""" start_time = time.time() @@ -292,6 +296,7 @@ class AccountController(object): self.logger.timing_since('GET.timing', start_time) return ret + @public def REPLICATE(self, req): """ Handle HTTP REPLICATE request. @@ -318,6 +323,7 @@ class AccountController(object): self.logger.timing_since('REPLICATE.timing', start_time) return ret + @public def POST(self, req): """Handle HTTP POST request.""" start_time = time.time() @@ -357,10 +363,14 @@ class AccountController(object): res = HTTPPreconditionFailed(body='Invalid UTF8') else: try: - if hasattr(self, req.method): - res = getattr(self, req.method)(req) - else: + # disallow methods which are not publicly accessible + try: + method = getattr(self, req.method) + getattr(method, 'publicly_accessible') + except AttributeError: res = HTTPMethodNotAllowed() + else: + res = method(req) except (Exception, Timeout): self.logger.exception(_('ERROR __call__ error with %(method)s' ' %(path)s '), {'method': req.method, 'path': req.path}) diff --git a/swift/common/utils.py b/swift/common/utils.py index f0f0e39dfb..021ee7d231 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -1246,3 +1246,18 @@ def streq_const_time(s1, s2): for (a, b) in zip(s1, s2): result |= ord(a) ^ ord(b) return result == 0 + + +def public(func): + """ + Decorator to declare which methods are publicly accessible as HTTP + requests + + :param func: function to make public + """ + func.publicly_accessible = True + + @functools.wraps(func) + def wrapped(*a, **kw): + return func(*a, **kw) + return wrapped diff --git a/swift/container/server.py b/swift/container/server.py index 0281881163..17098e2241 100644 --- a/swift/container/server.py +++ b/swift/container/server.py @@ -31,7 +31,7 @@ from webob.exc import HTTPAccepted, HTTPBadRequest, HTTPConflict, \ import swift.common.db from swift.common.db import ContainerBroker -from swift.common.utils import get_logger, get_param, hash_path, \ +from swift.common.utils import get_logger, get_param, hash_path, public, \ normalize_timestamp, storage_directory, split_path, validate_sync_to, \ TRUE_VALUES from swift.common.constraints import CONTAINER_LISTING_LIMIT, \ @@ -138,6 +138,7 @@ class ContainerController(object): 'device': account_device}) return None + @public def DELETE(self, req): """Handle HTTP DELETE request.""" start_time = time.time() @@ -187,6 +188,7 @@ class ContainerController(object): return HTTPNoContent(request=req) return HTTPNotFound() + @public def PUT(self, req): """Handle HTTP PUT request.""" start_time = time.time() @@ -255,6 +257,7 @@ class ContainerController(object): else: return HTTPAccepted(request=req) + @public def HEAD(self, req): """Handle HTTP HEAD request.""" start_time = time.time() @@ -288,6 +291,7 @@ class ContainerController(object): self.logger.timing_since('HEAD.timing', start_time) return HTTPNoContent(request=req, headers=headers) + @public def GET(self, req): """Handle HTTP GET request.""" start_time = time.time() @@ -409,6 +413,7 @@ class ContainerController(object): self.logger.timing_since('GET.timing', start_time) return ret + @public def REPLICATE(self, req): """ Handle HTTP REPLICATE request (json-encoded RPC calls for replication.) @@ -434,6 +439,7 @@ class ContainerController(object): self.logger.timing_since('REPLICATE.timing', start_time) return ret + @public def POST(self, req): """Handle HTTP POST request.""" start_time = time.time() @@ -485,10 +491,14 @@ class ContainerController(object): res = HTTPPreconditionFailed(body='Invalid UTF8') else: try: - if hasattr(self, req.method): - res = getattr(self, req.method)(req) - else: + # disallow methods which have not been marked 'public' + try: + method = getattr(self, req.method) + getattr(method, 'publicly_accessible') + except AttributeError: res = HTTPMethodNotAllowed() + else: + res = method(req) except (Exception, Timeout): self.logger.exception(_('ERROR __call__ error with %(method)s' ' %(path)s '), {'method': req.method, 'path': req.path}) diff --git a/swift/obj/server.py b/swift/obj/server.py index 9a636890ed..d1b267bd09 100644 --- a/swift/obj/server.py +++ b/swift/obj/server.py @@ -35,7 +35,7 @@ from webob.exc import HTTPAccepted, HTTPBadRequest, HTTPCreated, \ from xattr import getxattr, setxattr from eventlet import sleep, Timeout, tpool -from swift.common.utils import mkdirs, normalize_timestamp, \ +from swift.common.utils import mkdirs, normalize_timestamp, public, \ storage_directory, hash_path, renamer, fallocate, \ split_path, drop_buffer_cache, get_logger, write_pickle from swift.common.bufferedhttp import http_connect @@ -484,6 +484,7 @@ class ObjectController(object): '%s-%s/%s/%s' % (delete_at, account, container, obj), host, partition, contdevice, headers_out, objdevice) + @public def POST(self, request): """Handle HTTP POST requests for the Swift Object Server.""" start_time = time.time() @@ -543,6 +544,7 @@ class ObjectController(object): self.logger.timing_since('POST.timing', start_time) return response_class(request=request) + @public def PUT(self, request): """Handle HTTP PUT requests for the Swift Object Server.""" start_time = time.time() @@ -641,6 +643,7 @@ class ObjectController(object): self.logger.timing_since('PUT.timing', start_time) return resp + @public def GET(self, request): """Handle HTTP GET requests for the Swift Object Server.""" start_time = time.time() @@ -729,6 +732,7 @@ class ObjectController(object): self.logger.timing_since('GET.timing', start_time) return request.get_response(response) + @public def HEAD(self, request): """Handle HTTP HEAD requests for the Swift Object Server.""" start_time = time.time() @@ -774,6 +778,7 @@ class ObjectController(object): self.logger.timing_since('HEAD.timing', start_time) return response + @public def DELETE(self, request): """Handle HTTP DELETE requests for the Swift Object Server.""" start_time = time.time() @@ -824,6 +829,7 @@ class ObjectController(object): self.logger.timing_since('DELETE.timing', start_time) return resp + @public def REPLICATE(self, request): """ Handle REPLICATE requests for the Swift Object Server. This is used @@ -862,10 +868,14 @@ class ObjectController(object): res = HTTPPreconditionFailed(body='Invalid UTF8') else: try: - if hasattr(self, req.method): - res = getattr(self, req.method)(req) - else: + # disallow methods which have not been marked 'public' + try: + method = getattr(self, req.method) + getattr(method, 'publicly_accessible') + except AttributeError: res = HTTPMethodNotAllowed() + else: + res = method(req) except (Exception, Timeout): self.logger.exception(_('ERROR __call__ error with %(method)s' ' %(path)s '), {'method': req.method, 'path': req.path}) diff --git a/swift/proxy/server.py b/swift/proxy/server.py index aef685fadf..38e9e66e8f 100644 --- a/swift/proxy/server.py +++ b/swift/proxy/server.py @@ -53,7 +53,7 @@ from webob import Request, Response from swift.common.ring import Ring from swift.common.utils import cache_from_env, ContextPool, get_logger, \ - get_remote_client, normalize_timestamp, split_path, TRUE_VALUES + get_remote_client, normalize_timestamp, split_path, TRUE_VALUES, public from swift.common.bufferedhttp import http_connect from swift.common.constraints import check_metadata, check_object_creation, \ check_utf8, CONTAINER_LISTING_LIMIT, MAX_ACCOUNT_NAME_LENGTH, \ @@ -86,21 +86,6 @@ def update_headers(response, headers): response.headers[name] = value -def public(func): - """ - Decorator to declare which methods are publicly accessible as HTTP - requests - - :param func: function to make public - """ - func.publicly_accessible = True - - @functools.wraps(func) - def wrapped(*a, **kw): - return func(*a, **kw) - return wrapped - - def delay_denial(func): """ Decorator to declare which methods should have any swift.authorize call @@ -2022,11 +2007,8 @@ class BaseApplication(object): self.logger.client_ip = get_remote_client(req) try: handler = getattr(controller, req.method) - if not getattr(handler, 'publicly_accessible'): - handler = None + getattr(handler, 'publicly_accessible') except AttributeError: - handler = None - if not handler: self.logger.increment('method_not_allowed') return HTTPMethodNotAllowed(request=req) if path_parts['version']: diff --git a/test/unit/account/test_server.py b/test/unit/account/test_server.py index afa90b9e0d..ad9358ba23 100644 --- a/test/unit/account/test_server.py +++ b/test/unit/account/test_server.py @@ -962,6 +962,30 @@ class TestAccountController(unittest.TestCase): self.assertEquals(errbuf.getvalue(), '') self.assertEquals(outbuf.getvalue()[:4], '400 ') + def test_invalid_method_doesnt_exist(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + def start_response(*args): + outbuf.writelines(args) + self.controller.__call__({'REQUEST_METHOD': 'method_doesnt_exist', + 'PATH_INFO': '/sda1/p/a'}, + start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + + def test_invalid_method_is_not_public(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + def start_response(*args): + outbuf.writelines(args) + self.controller.__call__({'REQUEST_METHOD': '__init__', + 'PATH_INFO': '/sda1/p/a'}, + start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_params_utf8(self): self.controller.PUT(Request.blank('/sda1/p/a', headers={'X-Timestamp': normalize_timestamp(1)}, diff --git a/test/unit/container/test_server.py b/test/unit/container/test_server.py index fef364f301..1674e874c0 100644 --- a/test/unit/container/test_server.py +++ b/test/unit/container/test_server.py @@ -928,6 +928,30 @@ class TestContainerController(unittest.TestCase): self.assertEquals(errbuf.getvalue(), '') self.assertEquals(outbuf.getvalue()[:4], '400 ') + def test_invalid_method_doesnt_exist(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + def start_response(*args): + outbuf.writelines(args) + self.controller.__call__({'REQUEST_METHOD': 'method_doesnt_exist', + 'PATH_INFO': '/sda1/p/a/c'}, + start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + + def test_invalid_method_is_not_public(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + def start_response(*args): + outbuf.writelines(args) + self.controller.__call__({'REQUEST_METHOD': '__init__', + 'PATH_INFO': '/sda1/p/a/c'}, + start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_params_utf8(self): self.controller.PUT(Request.blank('/sda1/p/a/c', headers={'X-Timestamp': normalize_timestamp(1)}, diff --git a/test/unit/obj/test_server.py b/test/unit/obj/test_server.py index dd9d3c8cb4..96e2af0249 100644 --- a/test/unit/obj/test_server.py +++ b/test/unit/obj/test_server.py @@ -1342,6 +1342,30 @@ class TestObjectController(unittest.TestCase): self.assertEquals(errbuf.getvalue(), '') self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_invalid_method_doesnt_exist(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + def start_response(*args): + outbuf.writelines(args) + self.object_controller.__call__({'REQUEST_METHOD': 'method_doesnt_exist', + 'PATH_INFO': '/sda1/p/a/c/o'}, + start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + + def test_invalid_method_is_not_public(self): + inbuf = StringIO() + errbuf = StringIO() + outbuf = StringIO() + def start_response(*args): + outbuf.writelines(args) + self.object_controller.__call__({'REQUEST_METHOD': '__init__', + 'PATH_INFO': '/sda1/p/a/c/o'}, + start_response) + self.assertEquals(errbuf.getvalue(), '') + self.assertEquals(outbuf.getvalue()[:4], '405 ') + def test_chunked_put(self): listener = listen(('localhost', 0)) port = listener.getsockname()[1] diff --git a/test/unit/proxy/test_server.py b/test/unit/proxy/test_server.py index 1af4893a97..6187502ea7 100644 --- a/test/unit/proxy/test_server.py +++ b/test/unit/proxy/test_server.py @@ -560,6 +560,22 @@ class TestProxyServer(unittest.TestCase): resp = app.handle_request(req) self.assertEquals(resp.status_int, 500) + def test_internal_method_request(self): + baseapp = proxy_server.BaseApplication({}, + FakeMemcache(), container_ring=FakeRing(), object_ring=FakeRing(), + account_ring=FakeRing()) + resp = baseapp.handle_request( + Request.blank('/v1/a', environ={'REQUEST_METHOD': '__init__'})) + self.assertEquals(resp.status, '405 Method Not Allowed') + + def test_inexistent_method_request(self): + baseapp = proxy_server.BaseApplication({}, + FakeMemcache(), container_ring=FakeRing(), account_ring=FakeRing(), + object_ring=FakeRing()) + resp = baseapp.handle_request( + Request.blank('/v1/a', environ={'REQUEST_METHOD': '!invalid'})) + self.assertEquals(resp.status, '405 Method Not Allowed') + def test_calls_authorize_allow(self): called = [False]