Commit d9ab5f37 authored by Tomas Krizek's avatar Tomas Krizek

mypy: add additional annotations

parent 9b17273f
......@@ -9,6 +9,7 @@ from typing import ( # noqa
MismatchValue = Union[str, Sequence[str]]
ResolverID = str
QID = int
WireFormat = bytes
FieldLabel = str
......@@ -122,7 +123,7 @@ class Diff(collections.abc.Mapping):
__setitem__ = None
__delitem__ = None
def __init__(self, qid: QID, mismatches: Dict[FieldLabel, DataMismatch]) -> None:
def __init__(self, qid: QID, mismatches: Mapping[FieldLabel, DataMismatch]) -> None:
super(Diff, self).__init__()
self.qid = qid
self._mismatches = mismatches
......
......@@ -7,20 +7,22 @@ import multiprocessing.pool as pool
import os
import pickle
import sys
from typing import Dict
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # noqa
import dns.message
import dns.exception
import cfg
from dataformat import Reply, DataMismatch, DiffReport, Disagreements, DisagreementsCounter
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter, FieldLabel, MismatchValue,
Reply, ResolverID, QID)
from dbhelper import LMDB, key2qid
lmdb = None
lmdb = None # type: Optional[Any]
def compare_val(exp_val, got_val):
def compare_val(exp_val: MismatchValue, got_val: MismatchValue):
""" Compare values, throw exception if different. """
if exp_val != got_val:
raise DataMismatch(exp_val, got_val)
......@@ -122,7 +124,11 @@ def match_part(exp_msg, got_msg, code): # pylint: disable=inconsistent-return-s
raise NotImplementedError('unknown match request "%s"' % code)
def match(expected, got, match_fields):
def match(
expected: dns.message.Message,
got: dns.message.Message,
match_fields: Sequence[FieldLabel]
) -> Iterator[Tuple[FieldLabel, DataMismatch]]:
""" Compare scripted reply to given message based on match criteria. """
if expected is None or got is None:
if expected is not None:
......@@ -137,9 +143,10 @@ def match(expected, got, match_fields):
yield (code, ex)
def decode_wire_dict(wire_dict: Dict[str, Reply]) \
-> Dict[str, dns.message.Message]:
answers = {} # type: Dict[str, dns.message.Message]
def decode_wire_dict(
wire_dict: Mapping[ResolverID, Reply]
) -> Mapping[ResolverID, dns.message.Message]:
answers = {} # type: Dict[ResolverID, dns.message.Message]
for k, v in wire_dict.items():
# decode bytes to dns.message objects
# convert from wire format to DNS message object
......@@ -154,7 +161,7 @@ def decode_wire_dict(wire_dict: Dict[str, Reply]) \
return answers
def read_answers_lmdb(qid):
def read_answers_lmdb(qid: QID) -> Mapping[ResolverID, dns.message.Message]:
adb = lmdb.get_db(LMDB.ANSWERS)
with lmdb.env.begin(adb) as txn:
blob = txn.get(qid)
......@@ -163,14 +170,20 @@ def read_answers_lmdb(qid):
return decode_wire_dict(wire_dict)
def diff_pair(answers, criteria, name1, name2):
"""
Returns: sequence of (field, DataMismatch())
"""
def diff_pair(
answers: Mapping[ResolverID, dns.message.Message],
criteria: Sequence[FieldLabel],
name1: ResolverID,
name2: ResolverID
) -> Iterator[Tuple[FieldLabel, DataMismatch]]:
yield from match(answers[name1], answers[name2], criteria)
def transitive_equality(answers, criteria, resolvers):
def transitive_equality(
answers: Mapping[ResolverID, dns.message.Message],
criteria: Sequence[FieldLabel],
resolvers: Sequence[ResolverID]
) -> bool:
"""
Compare answers from all resolvers.
Optimization is based on transitivity of equivalence relation.
......@@ -183,7 +196,11 @@ def transitive_equality(answers, criteria, resolvers):
res_others))
def compare(answers, criteria, target):
def compare(
answers: Mapping[ResolverID, dns.message.Message],
criteria: Sequence[FieldLabel],
target: ResolverID
) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]:
others = list(answers.keys())
try:
others.remove(target)
......
......@@ -3,15 +3,15 @@ import socket
import ssl
import struct
import time
from typing import Dict # noqa: used in comment type hint
from typing import Dict, Mapping, Tuple # noqa
import dns.inet
import dns.message
import dataformat
from dataformat import Reply, ResolverID
TIMEOUT_REPLIES = {} # type: Dict[int, dataformat.Reply]
TIMEOUT_REPLIES = {} # type: Dict[float, Reply]
def sock_init(resolvers):
......@@ -66,14 +66,19 @@ def _recv_msg(sock, isstream):
return sock.recv(length)
def send_recv_parallel(dgram, selector, sockets, timeout):
def send_recv_parallel(
dgram,
selector,
sockets,
timeout: float
) -> Tuple[Mapping[ResolverID, Reply], bool]:
"""
dgram: DNS message in binary format suitable for UDP transport
"""
replies = {}
replies = {} # type: Dict[ResolverID, Reply]
streammsg = None
# optimization: create only one timeout_reply object per timeout value
timeout_reply = TIMEOUT_REPLIES.setdefault(timeout, dataformat.Reply(None, timeout))
timeout_reply = TIMEOUT_REPLIES.setdefault(timeout, Reply(None, timeout))
start_time = time.perf_counter()
end_time = start_time + timeout
for _, sock, isstream in sockets:
......@@ -104,7 +109,7 @@ def send_recv_parallel(dgram, selector, sockets, timeout):
# assert len(wire) > 14
if dgram[0:2] != wire[0:2]:
continue # wrong msgid, this might be a delayed answer - ignore it
replies[name] = dataformat.Reply(wire, time.perf_counter() - start_time)
replies[name] = Reply(wire, time.perf_counter() - start_time)
# set missing replies as timeout
for resolver, *_ in sockets:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment