From 2db9453722745300418980a1b122cbed9e557e0b Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Tue, 3 Dec 2013 14:49:57 -0800 Subject: [PATCH] Preserve closeability of app iterables PEP 333 (WSGI) says that if your iterable has a close() method, the framework must call it. WSGIContext._app_call pulls the first chunk off the returned iterable to make sure that it gets status and headers, and then it would itertools.chain() that first chunk back onto the iterable so the whole body went out. swob.Response.call_application() does it too. The problem is that an itertools.chain object doesn't have a close() method, so your iterable's fancy-pants close() method has no chance of getting called. This patch adds a slightly smarter CloseableChain that works like itertools.chain, but has a close() method that calls the underlying iterables' close() methods, if any. Change-Id: If975c93f53c27dfa0c2f52f4bbf599af25202f70 --- swift/common/utils.py | 20 +++++++++++++++++++- swift/common/wsgi.py | 6 +++--- test/unit/common/test_swob.py | 16 ++++++++++++++++ test/unit/common/test_wsgi.py | 21 +++++++++++++++++++++ 4 files changed, 59 insertions(+), 4 deletions(-) diff --git a/swift/common/utils.py b/swift/common/utils.py index 0a9fc4dedf..78f69781a6 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -2074,6 +2074,24 @@ def csv_append(csv_string, item): return item +class CloseableChain(object): + """ + Like itertools.chain, but with a close method that will attempt to invoke + its sub-iterators' close methods, if any. + """ + def __init__(self, *iterables): + self.iterables = iterables + + def __iter__(self): + return iter(itertools.chain(*(self.iterables))) + + def close(self): + for it in self.iterables: + close_method = getattr(it, 'close', None) + if close_method: + close_method() + + def reiterate(iterable): """ Consume the first item from an iterator, then re-chain it to the rest of @@ -2090,7 +2108,7 @@ def reiterate(iterable): chunk = '' while not chunk: chunk = next(iterator) - return itertools.chain([chunk], iterator) + return CloseableChain([chunk], iterator) except StopIteration: return [] diff --git a/swift/common/wsgi.py b/swift/common/wsgi.py index f75e1f7a12..1bacb4eff7 100644 --- a/swift/common/wsgi.py +++ b/swift/common/wsgi.py @@ -21,7 +21,6 @@ import signal import time import mimetools from swift import gettext_ as _ -from itertools import chain from StringIO import StringIO import eventlet @@ -35,7 +34,8 @@ from swift.common import utils from swift.common.swob import Request from swift.common.utils import capture_stdio, disable_fallocate, \ drop_privileges, get_logger, NullLogger, config_true_value, \ - validate_configuration, get_hub, config_auto_int_value + validate_configuration, get_hub, config_auto_int_value, \ + CloseableChain try: import multiprocessing @@ -401,7 +401,7 @@ class WSGIContext(object): except StopIteration: return iter([]) else: # We got a first_chunk - return chain([first_chunk], resp) + return CloseableChain([first_chunk], resp) def _get_status_int(self): """ diff --git a/test/unit/common/test_swob.py b/test/unit/common/test_swob.py index 3ba51ab858..d910c95d33 100644 --- a/test/unit/common/test_swob.py +++ b/test/unit/common/test_swob.py @@ -936,6 +936,22 @@ class TestResponse(unittest.TestCase): output_iter = resp(req.environ, lambda *_: None) self.assertEquals(list(output_iter), ['']) + def test_call_preserves_closeability(self): + def test_app(environ, start_response): + start_response('200 OK', []) + yield "igloo" + yield "shindig" + yield "macadamia" + yield "hullabaloo" + req = swift.common.swob.Request.blank('/') + req.method = 'GET' + status, headers, app_iter = req.call_application(test_app) + iterator = iter(app_iter) + self.assertEqual('igloo', iterator.next()) + self.assertEqual('shindig', iterator.next()) + app_iter.close() + self.assertRaises(StopIteration, iterator.next) + def test_location_rewrite(self): def start_response(env, headers): pass diff --git a/test/unit/common/test_wsgi.py b/test/unit/common/test_wsgi.py index 869fbe2593..af8987b198 100644 --- a/test/unit/common/test_wsgi.py +++ b/test/unit/common/test_wsgi.py @@ -598,6 +598,27 @@ class TestWSGIContext(unittest.TestCase): self.assertEquals(wc._response_status, '404 Not Found') self.assertEquals(''.join(it), 'Ok\n') + def test_app_iter_is_closable(self): + + def app(env, start_response): + start_response('200 OK', [('Content-Length', '25')]) + yield 'aaaaa' + yield 'bbbbb' + yield 'ccccc' + yield 'ddddd' + yield 'eeeee' + + wc = wsgi.WSGIContext(app) + r = Request.blank('/') + iterable = wc._app_call(r.environ) + self.assertEquals(wc._response_status, '200 OK') + + iterator = iter(iterable) + self.assertEqual('aaaaa', iterator.next()) + self.assertEqual('bbbbb', iterator.next()) + iterable.close() + self.assertRaises(StopIteration, iterator.next) + if __name__ == '__main__': unittest.main()