From 0e81ffd1e1481a73146fce17f61f2ab9e01eb684 Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Wed, 13 Jun 2018 14:28:28 -0700 Subject: [PATCH] Fix socket leak on object-server death Consider a client that's downloading a large replicated object of size N bytes. If the object server process dies (e.g. with a segfault) partway through the download, the proxy will have read fewer than N bytes, and then read(sockfd) will start returning 0 bytes. At this point, the proxy believes the object download is complete, and so the WSGI server waits for a new request to come in. Meanwhile, the client is waiting for the rest of their bytes. Until the client times out, that socket will be held open. The fix is to look at the Content-Length and Content-Range headers in the response from the object server, then retry with another object server in case the original GET is truncated. This way, the client gets all the bytes they should. Note that ResumingGetter already had retry logic for when an object-server is slow to send bytes -- this extends it to also cover unexpected disconnects. Change-Id: Iab1e07706193ddc86832fd2cff0d7c2cb6d79ad9 Related-Change: I74d8c13eba2a4917b5a116875b51a781b33a7abf Closes-Bug: 1568650 --- swift/common/exceptions.py | 4 + swift/proxy/controllers/base.py | 69 +++++++-- test/unit/__init__.py | 27 +++- test/unit/proxy/test_mem_server.py | 6 + test/unit/proxy/test_server.py | 225 ++++++++++++++++++++++++----- 5 files changed, 274 insertions(+), 57 deletions(-) diff --git a/swift/common/exceptions.py b/swift/common/exceptions.py index 8774ca8235..922d83833f 100644 --- a/swift/common/exceptions.py +++ b/swift/common/exceptions.py @@ -125,6 +125,10 @@ class ChunkReadError(SwiftException): pass +class ShortReadError(SwiftException): + pass + + class ChunkReadTimeout(Timeout): pass diff --git a/swift/proxy/controllers/base.py b/swift/proxy/controllers/base.py index 4e908e2408..7bfb7d12be 100644 --- a/swift/proxy/controllers/base.py +++ b/swift/proxy/controllers/base.py @@ -49,7 +49,7 @@ from swift.common.utils import Timestamp, config_true_value, \ from swift.common.bufferedhttp import http_connect from swift.common import constraints from swift.common.exceptions import ChunkReadTimeout, ChunkWriteTimeout, \ - ConnectionTimeout, RangeAlreadyComplete + ConnectionTimeout, RangeAlreadyComplete, ShortReadError from swift.common.header_key_dict import HeaderKeyDict from swift.common.http import is_informational, is_success, is_redirection, \ is_server_error, HTTP_OK, HTTP_PARTIAL_CONTENT, HTTP_MULTIPLE_CHOICES, \ @@ -750,6 +750,37 @@ def bytes_to_skip(record_size, range_start): return (record_size - (range_start % record_size)) % record_size +class ByteCountEnforcer(object): + """ + Enforces that successive calls to file_like.read() give at least + bytes before exhaustion. + + If file_like fails to do so, ShortReadError is raised. + + If more than bytes are read, we don't care. + """ + + def __init__(self, file_like, nbytes): + """ + :param file_like: file-like object + :param nbytes: number of bytes expected, or None if length is unknown. + """ + self.file_like = file_like + self.nbytes = self.bytes_left = nbytes + + def read(self, amt=None): + chunk = self.file_like.read(amt) + if self.bytes_left is None: + return chunk + elif len(chunk) == 0 and self.bytes_left > 0: + raise ShortReadError( + "Too few bytes; read %d, expecting %d" % ( + self.nbytes - self.bytes_left, self.nbytes)) + else: + self.bytes_left -= len(chunk) + return chunk + + class ResumingGetter(object): def __init__(self, app, req, server_type, node_iter, partition, path, backend_headers, concurrency=1, client_chunk_size=None, @@ -947,9 +978,9 @@ class ResumingGetter(object): except ChunkReadTimeout: new_source, new_node = self._get_source_and_node() if new_source: - self.app.exception_occurred( - node[0], _('Object'), - _('Trying to read during GET (retrying)')) + self.app.error_occurred( + node[0], _('Trying to read object during ' + 'GET (retrying)')) # Close-out the connection as best as possible. if getattr(source[0], 'swift_conn', None): close_swift_conn(source[0]) @@ -963,16 +994,21 @@ class ResumingGetter(object): else: raise StopIteration() - def iter_bytes_from_response_part(part_file): + def iter_bytes_from_response_part(part_file, nbytes): nchunks = 0 buf = b'' + part_file = ByteCountEnforcer(part_file, nbytes) while True: try: with ChunkReadTimeout(node_timeout): chunk = part_file.read(self.app.object_chunk_size) nchunks += 1 + # NB: this append must be *inside* the context + # manager for test.unit.SlowBody to do its thing buf += chunk - except ChunkReadTimeout: + if nbytes is not None: + nbytes -= len(chunk) + except (ChunkReadTimeout, ShortReadError): exc_type, exc_value, exc_traceback = exc_info() if self.newest or self.server_type != 'Object': raise @@ -985,9 +1021,9 @@ class ResumingGetter(object): buf = b'' new_source, new_node = self._get_source_and_node() if new_source: - self.app.exception_occurred( - node[0], _('Object'), - _('Trying to read during GET (retrying)')) + self.app.error_occurred( + node[0], _('Trying to read object during ' + 'GET (retrying)')) # Close-out the connection as best as possible. if getattr(source[0], 'swift_conn', None): close_swift_conn(source[0]) @@ -1006,8 +1042,9 @@ class ResumingGetter(object): except StopIteration: # Tried to find a new node from which to # finish the GET, but failed. There's - # nothing more to do here. - return + # nothing more we can do here. + six.reraise(exc_type, exc_value, exc_traceback) + part_file = ByteCountEnforcer(part_file, nbytes) else: six.reraise(exc_type, exc_value, exc_traceback) else: @@ -1069,10 +1106,18 @@ class ResumingGetter(object): while True: start_byte, end_byte, length, headers, part = \ get_next_doc_part() + # note: learn_size_from_content_range() sets + # self.skip_bytes self.learn_size_from_content_range( start_byte, end_byte, length) self.bytes_used_from_backend = 0 - part_iter = iter_bytes_from_response_part(part) + # not length; that refers to the whole object, so is the + # wrong value to use for GET-range responses + byte_count = ((end_byte - start_byte + 1) - self.skip_bytes + if (end_byte is not None + and start_byte is not None) + else None) + part_iter = iter_bytes_from_response_part(part, byte_count) yield {'start_byte': start_byte, 'end_byte': end_byte, 'entity_length': length, 'headers': headers, 'part_iter': part_iter} diff --git a/test/unit/__init__.py b/test/unit/__init__.py index 24c0eb08b2..4c71299af6 100644 --- a/test/unit/__init__.py +++ b/test/unit/__init__.py @@ -869,6 +869,9 @@ def fake_http_connect(*code_iter, **kwargs): class FakeConn(object): + SLOW_READS = 4 + SLOW_WRITES = 4 + def __init__(self, status, etag=None, body=b'', timestamp='1', headers=None, expect_headers=None, connection_id=None, give_send=None, give_expect=None): @@ -894,6 +897,12 @@ def fake_http_connect(*code_iter, **kwargs): self._next_sleep = kwargs['slow'].pop(0) except IndexError: self._next_sleep = None + + # if we're going to be slow, we need a body to send slowly + am_slow, _junk = self.get_slow() + if am_slow and len(self.body) < self.SLOW_READS: + self.body += " " * (self.SLOW_READS - len(self.body)) + # be nice to trixy bits with node_iter's eventlet.sleep() @@ -929,6 +938,7 @@ def fake_http_connect(*code_iter, **kwargs): else: etag = '"68b329da9893e34099c7d8ad5cb9c940"' + am_slow, _junk = self.get_slow() headers = HeaderKeyDict({ 'content-length': len(self.body), 'content-type': 'x-application/test', @@ -951,9 +961,6 @@ def fake_http_connect(*code_iter, **kwargs): headers['x-container-timestamp'] = '1' except StopIteration: pass - am_slow, value = self.get_slow() - if am_slow: - headers['content-length'] = '4' headers.update(self.headers) return headers.items() @@ -970,12 +977,16 @@ def fake_http_connect(*code_iter, **kwargs): def read(self, amt=None): am_slow, value = self.get_slow() if am_slow: - if self.sent < 4: + if self.sent < self.SLOW_READS: + slowly_read_byte = self.body[self.sent] self.sent += 1 eventlet.sleep(value) - return ' ' - rv = self.body[:amt] - self.body = self.body[amt:] + return slowly_read_byte + if amt is None: + rv = self.body[self.sent:] + else: + rv = self.body[self.sent:self.sent + amt] + self.sent += len(rv) return rv def send(self, data=None): @@ -983,7 +994,7 @@ def fake_http_connect(*code_iter, **kwargs): self.give_send(self, data) am_slow, value = self.get_slow() if am_slow: - if self.received < 4: + if self.received < self.SLOW_WRITES: self.received += 1 eventlet.sleep(value) diff --git a/test/unit/proxy/test_mem_server.py b/test/unit/proxy/test_mem_server.py index 0901091529..4dda8c5f87 100644 --- a/test/unit/proxy/test_mem_server.py +++ b/test/unit/proxy/test_mem_server.py @@ -46,6 +46,12 @@ class TestReplicatedObjectController( def test_policy_IO(self): pass + def test_GET_short_read(self): + pass + + def test_GET_short_read_resuming(self): + pass + class TestECObjectController(test_server.TestECObjectController): def test_PUT_ec(self): diff --git a/test/unit/proxy/test_server.py b/test/unit/proxy/test_server.py index bcc5a5bed1..454bb1b611 100644 --- a/test/unit/proxy/test_server.py +++ b/test/unit/proxy/test_server.py @@ -20,6 +20,8 @@ import logging import json import math import os +import posix +import socket import sys import traceback import unittest @@ -64,7 +66,7 @@ from swift.common.middleware import proxy_logging, versioned_writes, \ copy, listing_formats from swift.common.middleware.acl import parse_acl, format_acl from swift.common.exceptions import ChunkReadTimeout, DiskFileNotExist, \ - APIVersionError, ChunkWriteTimeout + APIVersionError, ChunkWriteTimeout, ChunkReadError from swift.common import utils, constraints from swift.common.utils import hash_path, storage_directory, \ parse_content_type, parse_mime_headers, \ @@ -960,14 +962,18 @@ class TestProxyServer(unittest.TestCase): self.kargs = kargs def getresponse(self): + body = 'Response from %s' % self.ip + def mygetheader(header, *args, **kargs): if header == "Content-Type": return "" + elif header == "Content-Length": + return str(len(body)) else: return 1 resp = mock.Mock() - resp.read.side_effect = ['Response from %s' % self.ip, ''] + resp.read.side_effect = [body, ''] resp.getheader = mygetheader resp.getheaders.return_value = {} resp.reason = '' @@ -2373,6 +2379,178 @@ class TestReplicatedObjectController( self.assertEqual(res.status_int, 200) self.assertEqual(res.body, '') + @unpatch_policies + def test_GET_short_read(self): + prolis = _test_sockets[0] + prosrv = _test_servers[0] + sock = connect_tcp(('localhost', prolis.getsockname()[1])) + fd = sock.makefile() + obj = (''.join( + ('%d bottles of beer on the wall\n' % i) + for i in reversed(range(1, 200)))) + + # if the object is too short, then we don't have a mid-stream + # exception after the headers are sent, but instead an early one + # before the headers + self.assertGreater(len(obj), wsgi.MINIMUM_CHUNK_SIZE) + + path = '/v1/a/c/o.bottles' + fd.write('PUT %s HTTP/1.1\r\n' + 'Connection: keep-alive\r\n' + 'Host: localhost\r\n' + 'X-Storage-Token: t\r\n' + 'Content-Length: %s\r\n' + 'Content-Type: application/beer-stream\r\n' + '\r\n%s' % (path, str(len(obj)), obj)) + fd.flush() + headers = readuntil2crlfs(fd) + exp = 'HTTP/1.1 201' + self.assertEqual(headers[:len(exp)], exp) + + # go shorten that object by a few bytes + shrinkage = 100 # bytes + shortened = 0 + for dirpath, _, filenames in os.walk(_testdir): + for filename in filenames: + if filename.endswith(".data"): + with open(os.path.join(dirpath, filename), "r+") as fh: + fh.truncate(len(obj) - shrinkage) + shortened += 1 + self.assertGreater(shortened, 0) # ensure test is working + + real_fstat = os.fstat + + # stop the object server from immediately quarantining the object + # and returning 404 + def lying_fstat(fd): + sr = real_fstat(fd) + fake_stat_result = posix.stat_result(( + sr.st_mode, sr.st_ino, sr.st_dev, sr.st_nlink, sr.st_uid, + sr.st_gid, + sr.st_size + shrinkage, # here's the lie + sr.st_atime, sr.st_mtime, sr.st_ctime)) + return fake_stat_result + + # Read the object back + with mock.patch('os.fstat', lying_fstat), \ + mock.patch.object(prosrv, 'client_chunk_size', 32), \ + mock.patch.object(prosrv, 'object_chunk_size', 32): + fd.write('GET %s HTTP/1.1\r\n' + 'Host: localhost\r\n' + 'Connection: keep-alive\r\n' + 'X-Storage-Token: t\r\n' + '\r\n' % (path,)) + fd.flush() + headers = readuntil2crlfs(fd) + exp = 'HTTP/1.1 200' + self.assertEqual(headers[:len(exp)], exp) + + obj_parts = [] + while True: + buf = fd.read(1024) + if not buf: + break + obj_parts.append(buf) + got_obj = ''.join(obj_parts) + self.assertLessEqual(len(got_obj), len(obj) - shrinkage) + + # Make sure the server closed the connection + with self.assertRaises(socket.error): + # Two calls are necessary; you can apparently write to a socket + # that the peer has closed exactly once without error, then the + # kernel discovers that the connection is not open and + # subsequent send attempts fail. + sock.sendall('GET /info HTTP/1.1\r\n') + sock.sendall('Host: localhost\r\n' + 'X-Storage-Token: t\r\n' + '\r\n') + + @unpatch_policies + def test_GET_short_read_resuming(self): + prolis = _test_sockets[0] + prosrv = _test_servers[0] + sock = connect_tcp(('localhost', prolis.getsockname()[1])) + fd = sock.makefile() + obj = (''.join( + ('%d bottles of beer on the wall\n' % i) + for i in reversed(range(1, 200)))) + + # if the object is too short, then we don't have a mid-stream + # exception after the headers are sent, but instead an early one + # before the headers + self.assertGreater(len(obj), wsgi.MINIMUM_CHUNK_SIZE) + + path = '/v1/a/c/o.bottles' + fd.write('PUT %s HTTP/1.1\r\n' + 'Connection: keep-alive\r\n' + 'Host: localhost\r\n' + 'X-Storage-Token: t\r\n' + 'Content-Length: %s\r\n' + 'Content-Type: application/beer-stream\r\n' + '\r\n%s' % (path, str(len(obj)), obj)) + fd.flush() + headers = readuntil2crlfs(fd) + exp = 'HTTP/1.1 201' + self.assertEqual(headers[:len(exp)], exp) + + # we shorten the first replica of the object by 200 bytes and leave + # the others untouched + _, obj_nodes = POLICIES.default.object_ring.get_nodes( + "a", "c", "o.bottles") + + shortened = 0 + for dirpath, _, filenames in os.walk( + os.path.join(_testdir, obj_nodes[0]['device'])): + for filename in filenames: + if filename.endswith(".data"): + if shortened == 0: + with open(os.path.join(dirpath, filename), "r+") as fh: + fh.truncate(len(obj) - 200) + shortened += 1 + self.assertEqual(shortened, 1) # sanity check + + real_fstat = os.fstat + + # stop the object server from immediately quarantining the object + # and returning 404 + def lying_fstat(fd): + sr = real_fstat(fd) + fake_stat_result = posix.stat_result(( + sr.st_mode, sr.st_ino, sr.st_dev, sr.st_nlink, sr.st_uid, + sr.st_gid, + len(obj), # sometimes correct, sometimes not + sr.st_atime, sr.st_mtime, sr.st_ctime)) + return fake_stat_result + + # Read the object back + with mock.patch('os.fstat', lying_fstat), \ + mock.patch.object(prosrv, 'client_chunk_size', 32), \ + mock.patch.object(prosrv, 'object_chunk_size', 32), \ + mock.patch.object(prosrv, 'sort_nodes', + lambda nodes, **kw: nodes): + fd.write('GET %s HTTP/1.1\r\n' + 'Host: localhost\r\n' + 'Connection: close\r\n' + 'X-Storage-Token: t\r\n' + '\r\n' % (path,)) + fd.flush() + headers = readuntil2crlfs(fd) + exp = 'HTTP/1.1 200' + self.assertEqual(headers[:len(exp)], exp) + + obj_parts = [] + while True: + buf = fd.read(1024) + if not buf: + break + obj_parts.append(buf) + got_obj = ''.join(obj_parts) + + # technically this is a redundant test, but it saves us from screens + # full of error message when got_obj is shorter than obj + self.assertEqual(len(obj), len(got_obj)) + self.assertEqual(obj, got_obj) + @unpatch_policies def test_GET_ranges_resuming(self): prolis = _test_sockets[0] @@ -2497,7 +2675,7 @@ class TestReplicatedObjectController( try: for chunk in res.app_iter: body += chunk - except ChunkReadTimeout: + except (ChunkReadTimeout, ChunkReadError): pass self.assertEqual(res.status_int, 206) @@ -3845,12 +4023,8 @@ class TestReplicatedObjectController( self.app.recoverable_node_timeout = 0.1 set_http_connect(200, 200, 200, slow=1.0) resp = req.get_response(self.app) - got_exc = False - try: + with self.assertRaises(ChunkReadTimeout): resp.body - except ChunkReadTimeout: - got_exc = True - self.assertTrue(got_exc) def test_node_read_timeout_retry(self): with save_globals(): @@ -3873,53 +4047,30 @@ class TestReplicatedObjectController( self.app.recoverable_node_timeout = 0.1 set_http_connect(200, 200, 200, slow=[1.0, 1.0, 1.0]) resp = req.get_response(self.app) - got_exc = False - try: - self.assertEqual('', resp.body) - except ChunkReadTimeout: - got_exc = True - self.assertTrue(got_exc) + with self.assertRaises(ChunkReadTimeout): + resp.body set_http_connect(200, 200, 200, body='lalala', slow=[1.0, 1.0]) resp = req.get_response(self.app) - got_exc = False - try: - self.assertEqual(resp.body, 'lalala') - except ChunkReadTimeout: - got_exc = True - self.assertFalse(got_exc) + self.assertEqual(resp.body, 'lalala') set_http_connect(200, 200, 200, body='lalala', slow=[1.0, 1.0], etags=['a', 'a', 'a']) resp = req.get_response(self.app) - got_exc = False - try: - self.assertEqual(resp.body, 'lalala') - except ChunkReadTimeout: - got_exc = True - self.assertFalse(got_exc) + self.assertEqual(resp.body, 'lalala') set_http_connect(200, 200, 200, body='lalala', slow=[1.0, 1.0], etags=['a', 'b', 'a']) resp = req.get_response(self.app) - got_exc = False - try: - self.assertEqual(resp.body, 'lalala') - except ChunkReadTimeout: - got_exc = True - self.assertFalse(got_exc) + self.assertEqual(resp.body, 'lalala') req = Request.blank('/v1/a/c/o', environ={'REQUEST_METHOD': 'GET'}) set_http_connect(200, 200, 200, body='lalala', slow=[1.0, 1.0], etags=['a', 'b', 'b']) resp = req.get_response(self.app) - got_exc = False - try: + with self.assertRaises(ChunkReadTimeout): resp.body - except ChunkReadTimeout: - got_exc = True - self.assertTrue(got_exc) def test_node_write_timeout(self): with save_globals():