diff --git a/swift/common/middleware/catch_errors.py b/swift/common/middleware/catch_errors.py index 0f123e63a0..4053825376 100644 --- a/swift/common/middleware/catch_errors.py +++ b/swift/common/middleware/catch_errors.py @@ -18,6 +18,37 @@ import uuid from swift.common.swob import Request, HTTPServerError from swift.common.utils import get_logger +from swift.common.wsgi import WSGIContext + + +class CatchErrorsContext(WSGIContext): + + def __init__(self, app, logger): + super(CatchErrorsContext, self).__init__(app) + self.logger = logger + + def handle_request(self, env, start_response): + trans_id = 'tx' + uuid.uuid4().hex + env['swift.trans_id'] = trans_id + self.logger.txn_id = trans_id + try: + # catch any errors in the pipeline + resp = self._app_call(env) + except (Exception, Timeout), err: + self.logger.exception(_('Error: %s'), err) + resp = HTTPServerError(request=Request(env), + body='An error occurred', + content_type='text/plain') + resp.headers['x-trans-id'] = trans_id + return resp(env, start_response) + + # make sure the response has the trans_id + if self._response_headers is None: + self._response_headers = [] + self._response_headers.append(('x-trans-id', trans_id)) + start_response(self._response_status, self._response_headers, + self._response_exc_info) + return resp class CatchErrorMiddleware(object): @@ -34,23 +65,8 @@ class CatchErrorMiddleware(object): """ If used, this should be the first middleware in pipeline. """ - trans_id = 'tx' + uuid.uuid4().hex - env['swift.trans_id'] = trans_id - self.logger.txn_id = trans_id - try: - - def my_start_response(status, response_headers, exc_info=None): - trans_header = ('x-trans-id', trans_id) - response_headers.append(trans_header) - return start_response(status, response_headers, exc_info) - return self.app(env, my_start_response) - except (Exception, Timeout), err: - self.logger.exception(_('Error: %s'), err) - resp = HTTPServerError(request=Request(env), - body='An error occurred', - content_type='text/plain') - resp.headers['x-trans-id'] = trans_id - return resp(env, start_response) + context = CatchErrorsContext(self.app, self.logger) + return context.handle_request(env, start_response) def filter_factory(global_conf, **local_conf): diff --git a/test/unit/common/middleware/test_except.py b/test/unit/common/middleware/test_except.py index 94ed2a74df..05c243777b 100644 --- a/test/unit/common/middleware/test_except.py +++ b/test/unit/common/middleware/test_except.py @@ -20,15 +20,19 @@ from swift.common.middleware import catch_errors from swift.common.utils import get_logger class FakeApp(object): - def __init__(self, error=False): + def __init__(self, error=False, body_iter=None): self.error = error + self.body_iter = body_iter def __call__(self, env, start_response): if 'swift.trans_id' not in env: raise Exception('Trans id should always be in env') if self.error: raise Exception('An error occurred') - return ["FAKE APP"] + if self.body_iter is None: + return ["FAKE APP"] + else: + return self.body_iter def start_response(*args): pass @@ -43,17 +47,18 @@ class TestCatchErrors(unittest.TestCase): app = catch_errors.CatchErrorMiddleware(FakeApp(), {}) req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'}) resp = app(req.environ, start_response) - self.assertEquals(resp, ['FAKE APP']) + self.assertEquals(list(resp), ['FAKE APP']) def test_catcherrors(self): app = catch_errors.CatchErrorMiddleware(FakeApp(True), {}) req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'}) resp = app(req.environ, start_response) - self.assertEquals(resp, ['An error occurred']) + self.assertEquals(list(resp), ['An error occurred']) def test_trans_id_header_pass(self): self.assertEquals(self.logger.txn_id, None) - def start_response(status, headers): + + def start_response(status, headers, exc_info=None): self.assert_('x-trans-id' in (x[0] for x in headers)) app = catch_errors.CatchErrorMiddleware(FakeApp(), {}) req = Request.blank('/v1/a/c/o') @@ -62,12 +67,20 @@ class TestCatchErrors(unittest.TestCase): def test_trans_id_header_fail(self): self.assertEquals(self.logger.txn_id, None) - def start_response(status, headers): + + def start_response(status, headers, exc_info=None): self.assert_('x-trans-id' in (x[0] for x in headers)) app = catch_errors.CatchErrorMiddleware(FakeApp(True), {}) req = Request.blank('/v1/a/c/o') app(req.environ, start_response) self.assertEquals(len(self.logger.txn_id), 34) + def test_error_in_iterator(self): + app = catch_errors.CatchErrorMiddleware( + FakeApp(body_iter=(int(x) for x in 'abcd')), {}) + req = Request.blank('/', environ={'REQUEST_METHOD': 'GET'}) + resp = app(req.environ, start_response) + self.assertEquals(list(resp), ['An error occurred']) + if __name__ == '__main__': unittest.main()