diff --git a/swift/common/middleware/catch_errors.py b/swift/common/middleware/catch_errors.py index f737cfdde7..0b3d37626b 100644 --- a/swift/common/middleware/catch_errors.py +++ b/swift/common/middleware/catch_errors.py @@ -22,7 +22,7 @@ class BadResponseLength(Exception): pass -def enforce_byte_count(inner_iter, nbytes): +class ByteEnforcer(object): """ Enforces that inner_iter yields exactly bytes before exhaustion. @@ -31,25 +31,39 @@ def enforce_byte_count(inner_iter, nbytes): :param inner_iter: iterable of bytestrings :param nbytes: number of bytes expected - """ - try: - bytes_left = nbytes - for chunk in inner_iter: - if bytes_left>= len(chunk): - yield chunk - bytes_left -= len(chunk) - else: - yield chunk[:bytes_left] - raise BadResponseLength( - "Too many bytes; truncating after %d bytes " - "with at least %d surplus bytes remaining" % ( - nbytes, len(chunk) - bytes_left)) - if bytes_left: - raise BadResponseLength('Expected another %d bytes' % ( - bytes_left,)) - finally: - close_if_possible(inner_iter) + N.B. since we require the nbytes param and require the inner_iter to yield + exactly that many bytes we can support the __len__ interface for anyone + happens to expect non chunked resp iterables to support that + (e.g. eventlet's wsgi.server). + """ + + def __init__(self, inner_iter, nbytes): + self.inner_iter = inner_iter + self.nbytes = nbytes + + def __len__(self): + return self.nbytes + + def __iter__(self): + try: + bytes_left = self.nbytes + for chunk in self.inner_iter: + if bytes_left>= len(chunk): + yield chunk + bytes_left -= len(chunk) + else: + yield chunk[:bytes_left] + raise BadResponseLength( + "Too many bytes; truncating after %d bytes " + "with at least %d surplus bytes remaining" % ( + self.nbytes, len(chunk) - bytes_left)) + + if bytes_left: + raise BadResponseLength('Expected another %d bytes' % ( + bytes_left,)) + finally: + close_if_possible(self.inner_iter) class CatchErrorsContext(WSGIContext): @@ -99,7 +113,7 @@ class CatchErrorsContext(WSGIContext): # and raise an exception to stop any more bytes from being # generated and also to kill the TCP connection. if env['REQUEST_METHOD'] == 'HEAD': - resp = enforce_byte_count(resp, 0) + resp = ByteEnforcer(resp, 0) elif self._response_headers: content_lengths = [val for header, val in self._response_headers @@ -110,7 +124,7 @@ class CatchErrorsContext(WSGIContext): except ValueError: pass else: - resp = enforce_byte_count(resp, content_length) + resp = ByteEnforcer(resp, content_length) # make sure the response has the trans_id if self._response_headers is None: diff --git a/swift/common/middleware/proxy_logging.py b/swift/common/middleware/proxy_logging.py index f4d2bc6ff9..6609099b93 100644 --- a/swift/common/middleware/proxy_logging.py +++ b/swift/common/middleware/proxy_logging.py @@ -89,7 +89,7 @@ import os import time from swift.common.constraints import valid_api_version -from swift.common.middleware.catch_errors import enforce_byte_count +from swift.common.middleware.catch_errors import ByteEnforcer from swift.common.request_helpers import get_log_info from swift.common.swob import Request from swift.common.utils import (get_logger, get_remote_client, @@ -430,7 +430,7 @@ class ProxyLoggingMiddleware(object): if method == 'HEAD': content_length = 0 if content_length is not None: - iterator = enforce_byte_count(iterator, content_length) + iterator = ByteEnforcer(iterator, content_length) wire_status_int = int(start_response_args[0][0].split(' ', 1)[0]) resp_headers = dict(start_response_args[0][1]) diff --git a/swift/common/middleware/s3api/s3response.py b/swift/common/middleware/s3api/s3response.py index fa7f81873a..d6a4f4f10f 100644 --- a/swift/common/middleware/s3api/s3response.py +++ b/swift/common/middleware/s3api/s3response.py @@ -243,6 +243,9 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException): swob.HTTPException.__init__( self, status=kwargs.pop('status', self._status), + # we use an app_iter, so that we can add our trans_id to the resp + # xml *after* we've been called - technically any non-None app_iter + # would do, we override swob.Response._response_iter anyway. app_iter=self._body_iter(), content_type='application/xml', *args, **kwargs) @@ -265,6 +268,9 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException): error_elem = Element('Error') SubElement(error_elem, 'Code').text = self._code SubElement(error_elem, 'Message').text = self._msg + # N.B. swob.Response objects don't normally have an environ attribute + # when they're created, but swob always gives this to us when we're + # __call__'d if 'swift.trans_id' in self.environ: request_id = self.environ['swift.trans_id'] SubElement(error_elem, 'RequestId').text = request_id @@ -274,6 +280,13 @@ class ErrorResponse(S3ResponseBase, swob.HTTPException): yield tostring(error_elem, use_s3ns=False, xml_declaration=self.xml_declaration) + def _response_iter(self, app_iter, body): + # we don't actually want our _response_iter to be a generator, a list + # of strings is much better for eventlet.wsgi.server connection + # handling and request pipelining and ErrorResponses are small. FWIW + # we now have self.environ, app_iter=self._body_iter() and body is None + return super()._response_iter(list(app_iter), body) + def _dict_to_etree(self, parent, d): for key, value in d.items(): tag = re.sub(r'\W', '', snake_to_camel(key)) diff --git a/test/unit/common/middleware/s3api/test_s3response.py b/test/unit/common/middleware/s3api/test_s3response.py index accd7aecb7..2cb8c47b1c 100644 --- a/test/unit/common/middleware/s3api/test_s3response.py +++ b/test/unit/common/middleware/s3api/test_s3response.py @@ -15,8 +15,9 @@ import unittest -from swift.common.swob import Response +from swift.common.swob import Response, Request from swift.common.utils import HeaderKeyDict +from swift.common.middleware.catch_errors import CatchErrorMiddleware from swift.common.middleware.s3api.s3response import S3Response, ErrorResponse from swift.common.middleware.s3api.utils import sysmeta_prefix @@ -124,6 +125,27 @@ class TestErrorResponse(unittest.TestCase): b"", resp.body) + def test_error_response_trans_id(self): + req = Request.blank('/bucket/object') + err = DummyErrorResponse(msg='my-msg', reason='my reason') + app = CatchErrorMiddleware(err, {}) + with unittest.mock.patch( + 'swift.common.middleware.catch_errors.generate_trans_id', + return_value='fake-trans-id'): + resp = req.get_response(app) + self.assertIn('swift.trans_id', req.environ) + self.assertEqual(418, resp.status_int) + self.assertIn('X-Trans-Id', resp.headers) + self.assertEqual( + b"\n" + b"" + b"DummyErrorResponse" + b"my-msg" + b"fake-trans-id" + b"", + resp.body) + self.assertEqual(146, int(resp.headers['Content-Length'])) + if __name__ == '__main__': unittest.main() diff --git a/test/unit/common/middleware/test_catch_errors.py b/test/unit/common/middleware/test_catch_errors.py index 2748c555f7..6b265b13d9 100644 --- a/test/unit/common/middleware/test_catch_errors.py +++ b/test/unit/common/middleware/test_catch_errors.py @@ -15,7 +15,7 @@ import unittest -from swift.common.swob import Request +from swift.common.swob import Request, HTTPOk from swift.common.middleware import catch_errors from swift.common.utils import get_logger @@ -137,6 +137,30 @@ class TestCatchErrors(unittest.TestCase): resp = app(req.environ, self.start_response) self.assertEqual(list(resp), [b'An error occurred']) + def test_has_len(self): + # sanity + app = HTTPOk(body='test-body') + req = Request.blank('/') + captured_status_length = [] + + def capture_start_resp(status, headers, exc_info=None): + length = None + for k, v in headers: + if k == 'Content-Length': + length = int(v) + captured_status_length.append((status, length)) + iterable = app(req.environ, capture_start_resp) + self.assertEqual(captured_status_length, [('200 OK', 9)]) + self.assertTrue(hasattr(iterable, '__len__')) + # wrapped should work the same way + app_resp = HTTPOk(body='test-body') + app = catch_errors.CatchErrorMiddleware(app_resp, {}) + req = Request.blank('/') + captured_status_length = [] + iterable = app(req.environ, capture_start_resp) + self.assertEqual(captured_status_length, [('200 OK', 9)]) + self.assertTrue(hasattr(iterable, '__len__')) + def test_HEAD_with_content_length(self): def cannot_count_app(env, sr): sr("200 OK", [("Content-Length", "10")])

AltStyle によって変換されたページ (->オリジナル) /