Fix rrset serialization, improve mdns tests

Code refactor and cleanup
Add port number logging

Change-Id: Ied150676166e038a005d73884788d406ad0e296c
Closes-Bug: #1550441
This commit is contained in:
Federico Ceratto 2016-03-01 12:36:47 +00:00
parent 3cf67d6e75
commit 4bc65992ce
6 changed files with 274 additions and 44 deletions

View File

@ -142,6 +142,10 @@ class SerializationMiddleware(DNSMiddleware):
elif isinstance(response, dns.renderer.Renderer):
yield response.get_wire()
else:
LOG.error(_LE("Unexpected response %(resp)s") %
repr(response))
class TsigInfoMiddleware(DNSMiddleware):
"""Middleware which looks up the information available for a TsigKey"""

View File

@ -64,6 +64,7 @@ class RequestHandler(xfr.XFRMixin):
# TSIG places the pseudo records into the additional section.
if (len(request.question) != 1 or
request.question[0].rdclass != dns.rdataclass.IN):
LOG.debug("Refusing due to numbers of questions or rdclass")
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
@ -88,6 +89,7 @@ class RequestHandler(xfr.XFRMixin):
else:
# Unhandled OpCode's include STATUS, IQUERY, UPDATE
LOG.debug("Refusing unhandled opcode")
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
@ -131,7 +133,7 @@ class RequestHandler(xfr.XFRMixin):
master_addr = zone.get_master_by_ip(notify_addr)
if not master_addr:
msg = _LW("NOTIFY for %(name)s from non-master server "
"%(addr)s, ignoring.")
"%(addr)s, refusing.")
LOG.warning(msg % {"name": zone.name, "addr": notify_addr})
response.set_rcode(dns.rcode.REFUSED)
yield response
@ -200,18 +202,13 @@ class RequestHandler(xfr.XFRMixin):
def _convert_to_rrset(self, zone, recordset):
# Fetch the zone or the config ttl if the recordset ttl is null
if recordset.ttl:
ttl = recordset.ttl
else:
ttl = zone.ttl
ttl = recordset.ttl or zone.ttl
# construct rdata from all the records
rdata = []
for record in recordset.records:
# TODO(Ron): this should be handled in the Storage query where we
# find the recordsets.
if record.action != 'DELETE':
rdata.append(str(record.data))
# TODO(Ron): this should be handled in the Storage query where we
# find the recordsets.
rdata = [str(record.data) for record in recordset.records
if record.action != 'DELETE']
# Now put the records into dnspython's RRsets
# answer section has 1 RR set. If the RR set has multiple
@ -219,13 +216,10 @@ class RequestHandler(xfr.XFRMixin):
# section.
# RRSet has name, ttl, class, type and rdata
# The rdata has one or more records
r_rrset = None
if rdata:
r_rrset = dns.rrset.from_text_list(
return dns.rrset.from_text_list(
recordset.name, ttl, dns.rdataclass.IN, recordset.type, rdata)
return r_rrset
def _handle_axfr(self, request):
context = request.environ['context']
q_rrset = request.question[0]
@ -361,31 +355,6 @@ class RequestHandler(xfr.XFRMixin):
}
recordset = self.storage.find_recordset(context, criterion)
try:
criterion = self._zone_criterion_from_request(
request, {'id': recordset.zone_id})
zone = self.storage.find_zone(context, criterion)
except exceptions.ZoneNotFound:
LOG.warning(_LW("ZoneNotFound while handling query request"
". Question was %(qr)s") % {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
except exceptions.Forbidden:
LOG.warning(_LW("Forbidden while handling query request. "
"Question was %(qr)s") % {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
r_rrset = self._convert_to_rrset(zone, recordset)
response.set_rcode(dns.rcode.NOERROR)
response.answer = [r_rrset]
# For all the data stored in designate mdns is Authoritative
response.flags |= dns.flags.AA
except exceptions.NotFound:
# If an FQDN exists, like www.rackspace.com, but the specific
# record type doesn't exist, like type SPF, then the return code
@ -403,9 +372,37 @@ class RequestHandler(xfr.XFRMixin):
#
# To simply things currently this returns a REFUSED in all cases.
# If zone transfers needs different errors, we could revisit this.
response.set_rcode(dns.rcode.REFUSED)
LOG.info(_LI("NotFound, refusing. Question was %(qr)s"),
{'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
except exceptions.Forbidden:
response.set_rcode(dns.rcode.REFUSED)
LOG.info(_LI("Forbidden, refusing. Question was %(qr)s"),
{'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
try:
criterion = self._zone_criterion_from_request(
request, {'id': recordset.zone_id})
zone = self.storage.find_zone(context, criterion)
except exceptions.ZoneNotFound:
LOG.warning(_LW("ZoneNotFound while handling query request"
". Question was %(qr)s") % {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
except exceptions.Forbidden:
LOG.warning(_LW("Forbidden while handling query request. "
"Question was %(qr)s") % {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
raise StopIteration
r_rrset = self._convert_to_rrset(zone, recordset)
response.answer = [r_rrset] if r_rrset else []
response.set_rcode(dns.rcode.NOERROR)
# For all the data stored in designate mdns is Authoritative
response.flags |= dns.flags.AA
yield response

View File

@ -13,15 +13,23 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import binascii
import socket
import dns
import dns.message
import mock
from oslo_log import log as logging
from designate.tests.test_mdns import MdnsTestCase
LOG = logging.getLogger(__name__)
def hex_wire(response):
return binascii.b2a_hex(response.to_wire())
class MdnsServiceTest(MdnsTestCase):
def setUp(self):
@ -65,3 +73,115 @@ class MdnsServiceTest(MdnsTestCase):
sock_udp=sock_udp)
sendto_mock.assert_called_once_with(
binascii.a2b_hex(expected_response), self.addr)
def _send_request_to_mdns(self, req):
"""Send request to localhost"""
self.assertTrue(len(self.service._dns_socks_udp))
port = self.service._dns_socks_udp[0].getsockname()[1]
response = dns.query.udp(req, '127.0.0.1', port=port, timeout=1)
LOG.info("\n-- RESPONSE --\n%s\n--------------\n" % response.to_text())
return response
def _query_mdns(self, qname, rdtype, rdclass=dns.rdataclass.IN):
"""Send query to localhost"""
req = dns.message.make_query(qname, rdtype, rdclass=rdclass)
req.id = 123
return self._send_request_to_mdns(req)
def test_query(self):
zone = self.create_zone()
# Reply query for NS
response = self._query_mdns(zone.name, dns.rdatatype.NS)
self.assertEqual(dns.rcode.NOERROR, response.rcode())
self.assertEqual(1, len(response.answer))
ans = response.answer[0]
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
# Reply query for SOA
response = self._query_mdns(zone.name, dns.rdatatype.SOA)
self.assertEqual(dns.rcode.NOERROR, response.rcode())
self.assertEqual(1, len(response.answer))
ans = response.answer[0]
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
# Refuse query for incorrect rdclass
response = self._query_mdns(zone.name, dns.rdatatype.SOA,
rdclass=dns.rdataclass.RESERVED0)
self.assertEqual(dns.rcode.REFUSED, response.rcode())
expected = b'007b81050001000000000000076578616d706c6503636f6d0000060000' # noqa
self.assertEqual(expected, hex_wire(response))
# Refuse query for ANY
response = self._query_mdns("www.%s" % zone.name, dns.rdatatype.ANY)
self.assertEqual(dns.rcode.REFUSED, response.rcode())
expected = b'007b8105000100000000000003777777076578616d706c6503636f6d0000ff0001' # noqa
self.assertEqual(expected, hex_wire(response))
# Reply query for A against inexistent record
response = self._query_mdns("nope.%s" % zone.name, dns.rdatatype.A)
self.assertEqual(dns.rcode.REFUSED, response.rcode())
expected = b'007b81050001000000000000046e6f7065076578616d706c6503636f6d0000010001' # noqa
self.assertEqual(expected, hex_wire(response))
# Reply query for A
recordset = self.create_recordset(zone)
self.create_record(zone, recordset)
response = self._query_mdns(recordset.name, dns.rdatatype.A)
self.assertEqual(dns.rcode.NOERROR, response.rcode())
self.assertEqual(1, len(response.answer))
ans = response.answer[0]
self.assertEqual(dns.rdatatype.A, ans.rdtype)
self.assertEqual(recordset.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
self.assertEqual('3600 IN A 192.0.2.1', str(ans.to_rdataset()))
expected = b'007b85000001000100000000046d61696c076578616d706c6503636f6d0000010001c00c0001000100000e100004c0000201' # noqa
self.assertEqual(expected, hex_wire(response))
def test_query_axfr(self):
zone = self.create_zone()
# Query for AXFR
response = self._query_mdns(zone.name, dns.rdatatype.AXFR)
self.assertEqual(dns.rcode.NOERROR, response.rcode())
self.assertEqual(2, len(response.answer))
ans = response.answer[0] # SOA
self.assertEqual(dns.rdatatype.SOA, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
ans = response.answer[1] # NS
self.assertEqual(dns.rdatatype.NS, ans.rdtype)
self.assertEqual(zone.name, ans.name.to_text())
self.assertEqual(zone.ttl, ans.ttl)
def test_notify_notauth_primary_zone(self):
zone = self.create_zone()
# Send NOTIFY to mdns: NOTAUTH for primary zone
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
notify.id = 123
notify.flags = 0
notify.set_opcode(dns.opcode.NOTIFY)
notify.flags |= dns.flags.AA
response = self._send_request_to_mdns(notify)
self.assertEqual(dns.rcode.NOTAUTH, response.rcode())
expected = b'007ba0090001000000000000076578616d706c6503636f6d0000060001' # noqa
self.assertEqual(expected, hex_wire(response))
def test_notify_non_master(self):
zone = self.create_zone(type='SECONDARY', email='test@example.com')
# Send NOTIFY to mdns: refuse from non-master
notify = dns.message.make_query(zone.name, dns.rdatatype.SOA)
notify.id = 123
notify.flags = 0
notify.set_opcode(dns.opcode.NOTIFY)
notify.flags |= dns.flags.AA
response = self._send_request_to_mdns(notify)
self.assertEqual(dns.rcode.REFUSED, response.rcode())
expected = b'007ba0050001000000000000076578616d706c6503636f6d0000060001' # noqa
self.assertEqual(expected, hex_wire(response))

View File

@ -13,13 +13,19 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import dns
from mock import Mock
from oslo_log import log as logging
import dns
from designate import exceptions
from designate import objects
from designate.mdns import handler
LOG = logging.getLogger(__name__)
class TestRequestHandlerCall(unittest.TestCase):
"""
@ -42,7 +48,6 @@ class TestRequestHandlerCall(unittest.TestCase):
self.handler._handle_query_error.assert_called_with(
request, error_type
)
return True
def test_central_api_property(self):
self.handler._central_api = 'foo'
@ -104,3 +109,99 @@ class TestRequestHandlerCall(unittest.TestCase):
request = Mock()
request.opcode.return_value = dns.opcode.NOTIFY
assert list(self.handler(request)) == ['Notify']
def test__convert_to_rrset_no_records(self):
zone = objects.Zone.from_dict({'ttl': 1234})
recordset = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
])
)
r_rrset = self.handler._convert_to_rrset(zone, recordset)
self.assertEqual(None, r_rrset)
def test__convert_to_rrset(self):
zone = objects.Zone.from_dict({'ttl': 1234})
recordset = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.1'),
objects.Record(data='192.0.2.2'),
])
)
r_rrset = self.handler._convert_to_rrset(zone, recordset)
self.assertEqual(2, len(r_rrset))
class HandleRecordQueryTest(unittest.TestCase):
def setUp(self):
self.storage = Mock()
self.tg = Mock()
self.handler = handler.RequestHandler(self.storage, self.tg)
def test__handle_record_query_empty_recordlist(self):
# bug #1550441
self.storage.find_recordset.return_value = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
])
)
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response_gen = self.handler._handle_record_query(request)
for r in response_gen:
# This was raising an exception due to bug #1550441
out = r.to_wire(max_size=65535)
self.assertEqual(33, len(out))
def test__handle_record_query_zone_not_found(self):
self.storage.find_recordset.return_value = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.2'),
])
)
self.storage.find_zone.side_effect = exceptions.ZoneNotFound
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
def test__handle_record_query_forbidden(self):
self.storage.find_recordset.return_value = objects.RecordSet(
name='www.example.org.',
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.2'),
])
)
self.storage.find_zone.side_effect = exceptions.Forbidden
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
def test__handle_record_query_find_recordsed_forbidden(self):
self.storage.find_recordset.side_effect = exceptions.Forbidden
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
def test__handle_record_query_find_recordsed_not_found(self):
self.storage.find_recordset.side_effect = exceptions.NotFound
request = dns.message.make_query('www.example.org.', dns.rdatatype.A)
request.environ = dict(context='ctx')
response = tuple(self.handler._handle_record_query(request))
self.assertEqual(1, len(response))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())

View File

@ -482,6 +482,10 @@ def bind_tcp(host, port, tcp_backlog, tcp_keepidle=None):
sock_tcp.setblocking(True)
sock_tcp.bind((host, port))
if port == 0:
newport = sock_tcp.getsockname()[1]
LOG.info(_LI('Listening on TCP port %(port)d'), {'port': newport})
sock_tcp.listen(tcp_backlog)
return sock_tcp
@ -502,5 +506,8 @@ def bind_udp(host, port):
sock_udp.setblocking(True)
sock_udp.bind((host, port))
if port == 0:
newport = sock_udp.getsockname()[1]
LOG.info(_LI('Listening on UDP port %(port)d'), {'port': newport})
return sock_udp

View File

@ -8,6 +8,7 @@ designate.tests.test_backend.test_nsd4
designate.tests.test_central.test_service
designate.tests.test_dnsutils
designate.tests.test_mdns.test_handler
designate.tests.test_mdns.test_service.MdnsServiceTest.test_query
designate.tests.test_notification_handler.test_neutron
designate.tests.test_notification_handler.test_nova
designate.tests.test_pool_manager.test_service