Refactor UDP query handling, add tests

Change-Id: If4e18f50126089a1971ab3ba945b34f6774423dd
This commit is contained in:
Federico Ceratto 2016-04-30 17:20:21 +01:00
parent 1eb64bf96c
commit 6ae192335b
2 changed files with 132 additions and 142 deletions

View File

@ -420,6 +420,12 @@ class DNSService(object):
client.close() client.close()
def _dns_handle_udp(self, sock_udp): def _dns_handle_udp(self, sock_udp):
"""Handle a DNS Query over UDP in a dedicated thread
:param sock_udp: UDP socket
:type sock_udp: socket
:raises: None
"""
LOG.info(_LI("_handle_udp thread started")) LOG.info(_LI("_handle_udp thread started"))
while True: while True:
@ -432,8 +438,8 @@ class DNSService(object):
{'host': addr[0], 'port': addr[1]}) {'host': addr[0], 'port': addr[1]})
# Dispatch a thread to handle the query # Dispatch a thread to handle the query
self.tg.add_thread(self._dns_handle, addr, payload, self.tg.add_thread(self._dns_handle_udp_query, sock_udp, addr,
sock_udp=sock_udp) payload)
except socket.error as e: except socket.error as e:
errname = errno.errorcode[e.args[0]] errname = errno.errorcode[e.args[0]]
@ -446,13 +452,17 @@ class DNSService(object):
"from: %(host)s:%(port)d") % "from: %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]}) {'host': addr[0], 'port': addr[1]})
def _dns_handle(self, addr, payload, client=None, sock_udp=None): def _dns_handle_udp_query(self, sock, addr, payload):
""" """
Handle a DNS Query Handle a DNS Query over UDP
:param sock: UDP socket
:type sock: socket
:param addr: Tuple of the client's (IP, Port) :param addr: Tuple of the client's (IP, Port)
:type addr: tuple
:param payload: Raw DNS query payload :param payload: Raw DNS query payload
:param client: Client socket (for TCP only) :type payload: string
:raises: None
""" """
try: try:
# Call into the DNS Application itself with the payload and addr # Call into the DNS Application itself with the payload and addr
@ -461,24 +471,13 @@ class DNSService(object):
# Send back a response only if present # Send back a response only if present
if response is not None: if response is not None:
if client: sock.sendto(response, addr)
# Handle TCP Responses
msg_length = len(response)
tcp_response = struct.pack("!H", msg_length) + response
client.sendall(tcp_response)
else:
# Handle UDP Responses
sock_udp.sendto(response, addr)
except Exception: except Exception:
LOG.exception(_LE("Unhandled exception while processing request " LOG.exception(_LE("Unhandled exception while processing request "
"from %(host)s:%(port)d") % "from %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]}) {'host': addr[0], 'port': addr[1]})
# Close the TCP connection if we have one.
if client:
client.close()
_launcher = None _launcher = None

View File

@ -15,7 +15,9 @@
# under the License. # under the License.
import binascii import binascii
import errno
import socket import socket
import struct
import dns import dns
import dns.message import dns.message
@ -32,6 +34,27 @@ def hex_wire(response):
class MdnsServiceTest(MdnsTestCase): class MdnsServiceTest(MdnsTestCase):
# DNS packet with IQUERY opcode
query_payload = binascii.a2b_hex(
"271209000001000000000000076578616d706c6503636f6d0000010001"
)
expected_response = binascii.a2b_hex(
b"271289050001000000000000076578616d706c6503636f6d0000010001"
)
# expected response is an error code REFUSED. The other fields are
# id 10002
# opcode IQUERY
# rcode REFUSED
# flags QR RD
# ;QUESTION
# example.com. IN A
# ;ANSWER
# ;AUTHORITY
# ;ADDITIONAL
# Use self._print_dns_msg() to display the messages
def setUp(self): def setUp(self):
super(MdnsServiceTest, self).setUp() super(MdnsServiceTest, self).setUp()
@ -41,147 +64,115 @@ class MdnsServiceTest(MdnsTestCase):
self.service = self.start_service('mdns') self.service = self.start_service('mdns')
self.addr = ['0.0.0.0', 5556] self.addr = ['0.0.0.0', 5556]
@staticmethod
def _print_dns_msg(desc, wire):
"""Print DNS message for debugging"""
q = dns.message.from_wire(wire).to_text()
print("%s:\n%s\n" % (desc, q))
def test_stop(self): def test_stop(self):
# NOTE: Start is already done by the fixture in start_service() # NOTE: Start is already done by the fixture in start_service()
self.service.stop() self.service.stop()
@mock.patch.object(dns.message, 'make_query') @mock.patch.object(dns.message, 'make_query')
def test_handle_empty_payload(self, query_mock): def test_handle_empty_payload(self, query_mock):
self.service._dns_handle(self.addr, ' '.encode('utf-8')) mock_socket = mock.Mock()
self.service._dns_handle_udp_query(mock_socket, self.addr,
' '.encode('utf-8'))
query_mock.assert_called_once_with('unknown', dns.rdatatype.A) query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
@mock.patch.object(socket.socket, 'sendto', new_callable=mock.MagicMock) def test_handle_udp_payload(self):
def test_handle_udp_payload(self, sendto_mock): mock_socket = mock.Mock()
# DNS packet with IQUERY opcode self.service._dns_handle_udp_query(mock_socket, self.addr,
payload = "271209000001000000000000076578616d706c6503636f6d0000010001" self.query_payload)
mock_socket.sendto.assert_called_once_with(self.expected_response,
self.addr)
# expected response is an error code REFUSED. The other fields are def test__dns_handle_tcp_conn_fail_unpack(self):
# id 10002 # will call recv() only once
# opcode IQUERY mock_socket = mock.Mock()
# rcode REFUSED mock_socket.recv.side_effect = ['X', 'boo'] # X will fail unpack
# flags QR RD
# ;QUESTION
# example.com. IN A
# ;ANSWER
# ;AUTHORITY
# ;ADDITIONAL
expected_response = (b"271289050001000000000000076578616d706c6503636f6"
b"d0000010001")
sock_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
self.service._dns_handle(self.addr, binascii.a2b_hex(payload), self.assertEqual(1, mock_socket.recv.call_count)
sock_udp=sock_udp) self.assertEqual(1, mock_socket.close.call_count)
sendto_mock.assert_called_once_with(
binascii.a2b_hex(expected_response), self.addr)
def _send_request_to_mdns(self, req): def test__dns_handle_tcp_conn_one_query(self):
"""Send request to localhost""" payload = self.query_payload
self.assertTrue(len(self.service._dns_socks_udp)) mock_socket = mock.Mock()
port = self.service._dns_socks_udp[0].getsockname()[1] pay_len = struct.pack("!H", len(payload))
response = dns.query.udp(req, '127.0.0.1', port=port, timeout=1) mock_socket.recv.side_effect = [pay_len, payload, socket.timeout]
LOG.info("\n-- RESPONSE --\n%s\n--------------\n" % response.to_text())
return response
def _query_mdns(self, qname, rdtype, rdclass=dns.rdataclass.IN): self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
"""Send query to localhost"""
req = dns.message.make_query(qname, rdtype, rdclass=rdclass)
req.id = 123
return self._send_request_to_mdns(req)
def test_query(self): self.assertEqual(3, mock_socket.recv.call_count)
zone = self.create_zone() self.assertEqual(1, mock_socket.sendall.call_count)
self.assertEqual(1, mock_socket.close.call_count)
wire = mock_socket.sendall.call_args[0][0]
expected_length_raw = wire[:2]
(expected_length, ) = struct.unpack('!H', expected_length_raw)
self.assertEqual(len(wire), expected_length + 2)
self.assertEqual(self.expected_response, wire[2:])
# Reply query for NS def test__dns_handle_tcp_conn_multiple_queries(self):
response = self._query_mdns(zone.name, dns.rdatatype.NS) payload = self.query_payload
self.assertEqual(dns.rcode.NOERROR, response.rcode()) mock_socket = mock.Mock()
self.assertEqual(1, len(response.answer)) pay_len = struct.pack("!H", len(payload))
ans = response.answer[0] # Process 5 queries, than receive a misaligned query and close the
self.assertEqual(dns.rdatatype.NS, ans.rdtype) # connection there
self.assertEqual(zone.name, ans.name.to_text()) mock_socket.recv.side_effect = [
self.assertEqual(zone.ttl, ans.ttl) pay_len, payload,
pay_len, payload,
pay_len, payload,
pay_len, payload,
pay_len, payload,
'X', payload,
pay_len, payload,
pay_len, payload,
]
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
# Reply query for SOA self.assertEqual(11, mock_socket.recv.call_count)
response = self._query_mdns(zone.name, dns.rdatatype.SOA) self.assertEqual(5, mock_socket.sendall.call_count)
self.assertEqual(dns.rcode.NOERROR, response.rcode()) self.assertEqual(1, mock_socket.close.call_count)
self.assertEqual(1, len(response.answer))
ans = response.answer[0]
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
# Refuse query for incorrect rdclass def test__dns_handle_tcp_conn_multiple_queries_socket_error(self):
response = self._query_mdns(zone.name, dns.rdatatype.SOA, payload = self.query_payload
rdclass=dns.rdataclass.RESERVED0) mock_socket = mock.Mock()
self.assertEqual(dns.rcode.REFUSED, response.rcode()) pay_len = struct.pack("!H", len(payload))
expected = b'007b81050001000000000000076578616d706c6503636f6d0000060000' # noqa # Process 5 queries, than receive a socket error and close the
self.assertEqual(expected, hex_wire(response)) # connection there
mock_socket.recv.side_effect = [
pay_len, payload,
pay_len, payload,
pay_len, payload,
pay_len, payload,
pay_len, payload,
socket.error(errno.EAGAIN),
pay_len, payload,
pay_len, payload,
]
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
# Refuse query for ANY self.assertEqual(11, mock_socket.recv.call_count)
response = self._query_mdns("www.%s" % zone.name, dns.rdatatype.ANY) self.assertEqual(5, mock_socket.sendall.call_count)
self.assertEqual(dns.rcode.REFUSED, response.rcode()) self.assertEqual(1, mock_socket.close.call_count)
expected = b'007b8105000100000000000003777777076578616d706c6503636f6d0000ff0001' # noqa
self.assertEqual(expected, hex_wire(response))
# Reply query for A against inexistent record def test__dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
response = self._query_mdns("nope.%s" % zone.name, dns.rdatatype.A) payload = self.query_payload
self.assertEqual(dns.rcode.REFUSED, response.rcode()) mock_socket = mock.Mock()
expected = b'007b81050001000000000000046e6f7065076578616d706c6503636f6d0000010001' # noqa pay_len = struct.pack("!H", len(payload))
self.assertEqual(expected, hex_wire(response)) # Ignore a broken query and keep going as long as the query len
# header was correct
mock_socket.recv.side_effect = [
pay_len, payload,
pay_len, payload[:-5] + b'hello',
pay_len, payload,
pay_len, payload,
pay_len, payload,
]
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
# Reply query for A self.assertEqual(11, mock_socket.recv.call_count)
recordset = self.create_recordset(zone) self.assertEqual(4, mock_socket.sendall.call_count)
self.create_record(zone, recordset) self.assertEqual(1, mock_socket.close.call_count)
response = self._query_mdns(recordset.name, dns.rdatatype.A)
self.assertEqual(dns.rcode.NOERROR, response.rcode())
self.assertEqual(1, len(response.answer))
ans = response.answer[0]
self.assertEqual(dns.rdatatype.A, ans.rdtype)
self.assertEqual(recordset.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
self.assertEqual('3600 IN A 192.0.2.1', str(ans.to_rdataset()))
expected = b'007b85000001000100000000046d61696c076578616d706c6503636f6d0000010001c00c0001000100000e100004c0000201' # noqa
self.assertEqual(expected, hex_wire(response))
def test_query_axfr(self):
zone = self.create_zone()
# Query for AXFR
response = self._query_mdns(zone.name, dns.rdatatype.AXFR)
self.assertEqual(dns.rcode.NOERROR, response.rcode())
self.assertEqual(2, len(response.answer))
ans = response.answer[0] # SOA
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
ans = response.answer[1] # NS
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
def test_notify_notauth_primary_zone(self):
zone = self.create_zone()
# Send NOTIFY to mdns: NOTAUTH for primary zone
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
notify.id = 123
notify.flags = 0
notify.set_opcode(dns.opcode.NOTIFY)
notify.flags |= dns.flags.AA
response = self._send_request_to_mdns(notify)
self.assertEqual(dns.rcode.NOTAUTH, response.rcode())
expected = b'007ba0090001000000000000076578616d706c6503636f6d0000060001' # noqa
self.assertEqual(expected, hex_wire(response))
def test_notify_non_master(self):
zone = self.create_zone(type='SECONDARY', email='test@example.com')
# Send NOTIFY to mdns: refuse from non-master
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
notify.id = 123
notify.flags = 0
notify.set_opcode(dns.opcode.NOTIFY)
notify.flags |= dns.flags.AA
response = self._send_request_to_mdns(notify)
self.assertEqual(dns.rcode.REFUSED, response.rcode())
expected = b'007ba0050001000000000000076578616d706c6503636f6d0000060001' # noqa
self.assertEqual(expected, hex_wire(response))