diff --git a/swift/common/utils.py b/swift/common/utils.py index 0a9fc4dedf..78f69781a6 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -2074,6 +2074,24 @@ def csv_append(csv_string, item): return item +class CloseableChain(object): + """ + Like itertools.chain, but with a close method that will attempt to invoke + its sub-iterators' close methods, if any. + """ + def __init__(self, *iterables): + self.iterables = iterables + + def __iter__(self): + return iter(itertools.chain(*(self.iterables))) + + def close(self): + for it in self.iterables: + close_method = getattr(it, 'close', None) + if close_method: + close_method() + + def reiterate(iterable): """ Consume the first item from an iterator, then re-chain it to the rest of @@ -2090,7 +2108,7 @@ def reiterate(iterable): chunk = '' while not chunk: chunk = next(iterator) - return itertools.chain([chunk], iterator) + return CloseableChain([chunk], iterator) except StopIteration: return [] diff --git a/swift/common/wsgi.py b/swift/common/wsgi.py index f75e1f7a12..1bacb4eff7 100644 --- a/swift/common/wsgi.py +++ b/swift/common/wsgi.py @@ -21,7 +21,6 @@ import signal import time import mimetools from swift import gettext_ as _ -from itertools import chain from StringIO import StringIO import eventlet @@ -35,7 +34,8 @@ from swift.common import utils from swift.common.swob import Request from swift.common.utils import capture_stdio, disable_fallocate, \ drop_privileges, get_logger, NullLogger, config_true_value, \ - validate_configuration, get_hub, config_auto_int_value + validate_configuration, get_hub, config_auto_int_value, \ + CloseableChain try: import multiprocessing @@ -401,7 +401,7 @@ class WSGIContext(object): except StopIteration: return iter([]) else: # We got a first_chunk - return chain([first_chunk], resp) + return CloseableChain([first_chunk], resp) def _get_status_int(self): """ diff --git a/test/unit/common/test_swob.py b/test/unit/common/test_swob.py index 3ba51ab858..d910c95d33 100644 --- a/test/unit/common/test_swob.py +++ b/test/unit/common/test_swob.py @@ -936,6 +936,22 @@ class TestResponse(unittest.TestCase): output_iter = resp(req.environ, lambda *_: None) self.assertEquals(list(output_iter), ['']) + def test_call_preserves_closeability(self): + def test_app(environ, start_response): + start_response('200 OK', []) + yield "igloo" + yield "shindig" + yield "macadamia" + yield "hullabaloo" + req = swift.common.swob.Request.blank('/') + req.method = 'GET' + status, headers, app_iter = req.call_application(test_app) + iterator = iter(app_iter) + self.assertEqual('igloo', iterator.next()) + self.assertEqual('shindig', iterator.next()) + app_iter.close() + self.assertRaises(StopIteration, iterator.next) + def test_location_rewrite(self): def start_response(env, headers): pass diff --git a/test/unit/common/test_wsgi.py b/test/unit/common/test_wsgi.py index 869fbe2593..af8987b198 100644 --- a/test/unit/common/test_wsgi.py +++ b/test/unit/common/test_wsgi.py @@ -598,6 +598,27 @@ class TestWSGIContext(unittest.TestCase): self.assertEquals(wc._response_status, '404 Not Found') self.assertEquals(''.join(it), 'Ok\n') + def test_app_iter_is_closable(self): + + def app(env, start_response): + start_response('200 OK', [('Content-Length', '25')]) + yield 'aaaaa' + yield 'bbbbb' + yield 'ccccc' + yield 'ddddd' + yield 'eeeee' + + wc = wsgi.WSGIContext(app) + r = Request.blank('/') + iterable = wc._app_call(r.environ) + self.assertEquals(wc._response_status, '200 OK') + + iterator = iter(iterable) + self.assertEqual('aaaaa', iterator.next()) + self.assertEqual('bbbbb', iterator.next()) + iterable.close() + self.assertRaises(StopIteration, iterator.next) + if __name__ == '__main__': unittest.main()