pydnstest/scenario: return DNSMessage from reply()

parent e6bbb4fe
# FIXME pylint: disable=too-many-lines
import binascii
import calendar
from datetime import datetime
import errno
import logging
import os
......@@ -9,7 +11,7 @@ import socket
import string
import struct
import time
from datetime import datetime
from typing import Optional
import dns.dnssec
import dns.message
......@@ -135,6 +137,30 @@ def replay_rrs(rrs, nqueries, destination, args=None):
return nsent, nrcvd
class DNSMessage:
def __init__(
message: Optional[dns.message.Message] = None,
wire: Optional[bytes] = None
) -> None:
self.message = message
self._wire = wire
self.is_raw_data = wire is not None
def wire(self) -> bytes:
if self.is_raw_data:
assert self._wire is not None
return self._wire
elif self.message is not None:
return self.message.to_wire(max_size=65535)
raise ValueError('No DNS message or raw wire data available!')
class ReplyNotFound(Exception):
class Entry:
Data entry represents scripted message and extra metadata,
......@@ -155,7 +181,7 @@ class Entry:
self.fired = 0
self.raw_data = None
self.raw_data = None # type: Optional[bytes]
self.is_raw_data_entry = self.process_raw()
......@@ -364,7 +390,7 @@ class Entry:
if expected != got:
raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))
def _adjust_reply(self, query):
def _adjust_reply(self, query: dns.message.Message) -> dns.message.Message:
""" Copy scripted reply and adjust to received query. """
answer = dns.message.from_wire(self.message.to_wire(),
......@@ -386,7 +412,8 @@ class Entry:
assert len(answer.additional) == len(self.message.additional)
return answer
def _adjust_raw_reply(self, query):
def _adjust_raw_reply(self, query: dns.message.Message) -> bytes:
assert self.raw_data is not None
if 'raw_id' in self.adjust_fields:
assert len(self.raw_data) >= 2, "RAW message has to contain at least 2 bytes"
raw_answer = bytearray(self.raw_data)
......@@ -394,12 +421,14 @@ class Entry:
return bytes(raw_answer)
return self.raw_data
def reply(self, query):
if self.ignore:
return None, True
def reply(self, query) -> Optional[DNSMessage]:
if 'do_not_answer' in self.adjust_fields:
return None
if self.is_raw_data_entry:
return self._adjust_raw_reply(query), True
return self._adjust_reply(query), False
wire = self._adjust_raw_reply(query)
return DNSMessage(wire=wire)
msg = self._adjust_reply(query)
return DNSMessage(msg)
def set_edns(self, fields):
""" Set EDNS version and bufsize. """
......@@ -477,15 +506,8 @@ class Range:
or address in self.addresses)
return False
def reply(self, query):
Get answer for given query (adjusted if needed).
Returns: (answer, is_raw_data)
answer: DNS message object or wire-format (bytes) or None if there is no
candidate in this range
is_raw_data: True if response is wire-format bytes, instead of DNS message object
def reply(self, query: dns.message.Message) -> Optional[DNSMessage]:
"""Get answer for given query (adjusted if needed)."""
self.received += 1
for candidate in self.stored:
......@@ -494,13 +516,13 @@ class Range:
# Probabilistic loss
if 'LOSS' in self.args:
if random.random() < float(self.args['LOSS']):
return None, None
raise ReplyNotFound
self.sent += 1
candidate.fired += 1
return resp
except ValueError:
return None, None
raise ReplyNotFound
class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods
......@@ -758,18 +780,12 @@ class Scenario:
txt += "\nSCENARIO_END"
return txt
def reply(self, query, address=None):
Generate answer packet for given query.
The answer can be DNS message object or a binary blob.
(answer, boolean "is the answer binary blob?")
def reply(self, query: dns.message.Message, address=None) -> Optional[DNSMessage]:
"""Generate answer packet for given query."""
current_step_id =
# Unknown address, select any match
# TODO: workaround until the server supports stub zones
all_addresses = set()
all_addresses = set() # type: ignore
for rng in self.ranges:
if address not in all_addresses:
......@@ -790,7 +806,7 @@ class Scenario:
return candidate.reply(query)
except (IndexError, ValueError):
return None, False
raise ReplyNotFound
def play(self, paddr):
""" Play given scenario. """
......@@ -85,34 +85,40 @@ class TestServer:
True if client socket should be closed by caller
False if client socket should be kept open
log = logging.getLogger('pydnstest.testserver.handle_query')
server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client)
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
response, second = self.scenario.reply(query, server_addr)
if response:
is_raw_data = second
if not is_raw_data:
data_to_wire = response.to_wire(max_size=65535)
log.debug('response: %s', response)
data_to_wire = response
log.debug('raw response not printed')
elif not second:
def servfail_reply() -> bytes:
response = dns.message.make_response(query)
data_to_wire = response.to_wire()
self.undefined_answers += 1
'server %s has no response for question %s, answering with SERVFAIL',
'; '.join([str(rr) for rr in query.question]))
return response.to_wire()
log = logging.getLogger('pydnstest.testserver.handle_query')
server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client)
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
message = self.scenario.reply(query, server_addr)
except scenario.ReplyNotFound:
data_to_wire = servfail_reply()
# Just ignore
return True
if not message:
return True
if message.is_raw_data:
log.debug('raw response not printed')
log.debug('response: %s', message.message)
data_to_wire = message.wire
except ValueError:
data_to_wire = servfail_reply()
scenario.sendto_msg(client, data_to_wire, client_addr)
return True
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment