Commit db1298dd authored by Tomas Krizek's avatar Tomas Krizek Committed by Petr Špaček

respdiff: lmdb refactoring

Handle LMDB environment and databases in a single class to reduce
duplicate code.
parent 5dcda045
from typing import Dict, Any # NOQA: needed for type hint in comment
from typing import Dict, Any, Tuple, Generator # NOQA: needed for type hint in comment
import os
import lmdb
# Names of the sub-databases inside the LMDB environment (bytes, as the
# lmdb API requires for DB keys).
ANSWERS_DB_NAME = b'answers'
DIFFS_DB_NAME = b'diffs'
QUERIES_DB_NAME = b'queries'
STATS_DB_NAME = b'stats'
def qid2key(qid):
    """Encode query ID to database key"""
    return bytes(str(qid), encoding='ascii')
# Default keyword arguments for lmdb.Environment(); callers .copy() this
# dict and add 'path' / 'readonly' / 'create' before opening.
env_open = {
    'map_size': 1024**4,    # 1 TiB virtual mapping; actual file grows on demand
    'max_readers': 64,
    'max_dbs': 5,
    'max_spare_txns': 64,
} # type: Dict[str, Any]
class LMDB:
ANSWERS = b'answers'
DIFFS = b'diffs'
QUERIES = b'queries'
REPROSTATS = b'reprostats'
STATS = b'stats'
db_open = {
'reverse_key': True
}
ENV_DEFAULTS = {
'map_size': 1024**4,
'max_readers': 64,
'max_dbs': 5,
'max_spare_txns': 64,
} # type: Dict[str, Any]
DB_OPEN_DEFAULTS = {
'reverse_key': True
} # type: Dict[str, Any]
def key_stream(lenv, db):
    """
    yield all keys from given db
    """
    with lenv.begin(db) as txn:
        with txn.cursor(db) as cur:
            # an unpositioned cursor iterates from the first record
            yield from cur.iternext(keys=True, values=False)
def __init__(self, path: str, create: bool = False,
             readonly: bool = False, fast: bool = False) -> None:
    """Open (creating the directory if needed) an LMDB environment at path.

    create/readonly are passed through to lmdb.Environment; fast trades
    crash safety for speed (see inline comment below).
    """
    self.path = path
    # handles of opened sub-databases, keyed by DB name; filled by open_db()
    self.dbs = {}  # type: Dict[bytes, Any]
    self.config = LMDB.ENV_DEFAULTS.copy()
    self.config.update({
        'path': path,
        'create': create,
        'readonly': readonly
    })
    if fast:  # unsafe on crashes, but faster
        self.config.update({
            'writemap': True,
            'sync': False,
            'map_async': True,
        })
    if not os.path.exists(self.path):
        os.makedirs(self.path)
    self.env = lmdb.Environment(**self.config)
def key_value_stream(lenv, db):
    """
    yield all (key, value) pairs from given db
    """
    with lenv.begin(db) as txn:
        # iterating an lmdb cursor yields (key, value) tuples
        yield from txn.cursor(db)
def __enter__(self):
    # context-manager entry: the LMDB wrapper itself is the managed resource
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    # close the underlying environment; exceptions are not suppressed
    self.env.close()
def qid2key(qid):
    """Encode query ID to database key"""
    text = str(qid)
    return text.encode('ascii')
def open_db(self, dbname: bytes, create: bool = False, check_exists: bool = False,
            check_notexists: bool = False, drop: bool = False):
    """Open a named sub-database, cache its handle in self.dbs, and return it.

    create -- create the DB if it does not exist yet
    check_exists -- raise RuntimeError unless the DB already exists
    check_notexists -- raise RuntimeError if the DB already exists
    drop -- wipe any existing contents of the DB before (re-)opening it
    """
    assert self.env is not None, "LMDB wasn't initialized!"
    if check_exists and not self.exists_db(dbname):
        msg = 'LMDB environment "{}" does not contain DB {}! '.format(
            self.path, dbname.decode('utf-8'))
        raise RuntimeError(msg)
    if check_notexists and self.exists_db(dbname):
        msg = ('LMDB environment "{}" already contains DB {}! '
               'Overwritting it would invalidate data in the environment, '
               'terminating.').format(self.path, dbname.decode('utf-8'))
        raise RuntimeError(msg)
    if drop:
        try:
            # open with create=False so a missing DB raises NotFoundError,
            # which makes dropping a non-existent DB a no-op
            db = self.env.open_db(key=dbname, create=False, **LMDB.DB_OPEN_DEFAULTS)
            with self.env.begin(write=True) as txn:
                txn.drop(db)
        except lmdb.NotFoundError:
            pass
    db = self.env.open_db(key=dbname, create=create, **LMDB.DB_OPEN_DEFAULTS)
    self.dbs[dbname] = db
    return db
def exists_db(self, dbname: bytes) -> bool:
    """Return True when the named sub-DB already exists in this environment.

    Probes via a second, read-only environment so the main handle is untouched.
    """
    probe_config = LMDB.ENV_DEFAULTS.copy()
    probe_config['path'] = self.path
    probe_config['readonly'] = True
    probe_config['create'] = False
    try:
        with lmdb.Environment(**probe_config) as probe_env:
            probe_env.open_db(key=dbname, create=False, **LMDB.DB_OPEN_DEFAULTS)
    except (lmdb.NotFoundError, lmdb.Error):
        return False
    return True
def get_db(self, dbname: bytes):
    """Return the cached handle of an already-opened DB; RuntimeError otherwise."""
    if dbname not in self.dbs:
        raise RuntimeError("Database {} isn't open!".format(dbname.decode('utf-8')))
    return self.dbs[dbname]
def key_stream(self, dbname: bytes):
    """yield all keys from given db"""
    db = self.get_db(dbname)
    with self.env.begin(db) as txn:
        yield from txn.cursor(db).iternext(keys=True, values=False)
def db_exists(envdir, dbname):
    """
    Determine if named DB exists in environment specified by path.
    """
    # probe with a read-only environment so nothing gets created as a side effect
    config = dict(env_open, path=envdir, readonly=True)
    try:
        with lmdb.Environment(**config) as env:
            env.open_db(key=dbname, create=False, **db_open)
    except (lmdb.NotFoundError, lmdb.Error):
        return False
    return True
def key_value_stream(self, dbname: bytes):
    """yield all (key, value) pairs from given db"""
    db = self.get_db(dbname)
    with self.env.begin(db) as txn:
        # cursor iteration already produces (key, value) tuples
        yield from txn.cursor(db)
......@@ -2,9 +2,7 @@ import pickle
import subprocess
import sys
import lmdb
import dbhelper
from dbhelper import LMDB
import diffsum
from msgdiff import DataMismatch # noqa: needed for unpickling
import msgdiff
......@@ -12,24 +10,11 @@ import orchestrator
import sendrecv
def open_db(envdir):
    """Open the LMDB environment read-write.

    Returns (env, queries_db, diffs_db, reprostats_db); 'queries' and 'diffs'
    must already exist, 'reprostats' is created if missing.
    """
    config = dbhelper.env_open.copy()
    config.update({
        'path': envdir,
        'readonly': False,
        'create': False
    })
    lenv = lmdb.Environment(**config)
    qdb = lenv.open_db(key=b'queries', create=False, **dbhelper.db_open)
    ddb = lenv.open_db(key=b'diffs', create=False, **dbhelper.db_open)
    reprodb = lenv.open_db(key=b'reprostats', create=True, **dbhelper.db_open)
    return lenv, qdb, ddb, reprodb
def load_stats(lenv, reprodb, qid):
def load_stats(lmdb, qid):
"""(count, others_agreed, diff_matched)"""
with lenv.begin() as txn:
stats_bin = txn.get(qid, db=reprodb)
reprodb = lmdb.get_db(LMDB.REPROSTATS)
with lmdb.env.begin(reprodb) as txn:
stats_bin = txn.get(qid)
if stats_bin:
stats = pickle.loads(stats_bin)
else:
......@@ -40,13 +25,14 @@ def load_stats(lenv, reprodb, qid):
return stats[0], stats[1], stats[2]
def save_stats(lenv, reprodb, qid, stats):
def save_stats(lmdb, qid, stats):
assert len(stats) == 3
assert stats[0] >= stats[1] >= stats[2]
stats_bin = pickle.dumps(stats)
with lenv.begin(write=True) as txn:
txn.put(qid, stats_bin, db=reprodb)
reprodb = lmdb.get_db(LMDB.REPROSTATS)
with lmdb.env.begin(reprodb, write=True) as txn:
txn.put(qid, stats_bin)
def main():
......@@ -54,38 +40,43 @@ def main():
'opcode', 'rcode', 'flags', 'question', 'qname', 'qtype', 'answertypes', 'answerrrsigs'
] # FIXME
selector, sockets = sendrecv.sock_init(getattr(orchestrator, 'resolvers'))
lenv, qdb, ddb, reprodb = open_db(sys.argv[1])
diff_stream = diffsum.read_diffs_lmdb(lenv, qdb, ddb)
processed = 0
verified = 0
for qid, qwire, orig_others_agree, orig_diffs in diff_stream:
if not orig_others_agree:
continue # others do not agree, nothing to verify
# others agree, verify if answers are stable and the diff is reproducible
retries, upstream_stable, diff_matches = load_stats(lenv, reprodb, qid)
if retries > 0:
if retries != upstream_stable or upstream_stable != diff_matches:
continue # either unstable upstream or diff is not 100 % reproducible, skip it
processed += 1
# it might be reproducible, restart everything
if len(sys.argv) == 3:
subprocess.check_call([sys.argv[2]])
wire_blobs = sendrecv.send_recv_parallel(qwire, selector, sockets, orchestrator.timeout)
answers = msgdiff.decode_wire_dict(wire_blobs)
new_others_agree, new_diffs = msgdiff.compare(answers, criteria, 'kresd') # FIXME
retries += 1
if orig_others_agree == new_others_agree:
upstream_stable += 1
if orig_diffs == new_diffs:
diff_matches += 1
print(qid, (retries, upstream_stable, diff_matches))
save_stats(lenv, reprodb, qid, (retries, upstream_stable, diff_matches))
if retries == upstream_stable == diff_matches:
verified += 1
with LMDB(sys.argv[1]) as lmdb:
lmdb.open_db(LMDB.QUERIES)
lmdb.open_db(LMDB.DIFFS)
lmdb.open_db(LMDB.REPROSTATS, create=True)
diff_stream = diffsum.read_diffs_lmdb(lmdb)
processed = 0
verified = 0
for qid, qwire, orig_others_agree, orig_diffs in diff_stream:
if not orig_others_agree:
continue # others do not agree, nothing to verify
# others agree, verify if answers are stable and the diff is reproducible
retries, upstream_stable, diff_matches = load_stats(lmdb, qid)
if retries > 0:
if retries != upstream_stable or upstream_stable != diff_matches:
continue # either unstable upstream or diff is not 100 % reproducible, skip it
processed += 1
# it might be reproducible, restart everything
if len(sys.argv) == 3:
subprocess.check_call([sys.argv[2]])
wire_blobs = sendrecv.send_recv_parallel(qwire, selector, sockets, orchestrator.timeout)
answers = msgdiff.decode_wire_dict(wire_blobs)
new_others_agree, new_diffs = msgdiff.compare(answers, criteria, 'kresd') # FIXME
retries += 1
if orig_others_agree == new_others_agree:
upstream_stable += 1
if orig_diffs == new_diffs:
diff_matches += 1
print(qid, (retries, upstream_stable, diff_matches))
save_stats(lmdb, qid, (retries, upstream_stable, diff_matches))
if retries == upstream_stable == diff_matches:
verified += 1
print('processed :', processed)
print('verified :', verified)
......
......@@ -6,10 +6,9 @@ import logging
import pickle
import dns.rdatatype
import lmdb
import cfg
import dbhelper
from dbhelper import LMDB
from msgdiff import DataMismatch # NOQA: needed for unpickling
......@@ -158,33 +157,15 @@ def print_field_queries(counter, n):
print("%s %s\t\t%s mismatches" % (qname, qtype, count))
def open_db(envdir):
    """Open the LMDB environment read-write.

    Returns (env, queries_db, answers_db, diffs_db, stats_db).
    Logs and re-raises lmdb.NotFoundError when any required DB is missing.
    """
    config = dbhelper.env_open.copy()
    config.update({
        'path': envdir,
        'readonly': False,
        'create': False
    })
    lenv = lmdb.Environment(**config)
    try:
        qdb = lenv.open_db(key=dbhelper.QUERIES_DB_NAME, create=False, **dbhelper.db_open)
        adb = lenv.open_db(key=dbhelper.ANSWERS_DB_NAME, create=False, **dbhelper.db_open)
        ddb = lenv.open_db(key=dbhelper.DIFFS_DB_NAME, create=False, **dbhelper.db_open)
        sdb = lenv.open_db(key=dbhelper.STATS_DB_NAME, create=False, **dbhelper.db_open)
    except lmdb.NotFoundError:
        logging.critical(
            'Unable to generate statistics. LMDB does not contain queries, answers, or diffs!')
        raise
    return lenv, qdb, adb, ddb, sdb
