Commit c192819c authored by Štěpán Balážik's avatar Štěpán Balážik Committed by Petr Špaček

matchpart: refactors out match_part function to module

Now it can be used independently from, and maybe even from respdiff in the future.
parent 5d2fd650
"""matchpart is used to compare two DNS messages using a single criterion"""
import dns.rcode
import dns.edns
class DataMismatch(Exception):
def __init__(self, exp_val, got_val):
self.exp_val = exp_val
self.got_val = got_val
def __str__(self):
return 'expected "{0.exp_val}" got "{0.got_val}"'.format(self)
def __hash__(self):
return hash((self.exp_val, self.got_val))
def __eq__(self, other):
return (isinstance(other, DataMismatch)
and self.exp_val == other.exp_val
and self.got_val == other.got_val)
def __ne__(self, other):
return not self.__eq__(other)
def compare_val(exp, got):
"""Compare arbitraty objects, throw exception if different. """
if exp != got:
raise DataMismatch(exp, got)
return True
def compare_rrs(expected, got):
""" Compare lists of RR sets, throw exception if different. """
for rr in expected:
if rr not in got:
raise DataMismatch(expected, got)
for rr in got:
if rr not in expected:
raise DataMismatch(expected, got)
if len(expected) != len(got):
raise DataMismatch(expected, got)
return True
def compare_rrs_types(exp_val, got_val, skip_rrsigs):
"""sets of RR types in both sections must match"""
def rr_ordering_key(rrset):
if rrset.covers:
return rrset.covers, 1 # RRSIGs go to the end of RRtype list
return rrset.rdtype, 0
def key_to_text(rrtype, rrsig):
if not rrsig:
return dns.rdatatype.to_text(rrtype)
return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype)
if skip_rrsigs:
exp_val = (rrset for rrset in exp_val
if rrset.rdtype != dns.rdatatype.RRSIG)
got_val = (rrset for rrset in got_val
if rrset.rdtype != dns.rdatatype.RRSIG)
exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
if exp_types != got_types:
exp_types = tuple(key_to_text(*i) for i in sorted(exp_types))
got_types = tuple(key_to_text(*i) for i in sorted(got_types))
raise DataMismatch(exp_types, got_types)
def match_opcode(exp, got):
return compare_val(exp.opcode(),
def match_qtype(exp, got):
if not exp.question:
return True
return compare_val(exp.question[0].rdtype,
def match_qname(exp, got):
if not exp.question:
return True
return compare_val(exp.question[0].name,
def match_qcase(exp, got):
return compare_val(exp.question[0].name.labels,
def match_subdomain(exp, got):
if not exp.question:
return True
qname =[0].name.to_text().lower())
if exp.question[0].name.is_superdomain(qname):
return True
raise DataMismatch(exp, got)
def match_flags(exp, got):
return compare_val(dns.flags.to_text(exp.flags),
def match_rcode(exp, got):
return compare_val(dns.rcode.to_text(exp.rcode()),
def match_question(exp, got):
return compare_rrs(exp.question,
def match_answer(exp, got):
return compare_rrs(exp.answer,
def match_ttl(exp, got):
return compare_rrs(exp.answer,
def match_answertypes(exp, got):
return compare_rrs_types(exp.answer,
got.answer, skip_rrsigs=True)
def match_answerrrsigs(exp, got):
return compare_rrs_types(exp.answer,
got.answer, skip_rrsigs=False)
def match_authority(exp, got):
return compare_rrs(exp.authority,
def match_additional(exp, got):
return compare_rrs(exp.additional,
def match_edns(exp, got):
if got.edns != exp.edns:
raise DataMismatch(exp.edns,
if got.payload != exp.payload:
raise DataMismatch(exp.payload,
def match_nsid(exp, got):
nsid_opt = None
for opt in exp.options:
if opt.otype == dns.edns.NSID:
nsid_opt = opt
# Find matching NSID
for opt in got.options:
if opt.otype == dns.edns.NSID:
if not nsid_opt:
raise DataMismatch(None,
if opt == nsid_opt:
return True
raise DataMismatch(,
if nsid_opt:
raise DataMismatch(, None)
return True
MATCH = {"opcode": match_opcode, "qtype": match_qtype, "qname": match_qname, "qcase": match_qcase,
"subdomain": match_subdomain, "flags": match_flags, "rcode": match_rcode,
"question": match_question, "answer": match_answer, "ttl": match_ttl,
"answertypes": match_answertypes, "answerrrsigs": match_answerrrsigs,
"authority": match_authority, "additional": match_additional, "edns": match_edns,
"nsid": match_nsid}
def match_part(exp, got, code):
return MATCH[code](exp, got)
except KeyError:
raise NotImplementedError('unknown match request "%s"' % code)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment