Refactor UDP query handling, add tests
Change-Id: If4e18f50126089a1971ab3ba945b34f6774423dd
This commit is contained in:
parent
1eb64bf96c
commit
6ae192335b
@ -420,6 +420,12 @@ class DNSService(object):
|
||||
client.close()
|
||||
|
||||
def _dns_handle_udp(self, sock_udp):
|
||||
"""Handle a DNS Query over UDP in a dedicated thread
|
||||
|
||||
:param sock_udp: UDP socket
|
||||
:type sock_udp: socket
|
||||
:raises: None
|
||||
"""
|
||||
LOG.info(_LI("_handle_udp thread started"))
|
||||
|
||||
while True:
|
||||
@ -432,8 +438,8 @@ class DNSService(object):
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
# Dispatch a thread to handle the query
|
||||
self.tg.add_thread(self._dns_handle, addr, payload,
|
||||
sock_udp=sock_udp)
|
||||
self.tg.add_thread(self._dns_handle_udp_query, sock_udp, addr,
|
||||
payload)
|
||||
|
||||
except socket.error as e:
|
||||
errname = errno.errorcode[e.args[0]]
|
||||
@ -446,13 +452,17 @@ class DNSService(object):
|
||||
"from: %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
def _dns_handle(self, addr, payload, client=None, sock_udp=None):
|
||||
def _dns_handle_udp_query(self, sock, addr, payload):
|
||||
"""
|
||||
Handle a DNS Query
|
||||
Handle a DNS Query over UDP
|
||||
|
||||
:param sock: UDP socket
|
||||
:type sock: socket
|
||||
:param addr: Tuple of the client's (IP, Port)
|
||||
:type addr: tuple
|
||||
:param payload: Raw DNS query payload
|
||||
:param client: Client socket (for TCP only)
|
||||
:type payload: string
|
||||
:raises: None
|
||||
"""
|
||||
try:
|
||||
# Call into the DNS Application itself with the payload and addr
|
||||
@ -461,24 +471,13 @@ class DNSService(object):
|
||||
|
||||
# Send back a response only if present
|
||||
if response is not None:
|
||||
if client:
|
||||
# Handle TCP Responses
|
||||
msg_length = len(response)
|
||||
tcp_response = struct.pack("!H", msg_length) + response
|
||||
client.sendall(tcp_response)
|
||||
else:
|
||||
# Handle UDP Responses
|
||||
sock_udp.sendto(response, addr)
|
||||
sock.sendto(response, addr)
|
||||
|
||||
except Exception:
|
||||
LOG.exception(_LE("Unhandled exception while processing request "
|
||||
"from %(host)s:%(port)d") %
|
||||
{'host': addr[0], 'port': addr[1]})
|
||||
|
||||
# Close the TCP connection if we have one.
|
||||
if client:
|
||||
client.close()
|
||||
|
||||
|
||||
_launcher = None
|
||||
|
||||
|
@ -15,7 +15,9 @@
|
||||
# under the License.
|
||||
|
||||
import binascii
|
||||
import errno
|
||||
import socket
|
||||
import struct
|
||||
|
||||
import dns
|
||||
import dns.message
|
||||
@ -32,6 +34,27 @@ def hex_wire(response):
|
||||
|
||||
|
||||
class MdnsServiceTest(MdnsTestCase):
|
||||
|
||||
# DNS packet with IQUERY opcode
|
||||
query_payload = binascii.a2b_hex(
|
||||
"271209000001000000000000076578616d706c6503636f6d0000010001"
|
||||
)
|
||||
expected_response = binascii.a2b_hex(
|
||||
b"271289050001000000000000076578616d706c6503636f6d0000010001"
|
||||
)
|
||||
# expected response is an error code REFUSED. The other fields are
|
||||
# id 10002
|
||||
# opcode IQUERY
|
||||
# rcode REFUSED
|
||||
# flags QR RD
|
||||
# ;QUESTION
|
||||
# example.com. IN A
|
||||
# ;ANSWER
|
||||
# ;AUTHORITY
|
||||
# ;ADDITIONAL
|
||||
|
||||
# Use self._print_dns_msg() to display the messages
|
||||
|
||||
def setUp(self):
|
||||
super(MdnsServiceTest, self).setUp()
|
||||
|
||||
@ -41,147 +64,115 @@ class MdnsServiceTest(MdnsTestCase):
|
||||
self.service = self.start_service('mdns')
|
||||
self.addr = ['0.0.0.0', 5556]
|
||||
|
||||
@staticmethod
|
||||
def _print_dns_msg(desc, wire):
|
||||
"""Print DNS message for debugging"""
|
||||
q = dns.message.from_wire(wire).to_text()
|
||||
print("%s:\n%s\n" % (desc, q))
|
||||
|
||||
def test_stop(self):
|
||||
# NOTE: Start is already done by the fixture in start_service()
|
||||
self.service.stop()
|
||||
|
||||
@mock.patch.object(dns.message, 'make_query')
|
||||
def test_handle_empty_payload(self, query_mock):
|
||||
self.service._dns_handle(self.addr, ' '.encode('utf-8'))
|
||||
mock_socket = mock.Mock()
|
||||
self.service._dns_handle_udp_query(mock_socket, self.addr,
|
||||
' '.encode('utf-8'))
|
||||
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
|
||||
|
||||
@mock.patch.object(socket.socket, 'sendto', new_callable=mock.MagicMock)
|
||||
def test_handle_udp_payload(self, sendto_mock):
|
||||
# DNS packet with IQUERY opcode
|
||||
payload = "271209000001000000000000076578616d706c6503636f6d0000010001"
|
||||
def test_handle_udp_payload(self):
|
||||
mock_socket = mock.Mock()
|
||||
self.service._dns_handle_udp_query(mock_socket, self.addr,
|
||||
self.query_payload)
|
||||
mock_socket.sendto.assert_called_once_with(self.expected_response,
|
||||
self.addr)
|
||||
|
||||
# expected response is an error code REFUSED. The other fields are
|
||||
# id 10002
|
||||
# opcode IQUERY
|
||||
# rcode REFUSED
|
||||
# flags QR RD
|
||||
# ;QUESTION
|
||||
# example.com. IN A
|
||||
# ;ANSWER
|
||||
# ;AUTHORITY
|
||||
# ;ADDITIONAL
|
||||
expected_response = (b"271289050001000000000000076578616d706c6503636f6"
|
||||
b"d0000010001")
|
||||
def test__dns_handle_tcp_conn_fail_unpack(self):
|
||||
# will call recv() only once
|
||||
mock_socket = mock.Mock()
|
||||
mock_socket.recv.side_effect = ['X', 'boo'] # X will fail unpack
|
||||
|
||||
sock_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.service._dns_handle(self.addr, binascii.a2b_hex(payload),
|
||||
sock_udp=sock_udp)
|
||||
sendto_mock.assert_called_once_with(
|
||||
binascii.a2b_hex(expected_response), self.addr)
|
||||
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||
self.assertEqual(1, mock_socket.recv.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
||||
def _send_request_to_mdns(self, req):
|
||||
"""Send request to localhost"""
|
||||
self.assertTrue(len(self.service._dns_socks_udp))
|
||||
port = self.service._dns_socks_udp[0].getsockname()[1]
|
||||
response = dns.query.udp(req, '127.0.0.1', port=port, timeout=1)
|
||||
LOG.info("\n-- RESPONSE --\n%s\n--------------\n" % response.to_text())
|
||||
return response
|
||||
def test__dns_handle_tcp_conn_one_query(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
mock_socket.recv.side_effect = [pay_len, payload, socket.timeout]
|
||||
|
||||
def _query_mdns(self, qname, rdtype, rdclass=dns.rdataclass.IN):
|
||||
"""Send query to localhost"""
|
||||
req = dns.message.make_query(qname, rdtype, rdclass=rdclass)
|
||||
req.id = 123
|
||||
return self._send_request_to_mdns(req)
|
||||
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||
|
||||
def test_query(self):
|
||||
zone = self.create_zone()
|
||||
self.assertEqual(3, mock_socket.recv.call_count)
|
||||
self.assertEqual(1, mock_socket.sendall.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
wire = mock_socket.sendall.call_args[0][0]
|
||||
expected_length_raw = wire[:2]
|
||||
(expected_length, ) = struct.unpack('!H', expected_length_raw)
|
||||
self.assertEqual(len(wire), expected_length + 2)
|
||||
self.assertEqual(self.expected_response, wire[2:])
|
||||
|
||||
# Reply query for NS
|
||||
response = self._query_mdns(zone.name, dns.rdatatype.NS)
|
||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
||||
self.assertEqual(1, len(response.answer))
|
||||
ans = response.answer[0]
|
||||
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
|
||||
self.assertEqual(zone.name, ans.name.to_text())
|
||||
self.assertEqual(zone.ttl, ans.ttl)
|
||||
def test__dns_handle_tcp_conn_multiple_queries(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
# Process 5 queries, than receive a misaligned query and close the
|
||||
# connection there
|
||||
mock_socket.recv.side_effect = [
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
'X', payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
]
|
||||
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||
|
||||
# Reply query for SOA
|
||||
response = self._query_mdns(zone.name, dns.rdatatype.SOA)
|
||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
||||
self.assertEqual(1, len(response.answer))
|
||||
ans = response.answer[0]
|
||||
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
|
||||
self.assertEqual(zone.name, ans.name.to_text())
|
||||
self.assertEqual(zone.ttl, ans.ttl)
|
||||
self.assertEqual(11, mock_socket.recv.call_count)
|
||||
self.assertEqual(5, mock_socket.sendall.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
||||
# Refuse query for incorrect rdclass
|
||||
response = self._query_mdns(zone.name, dns.rdatatype.SOA,
|
||||
rdclass=dns.rdataclass.RESERVED0)
|
||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
||||
expected = b'007b81050001000000000000076578616d706c6503636f6d0000060000' # noqa
|
||||
self.assertEqual(expected, hex_wire(response))
|
||||
def test__dns_handle_tcp_conn_multiple_queries_socket_error(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
# Process 5 queries, than receive a socket error and close the
|
||||
# connection there
|
||||
mock_socket.recv.side_effect = [
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
socket.error(errno.EAGAIN),
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
]
|
||||
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||
|
||||
# Refuse query for ANY
|
||||
response = self._query_mdns("www.%s" % zone.name, dns.rdatatype.ANY)
|
||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
||||
expected = b'007b8105000100000000000003777777076578616d706c6503636f6d0000ff0001' # noqa
|
||||
self.assertEqual(expected, hex_wire(response))
|
||||
self.assertEqual(11, mock_socket.recv.call_count)
|
||||
self.assertEqual(5, mock_socket.sendall.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
||||
# Reply query for A against inexistent record
|
||||
response = self._query_mdns("nope.%s" % zone.name, dns.rdatatype.A)
|
||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
||||
expected = b'007b81050001000000000000046e6f7065076578616d706c6503636f6d0000010001' # noqa
|
||||
self.assertEqual(expected, hex_wire(response))
|
||||
def test__dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
# Ignore a broken query and keep going as long as the query len
|
||||
# header was correct
|
||||
mock_socket.recv.side_effect = [
|
||||
pay_len, payload,
|
||||
pay_len, payload[:-5] + b'hello',
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
pay_len, payload,
|
||||
]
|
||||
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||
|
||||
# Reply query for A
|
||||
recordset = self.create_recordset(zone)
|
||||
self.create_record(zone, recordset)
|
||||
response = self._query_mdns(recordset.name, dns.rdatatype.A)
|
||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
||||
self.assertEqual(1, len(response.answer))
|
||||
ans = response.answer[0]
|
||||
self.assertEqual(dns.rdatatype.A, ans.rdtype)
|
||||
self.assertEqual(recordset.name, ans.name.to_text())
|
||||
self.assertEqual(zone.ttl, ans.ttl)
|
||||
self.assertEqual('3600 IN A 192.0.2.1', str(ans.to_rdataset()))
|
||||
expected = b'007b85000001000100000000046d61696c076578616d706c6503636f6d0000010001c00c0001000100000e100004c0000201' # noqa
|
||||
self.assertEqual(expected, hex_wire(response))
|
||||
|
||||
def test_query_axfr(self):
|
||||
zone = self.create_zone()
|
||||
|
||||
# Query for AXFR
|
||||
response = self._query_mdns(zone.name, dns.rdatatype.AXFR)
|
||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
||||
self.assertEqual(2, len(response.answer))
|
||||
ans = response.answer[0] # SOA
|
||||
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
|
||||
self.assertEqual(zone.name, ans.name.to_text())
|
||||
self.assertEqual(zone.ttl, ans.ttl)
|
||||
ans = response.answer[1] # NS
|
||||
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
|
||||
self.assertEqual(zone.name, ans.name.to_text())
|
||||
self.assertEqual(zone.ttl, ans.ttl)
|
||||
|
||||
def test_notify_notauth_primary_zone(self):
|
||||
zone = self.create_zone()
|
||||
|
||||
# Send NOTIFY to mdns: NOTAUTH for primary zone
|
||||
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
|
||||
notify.id = 123
|
||||
notify.flags = 0
|
||||
notify.set_opcode(dns.opcode.NOTIFY)
|
||||
notify.flags |= dns.flags.AA
|
||||
response = self._send_request_to_mdns(notify)
|
||||
self.assertEqual(dns.rcode.NOTAUTH, response.rcode())
|
||||
expected = b'007ba0090001000000000000076578616d706c6503636f6d0000060001' # noqa
|
||||
self.assertEqual(expected, hex_wire(response))
|
||||
|
||||
def test_notify_non_master(self):
|
||||
zone = self.create_zone(type='SECONDARY', email='test@example.com')
|
||||
|
||||
# Send NOTIFY to mdns: refuse from non-master
|
||||
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
|
||||
notify.id = 123
|
||||
notify.flags = 0
|
||||
notify.set_opcode(dns.opcode.NOTIFY)
|
||||
notify.flags |= dns.flags.AA
|
||||
response = self._send_request_to_mdns(notify)
|
||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
||||
expected = b'007ba0050001000000000000076578616d706c6503636f6d0000060001' # noqa
|
||||
self.assertEqual(expected, hex_wire(response))
|
||||
self.assertEqual(11, mock_socket.recv.call_count)
|
||||
self.assertEqual(4, mock_socket.sendall.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
Loading…
Reference in New Issue
Block a user