change catch_errors to use WSGIContext
The current catch_errors (ie without this patch) relinquishes control before the underlying middleware/app has been evaluated. This results in not catching errors in the stack when they occur in either the start_response or in generating the first chunk sent to the client of the underlying stack. Change-Id: Iecd21e4fc7e30fa20239d011f69216354b50baf1
This commit is contained in:
parent
4236e6379b
commit
20d4b00645
@ -18,6 +18,37 @@ import uuid
|
||||
|
||||
from swift.common.swob import Request, HTTPServerError
|
||||
from swift.common.utils import get_logger
|
||||
from swift.common.wsgi import WSGIContext
|
||||
|
||||
|
||||
class CatchErrorsContext(WSGIContext):
|
||||
|
||||
def __init__(self, app, logger):
|
||||
super(CatchErrorsContext, self).__init__(app)
|
||||
self.logger = logger
|
||||
|
||||
def handle_request(self, env, start_response):
|
||||
trans_id = 'tx' + uuid.uuid4().hex
|
||||
env['swift.trans_id'] = trans_id
|
||||
self.logger.txn_id = trans_id
|
||||
try:
|
||||
# catch any errors in the pipeline
|
||||
resp = self._app_call(env)
|
||||
except (Exception, Timeout), err:
|
||||
self.logger.exception(_('Error: %s'), err)
|
||||
resp = HTTPServerError(request=Request(env),
|
||||
body='An error occurred',
|
||||
content_type='text/plain')
|
||||
resp.headers['x-trans-id'] = trans_id
|
||||
return resp(env, start_response)
|
||||
|
||||
# make sure the response has the trans_id
|
||||
if self._response_headers is None:
|
||||
self._response_headers = []
|
||||
self._response_headers.append(('x-trans-id', trans_id))
|
||||
start_response(self._response_status, self._response_headers,
|
||||
self._response_exc_info)
|
||||
return resp
|
||||
|
||||
|
||||
class CatchErrorMiddleware(object):
|
||||
@ -34,23 +65,8 @@ class CatchErrorMiddleware(object):
|
||||
"""
|
||||
If used, this should be the first middleware in pipeline.
|
||||
"""
|
||||
trans_id = 'tx' + uuid.uuid4().hex
|
||||
env['swift.trans_id'] = trans_id
|
||||
self.logger.txn_id = trans_id
|
||||
try:
|
||||
|
||||
def my_start_response(status, response_headers, exc_info=None):
|
||||
trans_header = ('x-trans-id', trans_id)
|
||||
response_headers.append(trans_header)
|
||||
return start_response(status, response_headers, exc_info)
|
||||
return self.app(env, my_start_response)
|
||||
except (Exception, Timeout), err:
|
||||
self.logger.exception(_('Error: %s'), err)
|
||||
resp = HTTPServerError(request=Request(env),
|
||||
body='An error occurred',
|
||||
content_type='text/plain')
|
||||
resp.headers['x-trans-id'] = trans_id
|
||||
return resp(env, start_response)
|
||||
context = CatchErrorsContext(self.app, self.logger)
|
||||
return context.handle_request(env, start_response)
|
||||
|
||||
|
||||
def filter_factory(global_conf, **local_conf):
|
||||
|
@ -20,15 +20,19 @@ from swift.common.middleware import catch_errors
|
||||
from swift.common.utils import get_logger
|
||||
|
||||
class FakeApp(object):
|
||||
def __init__(self, error=False):
|
||||
def __init__(self, error=False, body_iter=None):
|
||||
self.error = error
|
||||
self.body_iter = body_iter
|
||||
|
||||
def __call__(self, env, start_response):
|
||||
if 'swift.trans_id' not in env:
|
||||
raise Exception('Trans id should always be in env')
|
||||
if self.error:
|
||||
raise Exception('An error occurred')
|
||||
return ["FAKE APP"]
|
||||
if self.body_iter is None:
|
||||
return ["FAKE APP"]
|
||||
else:
|
||||
return self.body_iter
|
||||
|
||||
def start_response(*args):
|
||||
pass
|
||||
@ -43,17 +47,18 @@ class TestCatchErrors(unittest.TestCase):
|
||||
app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
|
||||
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
|
||||
resp = app(req.environ, start_response)
|
||||
self.assertEquals(resp, ['FAKE APP'])
|
||||
self.assertEquals(list(resp), ['FAKE APP'])
|
||||
|
||||
def test_catcherrors(self):
|
||||
app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
|
||||
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
|
||||
resp = app(req.environ, start_response)
|
||||
self.assertEquals(resp, ['An error occurred'])
|
||||
self.assertEquals(list(resp), ['An error occurred'])
|
||||
|
||||
def test_trans_id_header_pass(self):
|
||||
self.assertEquals(self.logger.txn_id, None)
|
||||
def start_response(status, headers):
|
||||
|
||||
def start_response(status, headers, exc_info=None):
|
||||
self.assert_('x-trans-id' in (x[0] for x in headers))
|
||||
app = catch_errors.CatchErrorMiddleware(FakeApp(), {})
|
||||
req = Request.blank('/v1/a/c/o')
|
||||
@ -62,12 +67,20 @@ class TestCatchErrors(unittest.TestCase):
|
||||
|
||||
def test_trans_id_header_fail(self):
|
||||
self.assertEquals(self.logger.txn_id, None)
|
||||
def start_response(status, headers):
|
||||
|
||||
def start_response(status, headers, exc_info=None):
|
||||
self.assert_('x-trans-id' in (x[0] for x in headers))
|
||||
app = catch_errors.CatchErrorMiddleware(FakeApp(True), {})
|
||||
req = Request.blank('/v1/a/c/o')
|
||||
app(req.environ, start_response)
|
||||
self.assertEquals(len(self.logger.txn_id), 34)
|
||||
|
||||
def test_error_in_iterator(self):
|
||||
app = catch_errors.CatchErrorMiddleware(
|
||||
FakeApp(body_iter=(int(x) for x in 'abcd')), {})
|
||||
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'})
|
||||
resp = app(req.environ, start_response)
|
||||
self.assertEquals(list(resp), ['An error occurred'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user