Commit a5710bba authored by Petr Špaček's avatar Petr Špaček

Merge branch 'query-drop' into 'master'

scenario: ignore selected queries

Closes #4

See merge request !137
parents f011d004 ac0c32b8
Pipeline #41733 failed with stage
in 2 minutes and 10 seconds
......@@ -17,3 +17,4 @@ coverage.xml
*.cover
.hypothesis/
.pytest_cache/
.mypy_cache/
......@@ -33,7 +33,7 @@ let hex_re = /[0-9a-fA-F]+/
let match_option = "opcode" | "qtype" | "qcase" | "qname" | "subdomain" | "flags" | "rcode" | "question" | "answer" | "authority" | "additional" | "all" | "edns"
let adjust_option = "copy_id" | "copy_query" | "raw_id"
let adjust_option = "copy_id" | "copy_query" | "raw_id" | "do_not_answer"
let reply_option = "QR" | "TC" | "AA" | "AD" | "RD" | "RA" | "CD" | "DO" | "NOERROR" | "FORMERR" | "SERVFAIL" | "NXDOMAIN" | "NOTIMP" | "REFUSED" | "YXDOMAIN" | "YXRRSET" | "NXRRSET" | "NOTAUTH" | "NOTZONE" | "BADVERS" | "BADSIG" | "BADKEY" | "BADTIME" | "BADMODE" | "BADNAME" | "BADALG" | "BADTRUNC" | "BADCOOKIE"
let step_option = "REPLY" | "QUERY" | "CHECK_ANSWER" | "CHECK_OUT_QUERY" | /TIME_PASSES[ \t]+ELAPSE/
......
# FIXME pylint: disable=too-many-lines
from abc import ABC
import binascii
import calendar
from datetime import datetime
import errno
import logging
import os
......@@ -9,7 +12,7 @@ import socket
import string
import struct
import time
from datetime import datetime
from typing import Optional
import dns.dnssec
import dns.message
......@@ -135,6 +138,106 @@ def replay_rrs(rrs, nqueries, destination, args=None):
return nsent, nrcvd
class DNSBlob(ABC):
def to_wire(self) -> bytes:
raise NotImplementedError
def __str__(self) -> str:
return '<DNSBlob>'
class DNSMessage(DNSBlob):
def __init__(self, message: dns.message.Message) -> None:
assert message is not None
self.message = message
def to_wire(self) -> bytes:
return self.message.to_wire(max_size=65535)
def __str__(self) -> str:
return str(self.message)
class DNSReply(DNSMessage):
def __init__(
self,
message: dns.message.Message,
query: Optional[dns.message.Message] = None,
copy_id: bool = False,
copy_query: bool = False
) -> None:
super().__init__(message)
if copy_id or copy_query:
if query is None:
raise ValueError("query must be provided to adjust copy_id/copy_query")
self.adjust_reply(query, copy_id, copy_query)
def adjust_reply(
self,
query: dns.message.Message,
copy_id: bool = True,
copy_query: bool = True
) -> None:
answer = dns.message.from_wire(self.message.to_wire(),
xfr=self.message.xfr,
one_rr_per_rrset=True)
answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
if copy_id:
answer.id = query.id
# Copy letter-case if the template has QD
if answer.question:
answer.question[0].name = query.question[0].name
if copy_query:
answer.question = query.question
# Re-set, as the EDNS might have reset the ext-rcode
answer.set_rcode(self.message.rcode())
# sanity check: adjusted answer should be almost the same
assert len(answer.answer) == len(self.message.answer)
assert len(answer.authority) == len(self.message.authority)
assert len(answer.additional) == len(self.message.additional)
self.message = answer
class DNSReplyRaw(DNSBlob):
def __init__(
self,
wire: bytes,
query: Optional[dns.message.Message] = None,
copy_id: bool = False
) -> None:
assert wire is not None
self.wire = wire
if copy_id:
self.adjust_reply(query, copy_id)
def adjust_reply(
self,
query: dns.message.Message,
copy_id: bool = True
) -> None:
if copy_id:
if len(self.wire) < 2:
raise ValueError(
'wire data must contain at least 2 bytes to adjust query id')
raw_answer = bytearray(self.wire)
struct.pack_into('!H', raw_answer, 0, query.id)
self.wire = bytes(raw_answer)
def to_wire(self) -> bytes:
return self.wire
def __str__(self) -> str:
return '<DNSReplyRaw>'
class DNSReplyServfail(DNSMessage):
def __init__(self, query: dns.message.Message) -> None:
message = dns.message.make_response(query)
message.set_rcode(dns.rcode.SERVFAIL)
super().__init__(message)
class Entry:
"""
Data entry represents scripted message and extra metadata,
......@@ -155,7 +258,7 @@ class Entry:
self.fired = 0
# RAW
self.raw_data = None
self.raw_data = None # type: Optional[bytes]
self.is_raw_data_entry = self.process_raw()
# MATCH
......@@ -364,40 +467,16 @@ class Entry:
if expected != got:
raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))
def _adjust_reply(self, query):
""" Copy scripted reply and adjust to received query. """
answer = dns.message.from_wire(self.message.to_wire(),
xfr=self.message.xfr,
one_rr_per_rrset=True)
answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
if 'copy_id' in self.adjust_fields:
answer.id = query.id
# Copy letter-case if the template has QD
if answer.question:
answer.question[0].name = query.question[0].name
if 'copy_query' in self.adjust_fields:
answer.question = query.question
# Re-set, as the EDNS might have reset the ext-rcode
answer.set_rcode(self.message.rcode())
# sanity check: adjusted answer should be almost the same
assert len(answer.answer) == len(self.message.answer)
assert len(answer.authority) == len(self.message.authority)
assert len(answer.additional) == len(self.message.additional)
return answer
def _adjust_raw_reply(self, query):
if 'raw_id' in self.adjust_fields:
assert len(self.raw_data) >= 2, "RAW message has to contain at least 2 bytes"
raw_answer = bytearray(self.raw_data)
struct.pack_into('!H', raw_answer, 0, query.id)
return bytes(raw_answer)
return self.raw_data
def reply(self, query):
def reply(self, query) -> Optional[DNSBlob]:
if 'do_not_answer' in self.adjust_fields:
return None
if self.is_raw_data_entry:
return self._adjust_raw_reply(query), True
return self._adjust_reply(query), False
copy_id = 'raw_data' in self.adjust_fields
assert self.raw_data is not None
return DNSReplyRaw(self.raw_data, query, copy_id)
copy_id = 'copy_id' in self.adjust_fields
copy_query = 'copy_query' in self.adjust_fields
return DNSReply(self.message, query, copy_id, copy_query)
def set_edns(self, fields):
""" Set EDNS version and bufsize. """
......@@ -475,15 +554,8 @@ class Range:
or address in self.addresses)
return False
def reply(self, query):
"""
Get answer for given query (adjusted if needed).
Returns: (answer, is_raw_data)
answer: DNS message object or wire-format (bytes) or None if there is no
candidate in this range
is_raw_data: True if response is wire-format bytes, instead of DNS message object
"""
def reply(self, query: dns.message.Message) -> Optional[DNSBlob]:
"""Get answer for given query (adjusted if needed)."""
self.received += 1
for candidate in self.stored:
try:
......@@ -492,13 +564,13 @@ class Range:
# Probabilistic loss
if 'LOSS' in self.args:
if random.random() < float(self.args['LOSS']):
return None, None
return DNSReplyServfail(query)
self.sent += 1
candidate.fired += 1
return resp
except ValueError:
pass
return None, None
return DNSReplyServfail(query)
class StepLogger(logging.LoggerAdapter): # pylint: disable=too-few-public-methods
......@@ -756,18 +828,12 @@ class Scenario:
txt += "\nSCENARIO_END"
return txt
def reply(self, query, address=None):
"""
Generate answer packet for given query.
The answer can be DNS message object or a binary blob.
Returns:
(answer, boolean "is the answer binary blob?")
"""
def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]:
"""Generate answer packet for given query."""
current_step_id = self.current_step.id
# Unknown address, select any match
# TODO: workaround until the server supports stub zones
all_addresses = set()
all_addresses = set() # type: ignore
for rng in self.ranges:
all_addresses.update(rng.addresses)
if address not in all_addresses:
......@@ -788,7 +854,7 @@ class Scenario:
return candidate.reply(query)
except (IndexError, ValueError):
pass
return None, True
return DNSReplyServfail(query)
def play(self, paddr):
""" Play given scenario. """
......
......@@ -91,25 +91,21 @@ class TestServer:
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
response, is_raw_data = self.scenario.reply(query, server_addr)
if response:
if not is_raw_data:
data_to_wire = response.to_wire(max_size=65535)
log.debug('response: %s', response)
else:
data_to_wire = response
log.debug('raw response not printed')
else:
response = dns.message.make_response(query)
response.set_rcode(dns.rcode.SERVFAIL)
data_to_wire = response.to_wire()
message = self.scenario.reply(query, server_addr)
if not message:
log.debug('ignoring')
return True
elif isinstance(message, scenario.DNSReplyServfail):
self.undefined_answers += 1
self.scenario.current_step.log.error(
'server %s has no response for question %s, answering with SERVFAIL',
server_addr,
'; '.join([str(rr) for rr in query.question]))
else:
log.debug('response: %s', message)
scenario.sendto_msg(client, data_to_wire, client_addr)
scenario.sendto_msg(client, message.to_wire(), client_addr)
return True
def query_io(self):
......
......@@ -13,6 +13,7 @@ disable=
invalid-name,
global-statement,
no-else-return,
bad-continuation,
[SIMILARITIES]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment