py3: Monkey-patch json.loads to accept bytes on py35

I'm tired of creating code churn where I just slap

    .decode("nearly arbitrary choice of encoding")

in a bunch of places.

Change-Id: I79b2bc59fed130ca537e96c1074212861d7db6b8
This commit is contained in:
Tim Burke 2018-11-02 21:38:49 +00:00
parent 887ba87c5a
commit c112203e0e
12 changed files with 70 additions and 35 deletions

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys
import gettext import gettext
import pkg_resources import pkg_resources
@ -39,3 +40,37 @@ _t = gettext.translation('swift', localedir=_localedir, fallback=True)
def gettext_(msg): def gettext_(msg):
return _t.gettext(msg) return _t.gettext(msg)
if (3, 0) <= sys.version_info[:2] <= (3, 5):
# In the development of py3, json.loads() stopped accepting byte strings
# for a while. https://bugs.python.org/issue17909 got fixed for py36, but
# since it was termed an enhancement and not a regression, we don't expect
# any backports. At the same time, it'd be better if we could avoid
# leaving a whole bunch of json.loads(resp.body.decode(...)) scars in the
# code that'd probably persist even *after* we drop support for 3.5 and
# earlier. So, monkey patch stdlib.
import json
if not getattr(json.loads, 'patched_to_decode', False):
class JsonLoadsPatcher(object):
def __init__(self, orig):
self._orig = orig
def __call__(self, s, **kw):
if isinstance(s, bytes):
# No fancy byte-order mark detection for us; just assume
# UTF-8 and raise a UnicodeDecodeError if appropriate.
s = s.decode('utf8')
return self._orig(s, **kw)
def __getattribute__(self, attr):
if attr == 'patched_to_decode':
return True
if attr == '_orig':
return super().__getattribute__(attr)
# Pass through all other attrs to the original; among other
# things, this preserves doc strings, etc.
return getattr(self._orig, attr)
json.loads = JsonLoadsPatcher(json.loads)
del JsonLoadsPatcher

View File

@ -174,7 +174,7 @@ def _get_direct_account_container(path, stype, node, part,
if resp.status == HTTP_NO_CONTENT: if resp.status == HTTP_NO_CONTENT:
resp.read() resp.read()
return resp_headers, [] return resp_headers, []
return resp_headers, json.loads(resp.read().decode('ascii')) return resp_headers, json.loads(resp.read())
def gen_headers(hdrs_in=None, add_ts=True): def gen_headers(hdrs_in=None, add_ts=True):

View File

@ -298,7 +298,7 @@ class InternalClient(object):
if resp.status_int >= HTTP_MULTIPLE_CHOICES: if resp.status_int >= HTTP_MULTIPLE_CHOICES:
b''.join(resp.app_iter) b''.join(resp.app_iter)
break break
data = json.loads(resp.body.decode('ascii')) data = json.loads(resp.body)
if not data: if not data:
break break
for item in data: for item in data:
@ -844,7 +844,7 @@ class SimpleClient(object):
body = conn.read() body = conn.read()
info = conn.info() info = conn.info()
try: try:
body_data = json.loads(body.decode('ascii')) body_data = json.loads(body)
except ValueError: except ValueError:
body_data = None body_data = None
trans_stop = time() trans_stop = time()

View File

@ -315,7 +315,7 @@ class MemcacheRing(object):
else: else:
value = None value = None
elif int(line[2]) & JSON_FLAG: elif int(line[2]) & JSON_FLAG:
value = json.loads(value.decode('ascii')) value = json.loads(value)
fp.readline() fp.readline()
line = fp.readline().strip().split() line = fp.readline().strip().split()
self._return_conn(server, fp, sock) self._return_conn(server, fp, sock)
@ -484,7 +484,7 @@ class MemcacheRing(object):
else: else:
value = None value = None
elif int(line[2]) & JSON_FLAG: elif int(line[2]) & JSON_FLAG:
value = json.loads(value.decode('ascii')) value = json.loads(value)
responses[line[1]] = value responses[line[1]] = value
fp.readline() fp.readline()
line = fp.readline().strip().split() line = fp.readline().strip().split()

View File

@ -185,7 +185,7 @@ class ListingFilter(object):
body = b''.join(resp_iter) body = b''.join(resp_iter)
try: try:
listing = json.loads(body.decode('ascii')) listing = json.loads(body)
# Do a couple sanity checks # Do a couple sanity checks
if not isinstance(listing, list): if not isinstance(listing, list):
raise ValueError raise ValueError

View File

@ -295,7 +295,7 @@ class SymlinkContainerContext(WSGIContext):
""" """
with closing_if_possible(resp_iter): with closing_if_possible(resp_iter):
resp_body = b''.join(resp_iter) resp_body = b''.join(resp_iter)
body_json = json.loads(resp_body.decode('ascii')) body_json = json.loads(resp_body)
swift_version, account, _junk = split_path(req.path, 2, 3, True) swift_version, account, _junk = split_path(req.path, 2, 3, True)
new_body = json.dumps( new_body = json.dumps(
[self._extract_symlink_path_json(obj_dict, swift_version, account) [self._extract_symlink_path_json(obj_dict, swift_version, account)

View File

@ -78,7 +78,7 @@ class RingData(object):
""" """
json_len, = struct.unpack('!I', gz_file.read(4)) json_len, = struct.unpack('!I', gz_file.read(4))
ring_dict = json.loads(gz_file.read(json_len).decode('ascii')) ring_dict = json.loads(gz_file.read(json_len))
ring_dict['replica2part2dev_id'] = [] ring_dict['replica2part2dev_id'] = []
if metadata_only: if metadata_only:

View File

@ -3477,7 +3477,7 @@ def dump_recon_cache(cache_dict, cache_file, logger, lock_timeout=2,
try: try:
existing_entry = cf.readline() existing_entry = cf.readline()
if existing_entry: if existing_entry:
cache_entry = json.loads(existing_entry.decode('utf8')) cache_entry = json.loads(existing_entry)
except ValueError: except ValueError:
# file doesn't have a valid entry, we'll recreate it # file doesn't have a valid entry, we'll recreate it
pass pass

View File

@ -59,7 +59,7 @@ class TestListingMiddleware(S3ApiTestCase):
req = Request.blank('/v1/a/c') req = Request.blank('/v1/a/c')
status, headers, body = self.call_s3api(req) status, headers, body = self.call_s3api(req)
self.assertEqual(json.loads(body.decode('ascii')), [ self.assertEqual(json.loads(body), [
{'name': 'obj1', 'hash': '0123456789abcdef0123456789abcdef'}, {'name': 'obj1', 'hash': '0123456789abcdef0123456789abcdef'},
{'name': 'obj2', 'hash': 'swiftetag', 's3_etag': '"mu-etag"'}, {'name': 'obj2', 'hash': 'swiftetag', 's3_etag': '"mu-etag"'},
{'name': 'obj2', 'hash': 'swiftetag; something=else'}, {'name': 'obj2', 'hash': 'swiftetag; something=else'},

View File

@ -240,7 +240,7 @@ class TestListEndpoints(unittest.TestCase):
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json') self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/1/a/c/o1", "http://10.1.1.1:6200/sdb1/1/a/c/o1",
"http://10.1.2.2:6200/sdd1/1/a/c/o1" "http://10.1.2.2:6200/sdd1/1/a/c/o1"
]) ])
@ -260,14 +260,14 @@ class TestListEndpoints(unittest.TestCase):
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json') self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")), self.assertEqual(json.loads(resp.body),
expected[pol.idx]) expected[pol.idx])
# Here, 'o1/' is the object name. # Here, 'o1/' is the object name.
resp = Request.blank('/endpoints/a/c/o1/').get_response( resp = Request.blank('/endpoints/a/c/o1/').get_response(
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/a/c/o1/", "http://10.1.1.1:6200/sdb1/3/a/c/o1/",
"http://10.1.2.2:6200/sdd1/3/a/c/o1/" "http://10.1.2.2:6200/sdd1/3/a/c/o1/"
]) ])
@ -275,7 +275,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a/c2').get_response( resp = Request.blank('/endpoints/a/c2').get_response(
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sda1/2/a/c2", "http://10.1.1.1:6200/sda1/2/a/c2",
"http://10.1.2.1:6200/sdc1/2/a/c2" "http://10.1.2.1:6200/sdc1/2/a/c2"
]) ])
@ -283,7 +283,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a1').get_response( resp = Request.blank('/endpoints/a1').get_response(
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.2.1:6200/sdc1/0/a1", "http://10.1.2.1:6200/sdc1/0/a1",
"http://10.1.1.1:6200/sda1/0/a1", "http://10.1.1.1:6200/sda1/0/a1",
"http://10.1.1.1:6200/sdb1/0/a1" "http://10.1.1.1:6200/sdb1/0/a1"
@ -296,7 +296,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a/c 2').get_response( resp = Request.blank('/endpoints/a/c 2').get_response(
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/a/c%202", "http://10.1.1.1:6200/sdb1/3/a/c%202",
"http://10.1.2.2:6200/sdd1/3/a/c%202" "http://10.1.2.2:6200/sdd1/3/a/c%202"
]) ])
@ -304,7 +304,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a/c%202').get_response( resp = Request.blank('/endpoints/a/c%202').get_response(
self.list_endpoints) self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/a/c%202", "http://10.1.1.1:6200/sdb1/3/a/c%202",
"http://10.1.2.2:6200/sdd1/3/a/c%202" "http://10.1.2.2:6200/sdd1/3/a/c%202"
]) ])
@ -312,7 +312,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/ac%20count/con%20tainer/ob%20ject') \ resp = Request.blank('/endpoints/ac%20count/con%20tainer/ob%20ject') \
.get_response(self.list_endpoints) .get_response(self.list_endpoints)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [ self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/ac%20count/con%20tainer/ob%20ject", "http://10.1.1.1:6200/sdb1/3/ac%20count/con%20tainer/ob%20ject",
"http://10.1.2.2:6200/sdd1/3/ac%20count/con%20tainer/ob%20ject" "http://10.1.2.2:6200/sdd1/3/ac%20count/con%20tainer/ob%20ject"
]) ])
@ -342,7 +342,7 @@ class TestListEndpoints(unittest.TestCase):
.get_response(custom_path_le) .get_response(custom_path_le)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json') self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")), self.assertEqual(json.loads(resp.body),
expected[pol.idx]) expected[pol.idx])
# test custom path without trailing slash # test custom path without trailing slash
@ -356,7 +356,7 @@ class TestListEndpoints(unittest.TestCase):
.get_response(custom_path_le) .get_response(custom_path_le)
self.assertEqual(resp.status_int, 200) self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json') self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")), self.assertEqual(json.loads(resp.body),
expected[pol.idx]) expected[pol.idx])
def test_v1_response(self): def test_v1_response(self):
@ -364,7 +364,7 @@ class TestListEndpoints(unittest.TestCase):
resp = req.get_response(self.list_endpoints) resp = req.get_response(self.list_endpoints)
expected = ["http://10.1.1.1:6200/sdb1/1/a/c/o1", expected = ["http://10.1.1.1:6200/sdb1/1/a/c/o1",
"http://10.1.2.2:6200/sdd1/1/a/c/o1"] "http://10.1.2.2:6200/sdd1/1/a/c/o1"]
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
def test_v2_obj_response(self): def test_v2_obj_response(self):
req = Request.blank('/endpoints/v2/a/c/o1') req = Request.blank('/endpoints/v2/a/c/o1')
@ -374,7 +374,7 @@ class TestListEndpoints(unittest.TestCase):
"http://10.1.2.2:6200/sdd1/1/a/c/o1"], "http://10.1.2.2:6200/sdd1/1/a/c/o1"],
'headers': {'X-Backend-Storage-Policy-Index': "0"}, 'headers': {'X-Backend-Storage-Policy-Index': "0"},
} }
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
for policy in POLICIES: for policy in POLICIES:
patch_path = 'swift.common.middleware.list_endpoints' \ patch_path = 'swift.common.middleware.list_endpoints' \
'.get_container_info' '.get_container_info'
@ -390,7 +390,7 @@ class TestListEndpoints(unittest.TestCase):
'X-Backend-Storage-Policy-Index': str(int(policy))}, 'X-Backend-Storage-Policy-Index': str(int(policy))},
'endpoints': [path % node for node in nodes], 'endpoints': [path % node for node in nodes],
} }
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
def test_v2_non_obj_response(self): def test_v2_non_obj_response(self):
# account # account
@ -403,7 +403,7 @@ class TestListEndpoints(unittest.TestCase):
'headers': {}, 'headers': {},
} }
# container # container
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
req = Request.blank('/endpoints/v2/a/c') req = Request.blank('/endpoints/v2/a/c')
resp = req.get_response(self.list_endpoints) resp = req.get_response(self.list_endpoints)
expected = { expected = {
@ -412,7 +412,7 @@ class TestListEndpoints(unittest.TestCase):
"http://10.1.2.1:6200/sdc1/0/a/c"], "http://10.1.2.1:6200/sdc1/0/a/c"],
'headers': {}, 'headers': {},
} }
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
def test_version_account_response(self): def test_version_account_response(self):
req = Request.blank('/endpoints/a') req = Request.blank('/endpoints/a')
@ -420,10 +420,10 @@ class TestListEndpoints(unittest.TestCase):
expected = ["http://10.1.2.1:6200/sdc1/0/a", expected = ["http://10.1.2.1:6200/sdc1/0/a",
"http://10.1.1.1:6200/sda1/0/a", "http://10.1.1.1:6200/sda1/0/a",
"http://10.1.1.1:6200/sdb1/0/a"] "http://10.1.1.1:6200/sdb1/0/a"]
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
req = Request.blank('/endpoints/v1.0/a') req = Request.blank('/endpoints/v1.0/a')
resp = req.get_response(self.list_endpoints) resp = req.get_response(self.list_endpoints)
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
req = Request.blank('/endpoints/v2/a') req = Request.blank('/endpoints/v2/a')
resp = req.get_response(self.list_endpoints) resp = req.get_response(self.list_endpoints)
@ -433,7 +433,7 @@ class TestListEndpoints(unittest.TestCase):
"http://10.1.1.1:6200/sdb1/0/a"], "http://10.1.1.1:6200/sdb1/0/a"],
'headers': {}, 'headers': {},
} }
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected) self.assertEqual(json.loads(resp.body), expected)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -213,7 +213,7 @@ class TestDirectClient(unittest.TestCase):
self.assertEqual(conn.req_headers['user-agent'], self.assertEqual(conn.req_headers['user-agent'],
self.user_agent) self.user_agent)
self.assertEqual(resp_headers, stub_headers) self.assertEqual(resp_headers, stub_headers)
self.assertEqual(json.loads(body.decode('ascii')), resp) self.assertEqual(json.loads(body), resp)
self.assertIn('format=json', conn.query_string) self.assertIn('format=json', conn.query_string)
for k, v in req_params.items(): for k, v in req_params.items():
if v is None: if v is None:
@ -389,7 +389,7 @@ class TestDirectClient(unittest.TestCase):
self.assertEqual(conn.req_headers['user-agent'], self.assertEqual(conn.req_headers['user-agent'],
self.user_agent) self.user_agent)
self.assertEqual(headers, resp_headers) self.assertEqual(headers, resp_headers)
self.assertEqual(json.loads(body.decode('ascii')), resp) self.assertEqual(json.loads(body), resp)
self.assertIn('format=json', conn.query_string) self.assertIn('format=json', conn.query_string)
for k, v in req_params.items(): for k, v in req_params.items():
if v is None: if v is None:

View File

@ -62,7 +62,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req) resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException) self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp)) self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii')) info = json.loads(resp.body)
self.assertNotIn('admin', info) self.assertNotIn('admin', info)
self.assertIn('foo', info) self.assertIn('foo', info)
self.assertIn('bar', info['foo']) self.assertIn('bar', info['foo'])
@ -89,7 +89,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req) resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException) self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp)) self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii')) info = json.loads(resp.body)
self.assertNotIn('admin', info) self.assertNotIn('admin', info)
self.assertIn('foo', info) self.assertIn('foo', info)
self.assertIn('bar', info['foo']) self.assertIn('bar', info['foo'])
@ -120,7 +120,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req) resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException) self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp)) self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii')) info = json.loads(resp.body)
self.assertIn('foo', info) self.assertIn('foo', info)
self.assertIn('bar', info['foo']) self.assertIn('bar', info['foo'])
self.assertEqual(info['foo']['bar'], 'baz') self.assertEqual(info['foo']['bar'], 'baz')
@ -156,7 +156,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req) resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException) self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp)) self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii')) info = json.loads(resp.body)
self.assertIn('admin', info) self.assertIn('admin', info)
self.assertIn('qux', info['admin']) self.assertIn('qux', info['admin'])
self.assertIn('quux', info['admin']['qux']) self.assertIn('quux', info['admin']['qux'])
@ -279,7 +279,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req) resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException) self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp)) self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii')) info = json.loads(resp.body)
self.assertNotIn('foo2', info) self.assertNotIn('foo2', info)
self.assertIn('admin', info) self.assertIn('admin', info)
self.assertIn('disallowed_sections', info['admin']) self.assertIn('disallowed_sections', info['admin'])