Merge "Ensure domain stored in memcached gets utf8 decoded on py2"

This commit is contained in:
Zuul 2020-02-19 04:15:13 +00:00 committed by Gerrit Code Review
commit 6a47d9e4f9
2 changed files with 20 additions and 6 deletions

View File

@ -27,7 +27,7 @@ maximum lookup depth. If a match is found, the environment's Host header is
rewritten and the request is passed further down the WSGI chain. rewritten and the request is passed further down the WSGI chain.
""" """
from six.moves import range import six
from swift import gettext_ as _ from swift import gettext_ as _
@ -41,7 +41,8 @@ else: # executed if the try block finishes with no errors
MODULE_DEPENDENCY_MET = True MODULE_DEPENDENCY_MET = True
from swift.common.middleware import RewriteContext from swift.common.middleware import RewriteContext
from swift.common.swob import Request, HTTPBadRequest from swift.common.swob import Request, HTTPBadRequest, \
str_to_wsgi, wsgi_to_str
from swift.common.utils import cache_from_env, get_logger, is_valid_ip, \ from swift.common.utils import cache_from_env, get_logger, is_valid_ip, \
list_from_csv, parse_socket_string, register_swift_info list_from_csv, parse_socket_string, register_swift_info
@ -130,9 +131,10 @@ class CNAMELookupMiddleware(object):
if not self.storage_domain: if not self.storage_domain:
return self.app(env, start_response) return self.app(env, start_response)
if 'HTTP_HOST' in env: if 'HTTP_HOST' in env:
requested_host = given_domain = env['HTTP_HOST'] requested_host = env['HTTP_HOST']
else: else:
requested_host = given_domain = env['SERVER_NAME'] requested_host = env['SERVER_NAME']
given_domain = wsgi_to_str(requested_host)
port = '' port = ''
if ':' in given_domain: if ':' in given_domain:
given_domain, port = given_domain.rsplit(':', 1) given_domain, port = given_domain.rsplit(':', 1)
@ -148,6 +150,8 @@ class CNAMELookupMiddleware(object):
if self.memcache: if self.memcache:
memcache_key = ''.join(['cname-', a_domain]) memcache_key = ''.join(['cname-', a_domain])
found_domain = self.memcache.get(memcache_key) found_domain = self.memcache.get(memcache_key)
if six.PY2 and found_domain:
found_domain = found_domain.encode('utf-8')
if found_domain is None: if found_domain is None:
ttl, found_domain = lookup_cname(a_domain, self.resolver) ttl, found_domain = lookup_cname(a_domain, self.resolver)
if self.memcache and ttl > 0: if self.memcache and ttl > 0:
@ -166,9 +170,10 @@ class CNAMELookupMiddleware(object):
{'given_domain': given_domain, {'given_domain': given_domain,
'found_domain': found_domain}) 'found_domain': found_domain})
if port: if port:
env['HTTP_HOST'] = ':'.join([found_domain, port]) env['HTTP_HOST'] = ':'.join([
str_to_wsgi(found_domain), port])
else: else:
env['HTTP_HOST'] = found_domain env['HTTP_HOST'] = str_to_wsgi(found_domain)
error = False error = False
break break
else: else:

View File

@ -170,6 +170,10 @@ class TestCNAMELookup(unittest.TestCase):
return self.cache.get(key, None) return self.cache.get(key, None)
def set(self, key, value, *a, **kw): def set(self, key, value, *a, **kw):
# real memcache client will JSON-serialize, so our mock
# should be sure to return unicode
if isinstance(value, bytes):
value = value.decode('utf-8')
self.cache[key] = value self.cache[key] = value
module = 'swift.common.middleware.cname_lookup.lookup_cname' module = 'swift.common.middleware.cname_lookup.lookup_cname'
@ -186,6 +190,9 @@ class TestCNAMELookup(unittest.TestCase):
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
self.assertEqual(memcache.cache.get('cname-mysite2.com'), self.assertEqual(memcache.cache.get('cname-mysite2.com'),
'c.example.com') 'c.example.com')
self.assertIsInstance(req.environ['HTTP_HOST'], str)
self.assertEqual(req.environ['HTTP_HOST'], 'c.example.com')
req = Request.blank('/', environ={'REQUEST_METHOD': 'GET', req = Request.blank('/', environ={'REQUEST_METHOD': 'GET',
'swift.cache': memcache}, 'swift.cache': memcache},
headers={'Host': 'mysite2.com'}) headers={'Host': 'mysite2.com'})
@ -194,6 +201,8 @@ class TestCNAMELookup(unittest.TestCase):
self.assertEqual(m.call_count, 1) self.assertEqual(m.call_count, 1)
self.assertEqual(memcache.cache.get('cname-mysite2.com'), self.assertEqual(memcache.cache.get('cname-mysite2.com'),
'c.example.com') 'c.example.com')
self.assertIsInstance(req.environ['HTTP_HOST'], str)
self.assertEqual(req.environ['HTTP_HOST'], 'c.example.com')
for exc, num in ((dns.resolver.NXDOMAIN(), 3), for exc, num in ((dns.resolver.NXDOMAIN(), 3),
(dns.resolver.NoAnswer(), 4)): (dns.resolver.NoAnswer(), 4)):