Ensure domain stored in memcached gets utf8 decoded on py2

Change-Id: I73b5af9645f3f7349144384609bf18a79620e92f
Closes-Bug: #1862115
This commit is contained in:
Charles Hsu 2020-02-06 15:35:11 +08:00 committed by Tim Burke
parent c0b4d644df
commit 61bf5ee1c4
2 changed files with 20 additions and 6 deletions

View File

@ -27,7 +27,7 @@ maximum lookup depth. If a match is found, the environment's Host header is
rewritten and the request is passed further down the WSGI chain.
"""
from six.moves import range
import six
from swift import gettext_ as _
@ -41,7 +41,8 @@ else: # executed if the try block finishes with no errors
MODULE_DEPENDENCY_MET = True
from swift.common.middleware import RewriteContext
from swift.common.swob import Request, HTTPBadRequest
from swift.common.swob import Request, HTTPBadRequest, \
str_to_wsgi, wsgi_to_str
from swift.common.utils import cache_from_env, get_logger, is_valid_ip, \
list_from_csv, parse_socket_string, register_swift_info
@ -130,9 +131,10 @@ class CNAMELookupMiddleware(object):
if not self.storage_domain:
return self.app(env, start_response)
if 'HTTP_HOST' in env:
requested_host = given_domain = env['HTTP_HOST']
requested_host = env['HTTP_HOST']
else:
requested_host = given_domain = env['SERVER_NAME']
requested_host = env['SERVER_NAME']
given_domain = wsgi_to_str(requested_host)
port = ''
if ':' in given_domain:
given_domain, port = given_domain.rsplit(':', 1)
@ -148,6 +150,8 @@ class CNAMELookupMiddleware(object):
if self.memcache:
memcache_key = ''.join(['cname-', a_domain])
found_domain = self.memcache.get(memcache_key)
if six.PY2 and found_domain:
found_domain = found_domain.encode('utf-8')
if found_domain is None:
ttl, found_domain = lookup_cname(a_domain, self.resolver)
if self.memcache and ttl > 0:
@ -166,9 +170,10 @@ class CNAMELookupMiddleware(object):
{'given_domain': given_domain,
'found_domain': found_domain})
if port:
env['HTTP_HOST'] = ':'.join([found_domain, port])
env['HTTP_HOST'] = ':'.join([
str_to_wsgi(found_domain), port])
else:
env['HTTP_HOST'] = found_domain
env['HTTP_HOST'] = str_to_wsgi(found_domain)
error = False
break
else:

View File

@ -170,6 +170,10 @@ class TestCNAMELookup(unittest.TestCase):
return self.cache.get(key, None)
def set(self, key, value, *a, **kw):
# real memcache client will JSON-serialize, so our mock
# should be sure to return unicode
if isinstance(value, bytes):
value = value.decode('utf-8')
self.cache[key] = value
module = 'swift.common.middleware.cname_lookup.lookup_cname'
@ -186,6 +190,9 @@ class TestCNAMELookup(unittest.TestCase):
self.assertEqual(m.call_count, 1)
self.assertEqual(memcache.cache.get('cname-mysite2.com'),
'c.example.com')
self.assertIsInstance(req.environ['HTTP_HOST'], str)
self.assertEqual(req.environ['HTTP_HOST'], 'c.example.com')
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET',
'swift.cache': memcache},
headers={'Host': 'mysite2.com'})
@ -194,6 +201,8 @@ class TestCNAMELookup(unittest.TestCase):
self.assertEqual(m.call_count, 1)
self.assertEqual(memcache.cache.get('cname-mysite2.com'),
'c.example.com')
self.assertIsInstance(req.environ['HTTP_HOST'], str)
self.assertEqual(req.environ['HTTP_HOST'], 'c.example.com')
for exc, num in ((dns.resolver.NXDOMAIN(), 3),
(dns.resolver.NoAnswer(), 4)):