Fix socket leak on object-server death

Consider a client that's downloading a large replicated object of size
N bytes. If the object server process dies (e.g. with a segfault)
partway through the download, the proxy will have read fewer than N
bytes, and then read(sockfd) will start returning 0 bytes. At this
point, the proxy believes the object download is complete, and so the
WSGI server waits for a new request to come in. Meanwhile, the client
is waiting for the rest of their bytes. Until the client times out,
that socket will be held open.

The fix is to look at the Content-Length and Content-Range headers in
the response from the object server, then retry with another object
server in case the original GET is truncated. This way, the client
gets all the bytes they should.

Note that ResumingGetter already had retry logic for when an
object-server is slow to send bytes -- this extends it to also cover
unexpected disconnects.

Change-Id: Iab1e07706193ddc86832fd2cff0d7c2cb6d79ad9
Related-Change: I74d8c13eba2a4917b5a116875b51a781b33a7abf
Closes-Bug: 1568650
This commit is contained in:
Samuel Merritt 2018-06-13 14:28:28 -07:00 committed by Tim Burke
parent 9aa4cafa0e
commit 0e81ffd1e1
5 changed files with 274 additions and 57 deletions

View File

@ -125,6 +125,10 @@ class ChunkReadError(SwiftException):
pass pass
class ShortReadError(SwiftException):
pass
class ChunkReadTimeout(Timeout): class ChunkReadTimeout(Timeout):
pass pass

View File

@ -49,7 +49,7 @@ from swift.common.utils import Timestamp, config_true_value, \
from swift.common.bufferedhttp import http_connect from swift.common.bufferedhttp import http_connect
from swift.common import constraints from swift.common import constraints
from swift.common.exceptions import ChunkReadTimeout, ChunkWriteTimeout, \ from swift.common.exceptions import ChunkReadTimeout, ChunkWriteTimeout, \
ConnectionTimeout, RangeAlreadyComplete ConnectionTimeout, RangeAlreadyComplete, ShortReadError
from swift.common.header_key_dict import HeaderKeyDict from swift.common.header_key_dict import HeaderKeyDict
from swift.common.http import is_informational, is_success, is_redirection, \ from swift.common.http import is_informational, is_success, is_redirection, \
is_server_error, HTTP_OK, HTTP_PARTIAL_CONTENT, HTTP_MULTIPLE_CHOICES, \ is_server_error, HTTP_OK, HTTP_PARTIAL_CONTENT, HTTP_MULTIPLE_CHOICES, \
@ -750,6 +750,37 @@ def bytes_to_skip(record_size, range_start):
return (record_size - (range_start % record_size)) % record_size return (record_size - (range_start % record_size)) % record_size
class ByteCountEnforcer(object):
"""
Enforces that successive calls to file_like.read() give at least
<nbytes> bytes before exhaustion.
If file_like fails to do so, ShortReadError is raised.
If more than <nbytes> bytes are read, we don't care.
"""
def __init__(self, file_like, nbytes):
"""
:param file_like: file-like object
:param nbytes: number of bytes expected, or None if length is unknown.
"""
self.file_like = file_like
self.nbytes = self.bytes_left = nbytes
def read(self, amt=None):
chunk = self.file_like.read(amt)
if self.bytes_left is None:
return chunk
elif len(chunk) == 0 and self.bytes_left > 0:
raise ShortReadError(
"Too few bytes; read %d, expecting %d" % (
self.nbytes - self.bytes_left, self.nbytes))
else:
self.bytes_left -= len(chunk)
return chunk
class ResumingGetter(object): class ResumingGetter(object):
def __init__(self, app, req, server_type, node_iter, partition, path, def __init__(self, app, req, server_type, node_iter, partition, path,
backend_headers, concurrency=1, client_chunk_size=None, backend_headers, concurrency=1, client_chunk_size=None,
@ -947,9 +978,9 @@ class ResumingGetter(object):
except ChunkReadTimeout: except ChunkReadTimeout:
new_source, new_node = self._get_source_and_node() new_source, new_node = self._get_source_and_node()
if new_source: if new_source:
self.app.exception_occurred( self.app.error_occurred(
node[0], _('Object'), node[0], _('Trying to read object during '
_('Trying to read during GET (retrying)')) 'GET (retrying)'))
# Close-out the connection as best as possible. # Close-out the connection as best as possible.
if getattr(source[0], 'swift_conn', None): if getattr(source[0], 'swift_conn', None):
close_swift_conn(source[0]) close_swift_conn(source[0])
@ -963,16 +994,21 @@ class ResumingGetter(object):
else: else:
raise StopIteration() raise StopIteration()
def iter_bytes_from_response_part(part_file): def iter_bytes_from_response_part(part_file, nbytes):
nchunks = 0 nchunks = 0
buf = b'' buf = b''
part_file = ByteCountEnforcer(part_file, nbytes)
while True: while True:
try: try:
with ChunkReadTimeout(node_timeout): with ChunkReadTimeout(node_timeout):
chunk = part_file.read(self.app.object_chunk_size) chunk = part_file.read(self.app.object_chunk_size)
nchunks += 1 nchunks += 1
# NB: this append must be *inside* the context
# manager for test.unit.SlowBody to do its thing
buf += chunk buf += chunk
except ChunkReadTimeout: if nbytes is not None:
nbytes -= len(chunk)
except (ChunkReadTimeout, ShortReadError):
exc_type, exc_value, exc_traceback = exc_info() exc_type, exc_value, exc_traceback = exc_info()
if self.newest or self.server_type != 'Object': if self.newest or self.server_type != 'Object':
raise raise
@ -985,9 +1021,9 @@ class ResumingGetter(object):
buf = b'' buf = b''
new_source, new_node = self._get_source_and_node() new_source, new_node = self._get_source_and_node()
if new_source: if new_source:
self.app.exception_occurred( self.app.error_occurred(
node[0], _('Object'), node[0], _('Trying to read object during '
_('Trying to read during GET (retrying)')) 'GET (retrying)'))
# Close-out the connection as best as possible. # Close-out the connection as best as possible.
if getattr(source[0], 'swift_conn', None): if getattr(source[0], 'swift_conn', None):
close_swift_conn(source[0]) close_swift_conn(source[0])
@ -1006,8 +1042,9 @@ class ResumingGetter(object):
except StopIteration: except StopIteration:
# Tried to find a new node from which to # Tried to find a new node from which to
# finish the GET, but failed. There's # finish the GET, but failed. There's
# nothing more to do here. # nothing more we can do here.
return six.reraise(exc_type, exc_value, exc_traceback)
part_file = ByteCountEnforcer(part_file, nbytes)
else: else:
six.reraise(exc_type, exc_value, exc_traceback) six.reraise(exc_type, exc_value, exc_traceback)
else: else:
@ -1069,10 +1106,18 @@ class ResumingGetter(object):
while True: while True:
start_byte, end_byte, length, headers, part = \ start_byte, end_byte, length, headers, part = \
get_next_doc_part() get_next_doc_part()
# note: learn_size_from_content_range() sets
# self.skip_bytes
self.learn_size_from_content_range( self.learn_size_from_content_range(
start_byte, end_byte, length) start_byte, end_byte, length)
self.bytes_used_from_backend = 0 self.bytes_used_from_backend = 0
part_iter = iter_bytes_from_response_part(part) # not length; that refers to the whole object, so is the
# wrong value to use for GET-range responses
byte_count = ((end_byte - start_byte + 1) - self.skip_bytes
if (end_byte is not None
and start_byte is not None)
else None)
part_iter = iter_bytes_from_response_part(part, byte_count)
yield {'start_byte': start_byte, 'end_byte': end_byte, yield {'start_byte': start_byte, 'end_byte': end_byte,
'entity_length': length, 'headers': headers, 'entity_length': length, 'headers': headers,
'part_iter': part_iter} 'part_iter': part_iter}

View File

@ -869,6 +869,9 @@ def fake_http_connect(*code_iter, **kwargs):
class FakeConn(object): class FakeConn(object):
SLOW_READS = 4
SLOW_WRITES = 4
def __init__(self, status, etag=None, body=b'', timestamp='1', def __init__(self, status, etag=None, body=b'', timestamp='1',
headers=None, expect_headers=None, connection_id=None, headers=None, expect_headers=None, connection_id=None,
give_send=None, give_expect=None): give_send=None, give_expect=None):
@ -894,6 +897,12 @@ def fake_http_connect(*code_iter, **kwargs):
self._next_sleep = kwargs['slow'].pop(0) self._next_sleep = kwargs['slow'].pop(0)
except IndexError: except IndexError:
self._next_sleep = None self._next_sleep = None
# if we're going to be slow, we need a body to send slowly
am_slow, _junk = self.get_slow()
if am_slow and len(self.body) < self.SLOW_READS:
self.body += " " * (self.SLOW_READS - len(self.body))
# be nice to trixy bits with node_iter's # be nice to trixy bits with node_iter's
eventlet.sleep() eventlet.sleep()
@ -929,6 +938,7 @@ def fake_http_connect(*code_iter, **kwargs):
else: else:
etag = '"68b329da9893e34099c7d8ad5cb9c940"' etag = '"68b329da9893e34099c7d8ad5cb9c940"'
am_slow, _junk = self.get_slow()
headers = HeaderKeyDict({ headers = HeaderKeyDict({
'content-length': len(self.body), 'content-length': len(self.body),
'content-type': 'x-application/test', 'content-type': 'x-application/test',
@ -951,9 +961,6 @@ def fake_http_connect(*code_iter, **kwargs):
headers['x-container-timestamp'] = '1' headers['x-container-timestamp'] = '1'
except StopIteration: except StopIteration:
pass pass
am_slow, value = self.get_slow()
if am_slow:
headers['content-length'] = '4'
headers.update(self.headers) headers.update(self.headers)
return headers.items() return headers.items()
@ -970,12 +977,16 @@ def fake_http_connect(*code_iter, **kwargs):
def read(self, amt=None): def read(self, amt=None):
am_slow, value = self.get_slow() am_slow, value = self.get_slow()
if am_slow: if am_slow:
if self.sent < 4: if self.sent < self.SLOW_READS:
slowly_read_byte = self.body[self.sent]
self.sent += 1 self.sent += 1
eventlet.sleep(value) eventlet.sleep(value)
return ' ' return slowly_read_byte
rv = self.body[:amt] if amt is None:
self.body = self.body[amt:] rv = self.body[self.sent:]
else:
rv = self.body[self.sent:self.sent + amt]
self.sent += len(rv)
return rv return rv
def send(self, data=None): def send(self, data=None):
@ -983,7 +994,7 @@ def fake_http_connect(*code_iter, **kwargs):
self.give_send(self, data) self.give_send(self, data)
am_slow, value = self.get_slow() am_slow, value = self.get_slow()
if am_slow: if am_slow:
if self.received < 4: if self.received < self.SLOW_WRITES:
self.received += 1 self.received += 1
eventlet.sleep(value) eventlet.sleep(value)

View File

@ -46,6 +46,12 @@ class TestReplicatedObjectController(
def test_policy_IO(self): def test_policy_IO(self):
pass pass
def test_GET_short_read(self):
pass
def test_GET_short_read_resuming(self):
pass
class TestECObjectController(test_server.TestECObjectController): class TestECObjectController(test_server.TestECObjectController):
def test_PUT_ec(self): def test_PUT_ec(self):

View File

@ -20,6 +20,8 @@ import logging
import json import json
import math import math
import os import os
import posix
import socket
import sys import sys
import traceback import traceback
import unittest import unittest
@ -64,7 +66,7 @@ from swift.common.middleware import proxy_logging, versioned_writes, \
copy, listing_formats copy, listing_formats
from swift.common.middleware.acl import parse_acl, format_acl from swift.common.middleware.acl import parse_acl, format_acl
from swift.common.exceptions import ChunkReadTimeout, DiskFileNotExist, \ from swift.common.exceptions import ChunkReadTimeout, DiskFileNotExist, \
APIVersionError, ChunkWriteTimeout APIVersionError, ChunkWriteTimeout, ChunkReadError
from swift.common import utils, constraints from swift.common import utils, constraints
from swift.common.utils import hash_path, storage_directory, \ from swift.common.utils import hash_path, storage_directory, \
parse_content_type, parse_mime_headers, \ parse_content_type, parse_mime_headers, \
@ -960,14 +962,18 @@ class TestProxyServer(unittest.TestCase):
self.kargs = kargs self.kargs = kargs
def getresponse(self): def getresponse(self):
body = 'Response from %s' % self.ip
def mygetheader(header, *args, **kargs): def mygetheader(header, *args, **kargs):
if header == "Content-Type": if header == "Content-Type":
return "" return ""
elif header == "Content-Length":
return str(len(body))
else: else:
return 1 return 1
resp = mock.Mock() resp = mock.Mock()
resp.read.side_effect = ['Response from %s' % self.ip, ''] resp.read.side_effect = [body, '']
resp.getheader = mygetheader resp.getheader = mygetheader
resp.getheaders.return_value = {} resp.getheaders.return_value = {}
resp.reason = '' resp.reason = ''
@ -2373,6 +2379,178 @@ class TestReplicatedObjectController(
self.assertEqual(res.status_int, 200) self.assertEqual(res.status_int, 200)
self.assertEqual(res.body, '') self.assertEqual(res.body, '')
@unpatch_policies
def test_GET_short_read(self):
prolis = _test_sockets[0]
prosrv = _test_servers[0]
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile()
obj = (''.join(
('%d bottles of beer on the wall\n' % i)
for i in reversed(range(1, 200))))
# if the object is too short, then we don't have a mid-stream
# exception after the headers are sent, but instead an early one
# before the headers
self.assertGreater(len(obj), wsgi.MINIMUM_CHUNK_SIZE)
path = '/v1/a/c/o.bottles'
fd.write('PUT %s HTTP/1.1\r\n'
'Connection: keep-alive\r\n'
'Host: localhost\r\n'
'X-Storage-Token: t\r\n'
'Content-Length: %s\r\n'
'Content-Type: application/beer-stream\r\n'
'\r\n%s' % (path, str(len(obj)), obj))
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 201'
self.assertEqual(headers[:len(exp)], exp)
# go shorten that object by a few bytes
shrinkage = 100 # bytes
shortened = 0
for dirpath, _, filenames in os.walk(_testdir):
for filename in filenames:
if filename.endswith(".data"):
with open(os.path.join(dirpath, filename), "r+") as fh:
fh.truncate(len(obj) - shrinkage)
shortened += 1
self.assertGreater(shortened, 0) # ensure test is working
real_fstat = os.fstat
# stop the object server from immediately quarantining the object
# and returning 404
def lying_fstat(fd):
sr = real_fstat(fd)
fake_stat_result = posix.stat_result((
sr.st_mode, sr.st_ino, sr.st_dev, sr.st_nlink, sr.st_uid,
sr.st_gid,
sr.st_size + shrinkage, # here's the lie
sr.st_atime, sr.st_mtime, sr.st_ctime))
return fake_stat_result
# Read the object back
with mock.patch('os.fstat', lying_fstat), \
mock.patch.object(prosrv, 'client_chunk_size', 32), \
mock.patch.object(prosrv, 'object_chunk_size', 32):
fd.write('GET %s HTTP/1.1\r\n'
'Host: localhost\r\n'
'Connection: keep-alive\r\n'
'X-Storage-Token: t\r\n'
'\r\n' % (path,))
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 200'
self.assertEqual(headers[:len(exp)], exp)
obj_parts = []
while True:
buf = fd.read(1024)
if not buf:
break
obj_parts.append(buf)
got_obj = ''.join(obj_parts)
self.assertLessEqual(len(got_obj), len(obj) - shrinkage)
# Make sure the server closed the connection
with self.assertRaises(socket.error):
# Two calls are necessary; you can apparently write to a socket
# that the peer has closed exactly once without error, then the
# kernel discovers that the connection is not open and
# subsequent send attempts fail.
sock.sendall('GET /info HTTP/1.1\r\n')
sock.sendall('Host: localhost\r\n'
'X-Storage-Token: t\r\n'
'\r\n')
@unpatch_policies
def test_GET_short_read_resuming(self):
prolis = _test_sockets[0]
prosrv = _test_servers[0]
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile()
obj = (''.join(
('%d bottles of beer on the wall\n' % i)
for i in reversed(range(1, 200))))
# if the object is too short, then we don't have a mid-stream
# exception after the headers are sent, but instead an early one
# before the headers
self.assertGreater(len(obj), wsgi.MINIMUM_CHUNK_SIZE)
path = '/v1/a/c/o.bottles'
fd.write('PUT %s HTTP/1.1\r\n'
'Connection: keep-alive\r\n'
'Host: localhost\r\n'
'X-Storage-Token: t\r\n'
'Content-Length: %s\r\n'
'Content-Type: application/beer-stream\r\n'
'\r\n%s' % (path, str(len(obj)), obj))
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 201'
self.assertEqual(headers[:len(exp)], exp)
# we shorten the first replica of the object by 200 bytes and leave
# the others untouched
_, obj_nodes = POLICIES.default.object_ring.get_nodes(
"a", "c", "o.bottles")
shortened = 0
for dirpath, _, filenames in os.walk(
os.path.join(_testdir, obj_nodes[0]['device'])):
for filename in filenames:
if filename.endswith(".data"):
if shortened == 0:
with open(os.path.join(dirpath, filename), "r+") as fh:
fh.truncate(len(obj) - 200)
shortened += 1
self.assertEqual(shortened, 1) # sanity check
real_fstat = os.fstat
# stop the object server from immediately quarantining the object
# and returning 404
def lying_fstat(fd):
sr = real_fstat(fd)
fake_stat_result = posix.stat_result((
sr.st_mode, sr.st_ino, sr.st_dev, sr.st_nlink, sr.st_uid,
sr.st_gid,
len(obj), # sometimes correct, sometimes not
sr.st_atime, sr.st_mtime, sr.st_ctime))
return fake_stat_result
# Read the object back
with mock.patch('os.fstat', lying_fstat), \
mock.patch.object(prosrv, 'client_chunk_size', 32), \
mock.patch.object(prosrv, 'object_chunk_size', 32), \
mock.patch.object(prosrv, 'sort_nodes',
lambda nodes, **kw: nodes):
fd.write('GET %s HTTP/1.1\r\n'
'Host: localhost\r\n'
'Connection: close\r\n'
'X-Storage-Token: t\r\n'
'\r\n' % (path,))
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 200'
self.assertEqual(headers[:len(exp)], exp)
obj_parts = []
while True:
buf = fd.read(1024)
if not buf:
break
obj_parts.append(buf)
got_obj = ''.join(obj_parts)
# technically this is a redundant test, but it saves us from screens
# full of error message when got_obj is shorter than obj
self.assertEqual(len(obj), len(got_obj))
self.assertEqual(obj, got_obj)
@unpatch_policies @unpatch_policies
def test_GET_ranges_resuming(self): def test_GET_ranges_resuming(self):
prolis = _test_sockets[0] prolis = _test_sockets[0]
@ -2497,7 +2675,7 @@ class TestReplicatedObjectController(
try: try:
for chunk in res.app_iter: for chunk in res.app_iter:
body += chunk body += chunk
except ChunkReadTimeout: except (ChunkReadTimeout, ChunkReadError):
pass pass
self.assertEqual(res.status_int, 206) self.assertEqual(res.status_int, 206)
@ -3845,12 +4023,8 @@ class TestReplicatedObjectController(
self.app.recoverable_node_timeout = 0.1 self.app.recoverable_node_timeout = 0.1
set_http_connect(200, 200, 200, slow=1.0) set_http_connect(200, 200, 200, slow=1.0)
resp = req.get_response(self.app) resp = req.get_response(self.app)
got_exc = False with self.assertRaises(ChunkReadTimeout):
try:
resp.body resp.body
except ChunkReadTimeout:
got_exc = True
self.assertTrue(got_exc)
def test_node_read_timeout_retry(self): def test_node_read_timeout_retry(self):
with save_globals(): with save_globals():
@ -3873,53 +4047,30 @@ class TestReplicatedObjectController(
self.app.recoverable_node_timeout = 0.1 self.app.recoverable_node_timeout = 0.1
set_http_connect(200, 200, 200, slow=[1.0, 1.0, 1.0]) set_http_connect(200, 200, 200, slow=[1.0, 1.0, 1.0])
resp = req.get_response(self.app) resp = req.get_response(self.app)
got_exc = False with self.assertRaises(ChunkReadTimeout):
try: resp.body
self.assertEqual('', resp.body)
except ChunkReadTimeout:
got_exc = True
self.assertTrue(got_exc)
set_http_connect(200, 200, 200, body='lalala', set_http_connect(200, 200, 200, body='lalala',
slow=[1.0, 1.0]) slow=[1.0, 1.0])
resp = req.get_response(self.app) resp = req.get_response(self.app)
got_exc = False self.assertEqual(resp.body, 'lalala')
try:
self.assertEqual(resp.body, 'lalala')
except ChunkReadTimeout:
got_exc = True
self.assertFalse(got_exc)
set_http_connect(200, 200, 200, body='lalala', set_http_connect(200, 200, 200, body='lalala',
slow=[1.0, 1.0], etags=['a', 'a', 'a']) slow=[1.0, 1.0], etags=['a', 'a', 'a'])
resp = req.get_response(self.app) resp = req.get_response(self.app)
got_exc = False self.assertEqual(resp.body, 'lalala')
try:
self.assertEqual(resp.body, 'lalala')
except ChunkReadTimeout:
got_exc = True
self.assertFalse(got_exc)
set_http_connect(200, 200, 200, body='lalala', set_http_connect(200, 200, 200, body='lalala',
slow=[1.0, 1.0], etags=['a', 'b', 'a']) slow=[1.0, 1.0], etags=['a', 'b', 'a'])
resp = req.get_response(self.app) resp = req.get_response(self.app)
got_exc = False self.assertEqual(resp.body, 'lalala')
try:
self.assertEqual(resp.body, 'lalala')
except ChunkReadTimeout:
got_exc = True
self.assertFalse(got_exc)
req = Request.blank('/v1/a/c/o', environ={'REQUEST_METHOD': 'GET'}) req = Request.blank('/v1/a/c/o', environ={'REQUEST_METHOD': 'GET'})
set_http_connect(200, 200, 200, body='lalala', set_http_connect(200, 200, 200, body='lalala',
slow=[1.0, 1.0], etags=['a', 'b', 'b']) slow=[1.0, 1.0], etags=['a', 'b', 'b'])
resp = req.get_response(self.app) resp = req.get_response(self.app)
got_exc = False with self.assertRaises(ChunkReadTimeout):
try:
resp.body resp.body
except ChunkReadTimeout:
got_exc = True
self.assertTrue(got_exc)
def test_node_write_timeout(self): def test_node_write_timeout(self):
with save_globals(): with save_globals():