pydnstest/scenario: return DNSMessage from reply()

parent e6bbb4fe
# FIXME pylint: disable=too-many-lines
import binascii import binascii
import calendar import calendar
from datetime import datetime
import errno import errno
import logging import logging
import os import os
...@@ -9,7 +11,7 @@ import socket ...@@ -9,7 +11,7 @@ import socket
import string import string
import struct import struct
import time import time
from datetime import datetime from typing import Optional
import dns.dnssec import dns.dnssec
import dns.message import dns.message
...@@ -135,6 +137,30 @@ def replay_rrs(rrs, nqueries, destination, args=None): ...@@ -135,6 +137,30 @@ def replay_rrs(rrs, nqueries, destination, args=None):
return nsent, nrcvd return nsent, nrcvd
class DNSMessage:
def __init__(
self,
message: Optional[dns.message.Message] = None,
wire: Optional[bytes] = None
) -> None:
self.message = message
self._wire = wire
self.is_raw_data = wire is not None
@property
def wire(self) -> bytes:
if self.is_raw_data:
assert self._wire is not None
return self._wire
elif self.message is not None:
return self.message.to_wire(max_size=65535)
raise ValueError('No DNS message or raw wire data available!')
class ReplyNotFound(Exception):
pass
class Entry: class Entry:
""" """
Data entry represents scripted message and extra metadata, Data entry represents scripted message and extra metadata,
...@@ -155,7 +181,7 @@ class Entry: ...@@ -155,7 +181,7 @@ class Entry:
self.fired = 0 self.fired = 0
# RAW # RAW
self.raw_data = None self.raw_data = None # type: Optional[bytes]
self.is_raw_data_entry = self.process_raw() self.is_raw_data_entry = self.process_raw()
# MATCH # MATCH
...@@ -364,7 +390,7 @@ class Entry: ...@@ -364,7 +390,7 @@ class Entry:
if expected != got: if expected != got:
raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got)) raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))
def _adjust_reply(self, query): def _adjust_reply(self, query: dns.message.Message) -> dns.message.Message:
""" Copy scripted reply and adjust to received query. """ """ Copy scripted reply and adjust to received query. """
answer = dns.message.from_wire(self.message.to_wire(), answer = dns.message.from_wire(self.message.to_wire(),
xfr=self.message.xfr, xfr=self.message.xfr,
...@@ -386,7 +412,8 @@ class Entry: ...@@ -386,7 +412,8 @@ class Entry:
assert len(answer.additional) == len(self.message.additional) assert len(answer.additional) == len(self.message.additional)
return answer return answer
def _adjust_raw_reply(self, query): def _adjust_raw_reply(self, query: dns.message.Message) -> bytes:
assert self.raw_data is not None
if 'raw_id' in self.adjust_fields: if 'raw_id' in self.adjust_fields:
assert len(self.raw_data) >= 2, "RAW message has to contain at least 2 bytes" assert len(self.raw_data) >= 2, "RAW message has to contain at least 2 bytes"
raw_answer = bytearray(self.raw_data) raw_answer = bytearray(self.raw_data)
...@@ -394,12 +421,14 @@ class Entry: ...@@ -394,12 +421,14 @@ class Entry:
return bytes(raw_answer) return bytes(raw_answer)
return self.raw_data return self.raw_data
def reply(self, query): def reply(self, query) -> Optional[DNSMessage]:
if self.ignore: if 'do_not_answer' in self.adjust_fields:
return None, True return None
if self.is_raw_data_entry: if self.is_raw_data_entry:
return self._adjust_raw_reply(query), True wire = self._adjust_raw_reply(query)
return self._adjust_reply(query), False return DNSMessage(wire=wire)
msg = self._adjust_reply(query)
return DNSMessage(msg)
def set_edns(self, fields): def set_edns(self, fields):
""" Set EDNS version and bufsize. """ """ Set EDNS version and bufsize. """
...@@ -477,15 +506,8 @@ class Range: ...@@ -477,15 +506,8 @@ class Range:
or address in self.addresses) or address in self.addresses)
return False return False
def reply(self, query): def reply(self, query: dns.message.Message) -> Optional[DNSMessage]:
""" """Get answer for given query (adjusted if needed)."""
Get answer for given query (adjusted if needed).
Returns: (answer, is_raw_data)
answer: DNS message object or wire-format (bytes) or None if there is no
candidate in this range
is_raw_data: True if response is wire-format bytes, instead of DNS message object
"""
self.received += 1 self.received += 1
for candidate in self.stored: for candidate in self.stored:
try: try:
...@@ -494,13 +516,13 @@ class Range: ...@@ -494,13 +516,13 @@ class Range:
# Probabilistic loss # Probabilistic loss
if 'LOSS' in self.args: if 'LOSS' in self.args:
if random.random() < float(self.args['LOSS']): if random.random() < float(self.args['LOSS']):
return None, None raise ReplyNotFound
self.sent += 1 self.sent += 1
candidate.fired += 1 candidate.fired += 1
return resp return resp
except ValueError: except ValueError:
pass pass
return None, None raise ReplyNotFound
class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods
...@@ -758,18 +780,12 @@ class Scenario: ...@@ -758,18 +780,12 @@ class Scenario:
txt += "\nSCENARIO_END" txt += "\nSCENARIO_END"
return txt return txt
def reply(self, query, address=None): def reply(self, query: dns.message.Message, address=None) -> Optional[DNSMessage]:
""" """Generate answer packet for given query."""
Generate answer packet for given query.
The answer can be DNS message object or a binary blob.
Returns:
(answer, boolean "is the answer binary blob?")
"""
current_step_id = self.current_step.id current_step_id = self.current_step.id
# Unknown address, select any match # Unknown address, select any match
# TODO: workaround until the server supports stub zones # TODO: workaround until the server supports stub zones
all_addresses = set() all_addresses = set() # type: ignore
for rng in self.ranges: for rng in self.ranges:
all_addresses.update(rng.addresses) all_addresses.update(rng.addresses)
if address not in all_addresses: if address not in all_addresses:
...@@ -790,7 +806,7 @@ class Scenario: ...@@ -790,7 +806,7 @@ class Scenario:
return candidate.reply(query) return candidate.reply(query)
except (IndexError, ValueError): except (IndexError, ValueError):
pass pass
return None, False raise ReplyNotFound
def play(self, paddr): def play(self, paddr):
""" Play given scenario. """ """ Play given scenario. """
......
...@@ -85,34 +85,40 @@ class TestServer: ...@@ -85,34 +85,40 @@ class TestServer:
True if client socket should be closed by caller True if client socket should be closed by caller
False if client socket should be kept open False if client socket should be kept open
""" """
log = logging.getLogger('pydnstest.testserver.handle_query') def servfail_reply() -> bytes:
server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client)
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
response, second = self.scenario.reply(query, server_addr)
if response:
is_raw_data = second
if not is_raw_data:
data_to_wire = response.to_wire(max_size=65535)
log.debug('response: %s', response)
else:
data_to_wire = response
log.debug('raw response not printed')
elif not second:
response = dns.message.make_response(query) response = dns.message.make_response(query)
response.set_rcode(dns.rcode.SERVFAIL) response.set_rcode(dns.rcode.SERVFAIL)
data_to_wire = response.to_wire()
self.undefined_answers += 1 self.undefined_answers += 1
self.scenario.current_step.log.error( self.scenario.current_step.log.error(
'server %s has no response for question %s, answering with SERVFAIL', 'server %s has no response for question %s, answering with SERVFAIL',
server_addr, server_addr,
'; '.join([str(rr) for rr in query.question])) '; '.join([str(rr) for rr in query.question]))
return response.to_wire()
log = logging.getLogger('pydnstest.testserver.handle_query')
server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client)
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
try:
message = self.scenario.reply(query, server_addr)
except scenario.ReplyNotFound:
data_to_wire = servfail_reply()
else: else:
# Just ignore if not message:
log.debug('ignoring') log.debug('ignoring')
return True return True
if message.is_raw_data:
log.debug('raw response not printed')
else:
log.debug('response: %s', message.message)
try:
data_to_wire = message.wire
except ValueError:
data_to_wire = servfail_reply()
scenario.sendto_msg(client, data_to_wire, client_addr) scenario.sendto_msg(client, data_to_wire, client_addr)
return True return True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment