diff --git a/swift/common/middleware/catch_errors.py b/swift/common/middleware/catch_errors.py index 6e9334795a..106238daab 100644 --- a/swift/common/middleware/catch_errors.py +++ b/swift/common/middleware/catch_errors.py @@ -16,10 +16,44 @@ from swift import gettext_ as _ from swift.common.swob import Request, HTTPServerError -from swift.common.utils import get_logger, generate_trans_id +from swift.common.utils import get_logger, generate_trans_id, close_if_possible from swift.common.wsgi import WSGIContext +class BadResponseLength(Exception): + pass + + +def enforce_byte_count(inner_iter, nbytes): + """ + Enforces that inner_iter yields exactly bytes before + exhaustion. + + If inner_iter fails to do so, BadResponseLength is raised. + + :param inner_iter: iterable of bytestrings + :param nbytes: number of bytes expected + """ + try: + bytes_left = nbytes + for chunk in inner_iter: + if bytes_left >= len(chunk): + yield chunk + bytes_left -= len(chunk) + else: + yield chunk[:bytes_left] + raise BadResponseLength( + "Too many bytes; truncating after %d bytes " + "with at least %d surplus bytes remaining" % ( + nbytes, len(chunk) - bytes_left)) + + if bytes_left: + raise BadResponseLength('Expected another %d bytes' % ( + bytes_left,)) + finally: + close_if_possible(inner_iter) + + class CatchErrorsContext(WSGIContext): def __init__(self, app, logger, trans_id_suffix=''): @@ -35,6 +69,7 @@ class CatchErrorsContext(WSGIContext): trans_id = generate_trans_id(trans_id_suffix) env['swift.trans_id'] = trans_id + method = env['REQUEST_METHOD'] self.logger.txn_id = trans_id try: # catch any errors in the pipeline @@ -48,6 +83,37 @@ class CatchErrorsContext(WSGIContext): resp.headers['X-Openstack-Request-Id'] = trans_id return resp(env, start_response) + # If the app specified a Content-Length, enforce that it sends that + # many bytes. + # + # If an app gives too few bytes, then the client will wait for the + # remainder before sending another HTTP request on the same socket; + # since no more bytes are coming, this will result in either an + # infinite wait or a timeout. In this case, we want to raise an + # exception to signal to the WSGI server that it should close the + # TCP connection. + # + # If an app gives too many bytes, then we can deadlock with the + # client; if the client reads its N bytes and then sends a large-ish + # request (enough to fill TCP buffers), it'll block until we read + # some of the request. However, we won't read the request since + # we'll be trying to shove the rest of our oversized response out + # the socket. In that case, we truncate the response body at N bytes + # and raise an exception to stop any more bytes from being + # generated and also to kill the TCP connection. + if self._response_headers: + content_lengths = [val for header, val in self._response_headers + if header.lower() == "content-length"] + if len(content_lengths) == 1: + try: + content_length = int(content_lengths[0]) + except ValueError: + pass + else: + resp = enforce_byte_count( + resp, + 0 if method == 'HEAD' else content_length) + # make sure the response has the trans_id if self._response_headers is None: self._response_headers = [] diff --git a/swift/common/middleware/dlo.py b/swift/common/middleware/dlo.py index 4c4ce00bff..5c1ba9aeed 100644 --- a/swift/common/middleware/dlo.py +++ b/swift/common/middleware/dlo.py @@ -154,6 +154,9 @@ class GetContext(WSGIContext): con_resp = con_req.get_response(self.dlo.app) if not is_success(con_resp.status_int): + if req.method == 'HEAD': + close_if_possible(con_resp.app_iter) + con_resp.body = '' return con_resp, None with closing_if_possible(con_resp.app_iter): return None, json.loads(''.join(con_resp.app_iter)) diff --git a/test/unit/common/middleware/test_catch_errors.py b/test/unit/common/middleware/test_catch_errors.py index 36996b457e..1e5ee85ce7 100644 --- a/test/unit/common/middleware/test_catch_errors.py +++ b/test/unit/common/middleware/test_catch_errors.py @@ -137,6 +137,63 @@ class TestCatchErrors(unittest.TestCase): resp = app(req.environ, self.start_response) self.assertEqual(list(resp), ['An error occurred']) + def test_HEAD_with_content_length(self): + def cannot_count_app(env, sr): + sr("200 OK", [("Content-Length", "10")]) + return [b""] + + app = catch_errors.CatchErrorMiddleware(cannot_count_app, {}) + list(app({'REQUEST_METHOD': 'HEAD'}, self.start_response)) + + def test_short_response_body(self): + + def cannot_count_app(env, sr): + sr("200 OK", [("Content-Length", "2000")]) + return [b"our staff tailor is Euripedes Imenedes"] + + app = catch_errors.CatchErrorMiddleware(cannot_count_app, {}) + + with self.assertRaises(catch_errors.BadResponseLength): + list(app({'REQUEST_METHOD': 'GET'}, self.start_response)) + + def test_long_response_body(self): + def cannot_count_app(env, sr): + sr("200 OK", [("Content-Length", "10")]) + return [b"our optometric firm is C.F. Eye Care"] + + app = catch_errors.CatchErrorMiddleware(cannot_count_app, {}) + + with self.assertRaises(catch_errors.BadResponseLength): + list(app({'REQUEST_METHOD': 'GET'}, self.start_response)) + + def test_bogus_content_length(self): + + def bogus_cl_app(env, sr): + sr("200 OK", [("Content-Length", "25 cm")]) + return [b"our British cutlery specialist is Sir Irving Spoon"] + + app = catch_errors.CatchErrorMiddleware(bogus_cl_app, {}) + list(app({'REQUEST_METHOD': 'GET'}, self.start_response)) + + def test_no_content_length(self): + + def no_cl_app(env, sr): + sr("200 OK", [("Content-Type", "application/names")]) + return [b"our staff statistician is Marge Inovera"] + + app = catch_errors.CatchErrorMiddleware(no_cl_app, {}) + list(app({'REQUEST_METHOD': 'GET'}, self.start_response)) + + def test_multiple_content_lengths(self): + + def poly_cl_app(env, sr): + sr("200 OK", [("Content-Length", "30"), + ("Content-Length", "40")]) + return [b"The head of our personal trainers is Jim Shortz"] + + app = catch_errors.CatchErrorMiddleware(poly_cl_app, {}) + list(app({'REQUEST_METHOD': 'GET'}, self.start_response)) + if __name__ == '__main__': unittest.main() diff --git a/test/unit/common/middleware/test_dlo.py b/test/unit/common/middleware/test_dlo.py index d0a4ccc8d6..8ac99827d9 100644 --- a/test/unit/common/middleware/test_dlo.py +++ b/test/unit/common/middleware/test_dlo.py @@ -612,6 +612,21 @@ class TestDloGetManifest(DloTestCase): self.assertEqual(status, "200 OK") self.assertEqual(body, "aaaaabbbbbccccc") + def test_error_listing_container_HEAD(self): + self.app.register( + 'GET', '/v1/AUTH_test/c?prefix=seg_', + # for example, if a manifest refers to segments in another + # container, but the user is accessing the manifest via a + # container-level tempurl key + swob.HTTPUnauthorized, {}, None) + + req = swob.Request.blank('/v1/AUTH_test/mancon/manifest-many-segments', + environ={'REQUEST_METHOD': 'HEAD'}) + with mock.patch(LIMIT, 3): + status, headers, body = self.call_dlo(req) + self.assertEqual(status, "401 Unauthorized") + self.assertEqual(body, b"") + def test_mismatched_etag_fetching_second_segment(self): self.app.register( 'GET', '/v1/AUTH_test/c/seg_02',