def read_diffs_lmdb(levn, qdb, ddb):
with levn.begin() as txn:
def read_diffs_lmdb(lmdb):
qdb = lmdb.get_db(LMDB.QUERIES)
ddb = lmdb.get_db(LMDB.DIFFS)
with lmdb.env.begin() as txn:
with txn.cursor(ddb) as diffcur:
for qid, diffblob in diffcur:
others_agree, diff = pickle.loads(diffblob)
qwire = txn.get(qid, db=qdb)
yield (qid, qwire, others_agree, diff)
yield qid, qwire, others_agree, diff
def main():
......@@ -200,14 +181,18 @@ def main():
config = cfg.read_cfg(args.cfgpath)
field_weights = config['report']['field_weights']
lenv, qdb, adb, ddb, sdb = open_db(args.envdir)
diff_stream = read_diffs_lmdb(lenv, qdb, ddb)
global_stats, field_stats = process_results(field_weights, diff_stream)
with lenv.begin() as txn:
global_stats['queries'] = txn.stat(qdb)['entries']
global_stats['answers'] = txn.stat(adb)['entries']
with lenv.begin(sdb) as txn:
stats = pickle.loads(txn.get(b'global_stats'))
with LMDB(args.envdir, readonly=True) as lmdb:
qdb = lmdb.open_db(LMDB.QUERIES)
adb = lmdb.open_db(LMDB.ANSWERS)
lmdb.open_db(LMDB.DIFFS)
sdb = lmdb.open_db(LMDB.STATS)
diff_stream = read_diffs_lmdb(lmdb)
global_stats, field_stats = process_results(field_weights, diff_stream)
with lmdb.env.begin() as txn:
global_stats['queries'] = txn.stat(qdb)['entries']
global_stats['answers'] = txn.stat(adb)['entries']
with lmdb.env.begin(sdb) as txn:
stats = pickle.loads(txn.get(b'global_stats'))
global_stats['duration'] = round(stats['end_time'] - stats['start_time'])
print_results(global_stats, field_weights, field_stats)
......
......@@ -2,13 +2,11 @@ import pickle
import os
import sys
import lmdb
from dbhelper import LMDB
import dbhelper
def read_blobs_lmdb(lenv, db, qid):
with lenv.begin(db) as txn:
def read_blobs_lmdb(lmdb, db, qid):
with lmdb.env.begin(db) as txn:
blob = txn.get(qid)
assert blob
answers = pickle.loads(blob)
......@@ -23,18 +21,11 @@ def write_blobs(blob_dict, workdir):
def main():
config = dbhelper.env_open.copy()
config.update({
'path': sys.argv[1],
'readonly': True
})
lenv = lmdb.Environment(**config)
db = lenv.open_db(key=b'answers', **dbhelper.db_open, create=False)
qid = str(int(sys.argv[2])).encode('ascii')
blobs = read_blobs_lmdb(lenv, db, qid)
write_blobs(blobs, sys.argv[3])
lenv.close()
with LMDB(sys.argv[1], readonly=True) as lmdb:
db = lmdb.open_db(LMDB.ANSWERS)
qid = str(int(sys.argv[2])).encode('ascii')
blobs = read_blobs_lmdb(lmdb, db, qid)
write_blobs(blobs, sys.argv[3])
if __name__ == '__main__':
......
......@@ -5,21 +5,17 @@ from functools import partial
import logging
import multiprocessing.pool as pool
import pickle
import sys
from typing import Dict
import dns.message
import dns.exception
import lmdb
import cfg
import dataformat
import dbhelper
from dbhelper import LMDB
lenv = None
answers_db = None
diffs_db = None
lmdb = None
class DataMismatch(Exception):
......@@ -165,8 +161,9 @@ def decode_wire_dict(wire_dict: Dict[str, dataformat.Reply]) \
return answers
def read_answers_lmdb(lenv, db, qid): # pylint: disable=redefined-outer-name
with lenv.begin(db) as txn:
def read_answers_lmdb(qid):
adb = lmdb.get_db(LMDB.ANSWERS)
with lmdb.env.begin(adb) as txn:
blob = txn.get(qid)
assert blob
wire_dict = pickle.loads(blob)
......@@ -214,35 +211,20 @@ def compare(answers, criteria, target):
return (others_agree, target_diffs)
def lmdb_init(envdir):
    """Open the LMDB environment and set module globals lenv/answers_db/diffs_db.

    writemap=True + sync=False trade crash safety for write speed.
    'answers' must already exist; 'diffs' is created if missing.
    """
    global lenv
    global answers_db
    global diffs_db
    config = dbhelper.env_open.copy()
    config.update({
        'path': envdir,
        'readonly': False,
        'create': False,
        'writemap': True,
        'sync': False
    })
    lenv = lmdb.Environment(**config)
    answers_db = lenv.open_db(key=dbhelper.ANSWERS_DB_NAME, create=False, **dbhelper.db_open)
    diffs_db = lenv.open_db(key=dbhelper.DIFFS_DB_NAME, create=True, **dbhelper.db_open)
def compare_lmdb_wrapper(criteria, target, qid):
answers = read_answers_lmdb(lenv, answers_db, qid)
answers = read_answers_lmdb(qid)
others_agree, target_diffs = compare(answers, criteria, target)
if others_agree and not target_diffs:
return # all agreed, nothing to write
blob = pickle.dumps((others_agree, target_diffs))
with lenv.begin(diffs_db, write=True) as txn:
ddb = lmdb.get_db(LMDB.DIFFS)
with lmdb.env.begin(ddb, write=True) as txn:
txn.put(qid, blob)
def main():
global lmdb
logging.basicConfig(format='%(levelname)s %(message)s', level=logging.DEBUG)
parser = argparse.ArgumentParser(
description='compute diff from answers stored in LMDB and write diffs to LMDB')
......@@ -253,37 +235,18 @@ def main():
args = parser.parse_args()
config = cfg.read_cfg(args.cfgpath)
envconfig = dbhelper.env_open.copy()
envconfig.update({
'path': args.envdir,
'readonly': False,
'create': False
})
lenv_tmp = lmdb.Environment(**envconfig)
try:
lenv_tmp.open_db(key=dbhelper.ANSWERS_DB_NAME, create=False, **dbhelper.db_open)
except lmdb.NotFoundError:
logging.critical('LMDB does not contain DNS answers in DB %s, terminating.',
dbhelper.ANSWERS_DB_NAME)
sys.exit(1)
try: # drop diffs DB if it exists, it can be re-generated at will
ddb = lenv_tmp.open_db(key=dbhelper.DIFFS_DB_NAME, create=False, **dbhelper.db_open)
with lenv_tmp.begin(write=True) as txn:
txn.drop(ddb)
except lmdb.NotFoundError:
pass
lmdb_init(args.envdir)
criteria = config['diff']['criteria']
target = config['diff']['target']
qid_stream = dbhelper.key_stream(lenv, answers_db)
func = partial(compare_lmdb_wrapper, criteria, target)
with pool.Pool() as p:
for _ in p.imap_unordered(func, qid_stream, chunksize=10):
pass
with LMDB(args.envdir, fast=True) as lmdb_:
lmdb = lmdb_
lmdb.open_db(LMDB.ANSWERS, check_exists=True)
lmdb.open_db(LMDB.DIFFS, create=True, drop=True)
qid_stream = lmdb.key_stream(LMDB.ANSWERS)
func = partial(compare_lmdb_wrapper, criteria, target)
with pool.Pool() as p:
for _ in p.imap_unordered(func, qid_stream, chunksize=10):
pass
if __name__ == '__main__':
......
......@@ -3,20 +3,17 @@
import argparse
import multiprocessing.pool as pool
import pickle
import sys
import threading
import time
import logging
import lmdb
from typing import List, Tuple # noqa: type hints
import cfg
import dbhelper
from dbhelper import LMDB
import sendrecv
worker_state = {} # shared by all workers
resolvers = []
resolvers = [] # type: List[Tuple[str, str, str, int]]
timeout = None
......@@ -54,25 +51,6 @@ def worker_query_lmdb_wrapper(args):
return (qid, blob)
def lmdb_init(envdir):
    """Open LMDB environment and databases for writing.

    Returns (env, queries_db, answers_db, stats_db); 'queries' must already
    exist, 'answers' and 'stats' are created if missing.
    writemap/sync=False/map_async trade crash safety for speed.
    """
    config = dbhelper.env_open.copy()
    config.update({
        'path': envdir,
        'writemap': True,
        'sync': False,
        'map_async': True,
        'readonly': False
    })
    lenv = lmdb.Environment(**config)
    qdb = lenv.open_db(key=dbhelper.QUERIES_DB_NAME,
                       create=False,
                       **dbhelper.db_open)
    adb = lenv.open_db(key=dbhelper.ANSWERS_DB_NAME, create=True, **dbhelper.db_open)
    sdb = lenv.open_db(key=dbhelper.STATS_DB_NAME, create=True, **dbhelper.db_open)
    return (lenv, qdb, adb, sdb)
def main():
global timeout
......@@ -90,39 +68,27 @@ def main():
resolvers.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port']))
timeout = args.cfg['sendrecv']['timeout']
if not dbhelper.db_exists(args.envdir, dbhelper.QUERIES_DB_NAME):
logging.critical(
'LMDB environment "%s does not contain DB %s! '
'Use qprep to prepare queries.',
args.envdir, dbhelper.ANSWERS_DB_NAME)
sys.exit(1)
if dbhelper.db_exists(args.envdir, dbhelper.ANSWERS_DB_NAME):
logging.critical(
'LMDB environment "%s" already contains DB %s! '
'Overwritting it would invalidate data in the environment, '
'terminating.',
args.envdir, dbhelper.ANSWERS_DB_NAME)
sys.exit(1)
lenv, qdb, adb, sdb = lmdb_init(args.envdir)
qstream = dbhelper.key_value_stream(lenv, qdb)
stats = {
'start_time': time.time(),
'end_time': None,
}
with lenv.begin(adb, write=True) as txn:
with pool.Pool(
processes=args.cfg['sendrecv']['jobs'],
initializer=worker_init) as p:
for qid, blob in p.imap(worker_query_lmdb_wrapper, qstream, chunksize=100):
txn.put(qid, blob)
stats['end_time'] = time.time()
with lenv.begin(sdb, write=True) as txn:
txn.put(b'global_stats', pickle.dumps(stats))
with LMDB(args.envdir, fast=True) as lmdb:
lmdb.open_db(LMDB.QUERIES, check_exists=True)
adb = lmdb.open_db(LMDB.ANSWERS, create=True, check_notexists=True)
sdb = lmdb.open_db(LMDB.STATS, create=True)
qstream = lmdb.key_value_stream(LMDB.QUERIES)
with lmdb.env.begin(adb, write=True) as txn:
with pool.Pool(
processes=args.cfg['sendrecv']['jobs'],
initializer=worker_init) as p:
for qid, blob in p.imap(worker_query_lmdb_wrapper, qstream, chunksize=100):
txn.put(qid, blob)
stats['end_time'] = time.time()
with lmdb.env.begin(sdb, write=True) as txn:
txn.put(b'global_stats', pickle.dumps(stats))
if __name__ == "__main__":
......
......@@ -8,13 +8,12 @@ import sys
from typing import Tuple
import dpkt
import blacklist
import dbhelper
import dns.exception
import dns.message
import dns.rdatatype
import lmdb
import blacklist
from dbhelper import LMDB, qid2key
REPORT_CHUNKS = 10000
......@@ -78,7 +77,7 @@ def wrk_process_wire_packet(qid: int, wire_packet: bytes, log_repr: str) -> Tupl
:arg log_repr representation of packet for logs
"""
if not blacklist.is_blacklisted(wire_packet):
key = dbhelper.qid2key(qid)
key = qid2key(qid)
return key, wire_packet
logging.debug('Query "%s" blacklisted (skipping query ID %d)',
......@@ -109,19 +108,6 @@ def wire_from_text(text):
return msg.to_wire()
def lmdb_init(envdir):
    """Open LMDB environment and database.

    Returns (env, queries_db); the 'queries' DB is created if missing
    (lmdb's open_db defaults to create=True).
    """
    config = dbhelper.env_open.copy()
    config.update({
        'path': envdir,
        'sync': False,  # unsafe but fast
        'writemap': True  # we do not care, this is a new database
    })
    lenv = lmdb.Environment(**config)
    qdb = lenv.open_db(key=dbhelper.QUERIES_DB_NAME, **dbhelper.db_open)
    return (lenv, qdb)
def main():
logging.basicConfig(format='%(levelname)s %(message)s', level=logging.DEBUG)
parser = argparse.ArgumentParser(
......@@ -129,7 +115,7 @@ def main():
description='Convert queries data from standard input and store '
'wire format into LMDB "queries" DB.')
parser.add_argument('envpath', type=str, help='path where to create LMDB environment')
parser.add_argument('envdir', type=str, help='path where to create LMDB environment')
parser.add_argument('-f', '--in-format', type=str, choices=['text', 'pcap'], default='text',
help='define format for input data, default value is text\n'
'Expected input for "text" is: "<qname> <RR type>", '
......@@ -144,30 +130,23 @@ def main():
if args.in_format == 'pcap' and not args.pcap_file:
logging.critical("Missing path to pcap file, use argument --pcap-file")
sys.exit(1)
if dbhelper.db_exists(args.envpath, dbhelper.QUERIES_DB_NAME):
logging.critical(
'LMDB environment "%s" already contains DB %s! '
'Overwritting it would invalidate data in the environment, '
'terminating.',
args.envpath, dbhelper.QUERIES_DB_NAME)
sys.exit(1)
lenv, qdb = lmdb_init(args.envpath)
with lenv.begin(qdb, write=True) as txn:
with pool.Pool() as workers:
if args.in_format == 'text':
data_stream = read_lines(sys.stdin)
method = wrk_process_line
elif args.in_format == 'pcap':
data_stream = parse_pcap(args.pcap_file)
method = wrk_process_packet
else:
logging.error('unknown in-format, use "text" or "pcap"')
sys