Commit ee3d740c authored by Tomas Krizek's avatar Tomas Krizek

dbhelper.DNSReply: add timeout property

parent 95a49bca
......@@ -30,7 +30,8 @@ test:pylint:
test:pytest:
<<: *debian_stable
script:
- export PYTHONPATH="response_differences/:$PYTHONPATH"
# TODO: https://gitlab.labs.nic.cz/knot/resolver-benchmarking/issues/36
- export PYTHONPATH="response_differences/:response_differences/respdiff/:$PYTHONPATH"
- python3 -m pytest response_differences
test:respdiff:
......
......@@ -128,6 +128,14 @@ class LMDB:
class DNSReply:
def __init__(self, wire: Optional[WireFormat], duration: float) -> None:
self.wire = wire
self.duration = duration
def __init__(self, wire: Optional[WireFormat], time: float = 0) -> None:
if wire is None:
self.wire = b''
self.time = float('+inf')
else:
self.wire = wire
self.time = time
@property
def timeout(self):
return self.time == float('+inf')
......@@ -83,7 +83,7 @@ def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]:
def process_answers(
qkey: QKey,
replies: RepliesBlob,
replies_blob: RepliesBlob,
report: DiffReport,
criteria: Sequence[FieldLabel],
target: ResolverID
......@@ -92,8 +92,8 @@ def process_answers(
raise RuntimeError("Report doesn't contain necessary data!")
qid = key2qid(qkey)
reprocounter = report.reprodata[qid]
wire_dict = pickle.loads(replies)
answers = msgdiff.decode_wire_dict(wire_dict)
replies = pickle.loads(replies_blob)
answers = msgdiff.decode_replies(replies)
others_agree, mismatches = msgdiff.compare(answers, criteria, target)
reprocounter.retries += 1
......
......@@ -2,6 +2,7 @@
import argparse
from functools import partial
import logging
import multiprocessing.pool as pool
import pickle
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # noqa
......@@ -142,20 +143,18 @@ def match(
yield (code, ex)
def decode_wire_dict(
wire_dict: Mapping[ResolverID, DNSReply]
def decode_replies(
replies: Mapping[ResolverID, DNSReply]
) -> Mapping[ResolverID, dns.message.Message]:
answers = {} # type: Dict[ResolverID, dns.message.Message]
for k, v in wire_dict.items():
# decode bytes to dns.message objects
# convert from wire format to DNS message object
if v.wire is None: # query timed out
answers[k] = None
for resolver, reply in replies.items():
if reply.timeout:
answers[resolver] = None
continue
try:
answers[k] = dns.message.from_wire(v.wire)
except Exception:
# answers[k] = ex # decoding failed, record it!
answers[resolver] = dns.message.from_wire(reply.wire)
except Exception as exc:
logging.warning('Failed to decode DNS message from wire format: %s', exc)
continue
return answers
......@@ -165,10 +164,10 @@ def read_answers_lmdb(qid: QID) -> Mapping[ResolverID, dns.message.Message]:
raise RuntimeError("LMDB wasn't initialized!")
adb = lmdb.get_db(LMDB.ANSWERS)
with lmdb.env.begin(adb) as txn:
blob = txn.get(qid)
assert blob
wire_dict = pickle.loads(blob)
return decode_wire_dict(wire_dict)
replies_blob = txn.get(qid)
assert replies_blob
replies = pickle.loads(replies_blob)
return decode_replies(replies)
def diff_pair(
......
......@@ -46,7 +46,7 @@ __ignore_timeout = False
__timeout = 10
__time_delay_min = 0
__time_delay_max = 0
__timeout_replies = {} # type: Dict[float, DNSReply]
__timeout_reply = DNSReply(None) # optimization: create only one timeout_reply object
def module_init(args: Namespace) -> None:
......@@ -204,8 +204,6 @@ def send_recv_parallel(
) -> Tuple[Mapping[ResolverID, DNSReply], ReinitFlag]:
replies = {} # type: Dict[ResolverID, DNSReply]
streammsg = None
# optimization: create only one timeout_reply object per timeout value
timeout_reply = __timeout_replies.setdefault(timeout, DNSReply(None, timeout))
start_time = time.perf_counter()
end_time = start_time + timeout
for _, sock, isstream in sockets:
......@@ -242,6 +240,6 @@ def send_recv_parallel(
# set missing replies as timeout
for resolver, *_ in sockets: # type: ignore # python/mypy#465
if resolver not in replies:
replies[resolver] = timeout_reply
replies[resolver] = __timeout_reply
return replies, reinit
from respdiff.dbhelper import DNSReply
def test_dns_reply_timeout():
reply = DNSReply(None)
assert reply.timeout
assert reply.time == float('+inf')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment