Commit 0b9988d0 authored by Štěpán Balážik's avatar Štěpán Balážik Committed by Štěpán Balážik

scenario: factor out socket handling and sending DNS queries to mock_client.py

This is useful for code clarity and for other use-cases other than Deckard.
parent 23624b39
Pipeline #44374 passed with stage
in 3 minutes and 35 seconds
"""Module takes care of sending and recieving DNS messages as a mock client"""
import errno
import socket
import struct
import time
from typing import Optional, Tuple, Union
import dns.message
import dns.inet
SOCKET_OPERATION_TIMEOUT = 5
RECEIVE_MESSAGE_SIZE = 2**16-1
THROTTLE_BY = 0.1
def handle_socket_timeout(sock: socket.socket, deadline: float):
# deadline is always time.monotonic
remaining = deadline - time.monotonic()
if remaining <= 0:
raise RuntimeError("Server took too long to respond")
sock.settimeout(remaining)
def recv_n_bytes_from_tcp(stream: socket.socket, n: int, deadline: float) -> bytes:
# deadline is always time.monotonic
data = b""
while n != 0:
handle_socket_timeout(stream, deadline)
chunk = stream.recv(n)
# Empty bytes from socket.recv mean that socket is closed
if not chunk:
raise OSError()
n -= len(chunk)
data += chunk
return data
def recvfrom_blob(sock: socket.socket) -> Tuple[bytes, str]:
"""
Receive DNS message from TCP/UDP socket.
"""
# deadline is always time.monotonic
deadline = time.monotonic() + SOCKET_OPERATION_TIMEOUT
while True:
try:
if sock.type & socket.SOCK_DGRAM:
handle_socket_timeout(sock, deadline)
data, addr = sock.recvfrom(RECEIVE_MESSAGE_SIZE)
elif sock.type & socket.SOCK_STREAM:
# First 2 bytes of TCP packet are the size of the message
# See https://tools.ietf.org/html/rfc1035#section-4.2.2
data = recv_n_bytes_from_tcp(sock, 2, deadline)
msg_len = struct.unpack_from("!H", data)[0]
data = recv_n_bytes_from_tcp(sock, msg_len, deadline)
addr = sock.getpeername()[0]
else:
raise NotImplementedError("[recvfrom_blob]: unknown socket type '%i'" % sock.type)
return data, addr
except OSError as ex:
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
else:
raise
def recvfrom_msg(sock: socket.socket) -> Tuple[dns.message.Message, str]:
data, addr = recvfrom_blob(sock)
msg = dns.message.from_wire(data, one_rr_per_rrset=True)
return msg, addr
def sendto_msg(sock: socket.socket, message: bytes, addr: Optional[str] = None) -> None:
""" Send DNS/UDP/TCP message. """
try:
if sock.type & socket.SOCK_DGRAM:
if addr is None:
sock.send(message)
else:
sock.sendto(message, addr)
elif sock.type & socket.SOCK_STREAM:
data = struct.pack("!H", len(message)) + message
sock.sendall(data)
else:
raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % sock.type)
except OSError as ex:
# Reference: http://lkml.iu.edu/hypermail/linux/kernel/0002.3/0709.html
if ex.errno != errno.ECONNREFUSED:
raise
def setup_socket(address: str,
port: int,
tcp: bool = False) -> socket.socket:
family = dns.inet.af_for_address(address)
sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
if tcp:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
sock.settimeout(SOCKET_OPERATION_TIMEOUT)
sock.connect((address, port))
return sock
def send_query(sock: socket.socket, query: Union[dns.message.Message, bytes]) -> None:
message = query if isinstance(query, bytes) else query.to_wire()
while True:
try:
sendto_msg(sock, message)
break
except OSError as ex:
# ENOBUFS, throttle sending
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
else:
raise
def get_answer(sock: socket.socket) -> bytes:
""" Compatibility function """
answer, _ = recvfrom_blob(sock)
return answer
def get_dns_message(sock: socket.socket) -> dns.message.Message:
return dns.message.from_wire(get_answer(sock))
# FIXME pylint: disable=too-many-lines
from abc import ABC from abc import ABC
import binascii import binascii
import calendar import calendar
from datetime import datetime from datetime import datetime
import errno
import logging import logging
import os import os
import posixpath import posixpath
...@@ -22,6 +20,7 @@ import dns.tsigkeyring ...@@ -22,6 +20,7 @@ import dns.tsigkeyring
import pydnstest.augwrap import pydnstest.augwrap
import pydnstest.matchpart import pydnstest.matchpart
import pydnstest.mock_client
def str2bool(v): def str2bool(v):
...@@ -34,57 +33,6 @@ g_rtt = 0.0 ...@@ -34,57 +33,6 @@ g_rtt = 0.0
g_nqueries = 0 g_nqueries = 0
def recvfrom_msg(stream, raw=False):
"""
Receive DNS message from TCP/UDP socket.
Returns:
if raw == False: (DNS message object, peer address)
if raw == True: (blob, peer address)
"""
if stream.type & socket.SOCK_DGRAM:
data, addr = stream.recvfrom(4096)
elif stream.type & socket.SOCK_STREAM:
data = stream.recv(2)
if not data:
return None, None
msg_len = struct.unpack_from("!H", data)[0]
data = b""
received = 0
while received < msg_len:
next_chunk = stream.recv(4096)
if not next_chunk:
return None, None
data += next_chunk
received += len(next_chunk)
addr = stream.getpeername()[0]
else:
raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
if raw:
return data, addr
else:
msg = dns.message.from_wire(data, one_rr_per_rrset=True)
return msg, addr
def sendto_msg(stream, message, addr=None):
""" Send DNS/UDP/TCP message. """
try:
if stream.type & socket.SOCK_DGRAM:
if addr is None:
stream.send(message)
else:
stream.sendto(message, addr)
elif stream.type & socket.SOCK_STREAM:
data = struct.pack("!H", len(message)) + message
stream.send(data)
else:
raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
except socket.error as ex:
if ex.errno != errno.ECONNREFUSED: # TODO Investigate how this can happen
raise
class DNSBlob(ABC): class DNSBlob(ABC):
def to_wire(self) -> bytes: def to_wire(self) -> bytes:
raise NotImplementedError raise NotImplementedError
...@@ -580,8 +528,8 @@ class Step: ...@@ -580,8 +528,8 @@ class Step:
self.log.info('') self.log.info('')
self.log.debug(self.data[0].message.to_text()) self.log.debug(self.data[0].message.to_text())
# Parse QUERY-specific parameters # Parse QUERY-specific parameters
choice, tcp, source = None, False, None choice, tcp = None, False
return self.__query(ctx, tcp=tcp, choice=choice, source=source) return self.__query(ctx, tcp=tcp, choice=choice)
elif self.type == 'CHECK_OUT_QUERY': # ignore elif self.type == 'CHECK_OUT_QUERY': # ignore
self.log.info('') self.log.info('')
return None return None
...@@ -611,7 +559,7 @@ class Step: ...@@ -611,7 +559,7 @@ class Step:
self.log.debug("answer: %s", ctx.last_answer.to_text()) self.log.debug("answer: %s", ctx.last_answer.to_text())
expected.match(ctx.last_answer) expected.match(ctx.last_answer)
def __query(self, ctx, tcp=False, choice=None, source=None): def __query(self, ctx, tcp=False, choice=None):
""" """
Send query and wait for an answer (if the query is not RAW). Send query and wait for an answer (if the query is not RAW).
...@@ -628,45 +576,24 @@ class Step: ...@@ -628,45 +576,24 @@ class Step:
choice = list(ctx.client.keys())[0] choice = list(ctx.client.keys())[0]
if choice not in ctx.client: if choice not in ctx.client:
raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice)) raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
# Create socket to test subject
sock = None
destination = ctx.client[choice]
family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if tcp:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
sock.settimeout(3)
if source:
sock.bind((source, 0))
sock.connect(destination)
# Send query to client and wait for response
tstart = datetime.now() tstart = datetime.now()
while True:
try: # Send query and wait for answer
sendto_msg(sock, data_to_wire)
break
except OSError as ex:
# ENOBUFS, throttle sending
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
# Wait for a response for a reasonable time
answer = None answer = None
sock = pydnstest.mock_client.setup_socket(ctx.client[choice][0],
ctx.client[choice][1],
tcp)
pydnstest.mock_client.send_query(sock, data_to_wire)
if self.data[0].raw_data is None: if self.data[0].raw_data is None:
while True: answer = pydnstest.mock_client.get_answer(sock)
if (datetime.now() - tstart).total_seconds() > 5:
raise RuntimeError("Server took too long to respond")
try:
answer, _ = recvfrom_msg(sock, True)
break
except OSError as ex:
if ex.errno == errno.ENOBUFS:
time.sleep(0.1)
# Track RTT # Track RTT
rtt = (datetime.now() - tstart).total_seconds() * 1000 rtt = (datetime.now() - tstart).total_seconds() * 1000
global g_rtt, g_nqueries global g_rtt, g_nqueries
g_nqueries += 1 g_nqueries += 1
g_rtt += rtt g_rtt += rtt
# Remember last answer for checking later # Remember last answer for checking later
self.raw_answer = answer self.raw_answer = answer
ctx.last_raw_answer = answer ctx.last_raw_answer = answer
......
...@@ -12,7 +12,7 @@ import time ...@@ -12,7 +12,7 @@ import time
import dns.message import dns.message
import dns.rdatatype import dns.rdatatype
from pydnstest import scenario from pydnstest import scenario, mock_client
class TestServer: class TestServer:
...@@ -87,7 +87,7 @@ class TestServer: ...@@ -87,7 +87,7 @@ class TestServer:
""" """
log = logging.getLogger('pydnstest.testserver.handle_query') log = logging.getLogger('pydnstest.testserver.handle_query')
server_addr = client.getsockname()[0] server_addr = client.getsockname()[0]
query, client_addr = scenario.recvfrom_msg(client) query, client_addr = mock_client.recvfrom_msg(client)
if query is None: if query is None:
return False return False
log.debug('server %s received query from %s: %s', server_addr, client_addr, query) log.debug('server %s received query from %s: %s', server_addr, client_addr, query)
...@@ -105,7 +105,7 @@ class TestServer: ...@@ -105,7 +105,7 @@ class TestServer:
else: else:
log.debug('response: %s', message) log.debug('response: %s', message)
scenario.sendto_msg(client, message.to_wire(), client_addr) mock_client.sendto_msg(client, message.to_wire(), client_addr)
return True return True
def query_io(self): def query_io(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment