scenario: use object model for DNS messages

parent 29af67d9
Pipeline #41680 passed with stage
in 1 minute and 34 seconds
# FIXME pylint: disable=too-many-lines
from abc import ABC
import binascii
import calendar
from datetime import datetime
......@@ -137,28 +138,104 @@ def replay_rrs(rrs, nqueries, destination, args=None):
return nsent, nrcvd
class DNSMessage:
class DNSBlob(ABC):
def to_wire(self) -> bytes:
raise NotImplementedError
def __str__(self) -> str:
return '<DNSBlob>'
class DNSMessage(DNSBlob):
def __init__(self, message: dns.message.Message) -> None:
assert message is not None
self.message = message
def to_wire(self) -> bytes:
return self.message.to_wire(max_size=65535)
def __str__(self) -> str:
return str(self.message)
class DNSReply(DNSMessage):
def __init__(
self,
message: Optional[dns.message.Message] = None,
wire: Optional[bytes] = None
message: dns.message.Message,
query: Optional[dns.message.Message] = None,
copy_id: bool = False,
copy_query: bool = False
) -> None:
self.message = message
self._wire = wire
self.is_raw_data = wire is not None
super().__init__(message)
if copy_id or copy_query:
if query is None:
raise ValueError("query must be provided to adjust copy_id/copy_query")
self.adjust_reply(query, copy_id, copy_query)
@property
def wire(self) -> bytes:
if self.is_raw_data:
assert self._wire is not None
return self._wire
elif self.message is not None:
return self.message.to_wire(max_size=65535)
raise ValueError('No DNS message or raw wire data available!')
def adjust_reply(
self,
query: dns.message.Message,
copy_id: bool = True,
copy_query: bool = True
) -> None:
answer = dns.message.from_wire(self.message.to_wire(),
xfr=self.message.xfr,
one_rr_per_rrset=True)
answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
if copy_id:
answer.id = query.id
# Copy letter-case if the template has QD
if answer.question:
answer.question[0].name = query.question[0].name
if copy_query:
answer.question = query.question
# Re-set, as the EDNS might have reset the ext-rcode
answer.set_rcode(self.message.rcode())
# sanity check: adjusted answer should be almost the same
assert len(answer.answer) == len(self.message.answer)
assert len(answer.authority) == len(self.message.authority)
assert len(answer.additional) == len(self.message.additional)
self.message = answer
class ReplyNotFound(Exception):
pass
class DNSReplyRaw(DNSBlob):
def __init__(
self,
wire: bytes,
query: Optional[dns.message.Message] = None,
copy_id: bool = False
) -> None:
assert wire is not None
self.wire = wire
if copy_id:
self.adjust_reply(query, copy_id)
def adjust_reply(
self,
query: dns.message.Message,
copy_id: bool = True
) -> None:
if copy_id:
if len(self.wire) < 2:
raise ValueError(
'wire data must contain at least 2 bytes to adjust query id')
raw_answer = bytearray(self.wire)
struct.pack_into('!H', raw_answer, 0, query.id)
self.wire = bytes(raw_answer)
def to_wire(self) -> bytes:
return self.wire
def __str__(self) -> str:
return '<DNSReplyRaw>'
class DNSReplyServfail(DNSMessage):
def __init__(self, query: dns.message.Message) -> None:
message = dns.message.make_response(query)
message.set_rcode(dns.rcode.SERVFAIL)
super().__init__(message)
class Entry:
......@@ -390,45 +467,16 @@ class Entry:
if expected != got:
raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))
def _adjust_reply(self, query: dns.message.Message) -> dns.message.Message:
""" Copy scripted reply and adjust to received query. """
answer = dns.message.from_wire(self.message.to_wire(),
xfr=self.message.xfr,
one_rr_per_rrset=True)
answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
if 'copy_id' in self.adjust_fields:
answer.id = query.id
# Copy letter-case if the template has QD
if answer.question:
answer.question[0].name = query.question[0].name
if 'copy_query' in self.adjust_fields:
answer.question = query.question
# Re-set, as the EDNS might have reset the ext-rcode
answer.set_rcode(self.message.rcode())
# sanity check: adjusted answer should be almost the same
assert len(answer.answer) == len(self.message.answer)
assert len(answer.authority) == len(self.message.authority)
assert len(answer.additional) == len(self.message.additional)
return answer
def _adjust_raw_reply(self, query: dns.message.Message) -> bytes:
assert self.raw_data is not None
if 'raw_id' in self.adjust_fields:
assert len(self.raw_data) >= 2, "RAW message has to contain at least 2 bytes"
raw_answer = bytearray(self.raw_data)
struct.pack_into('!H', raw_answer, 0, query.id)
return bytes(raw_answer)
return self.raw_data
def reply(self, query) -> Optional[DNSMessage]:
def reply(self, query) -> Optional[DNSBlob]:
if 'do_not_answer' in self.adjust_fields:
return None
if self.is_raw_data_entry:
wire = self._adjust_raw_reply(query)
return DNSMessage(wire=wire)
msg = self._adjust_reply(query)
return DNSMessage(msg)
copy_id = 'raw_data' in self.adjust_fields
assert self.raw_data is not None
return DNSReplyRaw(self.raw_data, query, copy_id)
copy_id = 'copy_id' in self.adjust_fields
copy_query = 'copy_query' in self.adjust_fields
return DNSReply(self.message, query, copy_id, copy_query)
def set_edns(self, fields):
""" Set EDNS version and bufsize. """
......@@ -506,7 +554,7 @@ class Range:
or address in self.addresses)
return False
def reply(self, query: dns.message.Message) -> Optional[DNSMessage]:
def reply(self, query: dns.message.Message) -> Optional[DNSBlob]:
"""Get answer for given query (adjusted if needed)."""
self.received += 1
for candidate in self.stored:
......@@ -516,13 +564,13 @@ class Range:
# Probabilistic loss
if 'LOSS' in self.args:
if random.random() < float(self.args['LOSS']):
raise ReplyNotFound
return DNSReplyServfail(query)
self.sent += 1
candidate.fired += 1
return resp
except ValueError:
pass
raise ReplyNotFound
return DNSReplyServfail(query)
class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods
......@@ -780,7 +828,7 @@ class Scenario:
txt += "\nSCENARIO_END"
return txt
def reply(self, query: dns.message.Message, address=None) -> Optional[DNSMessage]:
def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]:
"""Generate answer packet for given query."""
current_step_id = self.current_step.id
# Unknown address, select any match
......@@ -806,7 +854,7 @@ class Scenario:
return candidate.reply(query)
except (IndexError, ValueError):
pass
raise ReplyNotFound
return DNSReplyServfail(query)
def play(self, paddr):
""" Play given scenario. """
......
......@@ -85,42 +85,27 @@ class TestServer:
True if client socket should be closed by caller
False if client socket should be kept open
"""
def servfail_reply() -> bytes:
response = dns.message.make_response(query)
response.set_rcode(dns.rcode.SERVFAIL)
self.undefined_answers += 1
self.scenario.current_step.log.error(
'server %s has no response for question %s, answering with SERVFAIL',
server_addr,
'; '.join([str(rr) for rr in query.question]))
return response.to_wire()
log = logging.getLogger('pydnstest.testserver.handle_query')
server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client)
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
try:
message = self.scenario.reply(query, server_addr)
except scenario.ReplyNotFound:
data_to_wire = servfail_reply()
else:
if not message:
log.debug('ignoring')
return True
if message.is_raw_data:
log.debug('raw response not printed')
else:
log.debug('response: %s', message.message)
try:
data_to_wire = message.wire
except ValueError:
data_to_wire = servfail_reply()
message = self.scenario.reply(query, server_addr)
if not message:
log.debug('ignoring')
return True
elif isinstance(message, scenario.DNSReplyServfail):
self.undefined_answers += 1
self.scenario.current_step.log.error(
'server %s has no response for question %s, answering with SERVFAIL',
server_addr,
'; '.join([str(rr) for rr in query.question]))
else:
log.debug('response: %s', message)
scenario.sendto_msg(client, data_to_wire, client_addr)
scenario.sendto_msg(client, message.to_wire(), client_addr)
return True
def query_io(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment