change catch_errors to use WSGIContext

The current catch_errors (ie without this patch) relinquishes control
before the underlying middleware/app has been evaluated. This results
in not catching errors in the stack when they occur in either the
start_response or in generating the first chunk sent to the client of
the underlying stack.

Change-Id: Iecd21e4fc7e30fa20239d011f69216354b50baf1
This commit is contained in:
John Dickinson 2012-11-13 16:07:16 -08:00
parent 4236e6379b
commit 20d4b00645
2 changed files with 52 additions and 23 deletions

View File

@ -18,6 +18,37 @@ import uuid
from swift.common.swob import Request, HTTPServerError from swift.common.swob import Request, HTTPServerError
from swift.common.utils import get_logger from swift.common.utils import get_logger
from swift.common.wsgi import WSGIContext
class CatchErrorsContext(WSGIContext):
def __init__(self, app, logger):
super(CatchErrorsContext, self).__init__(app)
self.logger = logger
def handle_request(self, env, start_response):
trans_id = 'tx' + uuid.uuid4().hex
env['swift.trans_id'] = trans_id
self.logger.txn_id = trans_id
try:
# catch any errors in the pipeline
resp = self._app_call(env)
except (Exception, Timeout), err:
self.logger.exception(_('Error: %s'), err)
resp = HTTPServerError(request=Request(env),
body='An error occurred',
content_type='text/plain')
resp.headers['x-trans-id'] = trans_id
return resp(env, start_response)
# make sure the response has the trans_id
if self._response_headers is None:
self._response_headers = []
self._response_headers.append(('x-trans-id', trans_id))
start_response(self._response_status, self._response_headers,
self._response_exc_info)
return resp
class CatchErrorMiddleware(object): class CatchErrorMiddleware(object):
@ -34,23 +65,8 @@ class CatchErrorMiddleware(object):
""" """
If used, this should be the first middleware in pipeline. If used, this should be the first middleware in pipeline.
""" """
trans_id = 'tx' + uuid.uuid4().hex context = CatchErrorsContext(self.app, self.logger)
env['swift.trans_id'] = trans_id return context.handle_request(env, start_response)
self.logger.txn_id = trans_id
try:
def my_start_response(status, response_headers, exc_info=None):
trans_header = ('x-trans-id', trans_id)
response_headers.append(trans_header)
return start_response(status, response_headers, exc_info)
return self.app(env, my_start_response)
except (Exception, Timeout), err:
self.logger.exception(_('Error: %s'), err)
resp = HTTPServerError(request=Request(env),
body='An error occurred',
content_type='text/plain')
resp.headers['x-trans-id'] = trans_id
return resp(env, start_response)
def filter_factory(global_conf, **local_conf): def filter_factory(global_conf, **local_conf):

View File

@ -20,15 +20,19 @@ from swift.common.middleware import catch_errors
from swift.common.utils import get_logger from swift.common.utils import get_logger
class FakeApp(object): class FakeApp(object):
def __init__(self, error=False): def __init__(self, error=False, body_iter=None):
self.error = error self.error = error
self.body_iter = body_iter
def __call__(self, env, start_response): def __call__(self, env, start_response):
if 'swift.trans_id' not in env: if 'swift.trans_id' not in env:
raise Exception('Trans id should always be in env') raise Exception('Trans id should always be in env')
if self.error: if self.error:
raise Exception('An error occurred') raise Exception('An error occurred')
if self.body_iter is None:
return ["FAKE APP"] return ["FAKE APP"]
else:
return self.body_iter
def start_response(*args): def start_response(*args):
pass pass
@ -43,17 +47,18 @@ class TestCatchErrors(unittest.TestCase):
app = catch_errors.CatchErrorMiddleware(FakeApp(), {}) app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'}) req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
resp = app(req.environ, start_response) resp = app(req.environ, start_response)
self.assertEquals(resp, ['FAKE APP']) self.assertEquals(list(resp), ['FAKE APP'])
def test_catcherrors(self): def test_catcherrors(self):
app = catch_errors.CatchErrorMiddleware(FakeApp(True), {}) app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'}) req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
resp = app(req.environ, start_response) resp = app(req.environ, start_response)
self.assertEquals(resp, ['An error occurred']) self.assertEquals(list(resp), ['An error occurred'])
def test_trans_id_header_pass(self): def test_trans_id_header_pass(self):
self.assertEquals(self.logger.txn_id, None) self.assertEquals(self.logger.txn_id, None)
def start_response(status, headers):
def start_response(status, headers, exc_info=None):
self.assert_('x-trans-id' in (x[0] for x in headers)) self.assert_('x-trans-id' in (x[0] for x in headers))
app = catch_errors.CatchErrorMiddleware(FakeApp(), {}) app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
req = Request.blank('/v1/a/c/o') req = Request.blank('/v1/a/c/o')
@ -62,12 +67,20 @@ class TestCatchErrors(unittest.TestCase):
def test_trans_id_header_fail(self): def test_trans_id_header_fail(self):
self.assertEquals(self.logger.txn_id, None) self.assertEquals(self.logger.txn_id, None)
def start_response(status, headers):
def start_response(status, headers, exc_info=None):
self.assert_('x-trans-id' in (x[0] for x in headers)) self.assert_('x-trans-id' in (x[0] for x in headers))
app = catch_errors.CatchErrorMiddleware(FakeApp(True), {}) app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
req = Request.blank('/v1/a/c/o') req = Request.blank('/v1/a/c/o')
app(req.environ, start_response) app(req.environ, start_response)
self.assertEquals(len(self.logger.txn_id), 34) self.assertEquals(len(self.logger.txn_id), 34)
def test_error_in_iterator(self):
app = catch_errors.CatchErrorMiddleware(
FakeApp(body_iter=(int(x) for x in 'abcd')), {})
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
resp = app(req.environ, start_response)
self.assertEquals(list(resp), ['An error occurred'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()