Commit 37b9cf21 authored by Petr Špaček's avatar Petr Špaček

orchestrator: support TCP transport

Potential problems I'm aware of:

1. Errors while connecting/sending are not handled. I thing this is a
   good thing to catch problems early on but I might be wrong.

2. If connection to any resolver server fails, all connections in given
   worker are closed and reinitialized.

3. Logging and postprocessing is missing. I.e. there is not way to
   distinguish connection reset from timeout for given query.
parent 080334a1
......@@ -25,6 +25,12 @@ def comma_list(lstr):
return [name.strip() for name in lstr.split(',')]
def transport_opt(ostr):
if ostr not in {'udp', 'tcp'}:
raise ValueError('unsupported transport')
return ostr
# declarative config format description for always-present sections
# dict structure: dict[section name][key name] = type
_CFGFMT = {
......@@ -48,7 +54,8 @@ _CFGFMT = {
# dict structure: dict[key name] = type
_CFGFMT_SERVER = {
'ip': ipaddr_check,
'port': int
'port': int,
'transport': transport_opt
}
......
......@@ -18,28 +18,45 @@ global worker_state
worker_state = {} # shared by all workers
def worker_init(envdir, resolvers, init_timeout):
def worker_init(init_resolvers, init_timeout):
"""
make sure it works with distincts processes and threads as well
"""
global worker_state # initialized to empty dict
global resolvers
global timeout
resolvers = init_resolvers
timeout = init_timeout
tid = threading.current_thread().ident
selector, sockets = sendrecv.sock_init(resolvers)
worker_state[tid] = (selector, sockets)
def worker_deinit(selector, sockets):
"""
Close all sockets and selector.
"""
selector.close()
for _, sck, _ in sockets:
sck.close()
def worker_query_lmdb_wrapper(args):
global worker_state # initialized in worker_init
global timeout
qid, qwire = args
tid = threading.current_thread().ident
selector, sockets = worker_state[tid]
replies = sendrecv.send_recv_parallel(qwire, selector, sockets, timeout)
blob = pickle.dumps(replies)
replies, reinit = sendrecv.send_recv_parallel(qwire, selector, sockets, timeout)
if reinit: # a connection is broken or something
# TODO: log this?
worker_deinit(selector, sockets)
worker_init(resolvers, timeout)
blob = pickle.dumps(replies)
return (qid, blob)
......@@ -74,7 +91,7 @@ def main():
resolvers = []
for resname in args.cfg['servers']['names']:
rescfg = args.cfg[resname]
resolvers.append((resname, rescfg['ip'], rescfg['port']))
resolvers.append((resname, rescfg['ip'], rescfg['transport'], rescfg['port']))
if not dbhelper.db_exists(args.envdir, dbhelper.QUERIES_DB_NAME):
logging.critical(
......@@ -98,7 +115,7 @@ def main():
with pool.Pool(
processes=args.cfg['sendrecv']['jobs'],
initializer=worker_init,
initargs=[args.envdir, resolvers, args.cfg['sendrecv']['timeout']]) as p:
initargs=[resolvers, args.cfg['sendrecv']['timeout']]) as p:
for qid, blob in p.imap(worker_query_lmdb_wrapper, qstream, chunksize=100):
txn.put(qid, blob)
......
......@@ -13,15 +13,18 @@ names = kresd, bind, unbound
# containing IP address and port of particular server
[kresd]
ip = ::1
port = 5353
port = 53021
transport = tcp
[bind]
ip = 127.0.0.1
port = 53533
port = 53011
transport = udp
[unbound]
ip = 127.0.0.1
port = 53535
port = 53001
transport = udp
[diff]
# symbolic name of server under test
......
import os
import selectors
import socket
import struct
import dns.inet
import dns.message
......@@ -8,12 +9,12 @@ import dns.message
def sock_init(resolvers):
"""
resolvers: [(name, ipaddr, port)]
returns (selector, [(name, socket, sendtoarg)])
resolvers: [(name, ipaddr, transport, port)]
returns (selector, [(name, socket, isstream)])
"""
sockets = []
selector = selectors.DefaultSelector()
for name, ipaddr, port in resolvers:
for name, ipaddr, transport, port in resolvers:
af = dns.inet.af_for_address(ipaddr)
if af == dns.inet.AF_INET:
destination = (ipaddr, port)
......@@ -21,31 +22,74 @@ def sock_init(resolvers):
destination = (ipaddr, port, 0, 0)
else:
raise NotImplementedError('AF')
sock = socket.socket(af, socket.SOCK_DGRAM, 0)
if transport == 'tcp':
socktype = socket.SOCK_STREAM
isstream = True
elif transport == 'udp':
socktype = socket.SOCK_DGRAM
isstream = False
else:
raise NotImplementedError('socktype: {}'.format(socktype))
sock = socket.socket(af, socktype, 0)
sock.connect(destination)
sock.setblocking(False)
sockets.append((name, sock, destination))
selector.register(sock, selectors.EVENT_READ, name)
sockets.append((name, sock, isstream))
selector.register(sock, selectors.EVENT_READ, (name, isstream))
# selector.close() ? # TODO
return selector, sockets
def send_recv_parallel(what, selector, sockets, timeout):
def _recv_msg(sock, isstream):
"""
receive DNS message from socket
issteam: Is message preceeded by RFC 1034 section 4.2.2 length?
returns: wire format without preambule or ConnectionError
"""
if isstream: # parse preambule
blength = sock.recv(2, socket.MSG_WAITALL)
if len(blength) == 0: # stream closed
raise ConnectionError('TCP recv length == 0')
(length, ) = struct.unpack('!H', blength)
else:
length = 65535 # max. UDP message size, no IPv6 jumbograms
return sock.recv(length)
def send_recv_parallel(dgram, selector, sockets, timeout):
"""
dgram: DNS message in binary format suitable for UDP transport
"""
replies = {}
for _, sock, destination in sockets:
sock.sendto(what, destination)
streammsg = None
for _, sock, isstream in sockets:
if isstream: # prepend length, RFC 1034 section 4.2.2
if not streammsg:
length = len(dgram)
streammsg = struct.pack('!H', length) + dgram
sock.sendall(streammsg)
else:
sock.sendall(dgram)
# receive replies
reinit = False
while len(replies) != len(sockets):
events = selector.select(timeout=timeout) # BLEH! timeout shortening
for key, _ in events:
name = key.data
name, isstream = key.data
sock = key.fileobj
(wire, from_address) = sock.recvfrom(65535)
try:
wire = _recv_msg(sock, isstream)
except ConnectionError:
reinit = True
selector.unregister(sock)
continue # receive answers from other parties
# assert len(wire) > 14
if what[0:2] != wire[0:2]:
if dgram[0:2] != wire[0:2]:
continue # wrong msgid, this might be a delayed answer - ignore it
replies[name] = wire
if not events:
break # TIMEOUT
return replies
return replies, reinit
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment