Commit 92622ed2 authored by Tomas Krizek's avatar Tomas Krizek

dbhlper: create MetaDatabase helper

parent dce39943
from abc import ABC
from contextlib import contextmanager
import logging
import os
import struct
import sys
from typing import Any, Dict, Iterator, Optional, Tuple, Sequence # noqa
import time
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Sequence # noqa
import lmdb
......@@ -15,6 +18,9 @@ QKey = bytes
WireFormat = bytes
VERSION = '2018-05-21'
def qid2key(qid: QID) -> QKey:
return struct.pack('<I', qid)
......@@ -27,6 +33,7 @@ class LMDB:
ANSWERS = b'answers'
DIFFS = b'diffs'
QUERIES = b'queries'
META = b'meta'
ENV_DEFAULTS = {
'map_size': 10 * 1024**3, # 10 G
......@@ -135,13 +142,13 @@ class DNSReply:
SIZEOF_INT = 4
SIZEOF_SHORT = 2
def __init__(self, wire: Optional[WireFormat], time: float = 0) -> None:
def __init__(self, wire: Optional[WireFormat], time_: float = 0) -> None:
if wire is None:
self.wire = b''
self.time = float('+inf')
else:
self.wire = wire
self.time = time
self.time = time_
@property
def timeout(self) -> bool:
......@@ -184,10 +191,10 @@ class DNSReply:
raise ValueError('Missing data in binary format')
if time_int == cls.TIMEOUT_INT:
time = float('+inf')
time_ = float('+inf')
else:
time = time_int / (10 ** 6)
reply = DNSReply(wire, time)
time_ = time_int / (10 ** 6)
reply = DNSReply(wire, time_)
return reply, buff[offset:]
......@@ -208,5 +215,98 @@ class DNSRepliesFactory:
return replies
class Database(ABC):
DB_NAME = b''
def __init__(self, lmdb_):
self.lmdb = lmdb_
self.db = None
@contextmanager
def transaction(self, write: bool = False):
# ensure database is open
if self.db is None:
if not self.DB_NAME:
raise RuntimeError('No database to initialize!')
self.lmdb.open_db(self.DB_NAME, create=True)
with self.lmdb.env.begin(self.db, write=write) as txn:
yield txn
def read_key(self, key: bytes) -> bytes:
with self.transaction() as txn:
data = txn.get(key)
if data is None:
raise ValueError("Missing '{}' key in '{}' database!".format(
key.decode('ascii'), self.DB_NAME.decode('ascii')))
return data
def write_key(self, key: bytes, value: bytes) -> None:
with self.transaction(write=True) as txn:
txn.put(key, value)
class MetaDatabase(Database):
DB_NAME = LMDB.META
KEY_VERSION = b'version'
KEY_START_TIME = b'start_time'
KEY_END_TIME = b'end_time'
KEY_SERVERS = b'servers'
KEY_NAME = b'name'
def read_servers(self) -> List[ResolverID]:
servers = []
ndata = self.read_key(self.KEY_SERVERS)
n, = struct.unpack('<I', ndata)
for i in range(n):
key = self.KEY_NAME + str(i).encode('ascii')
server = self.read_key(key)
servers.append(server.decode('ascii'))
return servers
def write_servers(self, servers: Sequence[ResolverID]) -> None:
n = struct.pack('<I', len(servers))
self.write_key(self.KEY_SERVERS, n)
for i, server in enumerate(servers):
key = self.KEY_NAME + str(i).encode('ascii')
self.write_key(key, server.encode('ascii'))
def write_version(self) -> None:
self.write_key(self.KEY_VERSION, VERSION.encode('ascii'))
def check_version(self) -> None:
version = self.read_key(self.KEY_VERSION).decode('ascii')
if version != VERSION:
raise ValueError(
'LMDB version mismatch! (expected "{}", got "{}")'.format(
VERSION, version))
def write_start_time(self, timestamp: Optional[int] = None) -> None:
self._write_timestamp(self.KEY_START_TIME, timestamp)
def write_end_time(self, timestamp: Optional[int] = None) -> None:
self._write_timestamp(self.KEY_END_TIME, timestamp)
def read_start_time(self) -> Optional[int]:
return self._read_timestamp(self.KEY_START_TIME)
def read_end_time(self) -> Optional[int]:
return self._read_timestamp(self.KEY_END_TIME)
def _read_timestamp(self, key: bytes) -> Optional[int]:
try:
data = self.read_key(key)
except ValueError:
return None
else:
return struct.unpack('<I', data)[0]
def _write_timestamp(self, key: bytes, timestamp: Optional[int]) -> None:
if timestamp is None:
timestamp = round(time.time())
data = struct.pack('<I', timestamp)
self.write_key(key, data)
# upon import, check we're on a little endian platform
assert sys.byteorder == 'little', 'Big endian platforms are not supported'
......@@ -6,6 +6,7 @@ import logging
import multiprocessing.pool as pool
import pickle
from typing import Any, Dict, Iterator, Mapping, Optional, Sequence, Tuple # noqa
import sys
import dns.exception
import dns.message
......@@ -15,7 +16,7 @@ import cli
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter,
FieldLabel, MismatchValue, QID)
from dbhelper import DNSReply, LMDB, key2qid, ResolverID
from dbhelper import DNSReply, key2qid, LMDB, MetaDatabase, ResolverID
lmdb = None
......@@ -267,8 +268,18 @@ def main():
criteria = args.cfg['diff']['criteria']
target = args.cfg['diff']['target']
# fast=True would later cause lmdb.BadRslotError in conjuction with multiprocessing
with LMDB(args.envdir) as lmdb_:
meta = MetaDatabase(lmdb_)
try:
meta.check_version()
except ValueError as exc:
logging.critical(exc)
sys.exit(1)
with LMDB(args.envdir, fast=True) as lmdb_:
lmdb = lmdb_
lmdb.open_db(LMDB.ANSWERS)
lmdb.open_db(LMDB.DIFFS, create=True, drop=True)
qid_stream = lmdb.key_stream(LMDB.ANSWERS)
......
......@@ -9,7 +9,7 @@ import sys
import cli
from dataformat import DiffReport
from dbhelper import LMDB
from dbhelper import LMDB, MetaDatabase
import sendrecv
......@@ -52,6 +52,11 @@ def main():
start_time = int(time.time())
with LMDB(args.envdir) as lmdb:
meta = MetaDatabase(lmdb)
meta.write_version()
meta.write_start_time()
meta.write_servers(args.cfg['servers']['names'])
lmdb.open_db(LMDB.QUERIES)
adb = lmdb.open_db(LMDB.ANSWERS, create=True, check_notexists=True)
......@@ -75,6 +80,7 @@ def main():
finally:
# attempt to preserve data if something went wrong (or not)
txn.commit()
meta.write_end_time()
# get query/answer statistics
export_statistics(lmdb, datafile, start_time)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment