Commit 808d0f3d authored by Petr Špaček's avatar Petr Špaček

Deckard: PEP8 whitespace fixes

Cheap re-indentation using python-autopep8-1.2.1-3.fc25 with few manual
tweaks for very long lines.

This costs nothing and will avoid PEP8 complaints about whitespace in CI.
parent ed51c163
......@@ -21,6 +21,7 @@ import string
import itertools
import calendar
def str2bool(v):
""" Return conversion of JSON-ish string value to boolean. """
return v.lower() in ('yes', 'true', 'on')
......@@ -32,7 +33,7 @@ def del_files(path_to, delpath):
os.unlink(os.path.join(root, f))
if delpath == True:
try:
os.rmdir(path_to);
os.rmdir(path_to)
except:
pass
......@@ -47,9 +48,9 @@ DEFAULT_FEATURE_PAIR_DELIM = '='
if "SOCKET_WRAPPER_DEFAULT_IFACE" in os.environ:
DEFAULT_IFACE = int(os.environ["SOCKET_WRAPPER_DEFAULT_IFACE"])
if DEFAULT_IFACE < 2 or DEFAULT_IFACE > 254 :
if DEFAULT_IFACE < 2 or DEFAULT_IFACE > 254:
DEFAULT_IFACE = 2
os.environ["SOCKET_WRAPPER_DEFAULT_IFACE"]="{}".format(DEFAULT_IFACE)
os.environ["SOCKET_WRAPPER_DEFAULT_IFACE"] = "{}".format(DEFAULT_IFACE)
if "KRESD_WRAPPER_DEFAULT_IFACE" in os.environ:
CHILD_IFACE = int(os.environ["KRESD_WRAPPER_DEFAULT_IFACE"])
......@@ -72,7 +73,9 @@ if TMPDIR == "" or os.path.isdir(TMPDIR) is False:
if "VERBOSE" in os.environ:
try:
VERBOSE = int(os.environ["VERBOSE"])
except: pass
except:
pass
def find_objects(path):
""" Recursively scan file/directory for scenarios. """
......@@ -85,12 +88,14 @@ def find_objects(path):
result.append(path)
return result
def write_timestamp_file(path, tst):
time_file = open(path, 'w')
time_file.write(datetime.fromtimestamp(tst).strftime('@%Y-%m-%d %H:%M:%S'))
time_file.flush()
time_file.close()
def setup_env(scenario, child_env, config, config_name_list, j2template_list):
""" Set up test environment and config """
# Clear test directory
......@@ -100,7 +105,7 @@ def setup_env(scenario, child_env, config, config_name_list, j2template_list):
os.environ["FAKETIME_TIMESTAMP_FILE"] = '%s/.time' % TMPDIR
child_env["FAKETIME_NO_CACHE"] = "1"
child_env["FAKETIME_TIMESTAMP_FILE"] = '%s/.time' % TMPDIR
write_timestamp_file(child_env["FAKETIME_TIMESTAMP_FILE"], int (time.time()))
write_timestamp_file(child_env["FAKETIME_TIMESTAMP_FILE"], int(time.time()))
# Set up child process env()
child_env["SOCKET_WRAPPER_DEFAULT_IFACE"] = "%i" % CHILD_IFACE
child_env["SOCKET_WRAPPER_DIR"] = TMPDIR
......@@ -114,7 +119,7 @@ def setup_env(scenario, child_env, config, config_name_list, j2template_list):
feature_list_delimiter = DEFAULT_FEATURE_LIST_DELIM
feature_pair_delimiter = DEFAULT_FEATURE_PAIR_DELIM
selfaddr = testserver.get_local_addr_str(socket.AF_INET, DEFAULT_IFACE)
for k,v in config:
for k, v in config:
# Enable selectively for some tests
if k == 'query-minimization' and str2bool(v):
no_minimize = "false"
......@@ -131,33 +136,37 @@ def setup_env(scenario, child_env, config, config_name_list, j2template_list):
ovr_hr = override_date_str[8:10]
ovr_min = override_date_str[10:12]
ovr_sec = override_date_str[12:]
override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format(ovr_yr,ovr_mnt,ovr_day,ovr_hr,ovr_min,ovr_sec)
override_date = time.strptime(override_date_str_arg,"%Y %m %d %H %M %S")
override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format(
ovr_yr, ovr_mnt, ovr_day, ovr_hr, ovr_min, ovr_sec)
override_date = time.strptime(override_date_str_arg, "%Y %m %d %H %M %S")
override_date_timestamp = calendar.timegm(override_date)
write_timestamp_file(child_env["FAKETIME_TIMESTAMP_FILE"], override_date_timestamp)
elif k == 'stub-addr':
stub_addr = v.strip('"\'')
elif k == 'features':
feature_list = v.split(feature_list_delimiter)
try :
try:
for f_item in feature_list:
if f_item.find(feature_pair_delimiter) != -1:
f_key, f_value = [x.strip() for x in f_item.split(feature_pair_delimiter,1)]
f_key, f_value = [x.strip()
for x
in f_item.split(feature_pair_delimiter, 1)]
else:
f_key = f_item.strip()
f_value = ""
features[f_key] = f_value
except Exception as e:
raise Exception ("can't parse features (%s) in config section (%s)" % (v,str(e)));
raise Exception("can't parse features (%s) in config section (%s)" % (v, str(e)))
elif k == 'feature-list':
try :
f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter,1)]
try:
f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)]
if f_key not in features:
features[f_key] = []
f_value = f_value.replace("{{INSTALL_DIR}}",INSTALLDIR)
f_value = f_value.replace("{{INSTALL_DIR}}", INSTALLDIR)
features[f_key].append(f_value)
except Exception as e:
raise Exception ("can't parse feature-list (%s) in config section (%s)" % (v,str(e)));
raise Exception("can't parse feature-list (%s) in config section (%s)"
% (v, str(e)))
elif k == 'force-ipv6' and v.upper() == 'TRUE':
scenario.force_ipv6 = True
......@@ -180,24 +189,26 @@ def setup_env(scenario, child_env, config, config_name_list, j2template_list):
if sock_type & socket.SOCK_STREAM:
sock.listen(5)
# Generate configuration files
j2template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(os.path.abspath(__file__)))
j2template_loader = jinja2.FileSystemLoader(
searchpath=os.path.dirname(os.path.abspath(__file__)))
j2template_env = jinja2.Environment(loader=j2template_loader)
j2template_ctx = {
"ROOT_ADDR" : selfaddr,
"SELF_ADDR" : childaddr,
"NO_MINIMIZE" : no_minimize,
"TRUST_ANCHORS" : trust_anchor_list,
"WORKING_DIR" : TMPDIR,
"INSTALL_DIR" : INSTALLDIR,
"FEATURES" : features
"ROOT_ADDR": selfaddr,
"SELF_ADDR": childaddr,
"NO_MINIMIZE": no_minimize,
"TRUST_ANCHORS": trust_anchor_list,
"WORKING_DIR": TMPDIR,
"INSTALL_DIR": INSTALLDIR,
"FEATURES": features
}
for template_name, config_name in zip(j2template_list,config_name_list):
for template_name, config_name in zip(j2template_list, config_name_list):
j2template = j2template_env.get_template(template_name)
cfg_rendered = j2template.render(j2template_ctx)
f = open(os.path.join(TMPDIR,config_name), 'w')
f = open(os.path.join(TMPDIR, config_name), 'w')
f.write(cfg_rendered)
f.close()
def play_object(path, binary_name, config_name, j2template, binary_additional_pars):
""" Play scenario from a file object. """
......@@ -216,7 +227,7 @@ def play_object(path, binary_name, config_name, j2template, binary_additional_pa
daemon_proc = None
daemon_log = open('%s/server.log' % TMPDIR, 'w')
daemon_args = [binary_name] + binary_additional_pars
try :
try:
daemon_proc = subprocess.Popen(daemon_args, stdout=daemon_log, stderr=daemon_log,
cwd=TMPDIR, preexec_fn=os.setsid, env=daemon_env)
except Exception as e:
......@@ -233,10 +244,12 @@ def play_object(path, binary_name, config_name, j2template, binary_additional_pa
if daemon_proc.poll() != None:
server.stop()
print(open('%s/server.log' % TMPDIR).read())
raise Exception('process died "%s", logs in "%s"' % (os.path.basename(binary_name), TMPDIR))
raise Exception('process died "%s", logs in "%s"' %
(os.path.basename(binary_name), TMPDIR))
try:
sock.connect((testserver.get_local_addr_str(sockfamily, CHILD_IFACE), 53))
except: continue
except:
continue
break
sock.close()
......@@ -249,7 +262,10 @@ def play_object(path, binary_name, config_name, j2template, binary_additional_pa
for s in case.steps:
if s.type == 'REPLY':
reply = s.data[0].message
for rr in itertools.chain(reply.answer,reply.additional,reply.question,reply.authority):
for rr in itertools.chain(reply.answer,
reply.additional,
reply.question,
reply.authority):
for rd in rr:
if rd.rdtype == dns.rdatatype.A:
server.start_srv((rd.address, 53), socket.AF_INET)
......@@ -276,6 +292,7 @@ def play_object(path, binary_name, config_name, j2template, binary_additional_pa
# Do not clear files if the server crashed (for analysis)
del_files(TMPDIR, OWN_TMPDIR)
def test_platform(*args):
if sys.platform == 'windows':
raise Exception('not supported at all on Windows')
......@@ -283,7 +300,8 @@ def test_platform(*args):
if __name__ == '__main__':
if len(sys.argv) < 5:
print("Usage: test_integration.py <scenario> <binary> <template> <config name> [<additional>]")
print("Usage:")
print("test_integration.py <scenario> <binary> <template> <config name> [<additional>]")
print("\t<scenario> - path to scenario")
print("\t<binary> - executable to test")
print("\t<template> - colon-separated list of jinja2 template files")
......@@ -303,8 +321,8 @@ if __name__ == '__main__':
binary_name = sys.argv[2]
template_name_list = sys.argv[3].split(':')
config_name_list = sys.argv[4].split(':')
if len(template_name_list) != len (config_name_list):
print("ERROR: Number of j2 template files not equal to number of file names to be generated")
if len(template_name_list) != len(config_name_list):
print("ERROR: Number of j2 template files not equal to number of files to be generated")
print("i.e. len(<template>) != len(<config name>), see usage")
sys.exit(0)
......@@ -316,5 +334,6 @@ if __name__ == '__main__':
for arg in [path_to_scenario]:
objects = find_objects(arg)
for path in objects:
test.add(path, play_object, path, binary_name, config_name_list, template_name_list, binary_additional_pars)
test.add(path, play_object, path, binary_name, config_name_list,
template_name_list, binary_additional_pars)
sys.exit(test.run())
......@@ -5,6 +5,7 @@ import threading
dprint_lock = threading.Lock()
def dprint(tag, msg):
""" Verbose logging (if enabled). """
if 'VERBOSE' in os.environ:
......
......@@ -6,9 +6,14 @@ import dns.rcode
import dns.dnssec
import dns.tsigkeyring
import binascii
import socket, struct
import os, sys, errno
import itertools, random, string
import socket
import struct
import os
import sys
import errno
import itertools
import random
import string
import time
from datetime import datetime
from pydnstest.dprint import dprint
......@@ -23,7 +28,8 @@ g_nqueries = 0
# Element comparators
#
def create_rr(owner, args, ttl = 3600, rdclass = 'IN', origin = '.'):
def create_rr(owner, args, ttl=3600, rdclass='IN', origin='.'):
""" Parse RR from tokenized string. """
if not owner.endswith('.'):
owner += origin
......@@ -43,10 +49,12 @@ def create_rr(owner, args, ttl = 3600, rdclass = 'IN', origin = '.'):
if (rr.rdtype == dns.rdatatype.DS):
# convert textual algorithm identifier to number
args[1] = str(dns.dnssec.algorithm_from_text(args[1]))
rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(args), origin=dns.name.from_text(origin), relativize=False)
rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(
args), origin=dns.name.from_text(origin), relativize=False)
rr.add(rd)
return rr
def compare_rrs(expected, got):
""" Compare lists of RR sets, throw exception if different. """
for rr in expected:
......@@ -68,13 +76,15 @@ def compare_val(expected, got):
raise Exception("expected '%s', got '%s'" % (expected, got))
return True
def compare_sub(got, expected):
""" Check if got subdomain of expected, throw exception if different. """
if not expected.is_subdomain(got):
raise Exception("expected subdomain of '%s', got '%s'" % (expected, got))
return True
def replay_rrs(rrs, nqueries, destination, args = []):
def replay_rrs(rrs, nqueries, destination, args=[]):
""" Replay list of queries and report statistics. """
navail, queries = len(rrs), []
chunksize = 16
......@@ -82,7 +92,8 @@ def replay_rrs(rrs, nqueries, destination, args = []):
rr = rrs[i % navail]
name = rr.name
if 'RAND' in args:
prefix = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(8)])
prefix = ''.join([random.choice(string.ascii_letters + string.digits)
for n in range(8)])
name = prefix + '.' + rr.name.to_text()
msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
if 'DO' in args:
......@@ -109,7 +120,7 @@ def replay_rrs(rrs, nqueries, destination, args = []):
nwait += 1
nsent += 1
except:
pass # EINVAL
pass # EINVAL
if len(to_read) > 0:
try:
while nwait > 0:
......@@ -119,13 +130,15 @@ def replay_rrs(rrs, nqueries, destination, args = []):
except:
pass
if len(to_write) == 0 and len(to_read) == 0:
nwait = 0 # Timeout, started dropping packets
nwait = 0 # Timeout, started dropping packets
break
return nsent, nrcvd
class Entry:
"""
Data entry represents scripted message and extra metadata, notably match criteria and reply adjustments.
Data entry represents scripted message and extra metadata,
notably match criteria and reply adjustments.
"""
# Globals
......@@ -133,20 +146,20 @@ class Entry:
default_cls = 'IN'
default_rc = 'NOERROR'
def __init__(self, lineno = 0):
def __init__(self, lineno=0):
""" Initialize data entry. """
self.match_fields = ['opcode', 'qtype', 'qname']
self.adjust_fields = ['copy_id']
self.origin = '.'
self.message = dns.message.Message()
self.message.use_edns(edns = 0, payload = 4096)
self.message.use_edns(edns=0, payload=4096)
self.sections = []
self.is_raw_data_entry = False
self.raw_data_pending = False
self.raw_data = None
self.lineno = lineno
self.mandatory = False
self.fired = 0;
self.fired = 0
def match_part(self, code, msg):
""" Compare scripted reply to given message using single criteria. """
......@@ -185,7 +198,8 @@ class Entry:
if msg.edns != expected.edns:
raise Exception('expected EDNS %d, got %d' % (expected.edns, msg.edns))
if msg.payload != expected.payload:
raise Exception('expected EDNS bufsize %d, got %d' % (expected.payload, msg.payload))
raise Exception('expected EDNS bufsize %d, got %d'
% (expected.payload, msg.payload))
elif code == 'nsid':
nsid_opt = None
for opt in expected.options:
......@@ -229,11 +243,15 @@ class Entry:
if raw_value is not None:
got = binascii.hexlify(raw_value)
if expected != got:
print("expected '",expected,"', got '",got,"'")
print("expected '", expected, "', got '", got, "'")
raise Exception("comparsion failed")
def set_match(self, fields):
""" Set conditions for message comparison [all, flags, question, answer, authority, additional, edns] """
"""
Set list of conditions for message comparison
[all, flags, question, answer, authority, additional, edns]
"""
self.match_fields = fields
def adjust_reply(self, query):
......@@ -303,11 +321,12 @@ class Entry:
prefix = len(addr) * 8
if len(net) > 1:
prefix = int(net[1])
addr = addr[0 : (prefix + 7)/8]
if prefix % 8 != 0: # Mask the last byte
addr = addr[0: (prefix + 7) / 8]
if prefix % 8 != 0: # Mask the last byte
addr = addr[:-1] + chr(ord(addr[-1]) & 0xFF << (8 - prefix % 8))
opts.append(dns.edns.GenericOption(8, struct.pack("!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr))
self.message.use_edns(edns = version, payload = bufsize, options = opts)
opts.append(dns.edns.GenericOption(8, struct.pack(
"!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr))
self.message.use_edns(edns=version, payload=bufsize, options=opts)
def begin_raw(self):
""" Set raw data pending flag. """
......@@ -331,7 +350,8 @@ class Entry:
self.raw_data_pending = False
self.is_raw_data_entry = True
else:
rr = create_rr(owner, args, ttl = self.default_ttl, rdclass = self.default_cls, origin = self.origin)
rr = create_rr(owner, args, ttl=self.default_ttl,
rdclass=self.default_cls, origin=self.origin)
if self.section == 'QUESTION':
if rr.rdtype == dns.rdatatype.AXFR:
self.message.xfr = True
......@@ -358,6 +378,7 @@ class Entry:
def set_mandatory(self):
self.mandatory = True
class Range:
"""
Range represents a set of scripted queries valid for given step range.
......@@ -448,7 +469,6 @@ class Step:
except Exception as e:
raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))
def add(self, entry):
""" Append a data entry to this step. """
self.data.append(entry)
......@@ -461,7 +481,7 @@ class Step:
# Parse QUERY-specific parameters
choice, tcp, source = None, False, None
for v in self.args:
if '=' in v: # Key=Value
if '=' in v: # Key=Value
v = v.split('=')
if v[0].lower() == 'source':
source = v[1]
......@@ -469,10 +489,10 @@ class Step:
tcp = True
else:
choice = v
return self.__query(ctx, tcp = tcp, choice = choice, source = source)
return self.__query(ctx, tcp=tcp, choice=choice, source=source)
elif self.type == 'CHECK_OUT_QUERY':
dprint(dtag, '')
pass # Ignore
pass # Ignore
elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER':
dprint(dtag, '')
return self.__check_answer(ctx)
......@@ -507,13 +527,14 @@ class Step:
dprint("", ctx.last_answer.to_text())
expected.match(ctx.last_answer)
def __replay(self, ctx, chunksize = 8):
def __replay(self, ctx, chunksize=8):
dtag = '[ STEP %03d ] %s' % (self.id, self.type)
nqueries = len(self.queries)
if len(self.args) > 0 and self.args[0].isdigit():
nqueries = int(self.args.pop(0))
destination = ctx.client[ctx.client.keys()[0]]
dprint(dtag, 'replaying %d queries to %s@%d (%s)' % (nqueries, destination[0], destination[1], ' '.join(self.args)))
dprint(dtag, 'replaying %d queries to %s@%d (%s)' %
(nqueries, destination[0], destination[1], ' '.join(self.args)))
if 'INTENSIFY' in os.environ:
nqueries *= int(os.environ['INTENSIFY'])
tstart = datetime.now()
......@@ -527,10 +548,10 @@ class Step:
if arg.upper().startswith('PRINT'):
_, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
if tag:
print(' [ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d' % (tag.ljust(11), pps, rtt, nsent, nrcvd))
print(' [ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d' %
(tag.ljust(11), pps, rtt, nsent, nrcvd))
def __query(self, ctx, tcp = False, choice = None, source = None):
def __query(self, ctx, tcp=False, choice=None, source=None):
"""
Send query and wait for an answer (if the query is not RAW).
......@@ -616,10 +637,13 @@ class Step:
subexpr.append(str(ee))
except:
subexpr.append(expr)
assert result is True, '"%s" assertion fails (%s)' % (' '.join(self.args), ' '.join(subexpr))
assert result is True, '"%s" assertion fails (%s)' % (
' '.join(self.args), ' '.join(subexpr))
class Scenario:
def __init__(self, info, filename = ''):
def __init__(self, info, filename=''):
""" Initialize scenario with description. """
self.info = info
self.file = filename
......@@ -630,7 +654,7 @@ class Scenario:
self.client = {}
self.force_ipv6 = False
def reply(self, query, address = None):
def reply(self, query, address=None):
"""
Generate answer packet for given query.
......@@ -685,19 +709,24 @@ class Scenario:
step.play(self)
except Exception as e:
if (step.repeat_if_fail > 0):
dprint ('[play]',"step %d: exception catched - '%s', retrying step %d (%d left)" % (step.id, e, step.next_if_fail, step.repeat_if_fail))
dprint("[play]",
"step %d: exception catched - '%s', retrying step %d (%d left)" %
(step.id, e, step.next_if_fail, step.repeat_if_fail))
step.repeat_if_fail -= 1
if (step.pause_if_fail > 0):
time.sleep(step.pause_if_fail)
if (step.next_if_fail != -1):
next_steps = [j for j in range(len(self.steps)) if self.steps[j].id == step.next_if_fail]
next_steps = [j for j in range(len(self.steps)) if self.steps[
j].id == step.next_if_fail]
if (len(next_steps) == 0):
raise Exception('step %d: wrong NEXT value "%d"' % (step.id, step.next_if_fail))
raise Exception('step %d: wrong NEXT value "%d"' %
(step.id, step.next_if_fail))
next_step = next_steps[0]
if (next_step < len(self.steps)):
i = next_step
else:
raise Exception('step %d: Can''t branch to NEXT value "%d"' % (step.id, step.next_if_fail))
raise Exception('step %d: Can''t branch to NEXT value "%d"' %
(step.id, step.next_if_fail))
continue
else:
raise Exception('%s step %d %s' % (self.file, step.id, str(e)))
......@@ -709,7 +738,7 @@ class Scenario:
raise Exception('Mandatory section at line %d is not fired' % e.lineno)
def get_next(file_in, skip_empty = True):
def get_next(file_in, skip_empty=True):
""" Return next token from the input stream. """
while True:
line = file_in.readline()
......@@ -735,14 +764,15 @@ def get_next(file_in, skip_empty = True):
op = tokens.pop(0)
return op, tokens
def parse_entry(op, args, file_in, in_entry = False):
def parse_entry(op, args, file_in, in_entry=False):
""" Parse entry definition. """
out = Entry(file_in.lineno())
for op, args in iter(lambda: get_next(file_in, in_entry), False):
if op == 'ENTRY_END' or op == '':
in_entry = False
break
elif op == 'ENTRY_BEGIN': # Optional, compatibility with Unbound tests
elif op == 'ENTRY_BEGIN': # Optional, compatibility with Unbound tests
if in_entry:
raise Exception('nested ENTRY_BEGIN not supported')
in_entry = True
......@@ -767,6 +797,7 @@ def parse_entry(op, args, file_in, in_entry = False):
out.add_record(op, args)
return out
def parse_queries(out, file_in):
""" Parse list of queries terminated by blank line. """
out.queries = []
......@@ -777,6 +808,8 @@ def parse_queries(out, file_in):
return out
auto_step = 0
def parse_step(op, args, file_in):
""" Parse range definition. """
global auto_step
......@@ -813,7 +846,7 @@ def parse_range(op, args, file_in):
if op == 'ADDRESS':
out.addresses.add(args[0])
elif op == 'ENTRY_BEGIN':
out.add(parse_entry(op, args, file_in, in_entry = True))
out.add(parse_entry(op, args, file_in, in_entry=True))
elif op == 'RANGE_END':
break
return out
......@@ -848,7 +881,7 @@ def parse_file(file_in):
line = line[0:line.index('#')]
# Break to key-value pairs
# e.g.: ['minimization', 'on']
kv = [x.strip() for x in line.split(':',1)]
kv = [x.strip() for x in line.split(':', 1)]
if len(kv) >= 2:
config.append(kv)
line = file_in.readline()
......
......@@ -3,6 +3,7 @@ import os
import traceback
import time
class Test:
""" Small library to imitate CMocka output. """
......
......@@ -13,7 +13,8 @@ import struct
import binascii
from pydnstest.dprint import dprint
def recvfrom_msg(stream, raw = False):
def recvfrom_msg(stream, raw=False):
"""
Receive DNS message from TCP/UDP socket.
......@@ -27,7 +28,7 @@ def recvfrom_msg(stream, raw = False):
data = stream.recv(2)
if len(data) == 0:
return None, None
msg_len = struct.unpack_from("!H",data)[0]
msg_len = struct.unpack_from("!H", data)[0]
data = b""
received = 0
while received < msg_len:
......@@ -53,13 +54,14 @@ def sendto_msg(stream, message, addr=None):
else:
stream.sendto(message, addr)
elif stream.type & socket.SOCK_STREAM:
data = struct.pack("!H",len(message)) + message
data = struct.pack("!H", len(message)) + message
stream.send(data)
else:
assert False, "[sendto_msg]: unknown socket type '%i'" % stream.type
except: # Failure to respond is OK, resolver should recover
except: # Failure to respond is OK, resolver should recover
pass
def get_local_addr_str(family, iface):
""" Returns pattern string for localhost address """
if family == socket.AF_INET:
......@@ -70,12 +72,15 @@ def get_local_addr_str(family, iface):
raise NotImplementedError("[get_local_addr_str] family not supported '%i'" % family)
return addr_local_pattern.format(iface)
class AddrMapInfo:
""" Saves mapping info between adresses from rpl and cwrap adresses """
def __init__(self, family, local, external):
self.family = family
self.local = local
self.external = external
self.family = family
self.local = local
self.external = external
class TestServer:
""" This simulates UDP DNS server returning scripted or mirror DNS responses. """
......@@ -102,7 +107,7 @@ class TestServer:
if self.active is True:
self.stop()
def start(self, port = 53):
def start(self, port=53):
""" Synchronous start """
if self.active is True:
raise Exception('TestServer already started')
......@@ -126,7 +131,7 @@ class TestServer:
self.connections = []
self.scenario = None
def check_family (self, addr, family):
def check_family(self, addr, family):
""" Determines if address matches family """
test_addr = None
try:
......@@ -148,10 +153,10 @@ class TestServer:
if k == 'stub-addr':
kroot_addr = v
if kroot_addr is not None:
if self.check_family (kroot_addr, socket.AF_INET):
if self.check_family(kroot_addr, socket.AF_INET):
self.addr_family = socket.AF_INET
self.kroot_local = kroot_addr
elif self.check_family (kroot_addr, socket.AF_INET6):
elif self.check_family(kroot_addr, socket.AF_INET6):
self.addr_family = socket.AF_INET6
self.kroot_local = kroot_addr
else:
......@@ -160,10 +165,10 @@ class TestServer:
def address(self):
""" Returns opened sockets list """
addrlist = [];
addrlist = []
for s in self.srv_socks:
addrlist.append(s.getsockname());
return addrlist;
addrlist.append(s.getsockname())
return addrlist
def handle_query(self, client):
"""
......@@ -177,24 +182,23 @@ class TestServer:
query, addr = recvfrom_msg(client)
if query is None:
return False
dprint ("[ handle_query ]", "%s incoming query from %s\n%s" % (client_address, addr, query))
dprint("[ handle_query ]", "%s incoming query from %s\n%s" % (client_address, addr, query))
response = dns.message.make_response(query)
is_raw_data = False
if self.scenario is not None:
response, is_raw_data = self.scenario.reply(query, client_address)
if response:
if is_raw_data is False:
data_to_wire = response.to_wire(max_size = 65535)
dprint ("[ handle_query ]", "response\n%s" % response)
data_to_wire = response.to_wire(max_size=65535)
dprint("[ handle_query ]", "response\n%s" % response)
else:
data_to_wire = response
dprint ("[ handle_query ]", "raw response found")
dprint("[ handle_query ]", "raw response found")
else:
response = dns.message.make_response(query)
response.set_rcode(dns.rcode.SERVFAIL)
data_to_wire = response.to_wire()
dprint ("[ handle_query ]", "response failed, SERVFAIL")
dprint("[ handle_query ]", "response failed, SERVFAIL")
sendto_msg(client, data_to_wire, addr)
return True
......@@ -218,11 +222,12 @@ class TestServer:
sock.close()
self.connections.remove(sock)
else:
raise Exception("[query_io] Socket IO internal error {}, exit".format(sock.getsockname()))
raise Exception(
"[query_io] Socket IO internal error {}, exit".format(sock.getsockname()))
for sock in to_error:
raise Exception("[query_io] Socket IO error {}, exit".format(sock.getsockname()))
def start_srv(self, address = None, family = socket.AF_INET, proto = socket.IPPROTO_UDP):