Commit 2cc7b9ba authored by Tomas Krizek's avatar Tomas Krizek

mypy: fix mypy errors (for mypy 0.6)

parent 04429d99
......@@ -17,7 +17,7 @@ SaveFunction = Optional[Callable[[Any], Any]]
class Reply:
def __init__(self, wire: WireFormat, duration: float) -> None:
def __init__(self, wire: Optional[WireFormat], duration: float) -> None:
self.wire = wire
self.duration = duration
......@@ -307,6 +307,11 @@ class Summary(Disagreements):
reproducibility_threshold: float = 1
) -> 'Summary':
"""Get summary of disagreements above the specified reproduciblity threshold (0, 1]."""
if (report.other_disagreements is None
or report.target_disagreements is None
or report.total_answers is None):
raise RuntimeError("Report has insufficient data to create Summary")
summary = Summary()
summary.upstream_unstable = len(report.other_disagreements)
......@@ -448,8 +453,7 @@ class DiffReport(JSONDataObject): # pylint: disable=too-many-instance-attribute
return DiffReport(data=data)
@property
def duration(self) -> int:
try:
return self.end_time - self.start_time
except TypeError:
def duration(self) -> Optional[int]:
if self.end_time is None or self.start_time is None:
return None
return self.end_time - self.start_time
......@@ -51,6 +51,8 @@ def disagreement_query_stream(
skip_non_reproducible: bool = True,
shuffle: bool = True
) -> Iterator[Tuple[QKey, WireFormat]]:
if report.target_disagreements is None or report.reprodata is None:
raise RuntimeError("Report doesn't contain necessary data!")
qids = report.target_disagreements.keys() # type: Union[Sequence[QID], AbstractSet[QID]]
if shuffle:
# create a new, randomized list from disagreements
......@@ -86,6 +88,8 @@ def process_answers(
criteria: Sequence[FieldLabel],
target: ResolverID
) -> None:
if report.target_disagreements is None or report.reprodata is None:
raise RuntimeError("Report doesn't contain necessary data!")
qid = key2qid(qkey)
reprocounter = report.reprodata[qid]
wire_dict = pickle.loads(replies)
......@@ -95,6 +99,7 @@ def process_answers(
reprocounter.retries += 1
if others_agree:
reprocounter.upstream_stable += 1
assert mismatches is not None
if Diff(qid, mismatches) == report.target_disagreements[qid]:
reprocounter.verified += 1
......
......@@ -4,7 +4,9 @@ import argparse
import collections
import logging
import sys
from typing import Any, Callable, Iterable, Iterator, ItemsView, List, Set, Tuple, Union # noqa
from typing import ( # noqa
Any, Callable, Iterable, Iterator, ItemsView, List, Optional, Set, Tuple,
Union)
import dns.message
import dns.rdatatype
......@@ -23,6 +25,8 @@ GLOBAL_STATS_PCT_FORMAT = '{:21s} {:8d} {:5.2f} % {:s}'
def print_global_stats(report: DiffReport) -> None:
if report.total_answers is None or report.total_queries is None:
raise RuntimeError("Report doesn't contain sufficient data to print statistics!")
print('== Global statistics')
print(GLOBAL_STATS_FORMAT.format('duration', '{:d} s'.format(report.duration)))
print(GLOBAL_STATS_FORMAT.format('queries', report.total_queries))
......@@ -101,7 +105,7 @@ def print_mismatch_queries(
field: FieldLabel,
mismatch: DataMismatch,
queries: Iterator[Tuple[QID, WireFormat]],
limit: int = DEFAULT_LIMIT,
limit: Optional[int] = DEFAULT_LIMIT,
qwire_to_text_func: Callable[[WireFormat], str] = qwire_to_qname_qtype
) -> None:
occurences = collections.Counter() # type: collections.Counter
......
......@@ -17,7 +17,7 @@ from dbhelper import LMDB, key2qid
from sendrecv import ResolverID
lmdb = None # type: Optional[Any]
lmdb = None
def compare_val(exp_val: MismatchValue, got_val: MismatchValue):
......@@ -160,6 +160,8 @@ def decode_wire_dict(
def read_answers_lmdb(qid: QID) -> Mapping[ResolverID, dns.message.Message]:
if lmdb is None:
raise RuntimeError("LMDB wasn't initialized!")
adb = lmdb.get_db(LMDB.ANSWERS)
with lmdb.env.begin(adb) as txn:
blob = txn.get(qid)
......
......@@ -5,7 +5,7 @@ import logging
import multiprocessing.pool as pool
import struct
import sys
from typing import Tuple
from typing import Optional, Tuple
import dpkt
import dns.exception
......@@ -45,7 +45,9 @@ def parse_pcap(pcap_file):
yield (i, wire, '')
def wrk_process_line(args: Tuple[int, str, str]) -> Tuple[bytes, bytes]:
def wrk_process_line(
args: Tuple[int, str, str]
) -> Tuple[Optional[bytes], Optional[bytes]]:
"""
Worker: parse input line, creates a packet in binary format
......@@ -69,7 +71,11 @@ def wrk_process_packet(args: Tuple[int, bytes, str]):
wrk_process_wire_packet(qid, wire, log_repr)
def wrk_process_wire_packet(qid: int, wire_packet: bytes, log_repr: str) -> Tuple[bytes, bytes]:
def wrk_process_wire_packet(
qid: int,
wire_packet: bytes,
log_repr: str
) -> Tuple[Optional[bytes], Optional[bytes]]:
"""
Worker: Return packet's data if it's not blacklisted
......
......@@ -44,7 +44,7 @@ __resolvers = [] # type: Sequence[Tuple[ResolverID, IP, Protocol, Port]]
__worker_state = threading.local()
__max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver
__ignore_timeout = False
__timeout = None
__timeout = 10
__time_delay_min = 0
__time_delay_max = 0
__timeout_replies = {} # type: Dict[float, Reply]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment