Ensure mDNS can AXFR zones over 65k

Change-Id: Ic209a6d1c552326f51541c9ad9e524df347abe16
Closes-Bug: 1439125
This commit is contained in:
Kiall Mac Innes 2015-04-01 15:33:38 +01:00
parent 290e039922
commit 224fecf0b8
8 changed files with 288 additions and 85 deletions

View File

@ -88,7 +88,8 @@ class RequestHandler(object):
response = self._handle_query_error(request, dns.rcode.REFUSED) response = self._handle_query_error(request, dns.rcode.REFUSED)
# TODO(Tim): Answer Type 65XXX queries # TODO(Tim): Answer Type 65XXX queries
return response yield response
raise StopIteration
def _handle_query_error(self, request, rcode): def _handle_query_error(self, request, rcode):
""" """

View File

@ -124,11 +124,13 @@ class SerializationMiddleware(DNSMiddleware):
else: else:
# Hand the Deserialized packet onto the Application # Hand the Deserialized packet onto the Application
response = self.application(message) for response in self.application(message):
# Serialize and return the response if present # Serialize and return the response if present
if response is not None: if isinstance(response, dns.message.Message):
return response.to_wire(max_size=65535) yield response.to_wire(max_size=65535)
elif isinstance(response, dns.renderer.Renderer):
yield response.get_wire()
class TsigInfoMiddleware(DNSMiddleware): class TsigInfoMiddleware(DNSMiddleware):

View File

@ -40,6 +40,8 @@ OPTS = [
'signed'), 'signed'),
cfg.StrOpt('storage-driver', default='sqlalchemy', cfg.StrOpt('storage-driver', default='sqlalchemy',
help='The storage driver to use'), help='The storage driver to use'),
cfg.IntOpt('max-message-size', default=65535,
help='Maximum message size to emit'),
] ]
cfg.CONF.register_opts(OPTS, group='service:mdns') cfg.CONF.register_opts(OPTS, group='service:mdns')

View File

@ -58,7 +58,8 @@ class RequestHandler(xfr.XFRMixin):
# TSIG places the pseudo records into the additional section. # TSIG places the pseudo records into the additional section.
if (len(request.question) != 1 or if (len(request.question) != 1 or
request.question[0].rdclass != dns.rdataclass.IN): request.question[0].rdclass != dns.rdataclass.IN):
return self._handle_query_error(request, dns.rcode.REFUSED) yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
q_rrset = request.question[0] q_rrset = request.question[0]
# Handle AXFR and IXFR requests with an AXFR responses for now. # Handle AXFR and IXFR requests with an AXFR responses for now.
@ -66,15 +67,24 @@ class RequestHandler(xfr.XFRMixin):
# receiving an IXFR request. # receiving an IXFR request.
# TODO(Ron): send IXFR response when receiving IXFR request. # TODO(Ron): send IXFR response when receiving IXFR request.
if q_rrset.rdtype in (dns.rdatatype.AXFR, dns.rdatatype.IXFR): if q_rrset.rdtype in (dns.rdatatype.AXFR, dns.rdatatype.IXFR):
response = self._handle_axfr(request) for response in self._handle_axfr(request):
yield response
raise StopIteration
else: else:
response = self._handle_record_query(request) for response in self._handle_record_query(request):
yield response
raise StopIteration
elif request.opcode() == dns.opcode.NOTIFY: elif request.opcode() == dns.opcode.NOTIFY:
response = self._handle_notify(request) for response in self._handle_notify(request):
yield response
raise StopIteration
else: else:
# Unhandled OpCode's include STATUS, IQUERY, NOTIFY, UPDATE # Unhandled OpCode's include STATUS, IQUERY, NOTIFY, UPDATE
response = self._handle_query_error(request, dns.rcode.REFUSED) yield self._handle_query_error(request, dns.rcode.REFUSED)
return response raise StopIteration
def _handle_notify(self, request): def _handle_notify(self, request):
""" """
@ -90,7 +100,8 @@ class RequestHandler(xfr.XFRMixin):
if len(request.question) != 1: if len(request.question) != 1:
response.set_rcode(dns.rcode.FORMERR) response.set_rcode(dns.rcode.FORMERR)
return response yield response
raise StopIteration
else: else:
question = request.question[0] question = request.question[0]
@ -104,7 +115,8 @@ class RequestHandler(xfr.XFRMixin):
domain = self.storage.find_domain(context, criterion) domain = self.storage.find_domain(context, criterion)
except exceptions.DomainNotFound: except exceptions.DomainNotFound:
response.set_rcode(dns.rcode.NOTAUTH) response.set_rcode(dns.rcode.NOTAUTH)
return response yield response
raise StopIteration
notify_addr = request.environ['addr'][0] notify_addr = request.environ['addr'][0]
@ -117,7 +129,8 @@ class RequestHandler(xfr.XFRMixin):
"%(addr)s, ignoring.") "%(addr)s, ignoring.")
LOG.warn(msg % {"name": domain.name, "addr": notify_addr}) LOG.warn(msg % {"name": domain.name, "addr": notify_addr})
response.set_rcode(dns.rcode.REFUSED) response.set_rcode(dns.rcode.REFUSED)
return response yield response
raise StopIteration
resolver = dns.resolver.Resolver() resolver = dns.resolver.Resolver()
# According to RFC we should query the server that sent the NOTIFY # According to RFC we should query the server that sent the NOTIFY
@ -138,7 +151,8 @@ class RequestHandler(xfr.XFRMixin):
response.flags |= dns.flags.AA response.flags |= dns.flags.AA
return response yield response
raise StopIteration
def _handle_query_error(self, request, rcode): def _handle_query_error(self, request, rcode):
""" """
@ -244,9 +258,8 @@ class RequestHandler(xfr.XFRMixin):
def _handle_axfr(self, request): def _handle_axfr(self, request):
context = request.environ['context'] context = request.environ['context']
response = dns.message.make_response(request)
q_rrset = request.question[0] q_rrset = request.question[0]
# First check if there is an existing zone # First check if there is an existing zone
# TODO(vinod) once validation is separated from the api, # TODO(vinod) once validation is separated from the api,
# validate the parameters # validate the parameters
@ -259,42 +272,82 @@ class RequestHandler(xfr.XFRMixin):
LOG.warning(_LW("DomainNotFound while handling axfr request. " LOG.warning(_LW("DomainNotFound while handling axfr request. "
"Question was %(qr)s") % {'qr': q_rrset}) "Question was %(qr)s") % {'qr': q_rrset})
return self._handle_query_error(request, dns.rcode.REFUSED) yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
except exceptions.Forbidden: except exceptions.Forbidden:
LOG.warning(_LW("Forbidden while handling axfr request. " LOG.warning(_LW("Forbidden while handling axfr request. "
"Question was %(qr)s") % {'qr': q_rrset}) "Question was %(qr)s") % {'qr': q_rrset})
return self._handle_query_error(request, dns.rcode.REFUSED) yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
r_rrsets = []
# The AXFR response needs to have a SOA at the beginning and end. # The AXFR response needs to have a SOA at the beginning and end.
criterion = {'domain_id': domain.id, 'type': 'SOA'} criterion = {'domain_id': domain.id, 'type': 'SOA'}
soa_recordsets = self.storage.find_recordsets(context, criterion) soa_records = self.storage.find_recordsets_axfr(context, criterion)
for recordset in soa_recordsets: # Get all the records other than SOA
r_rrsets.append(self._convert_to_rrset(domain, recordset))
# Get all the recordsets other than SOA
criterion = {'domain_id': domain.id, 'type': '!SOA'} criterion = {'domain_id': domain.id, 'type': '!SOA'}
records = self.storage.find_recordsets_axfr(context, criterion)
# Get the raw record data out of storage and parse it # Place the SOA RRSet at the front and end of the RRSet list
raw_records = self.storage.find_recordsets_axfr(context, criterion) records.insert(0, soa_records[0])
r_rrsets.extend(self._prep_rrsets(raw_records, domain.ttl)) records.append(soa_records[0])
# Append the SOA recordset at the end # Build the DNSPython RRSets from the Records
for recordset in soa_recordsets: rrsets = self._prep_rrsets(records, domain.ttl)
r_rrsets.append(self._convert_to_rrset(domain, recordset))
response.set_rcode(dns.rcode.NOERROR) # Build up a dummy response, we're stealing it's logic for building
# TODO(vinod) check if we dnspython has an upper limit on the number # the Flags.
# of rrsets. response = dns.message.make_response(request)
response.answer = r_rrsets
# For all the data stored in designate mdns is Authoritative
response.flags |= dns.flags.AA response.flags |= dns.flags.AA
response.set_rcode(dns.rcode.NOERROR)
return response max_message_size = CONF['service:mdns'].max_message_size
# Render the results, yielding a packet after each TooBig exception.
i, renderer = 0, None
while i < len(rrsets):
# No renderer? Build one
if renderer is None:
renderer = dns.renderer.Renderer(
response.id, response.flags, max_message_size)
for q in request.question:
renderer.add_question(q.name, q.rdtype, q.rdclass)
try:
renderer.add_rrset(dns.renderer.ANSWER, rrsets[i])
i += 1
except dns.exception.TooBig:
renderer.write_header()
if request.had_tsig:
renderer.add_tsig(
request.keyname,
request.keyring[request.keyname],
request.fudge,
request.original_id,
request.tsig_error,
request.other_data,
request.request_mac,
request.keyalgorithm)
yield renderer
renderer = None
if renderer is not None:
renderer.write_header()
if request.had_tsig:
renderer.add_tsig(
request.keyname,
request.keyring[request.keyname],
request.fudge,
request.original_id,
request.tsig_error,
request.other_data,
request.request_mac,
request.keyalgorithm)
yield renderer
raise StopIteration
def _handle_record_query(self, request): def _handle_record_query(self, request):
"""Handle a DNS QUERY request for a record""" """Handle a DNS QUERY request for a record"""
@ -321,13 +374,15 @@ class RequestHandler(xfr.XFRMixin):
LOG.warning(_LW("DomainNotFound while handling query request" LOG.warning(_LW("DomainNotFound while handling query request"
". Question was %(qr)s") % {'qr': q_rrset}) ". Question was %(qr)s") % {'qr': q_rrset})
return self._handle_query_error(request, dns.rcode.REFUSED) yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
except exceptions.Forbidden: except exceptions.Forbidden:
LOG.warning(_LW("Forbidden while handling query request. " LOG.warning(_LW("Forbidden while handling query request. "
"Question was %(qr)s") % {'qr': q_rrset}) "Question was %(qr)s") % {'qr': q_rrset})
return self._handle_query_error(request, dns.rcode.REFUSED) yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
r_rrset = self._convert_to_rrset(domain, recordset) r_rrset = self._convert_to_rrset(domain, recordset)
response.set_rcode(dns.rcode.NOERROR) response.set_rcode(dns.rcode.NOERROR)
@ -357,4 +412,4 @@ class RequestHandler(xfr.XFRMixin):
except exceptions.Forbidden: except exceptions.Forbidden:
response.set_rcode(dns.rcode.REFUSED) response.set_rcode(dns.rcode.REFUSED)
return response yield response

View File

@ -318,10 +318,8 @@ class DNSService(object):
""" """
try: try:
# Call into the DNS Application itself with the payload and addr # Call into the DNS Application itself with the payload and addr
response = self._dns_application({ for response in self._dns_application(
'payload': payload, {'payload': payload, 'addr': addr}):
'addr': addr
})
# Send back a response only if present # Send back a response only if present
if response is not None: if response is not None:
@ -330,11 +328,14 @@ class DNSService(object):
msg_length = len(response) msg_length = len(response)
tcp_response = struct.pack("!H", msg_length) + response tcp_response = struct.pack("!H", msg_length) + response
client.send(tcp_response) client.send(tcp_response)
client.close()
else: else:
# Handle UDP Responses # Handle UDP Responses
self._dns_sock_udp.sendto(response, addr) self._dns_sock_udp.sendto(response, addr)
# Close the TCP connection if we have one.
if client:
client.close()
except Exception: except Exception:
LOG.exception(_LE("Unhandled exception while processing request " LOG.exception(_LE("Unhandled exception while processing request "
"from %(host)s:%(port)d") % "from %(host)s:%(port)d") %

View File

@ -551,7 +551,8 @@ class TestCase(base.BaseTestCase):
return self.central_service.create_domain( return self.central_service.create_domain(
context, objects.Domain.from_dict(values)) context, objects.Domain.from_dict(values))
def create_recordset(self, domain, type='A', **kwargs): def create_recordset(self, domain, type='A', increment_serial=True,
**kwargs):
context = kwargs.pop('context', self.admin_context) context = kwargs.pop('context', self.admin_context)
fixture = kwargs.pop('fixture', 0) fixture = kwargs.pop('fixture', 0)
@ -560,9 +561,11 @@ class TestCase(base.BaseTestCase):
values=kwargs) values=kwargs)
return self.central_service.create_recordset( return self.central_service.create_recordset(
context, domain['id'], objects.RecordSet.from_dict(values)) context, domain['id'], objects.RecordSet.from_dict(values),
increment_serial=increment_serial)
def create_record(self, domain, recordset, **kwargs): def create_record(self, domain, recordset, increment_serial=True,
**kwargs):
context = kwargs.pop('context', self.admin_context) context = kwargs.pop('context', self.admin_context)
fixture = kwargs.pop('fixture', 0) fixture = kwargs.pop('fixture', 0)
@ -571,7 +574,8 @@ class TestCase(base.BaseTestCase):
return self.central_service.create_record( return self.central_service.create_record(
context, domain['id'], recordset['id'], context, domain['id'], recordset['id'],
objects.Record.from_dict(values)) objects.Record.from_dict(values),
increment_serial=increment_serial)
def create_blacklist(self, **kwargs): def create_blacklist(self, **kwargs):
context = kwargs.pop('context', self.admin_context) context = kwargs.pop('context', self.admin_context)

View File

@ -56,7 +56,7 @@ class AgentRequestHandlerTest(AgentTestCase):
"6f6d0000060001") "6f6d0000060001")
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': ["0.0.0.0", 1234]} request.environ = {'addr': ["0.0.0.0", 1234]}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
def test_receive_notify_bad_notifier(self): def test_receive_notify_bad_notifier(self):
@ -78,7 +78,7 @@ class AgentRequestHandlerTest(AgentTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
# Bad 'requester' # Bad 'requester'
request.environ = {'addr': ["6.6.6.6", 1234]} request.environ = {'addr': ["6.6.6.6", 1234]}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -108,7 +108,7 @@ class AgentRequestHandlerTest(AgentTestCase):
designate.backend.agent_backend.impl_fake.FakeBackend, designate.backend.agent_backend.impl_fake.FakeBackend,
'find_domain_serial', return_value=None): 'find_domain_serial', return_value=None):
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
def test_receive_create_bad_notifier(self): def test_receive_create_bad_notifier(self):
@ -130,7 +130,7 @@ class AgentRequestHandlerTest(AgentTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
# Bad 'requester' # Bad 'requester'
request.environ = {'addr': ["6.6.6.6", 1234]} request.environ = {'addr': ["6.6.6.6", 1234]}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -154,7 +154,7 @@ class AgentRequestHandlerTest(AgentTestCase):
"00ff03ff00") "00ff03ff00")
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': ["0.0.0.0", 1234]} request.environ = {'addr': ["0.0.0.0", 1234]}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -178,7 +178,7 @@ class AgentRequestHandlerTest(AgentTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
# Bad 'requester' # Bad 'requester'
request.environ = {'addr': ["6.6.6.6", 1234]} request.environ = {'addr': ["6.6.6.6", 1234]}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -206,6 +206,6 @@ class AgentRequestHandlerTest(AgentTestCase):
with mock.patch.object( with mock.patch.object(
designate.backend.agent_backend.impl_fake.FakeBackend, designate.backend.agent_backend.impl_fake.FakeBackend,
'find_domain_serial', return_value=None): 'find_domain_serial', return_value=None):
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
doaxfr.assert_called_with('example.com.', [], source="1.2.3.4") doaxfr.assert_called_with('example.com.', [], source="1.2.3.4")
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))

View File

@ -93,7 +93,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -116,7 +116,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -172,7 +172,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
with mock.patch.object(self.handler.storage, 'find_domain', with mock.patch.object(self.handler.storage, 'find_domain',
return_value=domain): return_value=domain):
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.mock_tg.add_thread.assert_called_with( self.mock_tg.add_thread.assert_called_with(
self.handler.domain_sync, self.context, domain, [master]) self.handler.domain_sync, self.context, domain, [master])
@ -213,7 +213,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
with mock.patch.object(self.handler.storage, 'find_domain', with mock.patch.object(self.handler.storage, 'find_domain',
return_value=domain): return_value=domain):
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
assert not self.mock_tg.add_thread.called assert not self.mock_tg.add_thread.called
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -251,7 +251,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
with mock.patch.object(self.handler.storage, 'find_domain', with mock.patch.object(self.handler.storage, 'find_domain',
return_value=domain): return_value=domain):
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
assert not self.mock_tg.add_thread.called assert not self.mock_tg.add_thread.called
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -277,7 +277,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
'context': self.context 'context': self.context
} }
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
assert not self.mock_tg.add_thread.called assert not self.mock_tg.add_thread.called
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -305,7 +305,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
'context': self.context 'context': self.context
} }
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
assert not self.mock_tg.add_thread.called assert not self.mock_tg.add_thread.called
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -329,7 +329,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -350,7 +350,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
# request = dns.message.from_wire(binascii.a2b_hex(payload)) # request = dns.message.from_wire(binascii.a2b_hex(payload))
# request.environ = {'addr': self.addr, 'context': self.context} # request.environ = {'addr': self.addr, 'context': self.context}
# response = self.handler(request).to_wire() # response = self.handler(request).next().to_wire()
# # strip the id from the response and compare # # strip the id from the response and compare
# self.assertEqual(expected_response, binascii.b2a_hex(response)[5:]) # self.assertEqual(expected_response, binascii.b2a_hex(response)[5:])
@ -377,7 +377,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
"00000100010000292000000000000000") "00000100010000292000000000000000")
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -408,7 +408,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -439,10 +439,148 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
def test_dispatch_opcode_query_AXFR(self):
# Query is for example.com. IN AXFR
# id 18883
# opcode QUERY
# rcode NOERROR
# flags AD
# edns 0
# payload 4096
# ;QUESTION
# example.com. IN AXFR
# ;ANSWER
# ;AUTHORITY
# ;ADDITIONAL
payload = ("49c300200001000000000001076578616d706c6503636f6d0000fc0001"
"0000291000000000000000")
# id 18883
# opcode QUERY
# rcode NOERROR
# flags QR AA
# ;QUESTION
# example.com. IN AXFR
# ;ANSWER
# example.com. 3600 IN SOA ns1.example.org. example.example.com.
# -> 1427899961 3600 600 86400 3600
# mail.example.com. 3600 IN A 192.0.2.1
# example.com. 3600 IN NS ns1.example.org.
# ;AUTHORITY
# ;ADDITIONAL
expected_response = \
("49c384000001000400000000076578616d706c6503636f6d0000fc0001c00c00"
"06000100000e10002f036e7331076578616d706c65036f726700076578616d70"
"6c65c00c551c063900000e10000002580001518000000e10c00c000200010000"
"0e100002c029046d61696cc00c0001000100000e100004c0000201c00c000600"
"0100000e100018c029c03a551c063900000e10000002580001518000000e10")
domain = objects.Domain.from_dict({
'name': 'example.com.',
'ttl': 3600,
'serial': 1427899961,
'email': 'example@example.com',
})
def _find_recordsets_axfr(context, criterion):
if criterion['type'] == 'SOA':
return [['UUID1', 'SOA', '3600', 'example.com.',
'ns1.example.org. example.example.com. 1427899961 '
'3600 600 86400 3600', 'ACTION']]
elif criterion['type'] == '!SOA':
return [
['UUID2', 'NS', '3600', 'example.com.', 'ns1.example.org.',
'ACTION'],
['UUID3', 'A', '3600', 'mail.example.com.', '192.0.2.1',
'ACTION'],
]
with mock.patch.object(self.storage, 'find_domain',
return_value=domain):
with mock.patch.object(self.storage, 'find_recordsets_axfr',
side_effect=_find_recordsets_axfr):
request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).next().get_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response))
def test_dispatch_opcode_query_AXFR_multiple_messages(self):
# Query is for example.com. IN AXFR
# id 18883
# opcode QUERY
# rcode NOERROR
# flags AD
# edns 0
# payload 4096
# ;QUESTION
# example.com. IN AXFR
# ;ANSWER
# ;AUTHORITY
# ;ADDITIONAL
payload = ("49c300200001000000000001076578616d706c6503636f6d0000fc0001"
"0000291000000000000000")
expected_response = [
("49c384000001000300000000076578616d706c6503636f6d0000fc0001c00c00"
"06000100000e10002f036e7331076578616d706c65036f726700076578616d70"
"6c65c00c551c063900000e10000002580001518000000e10c00c000200010000"
"0e100002c029046d61696cc00c0001000100000e100004c0000201"),
("49c384000001000100000000076578616d706c6503636f6d0000fc0001c00c00"
"06000100000e10002f036e7331076578616d706c65036f726700076578616d70"
"6c65c00c551c063900000e10000002580001518000000e10"),
]
# Set the max-message-size to 128
self.config(max_message_size=128, group='service:mdns')
domain = objects.Domain.from_dict({
'name': 'example.com.',
'ttl': 3600,
'serial': 1427899961,
'email': 'example@example.com',
})
def _find_recordsets_axfr(context, criterion):
if criterion['type'] == 'SOA':
return [['UUID1', 'SOA', '3600', 'example.com.',
'ns1.example.org. example.example.com. 1427899961 '
'3600 600 86400 3600', 'ACTION']]
elif criterion['type'] == '!SOA':
return [
['UUID2', 'NS', '3600', 'example.com.', 'ns1.example.org.',
'ACTION'],
['UUID3', 'A', '3600', 'mail.example.com.', '192.0.2.1',
'ACTION'],
]
with mock.patch.object(self.storage, 'find_domain',
return_value=domain):
with mock.patch.object(self.storage, 'find_recordsets_axfr',
side_effect=_find_recordsets_axfr):
request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context}
response_generator = self.handler(request)
# Validate the first response
response_one = response_generator.next().get_wire()
self.assertEqual(
expected_response[0], binascii.b2a_hex(response_one))
# Validate the second response
response_two = response_generator.next().get_wire()
self.assertEqual(
expected_response[1], binascii.b2a_hex(response_two))
def test_dispatch_opcode_query_nonexistent_recordtype(self): def test_dispatch_opcode_query_nonexistent_recordtype(self):
# query is for mail.example.com. IN CNAME # query is for mail.example.com. IN CNAME
payload = ("271801000001000000000000046d61696c076578616d706c6503636f6d" payload = ("271801000001000000000000046d61696c076578616d706c6503636f6d"
@ -469,7 +607,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -492,7 +630,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
request = dns.message.from_wire(binascii.a2b_hex(payload)) request = dns.message.from_wire(binascii.a2b_hex(payload))
request.environ = {'addr': self.addr, 'context': self.context} request.environ = {'addr': self.addr, 'context': self.context}
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -532,7 +670,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
"0000010001c00c0001000100000e100004c0000205000029" "0000010001c00c0001000100000e100004c0000205000029"
"2000000000000000") "2000000000000000")
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -553,7 +691,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
expected_response = ("c28981050001000000000001076578616d706c6503636f6d" expected_response = ("c28981050001000000000001076578616d706c6503636f6d"
"00000100010000292000000000000000") "00000100010000292000000000000000")
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
def test_dispatch_opcode_query_tsig_scope_zone(self): def test_dispatch_opcode_query_tsig_scope_zone(self):
@ -598,7 +736,7 @@ class MdnsRequestHandlerTest(MdnsTestCase):
"0000010001c00c0001000100000e100004c0000205000029" "0000010001c00c0001000100000e100004c0000205000029"
"2000000000000000") "2000000000000000")
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))
@ -619,5 +757,5 @@ class MdnsRequestHandlerTest(MdnsTestCase):
expected_response = ("c28981050001000000000001076578616d706c6503636f6d" expected_response = ("c28981050001000000000001076578616d706c6503636f6d"
"00000100010000292000000000000000") "00000100010000292000000000000000")
response = self.handler(request).to_wire() response = self.handler(request).next().to_wire()
self.assertEqual(expected_response, binascii.b2a_hex(response)) self.assertEqual(expected_response, binascii.b2a_hex(response))