diff --git a/swift/common/direct_client.py b/swift/common/direct_client.py index d0a3cc1795..2d4e2ce6a7 100644 --- a/swift/common/direct_client.py +++ b/swift/common/direct_client.py @@ -177,11 +177,21 @@ def _get_direct_account_container(path, stype, node, part, return resp_headers, json.loads(resp.read().decode('ascii')) -def gen_headers(hdrs_in=None, add_ts=False, add_user_agent=True): +def gen_headers(hdrs_in=None, add_ts=True): + """ + Get the headers ready for a request. All requests should have a User-Agent + string, but if one is passed in don't over-write it. Not all requests will + need an X-Timestamp, but if one is passed in do not over-write it. + + :param headers: dict or None, base for HTTP headers + :param add_ts: boolean, should be True for any "unsafe" HTTP request + + :returns: HeaderKeyDict based on headers and ready for the request + """ hdrs_out = HeaderKeyDict(hdrs_in) if hdrs_in else HeaderKeyDict() - if add_ts: + if add_ts and 'X-Timestamp' not in hdrs_out: hdrs_out['X-Timestamp'] = Timestamp.now().internal - if add_user_agent: + if 'user-agent' not in hdrs_out: hdrs_out['User-Agent'] = 'direct-client %s' % os.getpid() return hdrs_out @@ -332,8 +342,7 @@ def direct_put_container(node, part, account, container, conn_timeout=5, lower_headers = set(k.lower() for k in headers) headers_out = gen_headers(headers, - add_ts='x-timestamp' not in lower_headers, - add_user_agent='user-agent' not in lower_headers) + add_ts='x-timestamp' not in lower_headers) path = _make_path(account, container) _make_req(node, part, 'PUT', path, headers_out, 'Container', conn_timeout, response_timeout, contents=contents, diff --git a/test/unit/common/test_direct_client.py b/test/unit/common/test_direct_client.py index 38fda1ccc8..36d2d56c5d 100644 --- a/test/unit/common/test_direct_client.py +++ b/test/unit/common/test_direct_client.py @@ -131,25 +131,47 @@ class TestDirectClient(unittest.TestCase): def test_gen_headers(self): stub_user_agent = 'direct-client %s' % os.getpid() - headers = direct_client.gen_headers() + headers = direct_client.gen_headers(add_ts=False) self.assertEqual(headers['user-agent'], stub_user_agent) self.assertEqual(1, len(headers)) now = time.time() - headers = direct_client.gen_headers(add_ts=True) + headers = direct_client.gen_headers() self.assertEqual(headers['user-agent'], stub_user_agent) self.assertTrue(now - 1 < Timestamp(headers['x-timestamp']) < now + 1) self.assertEqual(headers['x-timestamp'], Timestamp(headers['x-timestamp']).internal) self.assertEqual(2, len(headers)) - headers = direct_client.gen_headers(hdrs_in={'foo-bar': '47'}) + headers = direct_client.gen_headers(hdrs_in={'x-timestamp': '15'}) + self.assertEqual(headers['x-timestamp'], '15') self.assertEqual(headers['user-agent'], stub_user_agent) - self.assertEqual(headers['foo-bar'], '47') self.assertEqual(2, len(headers)) - headers = direct_client.gen_headers(hdrs_in={'user-agent': '47'}) + headers = direct_client.gen_headers(hdrs_in={'foo-bar': '63'}) self.assertEqual(headers['user-agent'], stub_user_agent) + self.assertEqual(headers['foo-bar'], '63') + self.assertTrue(now - 1 < Timestamp(headers['x-timestamp']) < now + 1) + self.assertEqual(headers['x-timestamp'], + Timestamp(headers['x-timestamp']).internal) + self.assertEqual(3, len(headers)) + + hdrs_in = {'foo-bar': '55'} + headers = direct_client.gen_headers(hdrs_in, add_ts=False) + self.assertEqual(headers['user-agent'], stub_user_agent) + self.assertEqual(headers['foo-bar'], '55') + self.assertEqual(2, len(headers)) + + headers = direct_client.gen_headers(hdrs_in={'user-agent': '32'}) + self.assertEqual(headers['user-agent'], '32') + self.assertTrue(now - 1 < Timestamp(headers['x-timestamp']) < now + 1) + self.assertEqual(headers['x-timestamp'], + Timestamp(headers['x-timestamp']).internal) + self.assertEqual(2, len(headers)) + + hdrs_in = {'user-agent': '47'} + headers = direct_client.gen_headers(hdrs_in, add_ts=False) + self.assertEqual(headers['user-agent'], '47') self.assertEqual(1, len(headers)) for policy in POLICIES: @@ -570,7 +592,7 @@ class TestDirectClient(unittest.TestCase): self.assertEqual(conn.req_headers['user-agent'], self.user_agent) self.assertEqual('bar', conn.req_headers.get('x-foo')) - self.assertNotIn('x-timestamp', conn.req_headers) + self.assertIn('x-timestamp', conn.req_headers) self.assertEqual(headers, resp) def test_direct_head_object_error(self):