Commit 42ed2a91 authored by Štěpán Kotek

Reformat code to conform to PEP 8.

refs knot/resolver-benchmarking#1
parent d5f886fb
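The diff below makes no functional changes; it only applies PEP 8 style fixes: a space after the '#' that opens a comment, spaces around binary operators, 'x not in y' instead of 'not x in y', two blank lines between top-level definitions, over-long lines wrapped onto continuation lines, and a newline added at the end of files. A small, hypothetical before/after sketch (not code from this repository) of the kind of cleanup applied throughout:

# Hypothetical example only -- illustrates the PEP 8 fixes in this commit,
# it is not part of the repository.

# Before:
#   def mismatch_share(field,counters,total):
#       #share of mismatches for one field
#       if not field in counters:
#           return 0
#       return 100.0*counters[field]/total

# After:
def mismatch_share(field, counters, total):
    # share of mismatches for one field
    if field not in counters:
        return 0
    return 100.0 * counters[field] / total


if __name__ == '__main__':
    print('{:6.2f} %'.format(mismatch_share('qname', {'qname': 3}, 10)))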
import dns.rdatatype
def obj_blacklisted(msg):
"""
Detect blacklisted DNS message objects.
@@ -18,7 +18,7 @@ def open_db(envdir):
'path': envdir,
'readonly': False,
'create': False
-})
+})
lenv = lmdb.Environment(**config)
qdb = lenv.open_db(key=b'queries', create=False, **dbhelper.db_open)
ddb = lenv.open_db(key=b'diffs', create=False, **dbhelper.db_open)
@@ -50,7 +50,9 @@ def save_stats(lenv, reprodb, qid, stats):
def main():
-criteria = ['opcode', 'rcode', 'flags', 'question', 'qname', 'qtype', 'answertypes', 'answerrrsigs'] # FIXME
+criteria = [
+'opcode', 'rcode', 'flags', 'question', 'qname', 'qtype', 'answertypes', 'answerrrsigs'
+] # FIXME
selector, sockets = sendrecv.sock_init(orchestrator.resolvers)
lenv, qdb, ddb, reprodb = open_db(sys.argv[1])
diff_stream = diffsum.read_diffs_lmdb(lenv, qdb, ddb)
@@ -88,8 +90,7 @@ def main():
print('processed :', processed)
print('verified :', verified)
print('falzified : {} {:6.2f} %'.format(
-processed-verified,
-100.0*(processed-verified)/processed))
+processed - verified, 100.0 * (processed - verified) / processed))
if __name__ == '__main__':
@@ -13,6 +13,7 @@ import cfg
import dbhelper
from msgdiff import DataMismatch # needed for unpickling
def process_diff(field_weights, field_stats, qwire, diff):
found = False
for field in field_weights:
@@ -41,12 +42,12 @@ def process_results(field_weights, diff_generator):
global_stats = {
'others_disagree': 0,
'target_disagrees': 0,
-}
+}
field_stats = {}
-#print('diffs = {')
+# print('diffs = {')
for qid, qwire, others_agree, target_diff in diff_generator:
-#print(qid, others_agree, target_diff)
+# print(qid, others_agree, target_diff)
if not others_agree:
global_stats['others_disagree'] += 1
continue
@@ -54,36 +55,40 @@ def process_results(field_weights, diff_generator):
if not target_diff: # everybody agreed, nothing to count
continue
-#print('(%s, %s): ' % (qid, question))
-#print(target_diff, ',')
+# print('(%s, %s): ' % (qid, question))
+# print(target_diff, ',')
global_stats['target_disagrees'] += 1
process_diff(field_weights, field_stats, qwire, target_diff)
-#print('}')
+# print('}')
return global_stats, field_stats
def combine_stats(counters):
field_mismatch_sums = {}
for field in counters:
field_mismatch_sums[field] = collections.Counter(
-{mismatch: sum(counter.values())
-for mismatch, counter in counters[field].items()})
+{mismatch: sum(counter.values())
+for mismatch, counter in counters[field].items()})
field_sums = collections.Counter(
{field: sum(counter.values())
for field, counter in field_mismatch_sums.items()})
return field_sums, field_mismatch_sums
def mismatch2str(mismatch):
if not isinstance(mismatch[0], str):
return (' '.join(mismatch[0]), ' '.join(mismatch[1]))
else:
return mismatch
def maxlen(iterable):
return max(len(str(it)) for it in iterable)
def print_results(gstats, field_weights, counters, n=10):
# global stats
field_sums, field_mismatch_sums = combine_stats(counters)
@@ -92,17 +97,17 @@ def print_results(gstats, field_weights, counters, n=10):
print('== Global statistics')
print('queries {:{}}'.format(gstats['queries'], maxcntlen))
print('answers {:{}} {:6.2f} % of queries'.format(
-gstats['answers'], maxcntlen, float(100)*gstats['answers']/gstats['queries']))
+gstats['answers'], maxcntlen, float(100) * gstats['answers'] / gstats['queries']))
others_agree = gstats['answers'] - gstats['others_disagree']
print('others agree {:{}} {:6.2f} % of answers (ignoring {:.2f} % of answers)'.format(
others_agree, maxcntlen,
-100.0*others_agree/gstats['answers'],
-100.0*gstats['others_disagree']/gstats['answers']))
+100.0 * others_agree / gstats['answers'],
+100.0 * gstats['others_disagree'] / gstats['answers']))
target_disagrees = gstats['target_disagrees']
print('target diagrees {:{}} {:6.2f} % of matching answers from others'.format(
gstats['target_disagrees'], maxcntlen,
-100.0*gstats['target_disagrees']/gstats['answers']))
+100.0 * gstats['target_disagrees'] / gstats['answers']))
if not field_sums.keys():
return
@@ -118,33 +123,34 @@ def print_results(gstats, field_weights, counters, n=10):
for field, n in (field_sums.most_common()):
print('{:{}} {:{}} {:3.0f} %'.format(
field, maxnamelen + 3,
-n, maxcntlen + 3,
-100.0*n/target_disagrees))
+n, maxcntlen + 3, 100.0 * n / target_disagrees))
for field in field_weights:
-if not field in field_mismatch_sums:
+if field not in field_mismatch_sums:
continue
print('')
print('== Field "%s" mismatch statistics' % field)
maxvallen = max((max(len(str(mismatch2str(mism)[0])), len(str(mismatch2str(mism)[1])))
-for mism in field_mismatch_sums[field].keys()))
+for mism in field_mismatch_sums[field].keys()))
maxcntlen = maxlen(field_mismatch_sums[field].values())
print('{:{}} != {:{}} {:{}} {}'.format(
'Expected', maxvallen,
-'Got', (maxvallen - (len('count') - maxcntlen)) if maxvallen - (len('count') - maxcntlen) > 1 else 1,
+'Got',
+(maxvallen - (len('count') - maxcntlen)) if maxvallen - (len('count') - maxcntlen) > 1
+else 1,
'count', maxcntlen,
'% of mismatches'
-))
+))
for mismatch, n in (field_mismatch_sums[field].most_common()):
mismatch = mismatch2str(mismatch)
print('{:{}} != {:{}} {:{}} {:3.0f} %'.format(
str(mismatch[0]), maxvallen,
str(mismatch[1]), maxvallen,
n, maxcntlen,
-100.0*n/target_disagrees))
+100.0 * n / target_disagrees))
for field in field_weights:
-if not field in counters:
+if field not in counters:
continue
for mismatch, n in (field_mismatch_sums[field].most_common()):
print('')
@@ -154,29 +160,32 @@ def print_results(gstats, field_weights, counters, n=10):
def print_field_queries(field, counter, n):
-#print('queries leading to mismatch in field "%s":' % field)
+# print('queries leading to mismatch in field "%s":' % field)
for query, count in counter.most_common(n):
qname, qtype = query
qtype = dns.rdatatype.to_text(qtype)
print("%s %s\t\t%s mismatches" % (qname, qtype, count))
def open_db(envdir):
config = dbhelper.env_open.copy()
config.update({
'path': envdir,
'readonly': False,
'create': False
-})
+})
lenv = lmdb.Environment(**config)
try:
qdb = lenv.open_db(key=dbhelper.QUERIES_DB_NAME, create=False, **dbhelper.db_open)
adb = lenv.open_db(key=dbhelper.ANSWERS_DB_NAME, create=False, **dbhelper.db_open)
ddb = lenv.open_db(key=dbhelper.DIFFS_DB_NAME, create=False, **dbhelper.db_open)
except lmdb.NotFoundError:
-logging.critical('Unable to generate statistics. LMDB does not contain queries, answers, or diffs!')
+logging.critical(
+'Unable to generate statistics. LMDB does not contain queries, answers, or diffs!')
raise
return lenv, qdb, adb, ddb
def read_diffs_lmdb(levn, qdb, ddb):
with levn.begin() as txn:
with txn.cursor(ddb) as diffcur:
@@ -185,6 +194,7 @@ def read_diffs_lmdb(levn, qdb, ddb):
qwire = txn.get(qid, db=qdb)
yield (qid, qwire, others_agree, diff)
def main():
logging.basicConfig(format='%(levelname)s %(message)s', level=logging.DEBUG)
parser = argparse.ArgumentParser(
@@ -206,5 +216,6 @@ def main():
global_stats['answers'] = txn.stat(adb)['entries']
print_results(global_stats, field_weights, field_stats)
if __name__ == '__main__':
main()
@@ -28,7 +28,7 @@ config = dbhelper.env_open.copy()
config.update({
'path': sys.argv[1],
'readonly': True
-})
+})
lenv = lmdb.Environment(**config)
db = lenv.open_db(key=b'answers', **dbhelper.db_open, create=False)
@@ -33,6 +33,7 @@ class DataMismatch(Exception):
def __ne__(self, other):
return self.__eq__(other)
def compare_val(exp_val, got_val):
""" Compare values, throw exception if different. """
if exp_val != got_val:
@@ -49,7 +50,7 @@ def compare_rrs(expected, got):
if rr not in expected:
raise DataMismatch(expected, got)
if len(expected) != len(got):
-raise DataMismatch(expected, got)
+raise DataMismatch(expected, got)
return True
@@ -80,6 +81,7 @@ def compare_rrs_types(exp_val, got_val, skip_rrsigs):
got_types = tuple(key_to_text(*i) for i in sorted(got_types))
raise DataMismatch(exp_types, got_types)
def match_part(exp_msg, got_msg, code):
""" Compare scripted reply to given message using single criteria. """
if code == 'opcode':
@@ -135,6 +137,7 @@ def match_part(exp_msg, got_msg, code):
else:
raise NotImplementedError('unknown match request "%s"' % code)
def match(expected, got, match_fields):
""" Compare scripted reply to given message based on match criteria. """
for code in match_fields:
@@ -149,12 +152,12 @@ def decode_wire_dict(wire_dict):
answers = {}
for k, v in wire_dict.items():
# decode bytes to dns.message objects
-#if isinstance(v, bytes):
+# if isinstance(v, bytes):
# convert from wire format to DNS message object
try:
answers[k] = dns.message.from_wire(v)
except Exception as ex:
-#answers[k] = ex # decoding failed, record it!
+# answers[k] = ex # decoding failed, record it!
continue
return answers
@@ -223,7 +226,7 @@ def worker_init(envdir_arg, criteria_arg, target_arg):
'create': False,
'writemap': True,
'sync': False
-})
+})
lenv = lmdb.Environment(**config)
answers_db = lenv.open_db(key=dbhelper.ANSWERS_DB_NAME, create=False, **dbhelper.db_open)
diffs_db = lenv.open_db(key=dbhelper.DIFFS_DB_NAME, create=True, **dbhelper.db_open)
@@ -263,7 +266,7 @@ def main():
'path': args.envdir,
'readonly': False,
'create': False
-})
+})
lenv = lmdb.Environment(**envconfig)
try:
@@ -282,11 +285,12 @@ def main():
qid_stream = dbhelper.key_stream(lenv, db)
with pool.Pool(
-initializer=worker_init,
-initargs=(args.envdir, config['diff']['criteria'], config['diff']['target'])
-) as p:
+initializer=worker_init,
+initargs=(args.envdir, config['diff']['criteria'], config['diff']['target'])
+) as p:
for i in p.imap_unordered(compare_lmdb_wrapper, qid_stream, chunksize=10):
pass
if __name__ == '__main__':
main()
@@ -9,7 +9,7 @@ for i in range(0, 256):
trans[i] = bytes(chr(i), encoding='ascii')
else:
trans[i] = ('\%03i' % i).encode('ascii')
-#pprint(trans)
+# pprint(trans)
while True:
line = sys.stdin.buffer.readline()
@@ -26,7 +26,7 @@ while True:
# normalize name
normalized = b''
-for nb in line[:typestart-1]:
+for nb in line[:typestart - 1]:
normalized += trans[nb]
sys.stdout.buffer.write(normalized)
sys.stdout.buffer.write(b' ')
@@ -35,7 +35,7 @@ def worker_init(envdir, resolvers, init_timeout):
'sync': False,
'map_async': True,
'readonly': False
-})
+})
lenv = lmdb.Environment(**config)
adb = lenv.open_db(key=dbhelper.ANSWERS_DB_NAME, create=True, **dbhelper.db_open)
@@ -62,7 +62,7 @@ def reader_init(envdir):
config.update({
'path': envdir,
'readonly': True
-})
+})
lenv = lmdb.Environment(**config)
qdb = lenv.open_db(key=dbhelper.QUERIES_DB_NAME,
create=False,
@@ -110,5 +110,6 @@ def main():
for _ in p.imap_unordered(worker_query_lmdb_wrapper, qstream, chunksize=100):
pass
if __name__ == "__main__":
main()
@@ -65,7 +65,7 @@ def wrk_lmdb_init(envdir):
'path': envdir,
'sync': False, # unsafe but fast
'writemap': True # we do not care, this is a new database
-})
+})
env = lmdb.Environment(**config)
db = env.open_db(key=b'queries', **dbhelper.db_open)
@@ -124,5 +124,6 @@ def main():
for _ in workers.imap_unordered(wrk_process_line, qstream, chunksize=1000):
pass
if __name__ == '__main__':
main()
@@ -19,7 +19,7 @@ def open_db(envdir):
'path': envdir,
'readonly': False,
'create': False
-})
+})
lenv = lmdb.Environment(**config)
qdb = lenv.open_db(key=b'queries', create=False, **dbhelper.db_open)
ddb = lenv.open_db(key=b'diffs', create=False, **dbhelper.db_open)
@@ -50,7 +50,6 @@ def read_repro_lmdb(levn, qdb, reprodb):
yield (qid, qwire, (count, others_agreed, diff_matched))
def main():
lenv, qdb, ddb, reprodb = open_db(sys.argv[1])
repro_stream = read_repro_lmdb(lenv, qdb, reprodb)
@@ -63,5 +62,6 @@ def main():
continue
print(qmsg.question[0])
if __name__ == '__main__':
main()
@@ -5,6 +5,7 @@ import socket
import dns.inet
import dns.message
def sock_init(resolvers):
"""
resolvers: [(name, ipaddr, port)]
@@ -27,6 +28,7 @@ def sock_init(resolvers):
# selector.close() ? # TODO
return selector, sockets
def send_recv_parallel(what, selector, sockets, timeout):
replies = {}
for _, sock, destination in sockets:
@@ -39,7 +41,7 @@ def send_recv_parallel(what, selector, sockets, timeout):
name = key.data
sock = key.fileobj
(wire, from_address) = sock.recvfrom(65535)
-#assert len(wire) > 14
+# assert len(wire) > 14
if what[0:2] != wire[0:2]:
continue # wrong msgid, this might be a delayed answer - ignore it
replies[name] = wire
@@ -7,7 +7,7 @@ Created on May 25, 2017
import sys
import handler
-if __name__ == '__main__':
+if __name__ == '__main__':
my_handler = handler.Handler()
my_handler.start()
-sys.exit(0)
\ No newline at end of file
+sys.exit(0)
@@ -7,31 +7,33 @@ import params
import maintain
import logging
log = logging.getLogger()
class Handler(object):
'''
Main program handler.
'''
def __init__(self):
-self.params= None
+self.params = None
def start(self):
-form='%(asctime)-15s %(levelname)s - %(message)s'
+form = '%(asctime)-15s %(levelname)s - %(message)s'
logging.basicConfig(level=logging.DEBUG, filename="logfile.log", format=form)
-self.params = params.Params()
-#load parameters
+self.params = params.Params()
+# load parameters
try:
self.params.read_params()
except IOError as e:
log.error(e)
raise
-#Prepare remote servers
+# Prepare remote servers
self.maintainer = maintain.MaintainResolver(self.params)
if self.params.kill_flag:
self.maintainer.kill_resolver(self.params.ip, self.params.port, self.params.ip)
else:
self.maintainer.prepare_resolvers()
-self.params.check_server(self.params.ip, self.params.port)
\ No newline at end of file
+self.params.check_server(self.params.ip, self.params.port)
@@ -12,93 +12,132 @@ import fcntl
import struct
from _threading_local import local
-log = logging.getLogger()
+log = logging.getLogger()
class MaintainResolver(object):
def __init__(self, params):
self.params = params
-self.tmp = os.path.join('/tmp/',self.params.user+'/')
+self.tmp = os.path.join('/tmp/', self.params.user + '/')
def prepare_resolvers(self):
'''
Uploads script and config files to foreign servers depending
on configuration of IP's of servers.
Folders are stored in /tmp/. And run that server
-'''
+'''
local_ip = self.get_ip_address(self.params.local_port)
# print "local_ip:"+local_ip
-## Send config folder
-log.info("Copying %s script + config to %s" %(self.params.resolver, self.params.ip))
-self.remote_copy("rsync", os.path.join(self.params.config, self.params.resolver+"-conf"), self.params.ip+":"+self.tmp, "-r")
+# Send config folder
+log.info("Copying %s script + config to %s" % (self.params.resolver, self.params.ip))
+self.remote_copy(
+"rsync",
+os.path.join(
+self.params.config,
+self.params.resolver +
+"-conf"),
+self.params.ip +
+":" +
+self.tmp,
+"-r")
-## Send scripts (cleaner and resolver starter)
-self.remote_copy(program="rsync", from_location=os.path.join(self.params.script_folder, "cleaner.sh"), to_location=self.params.ip+":"+self.tmp)
-self.remote_copy(program="rsync", from_location=os.path.join(self.params.script_folder, self.params.resolver+".sh"), to_location=self.params.ip+":"+self.tmp)
+# Send scripts (cleaner and resolver starter)
+self.remote_copy(
+program="rsync",
+from_location=os.path.join(
+self.params.script_folder,
+"cleaner.sh"),
+to_location=self.params.ip +
+":" +
+self.tmp)
+self.remote_copy(
+program="rsync",
+from_location=os.path.join(
+self.params.script_folder,
+self.params.resolver +
+".sh"),
+to_location=self.params.ip +
+":" +
+self.tmp)
log.info("Generate config %s" % self.params.resolver)
-self.generate_resolver_config(self.params.resolver, self.params.ip, self.params.port, local_ip)
+self.generate_resolver_config(
+self.params.resolver,
+self.params.ip,
+self.params.port,
+local_ip)
log.info("Starting %s" % self.params.resolver)
-#First make sure, that there is no resolver running at needed port
-self.kill_resolver(self.params.ip, self.params.port, self.params.ip)
+# First make sure, that there is no resolver running at needed port
+self.kill_resolver(self.params.ip, self.params.port, self.params.ip)
if self.params.resolver == "knot":
-#TODO: this just kill knot maintain port - if you uncomment - you need to add another parameter for this port
+# TODO: this just kill knot maintain port - if you uncomment -
+# you need to add another parameter for this port
# and also uncomment in /script/knot.sh line with maintain port
-#self.kill_resolver(self.params.ip, self.params.knot_maintain_port, "0.0.0.0")
-#TODO: possible to pick up branch for knot, not just master!
-command = "bash %s %s" % (os.path.join(self.tmp, self.params.resolver+".sh"), "master")
+# self.kill_resolver(self.params.ip, self.params.knot_maintain_port, "0.0.0.0")
+# TODO: possible to pick up branch for knot, not just master!
+command = "bash %s %s" % (os.path.join(
+self.tmp, self.params.resolver + ".sh"), "master")
rv = self.remote_command("ssh", self.params.ip, command)
-self.params.knot_commit = rv.strip()
-log.info("Knot Branch, commit: %s, %s" % ("master", self.params.knot_commit))
+self.params.knot_commit = rv.strip()
+log.info("Knot Branch, commit: %s, %s" % ("master", self.params.knot_commit))
else:
command = "bash %s" % (os.path.join(self.tmp, self.params.resolver+".sh"))
command = "bash %s" % (os.path.join(self.tmp, self.params.resolver + ".sh"))
self.remote_command("ssh", self.params.ip, command)
def kill_resolver(self, server, port, run_ip):
log.info("sudo bash %s %s %s" % (os.path.join(self.tmp, "cleaner.sh"), run_ip, port))
command = "sudo bash %s %s %s" % (os.path.join(self.tmp, "cleaner.sh"), run_ip, port)
self.remote_command("ssh", server, command)
self.remote_command("ssh", server, command)
def generate_resolver_config(self, resolver, server, port, local_ip):
log.info("Generating %s config" % resolver)
command = "bash %s %s %s %s" % (os.path.join(self.tmp,resolver+"-conf",resolver+"-conf-generator.sh"), server, port, local_ip)
self.remote_command("ssh", server, command)
command = "bash %s %s %s %s" % (os.path.join(
self.tmp,
resolver + "-conf",
resolver + "-conf-generator.sh"),
server,
port,
local_ip)
self.remote_command("ssh", server, command)
@staticmethod
def remote_command(program, ip, command):
log.info("%s %s %s" % (program, ip, command))
ssh = subprocess.Popen([program, ip, command],
-shell=False,
-stdout=subprocess.PIPE,
-stderr=subprocess.PIPE)
-ssh.wait()
+shell=False,
+stdout=subprocess.PIPE,
+stderr=subprocess.PIPE)
+ssh.wait()
retcode = ssh.returncode
if retcode == 1:
raise IOError("Cannot execute: %s." % command)
try:
retval = ssh.stdout.readlines()[0]
except IndexError:
raise IOError("Connection failed: %s" % command)
log.debug("Command done.")
return retval
@staticmethod
-def remote_copy(program, from_location, to_location, param=None):
+def remote_copy(program, from_location, to_location, param=None):
if not param:
param = ""
log.info("%s %s %s %s", program, param, from_location, to_location)
-with open(os.devnull, 'w') as FNULL:
+with open(os.devnull, 'w') as FNULL:
ssh = subprocess.Popen([program, param, from_location, to_location],
-shell=False,
-stdout=FNULL,
-stderr=FNULL)
+shell=False,
+stdout=FNULL,
+stderr=FNULL)
ssh.wait()
ssh.communicate()[0]
if ssh.returncode == 1:
raise IOError("Cannot execute: %s %s %s %s." %(program, param, from_location, to_location))
raise IOError(
"Cannot execute: %s %s %s %s." %
(program, param, from_location, to_location))
log.debug("Copy done.")
@staticmethod
def get_ip_address(ifname):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -109,5 +148,5 @@ class MaintainResolver(object):
struct.pack('256s', ifname[:15])
)[20:24])
except IOError as e:
log.error("Something wrong with given interface: %s."%ifname)
raise e
\ No newline at end of file
log.error("Something wrong with given interface: %s." % ifname)
raise e
@@ -14,14 +14,14 @@ import sys
import socket
import telnetlib
-log = logging.getLogger()
+log = logging.getLogger()
class Params(object):
'''
Parameters holder
'''
def __init__(self):
'''
Constructor
@@ -31,118 +31,135 @@ class Params(object):
self.config = ''
self.user = ''
self.kill_flag = False
-self.path,tail = os.path.split(os.path.dirname(os.path.abspath(__file__)))
-self.script_folder = os.path.join(self.path,"script")
-self.resolver=""
-self.local_port=""
+self.path, tail = os.path.split(os.path.dirname(os.path.abspath(__file__)))
+self.script_folder = os.path.join(self.path, "script")
+self.resolver = ""
+self.local_port = ""
-def read_params(self, argv=None): # IGNORE:C0111
+def read_params(self, argv=None): # IGNORE:C0111
'''Command line options.'''
-possible_resolvers=['knot', 'bind', 'unbound', 'pdns']
+possible_resolvers = ['knot', 'bind', 'unbound', 'pdns']
if argv is None:
argv = sys.argv
else:
sys.argv.extend(argv)
program_name = os.path.basename(sys.argv[0])
program_version = "v0.1"
program_build_date = "31.5.2017"
program_version_message = '%%(prog)s %s (%s)' % (program_version, program_build_date)
program_shortdesc = __import__('__main__').__doc__.split("\n")[1]
program_license = '''%s
Created by Jan Holusa on %s.
Copyright 2017 nic.cz. All rights reserved.
''' % (program_shortdesc, "31.5.2017")
try:
# Setup argument parse_result
parse_result = ArgumentParser(description=program_license, formatter_class=RawDescriptionHelpFormatter)
parse_result.add_argument("-i", "--ip", dest="ip", help="Set ip to run resolver", required=True )
parse_result.add_argument("-p", "--port", dest="port", help="Set port to run resolver", required=True )
parse_result.add_argument("-u", "--user", dest="user", help="User defined on remote server with granted acess", required=True )
parse_result.add_argument('-k', '--kill', dest='kill', help="Kill running remote server.", action='store_true', default=False)
parse_result.add_argument('-r', '--resolver', dest='resolver', help="Which resolver I should start. Possible values are: %s" % ", ".join(map(str, possible_resolvers)))
parse_result.add_argument('-l', '--local_port', dest='local_port', help="Name of the local port. For example eth0, tun4 and so on...")
parse_result.add_argument('-V', '--version', action='version', version=program_version_message)