Merge "Preserve closeability of app iterables"

This commit is contained in:
Jenkins 2013-12-06 22:20:20 +00:00 committed by Gerrit Code Review
commit fb1a985ff0
4 changed files with 59 additions and 4 deletions

View File

@ -2074,6 +2074,24 @@ def csv_append(csv_string, item):
return item return item
class CloseableChain(object):
"""
Like itertools.chain, but with a close method that will attempt to invoke
its sub-iterators' close methods, if any.
"""
def __init__(self, *iterables):
self.iterables = iterables
def __iter__(self):
return iter(itertools.chain(*(self.iterables)))
def close(self):
for it in self.iterables:
close_method = getattr(it, 'close', None)
if close_method:
close_method()
def reiterate(iterable): def reiterate(iterable):
""" """
Consume the first item from an iterator, then re-chain it to the rest of Consume the first item from an iterator, then re-chain it to the rest of
@ -2090,7 +2108,7 @@ def reiterate(iterable):
chunk = '' chunk = ''
while not chunk: while not chunk:
chunk = next(iterator) chunk = next(iterator)
return itertools.chain([chunk], iterator) return CloseableChain([chunk], iterator)
except StopIteration: except StopIteration:
return [] return []

View File

@ -21,7 +21,6 @@ import signal
import time import time
import mimetools import mimetools
from swift import gettext_ as _ from swift import gettext_ as _
from itertools import chain
from StringIO import StringIO from StringIO import StringIO
import eventlet import eventlet
@ -35,7 +34,8 @@ from swift.common import utils
from swift.common.swob import Request from swift.common.swob import Request
from swift.common.utils import capture_stdio, disable_fallocate, \ from swift.common.utils import capture_stdio, disable_fallocate, \
drop_privileges, get_logger, NullLogger, config_true_value, \ drop_privileges, get_logger, NullLogger, config_true_value, \
validate_configuration, get_hub, config_auto_int_value validate_configuration, get_hub, config_auto_int_value, \
CloseableChain
try: try:
import multiprocessing import multiprocessing
@ -401,7 +401,7 @@ class WSGIContext(object):
except StopIteration: except StopIteration:
return iter([]) return iter([])
else: # We got a first_chunk else: # We got a first_chunk
return chain([first_chunk], resp) return CloseableChain([first_chunk], resp)
def _get_status_int(self): def _get_status_int(self):
""" """

View File

@ -936,6 +936,22 @@ class TestResponse(unittest.TestCase):
output_iter = resp(req.environ, lambda *_: None) output_iter = resp(req.environ, lambda *_: None)
self.assertEquals(list(output_iter), ['']) self.assertEquals(list(output_iter), [''])
def test_call_preserves_closeability(self):
def test_app(environ, start_response):
start_response('200 OK', [])
yield "igloo"
yield "shindig"
yield "macadamia"
yield "hullabaloo"
req = swift.common.swob.Request.blank('/')
req.method = 'GET'
status, headers, app_iter = req.call_application(test_app)
iterator = iter(app_iter)
self.assertEqual('igloo', iterator.next())
self.assertEqual('shindig', iterator.next())
app_iter.close()
self.assertRaises(StopIteration, iterator.next)
def test_location_rewrite(self): def test_location_rewrite(self):
def start_response(env, headers): def start_response(env, headers):
pass pass

View File

@ -598,6 +598,27 @@ class TestWSGIContext(unittest.TestCase):
self.assertEquals(wc._response_status, '404 Not Found') self.assertEquals(wc._response_status, '404 Not Found')
self.assertEquals(''.join(it), 'Ok\n') self.assertEquals(''.join(it), 'Ok\n')
def test_app_iter_is_closable(self):
def app(env, start_response):
start_response('200 OK', [('Content-Length', '25')])
yield 'aaaaa'
yield 'bbbbb'
yield 'ccccc'
yield 'ddddd'
yield 'eeeee'
wc = wsgi.WSGIContext(app)
r = Request.blank('/')
iterable = wc._app_call(r.environ)
self.assertEquals(wc._response_status, '200 OK')
iterator = iter(iterable)
self.assertEqual('aaaaa', iterator.next())
self.assertEqual('bbbbb', iterator.next())
iterable.close()
self.assertRaises(StopIteration, iterator.next)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()