Commit 65b48167 authored by Tomas Krizek's avatar Tomas Krizek

use the new LMDB binary format in respdiff toolchain

This also Fixes knot/resolver-benchmarking#34, as all the data from the orchestrator is stored in LMDB.
parent dfef5726
......@@ -12,6 +12,8 @@ import lmdb
from dataformat import QID
# upon import, check we're on a little endian platform
assert sys.byteorder == 'little', 'Big endian platforms are not supported'
ResolverID = str
RepliesBlob = bytes
......@@ -233,9 +235,10 @@ class DNSRepliesFactory:
class Database(ABC):
DB_NAME = b''
def __init__(self, lmdb_):
def __init__(self, lmdb_, create: bool = False) -> None:
self.lmdb = lmdb_
self.db = None
self.create = create
@contextmanager
def transaction(self, write: bool = False):
......@@ -246,7 +249,10 @@ class Database(ABC):
try:
self.db = self.lmdb.get_db(self.DB_NAME)
except ValueError:
self.db = self.lmdb.open_db(self.DB_NAME, create=True)
try:
self.db = self.lmdb.open_db(self.DB_NAME, create=self.create)
except lmdb.Error as exc:
raise RuntimeError('Failed to open LMDB database: {}'.format(exc))
with self.lmdb.env.begin(self.db, write=write) as txn:
yield txn
......@@ -255,7 +261,7 @@ class Database(ABC):
with self.transaction() as txn:
data = txn.get(key)
if data is None:
raise ValueError("Missing '{}' key in '{}' database!".format(
raise KeyError("Missing '{}' key in '{}' database!".format(
key.decode('ascii'), self.DB_NAME.decode('ascii')))
return data
......@@ -272,6 +278,19 @@ class MetaDatabase(Database):
KEY_SERVERS = b'servers'
KEY_NAME = b'name'
def __init__(
self,
lmdb_,
servers: Sequence[ResolverID],
create: bool = False
) -> None:
super(MetaDatabase, self).__init__(lmdb_, create)
if create:
self.write_servers(servers)
else:
self.check_version()
self.check_servers(servers)
def read_servers(self) -> List[ResolverID]:
servers = []
ndata = self.read_key(self.KEY_SERVERS)
......@@ -283,19 +302,28 @@ class MetaDatabase(Database):
return servers
def write_servers(self, servers: Sequence[ResolverID]) -> None:
if not servers:
raise ValueError("Empty list of servers!")
n = struct.pack('<I', len(servers))
self.write_key(self.KEY_SERVERS, n)
for i, server in enumerate(servers):
key = self.KEY_NAME + str(i).encode('ascii')
self.write_key(key, server.encode('ascii'))
def check_servers(self, servers: Sequence[ResolverID]) -> None:
db_servers = self.read_servers()
if not servers == db_servers:
raise NotImplementedError(
'Servers defined in config differ from the ones in "meta" database! '
'(config: "{}", meta db: "{}")'.format(servers, db_servers))
def write_version(self) -> None:
self.write_key(self.KEY_VERSION, VERSION.encode('ascii'))
def check_version(self) -> None:
version = self.read_key(self.KEY_VERSION).decode('ascii')
if version != VERSION:
raise ValueError(
raise NotImplementedError(
'LMDB version mismatch! (expected "{}", got "{}")'.format(
VERSION, version))
......@@ -314,7 +342,7 @@ class MetaDatabase(Database):
def _read_timestamp(self, key: bytes) -> Optional[int]:
try:
data = self.read_key(key)
except ValueError:
except KeyError:
return None
else:
return struct.unpack('<I', data)[0]
......@@ -324,7 +352,3 @@ class MetaDatabase(Database):
timestamp = round(time.time())
data = struct.pack('<I', timestamp)
self.write_key(key, data)
# upon import, check we're on a little endian platform
assert sys.byteorder == 'little', 'Big endian platforms are not supported'
......@@ -4,15 +4,17 @@ import argparse
from itertools import zip_longest
import logging
from multiprocessing import pool
import pickle
import random
import subprocess
import sys
from typing import ( # noqa
Any, AbstractSet, Iterable, Iterator, Mapping, Sequence, Tuple, TypeVar,
Union)
import cli
from dbhelper import LMDB, key2qid, ResolverID, RepliesBlob, qid2key, QKey, WireFormat
from dbhelper import (
DNSReply, DNSRepliesFactory, key2qid, LMDB, MetaDatabase, ResolverID, qid2key,
QKey, WireFormat)
import diffsum
from dataformat import Diff, DiffReport, FieldLabel, ReproData, QID # noqa
import msgdiff
......@@ -82,7 +84,7 @@ def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]:
def process_answers(
qkey: QKey,
replies_blob: RepliesBlob,
replies: Mapping[ResolverID, DNSReply],
report: DiffReport,
criteria: Sequence[FieldLabel],
target: ResolverID
......@@ -91,7 +93,6 @@ def process_answers(
raise RuntimeError("Report doesn't contain necessary data!")
qid = key2qid(qkey)
reprocounter = report.reprodata[qid]
replies = pickle.loads(replies_blob)
answers = msgdiff.decode_replies(replies)
others_agree, mismatches = msgdiff.compare(answers, criteria, target)
......@@ -118,6 +119,8 @@ def main():
datafile = cli.get_datafile(args)
report = DiffReport.from_json(datafile)
restart_scripts = get_restart_scripts(args.cfg)
servers = args.cfg['servers']['names']
dnsreplies_factory = DNSRepliesFactory(servers)
if args.sequential:
nproc = 1
......@@ -130,6 +133,12 @@ def main():
with LMDB(args.envdir, readonly=True) as lmdb:
lmdb.open_db(LMDB.QUERIES)
try:
MetaDatabase(lmdb, servers, create=False) # check version and servers
except NotImplementedError as exc:
logging.critical(exc)
sys.exit(1)
dstream = disagreement_query_stream(lmdb, report)
try:
with pool.Pool(processes=nproc) as p:
......@@ -140,10 +149,11 @@ def main():
restart_resolver(script)
process_args = [args for args in process_args if args is not None]
for qkey, replies, in p.imap_unordered(
for qkey, replies_data, in p.imap_unordered(
sendrecv.worker_perform_single_query,
process_args,
chunksize=1):
replies = dnsreplies_factory.parse(replies_data)
process_answers(qkey, replies, report,
args.cfg['diff']['criteria'],
args.cfg['diff']['target'])
......
......@@ -4,14 +4,15 @@
import argparse
import logging
import math
import pickle
from typing import Dict, List
import sys
import lmdb
import numpy as np
import cfg
from dbhelper import LMDB
import cli
from dbhelper import DNSRepliesFactory, LMDB, MetaDatabase, ResolverID
# Force matplotlib to use a different backend to handle machines without a display
import matplotlib
......@@ -20,13 +21,16 @@ matplotlib.use('Agg')
import matplotlib.pyplot as plt # noqa
def load_data(txn: lmdb.Transaction) -> Dict[str, List[float]]:
data = {} # type: Dict[str, List[float]]
def load_data(
txn: lmdb.Transaction,
dnsreplies_factory: DNSRepliesFactory
) -> Dict[ResolverID, List[float]]:
data = {} # type: Dict[ResolverID, List[float]]
cursor = txn.cursor()
for value in cursor.iternext(keys=False, values=True):
replies = pickle.loads(value)
replies = dnsreplies_factory.parse(value)
for resolver, reply in replies.items():
data.setdefault(resolver, []).append(reply.duration)
data.setdefault(resolver, []).append(reply.time)
return data
......@@ -69,8 +73,7 @@ def plot_log_percentile_histogram(data: Dict[str, List[float]], config=None):
def main():
logging.basicConfig(
format='%(levelname)s %(message)s', level=logging.DEBUG)
cli.setup_logging()
parser = argparse.ArgumentParser(
description='Plot query response time histogram from answers stored '
'in LMDB')
......@@ -83,11 +86,20 @@ def main():
help='LMDB environment to read answers from')
args = parser.parse_args()
config = cfg.read_cfg(args.cfgpath)
servers = config['servers']['names']
dnsreplies_factory = DNSRepliesFactory(servers)
with LMDB(args.envdir, readonly=True) as lmdb_:
adb = lmdb_.open_db(LMDB.ANSWERS)
try:
MetaDatabase(lmdb_, servers, create=False) # check version and servers
except NotImplementedError as exc:
logging.critical(exc)
sys.exit(1)
with lmdb_.env.begin(adb) as txn:
data = load_data(txn)
data = load_data(txn, dnsreplies_factory)
plot_log_percentile_histogram(data, config)
# save to file
......
......@@ -17,7 +17,7 @@ import cli
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter,
FieldLabel, MismatchValue, QID)
from dbhelper import DNSReply, key2qid, LMDB, MetaDatabase, ResolverID
from dbhelper import DNSReply, DNSRepliesFactory, key2qid, LMDB, MetaDatabase, ResolverID
lmdb = None
......@@ -160,14 +160,16 @@ def decode_replies(
return answers
def read_answers_lmdb(qid: QID) -> Mapping[ResolverID, dns.message.Message]:
if lmdb is None:
raise RuntimeError("LMDB wasn't initialized!")
def read_answers_lmdb(
dnsreplies_factory: DNSRepliesFactory,
qid: QID
) -> Mapping[ResolverID, dns.message.Message]:
assert lmdb is not None, "LMDB wasn't initialized!"
adb = lmdb.get_db(LMDB.ANSWERS)
with lmdb.env.begin(adb) as txn:
replies_blob = txn.get(qid)
assert replies_blob
replies = pickle.loads(replies_blob)
replies = dnsreplies_factory.parse(replies_blob)
return decode_replies(replies)
......@@ -222,8 +224,14 @@ def compare(
return (others_agree, target_diffs)
def compare_lmdb_wrapper(criteria, target, qid):
answers = read_answers_lmdb(qid)
def compare_lmdb_wrapper(
criteria: Sequence[FieldLabel],
target: ResolverID,
dnsreplies_factory: DNSRepliesFactory,
qid: QID
) -> None:
assert lmdb is not None, "LMDB wasn't initialized!"
answers = read_answers_lmdb(dnsreplies_factory, qid)
others_agree, target_diffs = compare(answers, criteria, target)
if others_agree and not target_diffs:
return # all agreed, nothing to write
......@@ -234,9 +242,7 @@ def compare_lmdb_wrapper(criteria, target, qid):
def export_json(filename: str, report: DiffReport):
if lmdb is None:
raise RuntimeError("LMDB wasn't initialized!")
assert lmdb is not None, "LMDB wasn't initialized!"
report.other_disagreements = DisagreementsCounter()
report.target_disagreements = Disagreements()
......@@ -263,14 +269,14 @@ def export_json(filename: str, report: DiffReport):
report.export_json(filename)
def prepare_report(lmdb_):
def prepare_report(lmdb_, servers: Sequence[ResolverID]) -> DiffReport:
qdb = lmdb_.open_db(LMDB.QUERIES)
adb = lmdb_.open_db(LMDB.ANSWERS)
with lmdb_.env.begin() as txn:
total_queries = txn.stat(qdb)['entries']
total_answers = txn.stat(adb)['entries']
meta = MetaDatabase(lmdb_)
meta = MetaDatabase(lmdb_, servers)
start_time = meta.read_start_time()
end_time = meta.read_end_time()
......@@ -295,26 +301,27 @@ def main():
datafile = cli.get_datafile(args, check_exists=False)
criteria = args.cfg['diff']['criteria']
target = args.cfg['diff']['target']
servers = args.cfg['servers']['names']
with LMDB(args.envdir) as lmdb_:
# NOTE: To avoid an lmdb.BadRslotError, probably caused by weird
# interaction when using multiple transaction / processes, open a separate
# environment. Also, any dbs have to be opened before using MetaDatabase().
report = prepare_report(lmdb_)
meta = MetaDatabase(lmdb_)
report = prepare_report(lmdb_, servers)
try:
meta.check_version()
except ValueError as exc:
MetaDatabase(lmdb_, servers, create=False) # check version and servers
except NotImplementedError as exc:
logging.critical(exc)
sys.exit(1)
with LMDB(args.envdir, fast=True) as lmdb_:
lmdb = lmdb_
lmdb.open_db(LMDB.ANSWERS)
lmdb.open_db(LMDB.DIFFS, create=True, drop=True)
qid_stream = lmdb.key_stream(LMDB.ANSWERS)
func = partial(compare_lmdb_wrapper, criteria, target)
dnsreplies_factory = DNSRepliesFactory(servers)
func = partial(compare_lmdb_wrapper, criteria, target, dnsreplies_factory)
with pool.Pool() as p:
for _ in p.imap_unordered(func, qid_stream, chunksize=10):
pass
......
......@@ -24,10 +24,9 @@ def main():
sendrecv.module_init(args)
with LMDB(args.envdir) as lmdb:
meta = MetaDatabase(lmdb)
meta = MetaDatabase(lmdb, args.cfg['servers']['names'], create=True)
meta.write_version()
meta.write_start_time()
meta.write_servers(args.cfg['servers']['names'])
lmdb.open_db(LMDB.QUERIES)
adb = lmdb.open_db(LMDB.ANSWERS, create=True, check_notexists=True)
......
......@@ -10,7 +10,6 @@ threads or processes. Make sure not to break this compatibility.
from argparse import Namespace
import pickle
import random
import selectors
import socket
......@@ -23,7 +22,7 @@ from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints
import dns.inet
import dns.message
from dbhelper import DNSReply, RepliesBlob, ResolverID, QKey, WireFormat
from dbhelper import DNSReply, DNSRepliesFactory, RepliesBlob, ResolverID, QKey, WireFormat
IP = str
......@@ -45,6 +44,7 @@ __timeout = 10
__time_delay_min = 0
__time_delay_max = 0
__timeout_reply = DNSReply(None) # optimization: create only one timeout_reply object
__dnsreplies_factory = None
def module_init(args: Namespace) -> None:
......@@ -54,6 +54,7 @@ def module_init(args: Namespace) -> None:
global __timeout
global __time_delay_min
global __time_delay_max
global __dnsreplies_factory
__resolvers = get_resolvers(args.cfg)
__timeout = args.cfg['sendrecv']['timeout']
......@@ -68,6 +69,9 @@ def module_init(args: Namespace) -> None:
except AttributeError:
pass
servers = [resolver[0] for resolver in __resolvers]
__dnsreplies_factory = DNSRepliesFactory(servers)
def worker_init() -> None:
__worker_state.timeouts = {}
......@@ -108,7 +112,8 @@ def worker_perform_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, RepliesBl
worker_deinit()
worker_reinit()
blob = pickle.dumps(replies)
assert __dnsreplies_factory is not None, "Module wasn't initilized!"
blob = __dnsreplies_factory.serialize(replies)
return qkey, blob
......@@ -123,7 +128,8 @@ def worker_perform_single_query(args: Tuple[QKey, WireFormat]) -> Tuple[QKey, Re
worker_deinit()
blob = pickle.dumps(replies)
assert __dnsreplies_factory is not None, "Module wasn't initilized!"
blob = __dnsreplies_factory.serialize(replies)
return qkey, blob
......
......@@ -116,7 +116,7 @@ def test_lmdb_answers_single_server():
envdir = os.path.join(LMDB_DIR, 'answers_single_server')
with LMDB(envdir) as lmdb:
adb = lmdb.open_db(LMDB.ANSWERS)
meta = MetaDatabase(lmdb)
meta = MetaDatabase(lmdb, ['kresd'])
assert meta.read_start_time() == INT_3M
assert meta.read_end_time() == INT_3M
......@@ -136,7 +136,7 @@ def test_lmdb_answers_multiple_servers():
envdir = os.path.join(LMDB_DIR, 'answers_multiple_servers')
with LMDB(envdir) as lmdb:
adb = lmdb.open_db(LMDB.ANSWERS)
meta = MetaDatabase(lmdb)
meta = MetaDatabase(lmdb, ['kresd', 'bind', 'unbound'])
assert meta.read_start_time() is None
assert meta.read_end_time() is None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment