Refactor UDP query handling, add tests
Change-Id: If4e18f50126089a1971ab3ba945b34f6774423dd
This commit is contained in:
parent
1eb64bf96c
commit
6ae192335b
@ -420,6 +420,12 @@ class DNSService(object):
|
|||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
def _dns_handle_udp(self, sock_udp):
|
def _dns_handle_udp(self, sock_udp):
|
||||||
|
"""Handle a DNS Query over UDP in a dedicated thread
|
||||||
|
|
||||||
|
:param sock_udp: UDP socket
|
||||||
|
:type sock_udp: socket
|
||||||
|
:raises: None
|
||||||
|
"""
|
||||||
LOG.info(_LI("_handle_udp thread started"))
|
LOG.info(_LI("_handle_udp thread started"))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -432,8 +438,8 @@ class DNSService(object):
|
|||||||
{'host': addr[0], 'port': addr[1]})
|
{'host': addr[0], 'port': addr[1]})
|
||||||
|
|
||||||
# Dispatch a thread to handle the query
|
# Dispatch a thread to handle the query
|
||||||
self.tg.add_thread(self._dns_handle, addr, payload,
|
self.tg.add_thread(self._dns_handle_udp_query, sock_udp, addr,
|
||||||
sock_udp=sock_udp)
|
payload)
|
||||||
|
|
||||||
except socket.error as e:
|
except socket.error as e:
|
||||||
errname = errno.errorcode[e.args[0]]
|
errname = errno.errorcode[e.args[0]]
|
||||||
@ -446,13 +452,17 @@ class DNSService(object):
|
|||||||
"from: %(host)s:%(port)d") %
|
"from: %(host)s:%(port)d") %
|
||||||
{'host': addr[0], 'port': addr[1]})
|
{'host': addr[0], 'port': addr[1]})
|
||||||
|
|
||||||
def _dns_handle(self, addr, payload, client=None, sock_udp=None):
|
def _dns_handle_udp_query(self, sock, addr, payload):
|
||||||
"""
|
"""
|
||||||
Handle a DNS Query
|
Handle a DNS Query over UDP
|
||||||
|
|
||||||
|
:param sock: UDP socket
|
||||||
|
:type sock: socket
|
||||||
:param addr: Tuple of the client's (IP, Port)
|
:param addr: Tuple of the client's (IP, Port)
|
||||||
|
:type addr: tuple
|
||||||
:param payload: Raw DNS query payload
|
:param payload: Raw DNS query payload
|
||||||
:param client: Client socket (for TCP only)
|
:type payload: string
|
||||||
|
:raises: None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Call into the DNS Application itself with the payload and addr
|
# Call into the DNS Application itself with the payload and addr
|
||||||
@ -461,24 +471,13 @@ class DNSService(object):
|
|||||||
|
|
||||||
# Send back a response only if present
|
# Send back a response only if present
|
||||||
if response is not None:
|
if response is not None:
|
||||||
if client:
|
sock.sendto(response, addr)
|
||||||
# Handle TCP Responses
|
|
||||||
msg_length = len(response)
|
|
||||||
tcp_response = struct.pack("!H", msg_length) + response
|
|
||||||
client.sendall(tcp_response)
|
|
||||||
else:
|
|
||||||
# Handle UDP Responses
|
|
||||||
sock_udp.sendto(response, addr)
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
LOG.exception(_LE("Unhandled exception while processing request "
|
LOG.exception(_LE("Unhandled exception while processing request "
|
||||||
"from %(host)s:%(port)d") %
|
"from %(host)s:%(port)d") %
|
||||||
{'host': addr[0], 'port': addr[1]})
|
{'host': addr[0], 'port': addr[1]})
|
||||||
|
|
||||||
# Close the TCP connection if we have one.
|
|
||||||
if client:
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
|
|
||||||
_launcher = None
|
_launcher = None
|
||||||
|
|
||||||
|
@ -15,7 +15,9 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import binascii
|
import binascii
|
||||||
|
import errno
|
||||||
import socket
|
import socket
|
||||||
|
import struct
|
||||||
|
|
||||||
import dns
|
import dns
|
||||||
import dns.message
|
import dns.message
|
||||||
@ -32,6 +34,27 @@ def hex_wire(response):
|
|||||||
|
|
||||||
|
|
||||||
class MdnsServiceTest(MdnsTestCase):
|
class MdnsServiceTest(MdnsTestCase):
|
||||||
|
|
||||||
|
# DNS packet with IQUERY opcode
|
||||||
|
query_payload = binascii.a2b_hex(
|
||||||
|
"271209000001000000000000076578616d706c6503636f6d0000010001"
|
||||||
|
)
|
||||||
|
expected_response = binascii.a2b_hex(
|
||||||
|
b"271289050001000000000000076578616d706c6503636f6d0000010001"
|
||||||
|
)
|
||||||
|
# expected response is an error code REFUSED. The other fields are
|
||||||
|
# id 10002
|
||||||
|
# opcode IQUERY
|
||||||
|
# rcode REFUSED
|
||||||
|
# flags QR RD
|
||||||
|
# ;QUESTION
|
||||||
|
# example.com. IN A
|
||||||
|
# ;ANSWER
|
||||||
|
# ;AUTHORITY
|
||||||
|
# ;ADDITIONAL
|
||||||
|
|
||||||
|
# Use self._print_dns_msg() to display the messages
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(MdnsServiceTest, self).setUp()
|
super(MdnsServiceTest, self).setUp()
|
||||||
|
|
||||||
@ -41,147 +64,115 @@ class MdnsServiceTest(MdnsTestCase):
|
|||||||
self.service = self.start_service('mdns')
|
self.service = self.start_service('mdns')
|
||||||
self.addr = ['0.0.0.0', 5556]
|
self.addr = ['0.0.0.0', 5556]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _print_dns_msg(desc, wire):
|
||||||
|
"""Print DNS message for debugging"""
|
||||||
|
q = dns.message.from_wire(wire).to_text()
|
||||||
|
print("%s:\n%s\n" % (desc, q))
|
||||||
|
|
||||||
def test_stop(self):
|
def test_stop(self):
|
||||||
# NOTE: Start is already done by the fixture in start_service()
|
# NOTE: Start is already done by the fixture in start_service()
|
||||||
self.service.stop()
|
self.service.stop()
|
||||||
|
|
||||||
@mock.patch.object(dns.message, 'make_query')
|
@mock.patch.object(dns.message, 'make_query')
|
||||||
def test_handle_empty_payload(self, query_mock):
|
def test_handle_empty_payload(self, query_mock):
|
||||||
self.service._dns_handle(self.addr, ' '.encode('utf-8'))
|
mock_socket = mock.Mock()
|
||||||
|
self.service._dns_handle_udp_query(mock_socket, self.addr,
|
||||||
|
' '.encode('utf-8'))
|
||||||
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
|
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
|
||||||
|
|
||||||
@mock.patch.object(socket.socket, 'sendto', new_callable=mock.MagicMock)
|
def test_handle_udp_payload(self):
|
||||||
def test_handle_udp_payload(self, sendto_mock):
|
mock_socket = mock.Mock()
|
||||||
# DNS packet with IQUERY opcode
|
self.service._dns_handle_udp_query(mock_socket, self.addr,
|
||||||
payload = "271209000001000000000000076578616d706c6503636f6d0000010001"
|
self.query_payload)
|
||||||
|
mock_socket.sendto.assert_called_once_with(self.expected_response,
|
||||||
|
self.addr)
|
||||||
|
|
||||||
# expected response is an error code REFUSED. The other fields are
|
def test__dns_handle_tcp_conn_fail_unpack(self):
|
||||||
# id 10002
|
# will call recv() only once
|
||||||
# opcode IQUERY
|
mock_socket = mock.Mock()
|
||||||
# rcode REFUSED
|
mock_socket.recv.side_effect = ['X', 'boo'] # X will fail unpack
|
||||||
# flags QR RD
|
|
||||||
# ;QUESTION
|
|
||||||
# example.com. IN A
|
|
||||||
# ;ANSWER
|
|
||||||
# ;AUTHORITY
|
|
||||||
# ;ADDITIONAL
|
|
||||||
expected_response = (b"271289050001000000000000076578616d706c6503636f6"
|
|
||||||
b"d0000010001")
|
|
||||||
|
|
||||||
sock_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||||
self.service._dns_handle(self.addr, binascii.a2b_hex(payload),
|
self.assertEqual(1, mock_socket.recv.call_count)
|
||||||
sock_udp=sock_udp)
|
self.assertEqual(1, mock_socket.close.call_count)
|
||||||
sendto_mock.assert_called_once_with(
|
|
||||||
binascii.a2b_hex(expected_response), self.addr)
|
|
||||||
|
|
||||||
def _send_request_to_mdns(self, req):
|
def test__dns_handle_tcp_conn_one_query(self):
|
||||||
"""Send request to localhost"""
|
payload = self.query_payload
|
||||||
self.assertTrue(len(self.service._dns_socks_udp))
|
mock_socket = mock.Mock()
|
||||||
port = self.service._dns_socks_udp[0].getsockname()[1]
|
pay_len = struct.pack("!H", len(payload))
|
||||||
response = dns.query.udp(req, '127.0.0.1', port=port, timeout=1)
|
mock_socket.recv.side_effect = [pay_len, payload, socket.timeout]
|
||||||
LOG.info("\n-- RESPONSE --\n%s\n--------------\n" % response.to_text())
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _query_mdns(self, qname, rdtype, rdclass=dns.rdataclass.IN):
|
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||||
"""Send query to localhost"""
|
|
||||||
req = dns.message.make_query(qname, rdtype, rdclass=rdclass)
|
|
||||||
req.id = 123
|
|
||||||
return self._send_request_to_mdns(req)
|
|
||||||
|
|
||||||
def test_query(self):
|
self.assertEqual(3, mock_socket.recv.call_count)
|
||||||
zone = self.create_zone()
|
self.assertEqual(1, mock_socket.sendall.call_count)
|
||||||
|
self.assertEqual(1, mock_socket.close.call_count)
|
||||||
|
wire = mock_socket.sendall.call_args[0][0]
|
||||||
|
expected_length_raw = wire[:2]
|
||||||
|
(expected_length, ) = struct.unpack('!H', expected_length_raw)
|
||||||
|
self.assertEqual(len(wire), expected_length + 2)
|
||||||
|
self.assertEqual(self.expected_response, wire[2:])
|
||||||
|
|
||||||
# Reply query for NS
|
def test__dns_handle_tcp_conn_multiple_queries(self):
|
||||||
response = self._query_mdns(zone.name, dns.rdatatype.NS)
|
payload = self.query_payload
|
||||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
mock_socket = mock.Mock()
|
||||||
self.assertEqual(1, len(response.answer))
|
pay_len = struct.pack("!H", len(payload))
|
||||||
ans = response.answer[0]
|
# Process 5 queries, than receive a misaligned query and close the
|
||||||
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
|
# connection there
|
||||||
self.assertEqual(zone.name, ans.name.to_text())
|
mock_socket.recv.side_effect = [
|
||||||
self.assertEqual(zone.ttl, ans.ttl)
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
'X', payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
]
|
||||||
|
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||||
|
|
||||||
# Reply query for SOA
|
self.assertEqual(11, mock_socket.recv.call_count)
|
||||||
response = self._query_mdns(zone.name, dns.rdatatype.SOA)
|
self.assertEqual(5, mock_socket.sendall.call_count)
|
||||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
self.assertEqual(1, mock_socket.close.call_count)
|
||||||
self.assertEqual(1, len(response.answer))
|
|
||||||
ans = response.answer[0]
|
|
||||||
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
|
|
||||||
self.assertEqual(zone.name, ans.name.to_text())
|
|
||||||
self.assertEqual(zone.ttl, ans.ttl)
|
|
||||||
|
|
||||||
# Refuse query for incorrect rdclass
|
def test__dns_handle_tcp_conn_multiple_queries_socket_error(self):
|
||||||
response = self._query_mdns(zone.name, dns.rdatatype.SOA,
|
payload = self.query_payload
|
||||||
rdclass=dns.rdataclass.RESERVED0)
|
mock_socket = mock.Mock()
|
||||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
pay_len = struct.pack("!H", len(payload))
|
||||||
expected = b'007b81050001000000000000076578616d706c6503636f6d0000060000' # noqa
|
# Process 5 queries, than receive a socket error and close the
|
||||||
self.assertEqual(expected, hex_wire(response))
|
# connection there
|
||||||
|
mock_socket.recv.side_effect = [
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
socket.error(errno.EAGAIN),
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
]
|
||||||
|
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||||
|
|
||||||
# Refuse query for ANY
|
self.assertEqual(11, mock_socket.recv.call_count)
|
||||||
response = self._query_mdns("www.%s" % zone.name, dns.rdatatype.ANY)
|
self.assertEqual(5, mock_socket.sendall.call_count)
|
||||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
self.assertEqual(1, mock_socket.close.call_count)
|
||||||
expected = b'007b8105000100000000000003777777076578616d706c6503636f6d0000ff0001' # noqa
|
|
||||||
self.assertEqual(expected, hex_wire(response))
|
|
||||||
|
|
||||||
# Reply query for A against inexistent record
|
def test__dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
|
||||||
response = self._query_mdns("nope.%s" % zone.name, dns.rdatatype.A)
|
payload = self.query_payload
|
||||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
mock_socket = mock.Mock()
|
||||||
expected = b'007b81050001000000000000046e6f7065076578616d706c6503636f6d0000010001' # noqa
|
pay_len = struct.pack("!H", len(payload))
|
||||||
self.assertEqual(expected, hex_wire(response))
|
# Ignore a broken query and keep going as long as the query len
|
||||||
|
# header was correct
|
||||||
|
mock_socket.recv.side_effect = [
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload[:-5] + b'hello',
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
pay_len, payload,
|
||||||
|
]
|
||||||
|
self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket)
|
||||||
|
|
||||||
# Reply query for A
|
self.assertEqual(11, mock_socket.recv.call_count)
|
||||||
recordset = self.create_recordset(zone)
|
self.assertEqual(4, mock_socket.sendall.call_count)
|
||||||
self.create_record(zone, recordset)
|
self.assertEqual(1, mock_socket.close.call_count)
|
||||||
response = self._query_mdns(recordset.name, dns.rdatatype.A)
|
|
||||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
|
||||||
self.assertEqual(1, len(response.answer))
|
|
||||||
ans = response.answer[0]
|
|
||||||
self.assertEqual(dns.rdatatype.A, ans.rdtype)
|
|
||||||
self.assertEqual(recordset.name, ans.name.to_text())
|
|
||||||
self.assertEqual(zone.ttl, ans.ttl)
|
|
||||||
self.assertEqual('3600 IN A 192.0.2.1', str(ans.to_rdataset()))
|
|
||||||
expected = b'007b85000001000100000000046d61696c076578616d706c6503636f6d0000010001c00c0001000100000e100004c0000201' # noqa
|
|
||||||
self.assertEqual(expected, hex_wire(response))
|
|
||||||
|
|
||||||
def test_query_axfr(self):
|
|
||||||
zone = self.create_zone()
|
|
||||||
|
|
||||||
# Query for AXFR
|
|
||||||
response = self._query_mdns(zone.name, dns.rdatatype.AXFR)
|
|
||||||
self.assertEqual(dns.rcode.NOERROR, response.rcode())
|
|
||||||
self.assertEqual(2, len(response.answer))
|
|
||||||
ans = response.answer[0] # SOA
|
|
||||||
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
|
|
||||||
self.assertEqual(zone.name, ans.name.to_text())
|
|
||||||
self.assertEqual(zone.ttl, ans.ttl)
|
|
||||||
ans = response.answer[1] # NS
|
|
||||||
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
|
|
||||||
self.assertEqual(zone.name, ans.name.to_text())
|
|
||||||
self.assertEqual(zone.ttl, ans.ttl)
|
|
||||||
|
|
||||||
def test_notify_notauth_primary_zone(self):
|
|
||||||
zone = self.create_zone()
|
|
||||||
|
|
||||||
# Send NOTIFY to mdns: NOTAUTH for primary zone
|
|
||||||
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
|
|
||||||
notify.id = 123
|
|
||||||
notify.flags = 0
|
|
||||||
notify.set_opcode(dns.opcode.NOTIFY)
|
|
||||||
notify.flags |= dns.flags.AA
|
|
||||||
response = self._send_request_to_mdns(notify)
|
|
||||||
self.assertEqual(dns.rcode.NOTAUTH, response.rcode())
|
|
||||||
expected = b'007ba0090001000000000000076578616d706c6503636f6d0000060001' # noqa
|
|
||||||
self.assertEqual(expected, hex_wire(response))
|
|
||||||
|
|
||||||
def test_notify_non_master(self):
|
|
||||||
zone = self.create_zone(type='SECONDARY', email='test@example.com')
|
|
||||||
|
|
||||||
# Send NOTIFY to mdns: refuse from non-master
|
|
||||||
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
|
|
||||||
notify.id = 123
|
|
||||||
notify.flags = 0
|
|
||||||
notify.set_opcode(dns.opcode.NOTIFY)
|
|
||||||
notify.flags |= dns.flags.AA
|
|
||||||
response = self._send_request_to_mdns(notify)
|
|
||||||
self.assertEqual(dns.rcode.REFUSED, response.rcode())
|
|
||||||
expected = b'007ba0050001000000000000076578616d706c6503636f6d0000060001' # noqa
|
|
||||||
self.assertEqual(expected, hex_wire(response))
|
|
||||||
|
Loading…
Reference in New Issue
Block a user