Verified Commit 11207635 authored by Štěpán Balážik's avatar Štěpán Balážik Committed by Petr Špaček

scenario: now using refactored version of matchpart

parent dfa4fcb0
......@@ -21,6 +21,7 @@ import dns.rrset
import dns.tsigkeyring
import pydnstest.augwrap
import pydnstest.matchpart
def str2bool(v):
......@@ -32,39 +33,6 @@ def str2bool(v):
g_rtt = 0.0
g_nqueries = 0
#
# Element comparators
#
def compare_rrs(expected, got):
""" Compare lists of RR sets, throw exception if different. """
for rr in expected:
if rr not in got:
raise ValueError("expected record '%s'" % rr.to_text())
for rr in got:
if rr not in expected:
raise ValueError("unexpected record '%s'" % rr.to_text())
if len(expected) != len(got):
raise ValueError("expected %s records but got %s records "
"(a duplicate RR somewhere?)"
% (len(expected), len(got)))
return True
def compare_val(expected, got):
""" Compare values, throw exception if different. """
if expected != got:
raise ValueError("expected '%s', got '%s'" % (expected, got))
return True
def compare_sub(got, expected):
""" Check if got subdomain of expected, throw exception if different. """
if not expected.is_subdomain(got):
raise ValueError("expected subdomain of '%s', got '%s'" % (expected, got))
return True
def recvfrom_msg(stream, raw=False):
"""
......@@ -352,67 +320,6 @@ class Entry:
return None
return opcodes[0]
def match_part(self, code, msg):
""" Compare scripted reply to given message using single criteria. """
if code not in self.match_fields and 'all' not in self.match_fields:
return True
expected = self.message
if code == 'opcode':
return compare_val(expected.opcode(), msg.opcode())
elif code == 'qtype':
if not expected.question:
return True
return compare_val(expected.question[0].rdtype, msg.question[0].rdtype)
elif code == 'qname':
if not expected.question:
return True
qname = dns.name.from_text(msg.question[0].name.to_text().lower())
return compare_val(expected.question[0].name, qname)
elif code == 'qcase':
return compare_val(msg.question[0].name.labels, expected.question[0].name.labels)
elif code == 'subdomain':
if not expected.question:
return True
qname = dns.name.from_text(msg.question[0].name.to_text().lower())
return compare_sub(expected.question[0].name, qname)
elif code == 'flags':
return compare_val(dns.flags.to_text(expected.flags), dns.flags.to_text(msg.flags))
elif code == 'rcode':
return compare_val(dns.rcode.to_text(expected.rcode()), dns.rcode.to_text(msg.rcode()))
elif code == 'question':
return compare_rrs(expected.question, msg.question)
elif code == 'answer' or code == 'ttl':
return compare_rrs(expected.answer, msg.answer)
elif code == 'authority':
return compare_rrs(expected.authority, msg.authority)
elif code == 'additional':
return compare_rrs(expected.additional, msg.additional)
elif code == 'edns':
if msg.edns != expected.edns:
raise ValueError('expected EDNS %d, got %d' % (expected.edns, msg.edns))
if msg.payload != expected.payload:
raise ValueError('expected EDNS bufsize %d, got %d'
% (expected.payload, msg.payload))
elif code == 'nsid':
nsid_opt = None
for opt in expected.options:
if opt.otype == dns.edns.NSID:
nsid_opt = opt
break
# Find matching NSID
for opt in msg.options:
if opt.otype == dns.edns.NSID:
if not nsid_opt:
raise ValueError('unexpected NSID value "%s"' % opt.data)
if opt == nsid_opt:
return True
else:
raise ValueError('expected NSID "%s", got "%s"' % (nsid_opt.data, opt.data))
if nsid_opt:
raise ValueError('expected NSID "%s"' % nsid_opt.data)
else:
raise ValueError('unknown match request "%s"' % code)
def match(self, msg):
""" Compare scripted reply to given message based on match criteria. """
match_fields = self.match_fields
......@@ -421,8 +328,8 @@ class Entry:
match_fields += ['flags'] + ['rcode'] + self.sections
for code in match_fields:
try:
self.match_part(code, msg)
except ValueError as ex:
pydnstest.matchpart.match_part(self.message, msg, code)
except pydnstest.matchpart.DataMismatch as ex:
errstr = '%s in the response:\n%s' % (str(ex), msg.to_text())
# TODO: cisla radku
raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr))
......@@ -503,6 +410,7 @@ class Range:
self.node = node
self.a = int(node['/from'].value)
self.b = int(node['/to'].value)
assert self.a <= self.b
address = node["/address"].value
self.addresses = {address} if address is not None else set()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment