Commit b7c726b9 authored by Tomas Krizek

refactoring: move query parallelization logic to sendrecv

parent eab5a928
@@ -9,7 +9,6 @@ from typing import ( # noqa
MismatchValue = Union[str, Sequence[str]]
ResolverID = str
QID = int
WireFormat = bytes
FieldLabel = str
......
@@ -15,8 +15,9 @@ import dns.exception
import cli
from dataformat import (
DataMismatch, DiffReport, Disagreements, DisagreementsCounter, FieldLabel, MismatchValue,
Reply, ResolverID, QID)
Reply, QID)
from dbhelper import LMDB, key2qid
from sendrecv import ResolverID
lmdb = None # type: Optional[Any]
......
@@ -4,89 +4,15 @@ import argparse
import logging
import multiprocessing.pool as pool
import os
import pickle
import random
import threading
import time
from typing import List, Tuple, Dict, Any, Mapping, Sequence # noqa: type hints
import sys
import cli
from dataformat import DiffReport, ResolverID
from dataformat import DiffReport
from dbhelper import LMDB
import sendrecv
worker_state = threading.local()
resolvers = [] # type: List[Tuple[str, str, str, int]]
ignore_timeout = False
max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver
timeout = None
time_delay_min = 0
time_delay_max = 0
def worker_init():
"""
make sure it works with distinct processes and threads alike
"""
worker_state.timeouts = {}
worker_reinit()
def worker_reinit():
selector, sockets = sendrecv.sock_init(resolvers)
worker_state.selector = selector
worker_state.sockets = sockets
def worker_deinit(selector, sockets):
"""
Close all sockets and selector.
"""
selector.close()
for _, sck, _ in sockets:
sck.close()
def worker_query_lmdb_wrapper(args):
qid, qwire = args
selector = worker_state.selector
sockets = worker_state.sockets
# optional artificial delay for testing
if time_delay_max > 0:
time.sleep(random.uniform(time_delay_min, time_delay_max))
replies, reinit = sendrecv.send_recv_parallel(qwire, selector, sockets, timeout)
if not ignore_timeout:
check_timeout(replies)
if reinit: # a connection was broken; sockets must be reinitialized
# TODO: log this?
worker_deinit(selector, sockets)
worker_reinit()
blob = pickle.dumps(replies)
return (qid, blob)
def check_timeout(replies):
for resolver, reply in replies.items():
timeouts = worker_state.timeouts
if reply.wire is not None:
timeouts[resolver] = 0
else:
timeouts[resolver] = timeouts.get(resolver, 0) + 1
if timeouts[resolver] >= max_timeouts:
raise RuntimeError(
"Resolver '{}' timed-out {:d} times in a row. "
"Use '--ignore-timeout' to supress this error.".format(
resolver, max_timeouts))
def export_statistics(lmdb, datafile, start_time):
qdb = lmdb.get_db(LMDB.QUERIES)
adb = lmdb.get_db(LMDB.ANSWERS)
@@ -109,22 +35,7 @@ def export_statistics(lmdb, datafile, start_time):
report.export_json(datafile)
def get_resolvers(config: Mapping[str, Any]) -> Sequence[Tuple[ResolverID, str, str, int]]:
resolvers_ = []
for resname in config['servers']['names']:
rescfg = config[resname]
resolvers_.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port']))
return resolvers_
def main():
global ignore_timeout
global max_timeouts
global resolvers
global timeout
global time_delay_min
global time_delay_max
cli.setup_logging()
parser = argparse.ArgumentParser(
description='read queries from LMDB, send them in parallel to servers '
@@ -136,16 +47,8 @@ def main():
help='continue despite consecutive timeouts from resolvers')
args = parser.parse_args()
sendrecv.module_init(args)
datafile = cli.get_datafile(args)
resolvers = get_resolvers(args.cfg)
ignore_timeout = args.ignore_timeout
timeout = args.cfg['sendrecv']['timeout']
time_delay_min = args.cfg['sendrecv']['time_delay_min']
time_delay_max = args.cfg['sendrecv']['time_delay_max']
try:
max_timeouts = args.cfg['sendrecv']['max_timeouts']
except KeyError:
pass
start_time = int(time.time())
with LMDB(args.envdir) as lmdb:
@@ -158,9 +61,9 @@
# process queries in parallel
with pool.Pool(
processes=args.cfg['sendrecv']['jobs'],
initializer=worker_init) as p:
initializer=sendrecv.worker_init) as p:
i = 0
for qid, blob in p.imap(worker_query_lmdb_wrapper, qstream,
for qid, blob in p.imap(sendrecv.worker_perform_query, qstream,
chunksize=100):
i += 1
if i % 10000 == 0:
......
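After this change the orchestrator only wires its worker pool to the sendrecv module. A minimal sketch of the resulting flow, assuming a hypothetical run_queries() helper and a qstream iterable that yields (qid, qwire) tuples as in the diff above (argument parsing and LMDB plumbing omitted):

import multiprocessing.pool as pool
import sendrecv

def run_queries(args, qstream):  # hypothetical helper, not part of the commit
    # configure sendrecv's module-wide state (resolvers, timeout, delays)
    sendrecv.module_init(args)
    # each pool worker opens its own sockets via sendrecv.worker_init()
    with pool.Pool(processes=args.cfg['sendrecv']['jobs'],
                   initializer=sendrecv.worker_init) as p:
        for qid, blob in p.imap(sendrecv.worker_perform_query, qstream,
                                chunksize=100):
            yield qid, blob  # blob is a pickled {resolver: Reply} mapping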
"""
sendrecv module
===============
This module is used by orchestrator and diffrepro to perform DNS queries in parallel.
The entire module keeps global state, which makes it easy to use with both
threads and processes. Make sure not to break this compatibility.
"""
from argparse import Namespace
import pickle
import random
import selectors
import socket
import ssl
import struct
import time
from typing import Dict, Mapping, Tuple # noqa
import threading
from typing import Any, Dict, List, Mapping, Sequence, Tuple # noqa: type hints
import dns.inet
import dns.message
from dataformat import Reply, ResolverID
from dataformat import Reply, QID, WireFormat
ResolverID = str
IP = str
Protocol = str
Port = int
RepliesBlob = bytes
IsStreamFlag = bool # Is message preceded by RFC 1035 section 4.2.2 length?
ReinitFlag = bool
Selector = selectors.BaseSelector
Socket = socket.socket
ResolverSockets = Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]
# module-wide state variables
__resolvers = [] # type: Sequence[Tuple[ResolverID, IP, Protocol, Port]]
__worker_state = threading.local()
__max_timeouts = 10 # crash when N consecutive timeouts are received from a single resolver
__ignore_timeout = False
__timeout = None
__time_delay_min = 0
__time_delay_max = 0
__timeout_replies = {} # type: Dict[float, Reply]
def module_init(args: Namespace) -> None:
global __resolvers
global __max_timeouts
global __ignore_timeout
global __timeout
global __time_delay_min
global __time_delay_max
__resolvers = get_resolvers(args.cfg)
__timeout = args.cfg['sendrecv']['timeout']
__time_delay_min = args.cfg['sendrecv']['time_delay_min']
__time_delay_max = args.cfg['sendrecv']['time_delay_max']
try:
__max_timeouts = args.cfg['sendrecv']['max_timeouts']
except KeyError:
pass
try:
__ignore_timeout = args.ignore_timeout
except AttributeError:
pass
def worker_init() -> None:
__worker_state.timeouts = {}
worker_reinit()
def worker_reinit() -> None:
selector, sockets = sock_init() # type: Tuple[Selector, ResolverSockets]
__worker_state.selector = selector
__worker_state.sockets = sockets
TIMEOUT_REPLIES = {} # type: Dict[float, Reply]
def worker_deinit() -> None:
selector = __worker_state.selector
sockets = __worker_state.sockets
selector.close()
for _, sck, _ in sockets: # type: ignore # python/mypy#465
sck.close()
def worker_perform_query(args: Tuple[QID, WireFormat]) -> Tuple[QID, RepliesBlob]:
"""DNS query performed by orchestrator"""
qid, qwire = args
selector = __worker_state.selector
sockets = __worker_state.sockets
# optional artificial delay for testing
if __time_delay_max > 0:
time.sleep(random.uniform(__time_delay_min, __time_delay_max))
replies, reinit = send_recv_parallel(qwire, selector, sockets, __timeout)
if not __ignore_timeout:
_check_timeout(replies)
if reinit: # a connection was broken; reinitialize this worker's sockets
worker_deinit()
worker_reinit()
blob = pickle.dumps(replies)
return qid, blob
def get_resolvers(
config: Mapping[str, Any]
) -> Sequence[Tuple[ResolverID, IP, Protocol, Port]]:
resolvers = []
for resname in config['servers']['names']:
rescfg = config[resname]
resolvers.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port']))
return resolvers
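# Illustration (hypothetical values, not part of the commit): get_resolvers()
# above accepts any mapping of this shape:
#   cfg = {
#       'servers': {'names': ['kresd', 'bind']},
#       'kresd': {'ip': '127.0.0.1', 'transport': 'udp', 'port': 53},
#       'bind': {'ip': '127.0.0.2', 'transport': 'tcp', 'port': 53053},
#   }
# get_resolvers(cfg) would return:
#   [('kresd', '127.0.0.1', 'udp', 53), ('bind', '127.0.0.2', 'tcp', 53053)]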
def _check_timeout(replies: Mapping[ResolverID, Reply]) -> None:
for resolver, reply in replies.items():
timeouts = __worker_state.timeouts
if reply.wire is not None:
timeouts[resolver] = 0
else:
timeouts[resolver] = timeouts.get(resolver, 0) + 1
if timeouts[resolver] >= __max_timeouts:
raise RuntimeError(
"Resolver '{}' timed-out {:d} times in a row. "
"Use '--ignore-timeout' to supress this error.".format(
resolver, __max_timeouts))
def sock_init(resolvers):
"""
resolvers: [(name, ipaddr, transport, port)]
returns (selector, [(name, socket, isstream)])
"""
def sock_init() -> Tuple[Selector, Sequence[Tuple[ResolverID, Socket, IsStreamFlag]]]:
sockets = []
selector = selectors.DefaultSelector()
for name, ipaddr, transport, port in resolvers:
for name, ipaddr, transport, port in __resolvers:
af = dns.inet.af_for_address(ipaddr)
if af == dns.inet.AF_INET:
destination = (ipaddr, port)
destination = (ipaddr, port) # type: Any
elif af == dns.inet.AF_INET6:
destination = (ipaddr, port, 0, 0)
else:
@@ -46,16 +166,11 @@ def sock_init(resolvers):
sockets.append((name, sock, isstream))
selector.register(sock, selectors.EVENT_READ, (name, isstream))
# selector.close() ? # TODO
return selector, sockets
def _recv_msg(sock, isstream):
"""
receive DNS message from socket
isstream: Is message preceded by RFC 1035 section 4.2.2 length?
returns: wire format without preamble, or raises ConnectionError
"""
def _recv_msg(sock: Socket, isstream: IsStreamFlag) -> WireFormat:
"""Receive DNS message from socket and remove preambule (if present)."""
if isstream: # parse preamble
blength = sock.recv(2) # TODO: does not work with TLS: , socket.MSG_WAITALL)
if not blength: # stream closed
@@ -67,18 +182,15 @@ def _recv_msg(sock, isstream):
def send_recv_parallel(
dgram,
selector,
sockets,
dgram: WireFormat, # DNS message suitable for UDP transport
selector: Selector,
sockets: ResolverSockets,
timeout: float
) -> Tuple[Mapping[ResolverID, Reply], bool]:
"""
dgram: DNS message in binary format suitable for UDP transport
"""
) -> Tuple[Mapping[ResolverID, Reply], ReinitFlag]:
replies = {} # type: Dict[ResolverID, Reply]
streammsg = None
# optimization: create only one timeout_reply object per timeout value
timeout_reply = TIMEOUT_REPLIES.setdefault(timeout, Reply(None, timeout))
timeout_reply = __timeout_replies.setdefault(timeout, Reply(None, timeout))
start_time = time.perf_counter()
end_time = start_time + timeout
for _, sock, isstream in sockets:
@@ -96,9 +208,10 @@ def send_recv_parallel(
remaining_time = end_time - time.perf_counter()
if remaining_time <= 0:
break # timeout
events = selector.select(timeout=remaining_time)
for key, _ in events:
events = selector.select(timeout=remaining_time) # type: ignore # python/mypy#2070
for key, _ in events: # type: ignore # python/mypy#465
name, isstream = key.data
assert isinstance(key.fileobj, socket.socket) # fileobj can't be int
sock = key.fileobj
try:
wire = _recv_msg(sock, isstream)
@@ -112,7 +225,7 @@ def send_recv_parallel(
replies[name] = Reply(wire, time.perf_counter() - start_time)
# set missing replies as timeout
for resolver, *_ in sockets:
for resolver, *_ in sockets: # type: ignore # python/mypy#465
if resolver not in replies:
replies[resolver] = timeout_reply
......
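The stream preamble handled by _recv_msg is the two-byte, big-endian message length that RFC 1035 section 4.2.2 prescribes for DNS over TCP. A minimal standalone sketch of that framing, using a hypothetical read_stream_msg() helper (the commit's own receive loop is elided in the hunk above):

import socket
import struct

def read_stream_msg(sock: socket.socket) -> bytes:  # hypothetical helper
    # RFC 1035 4.2.2: a TCP DNS message is prefixed with a 2-byte length
    blength = sock.recv(2)
    if not blength:  # stream closed
        raise ConnectionError('stream closed')
    (length,) = struct.unpack('!H', blength)  # network byte order
    wire = b''
    while len(wire) < length:  # recv() may return partial data
        chunk = sock.recv(length - len(wire))
        if not chunk:
            raise ConnectionError('stream closed mid-message')
        wire += chunk
    return wire  # wire format without the length preamble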