Commit 0b9988d0 authored by Štěpán Balážik's avatar Štěpán Balážik Committed by Štěpán Balážik

scenario: factor out socket handling and sending DNS queries to mock_client.py

This is useful for code clarity and for other use-cases other than Deckard.
parent 23624b39
Pipeline #44374 passed with stage
in 3 minutes and 35 seconds
"""Module takes care of sending and recieving DNS messages as a mock client"""
import errno
import socket
import struct
import time
from typing import Optional, Tuple, Union
import dns.message
import dns.inet
SOCKET_OPERATION_TIMEOUT = 5
RECEIVE_MESSAGE_SIZE = 2**16-1
THROTTLE_BY = 0.1
def handle_socket_timeout(sock: socket.socket, deadline: float):
# deadline is always time.monotonic
remaining = deadline - time.monotonic()
if remaining <= 0:
raise RuntimeError("Server took too long to respond")
sock.settimeout(remaining)
def recv_n_bytes_from_tcp(stream: socket.socket, n: int, deadline: float) -> bytes:
# deadline is always time.monotonic
data = b""
while n != 0:
handle_socket_timeout(stream, deadline)
chunk = stream.recv(n)
# Empty bytes from socket.recv mean that socket is closed
if not chunk:
raise OSError()
n -= len(chunk)
data += chunk
return data
def recvfrom_blob(sock: socket.socket) -> Tuple[bytes, str]:
"""
Receive DNS message from TCP/UDP socket.
"""
# deadline is always time.monotonic
deadline = time.monotonic() + SOCKET_OPERATION_TIMEOUT
while True:
try:
if sock.type & socket.SOCK_DGRAM:
handle_socket_timeout(sock, deadline)
data, addr = sock.recvfrom(RECEIVE_MESSAGE_SIZE)
elif sock.type & socket.SOCK_STREAM:
# First 2 bytes of TCP packet are the size of the message
# See https://tools.ietf.org/html/rfc1035#section-4.2.2
data = recv_n_bytes_from_tcp(sock, 2, deadline)
msg_len = struct.unpack_from("!H", data)[0]
data = recv_n_bytes_from_tcp(sock, msg_len, deadline)
addr = sock.getpeername()[0]
else:
raise NotImplementedError("[recvfrom_blob]: unknown socket type '%i'" % sock.type)
return data, addr
except OSError as ex:
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
else:
raise
def recvfrom_msg(sock: socket.socket) -> Tuple[dns.message.Message, str]:
data, addr = recvfrom_blob(sock)
msg = dns.message.from_wire(data, one_rr_per_rrset=True)
return msg, addr
def sendto_msg(sock: socket.socket, message: bytes, addr: Optional[str] = None) -> None:
""" Send DNS/UDP/TCP message. """
try:
if sock.type & socket.SOCK_DGRAM:
if addr is None:
sock.send(message)
else:
sock.sendto(message, addr)
elif sock.type & socket.SOCK_STREAM:
data = struct.pack("!H", len(message)) + message
sock.sendall(data)
else:
raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % sock.type)
except OSError as ex:
# Reference: http://lkml.iu.edu/hypermail/linux/kernel/0002.3/0709.html
if ex.errno != errno.ECONNREFUSED:
raise
def setup_socket(address: str,
port: int,
tcp: bool = False) -> socket.socket:
family = dns.inet.af_for_address(address)
sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
if tcp:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
sock.settimeout(SOCKET_OPERATION_TIMEOUT)
sock.connect((address, port))
return sock
def send_query(sock: socket.socket, query: Union[dns.message.Message, bytes]) -> None:
message = query if isinstance(query, bytes) else query.to_wire()
while True:
try:
sendto_msg(sock, message)
break
except OSError as ex:
# ENOBUFS, throttle sending
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
else:
raise
def get_answer(sock: socket.socket) -> bytes:
""" Compatibility function """
answer, _ = recvfrom_blob(sock)
return answer
def get_dns_message(sock: socket.socket) -> dns.message.Message:
return dns.message.from_wire(get_answer(sock))
# FIXME pylint: disable=too-many-lines
from abc import ABC
import binascii
import calendar
from datetime import datetime
import errno
import logging
import os
import posixpath
......@@ -22,6 +20,7 @@ import dns.tsigkeyring
import pydnstest.augwrap
import pydnstest.matchpart
import pydnstest.mock_client
def str2bool(v):
......@@ -34,57 +33,6 @@ g_rtt = 0.0
g_nqueries = 0
def recvfrom_msg(stream, raw=False):
"""
Receive DNS message from TCP/UDP socket.
Returns:
if raw == False: (DNS message object, peer address)
if raw == True: (blob, peer address)
"""
if stream.type & socket.SOCK_DGRAM:
data, addr = stream.recvfrom(4096)
elif stream.type & socket.SOCK_STREAM:
data = stream.recv(2)
if not data:
return None, None
msg_len = struct.unpack_from("!H", data)[0]
data = b""
received = 0
while received < msg_len:
next_chunk = stream.recv(4096)
if not next_chunk:
return None, None
data += next_chunk
received += len(next_chunk)
addr = stream.getpeername()[0]
else:
raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
if raw:
return data, addr
else:
msg = dns.message.from_wire(data, one_rr_per_rrset=True)
return msg, addr
def sendto_msg(stream, message, addr=None):
""" Send DNS/UDP/TCP message. """
try:
if stream.type & socket.SOCK_DGRAM:
if addr is None:
stream.send(message)
else:
stream.sendto(message, addr)
elif stream.type & socket.SOCK_STREAM:
data = struct.pack("!H", len(message)) + message
stream.send(data)
else:
raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
except socket.error as ex:
if ex.errno != errno.ECONNREFUSED: # TODO Investigate how this can happen
raise
class DNSBlob(ABC):
def to_wire(self) -> bytes:
raise NotImplementedError
......@@ -580,8 +528,8 @@ class Step:
self.log.info('')
self.log.debug(self.data[0].message.to_text())
# Parse QUERY-specific parameters
choice, tcp, source = None, False, None
return self.__query(ctx, tcp=tcp, choice=choice, source=source)
choice, tcp = None, False
return self.__query(ctx, tcp=tcp, choice=choice)
elif self.type == 'CHECK_OUT_QUERY': # ignore
self.log.info('')
return None
......@@ -611,7 +559,7 @@ class Step:
self.log.debug("answer: %s", ctx.last_answer.to_text())
expected.match(ctx.last_answer)
def __query(self, ctx, tcp=False, choice=None, source=None):
def __query(self, ctx, tcp=False, choice=None):
"""
Send query and wait for an answer (if the query is not RAW).
......@@ -628,45 +576,24 @@ class Step:
choice = list(ctx.client.keys())[0]
if choice not in ctx.client:
raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
# Create socket to test subject
sock = None
destination = ctx.client[choice]
family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if tcp:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
sock.settimeout(3)
if source:
sock.bind((source, 0))
sock.connect(destination)
# Send query to client and wait for response
tstart = datetime.now()
while True:
try:
sendto_msg(sock, data_to_wire)
break
except OSError as ex:
# ENOBUFS, throttle sending
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
# Wait for a response for a reasonable time
# Send query and wait for answer
answer = None
sock = pydnstest.mock_client.setup_socket(ctx.client[choice][0],
ctx.client[choice][1],
tcp)
pydnstest.mock_client.send_query(sock, data_to_wire)
if self.data[0].raw_data is None:
while True:
if (datetime.now() - tstart).total_seconds() > 5:
raise RuntimeError("Server took too long to respond")
try:
answer, _ = recvfrom_msg(sock, True)
break
except OSError as ex:
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
answer = pydnstest.mock_client.get_answer(sock)
# Track RTT
rtt = (datetime.now() - tstart).total_seconds() * 1000
global g_rtt, g_nqueries
g_nqueries += 1
g_rtt += rtt
# Remember last answer for checking later
self.raw_answer = answer
ctx.last_raw_answer = answer
......
......@@ -12,7 +12,7 @@ import time
import dns.message
import dns.rdatatype
from pydnstest import scenario
from pydnstest import scenario, mock_client
class TestServer:
......@@ -87,7 +87,7 @@ class TestServer:
"""
log = logging.getLogger('pydnstest.testserver.handle_query')
server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client)
query, client_addr = mock_client.recvfrom_msg(client)
if query is None:
return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
......@@ -105,7 +105,7 @@ class TestServer:
else:
log.debug('response: %s', message)
scenario.sendto_msg(client, message.to_wire(), client_addr)
mock_client.sendto_msg(client, message.to_wire(), client_addr)
return True
def query_io(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment