pytests: refactor utils to generate msg ids

parent 5621cbfa
......@@ -6,7 +6,6 @@ import ssl
import subprocess
import time
import dns
import jinja2
import pytest
......@@ -43,9 +42,8 @@ def make_ssl_context():
def ping_alive(sock):
msgid = utils.random_msgid()
buf = utils.get_msgbuf('localhost.', dns.rdatatype.A, msgid)
sock.sendall(buf)
buff, msgid = utils.get_msgbuff()
sock.sendall(buff)
answer = utils.receive_parse_answer(sock)
return answer.id == msgid
......
"""TCP Connection Management tests"""
import dns
import dns.message
import utils
......@@ -13,16 +10,12 @@ def test_ignore_garbage(kresd_sock):
Expected: garbage must be ignored and the second query must be answered
"""
MSG_ID = 1
msg = utils.get_msgbuf('localhost.', dns.rdatatype.A, MSG_ID)
garbage = utils.get_prefixed_garbage(1024)
buf = garbage + msg
msg_buff, msgid = utils.get_msgbuff()
garbage_buff = utils.get_prefixed_garbage(1024)
kresd_sock.sendall(garbage_buff + msg_buff)
kresd_sock.sendall(buf)
msg_answer = utils.receive_parse_answer(kresd_sock)
assert msg_answer.id == MSG_ID
assert msg_answer.id == msgid
def test_pipelining(kresd_sock):
......@@ -31,13 +24,10 @@ def test_pipelining(kresd_sock):
Expected: answer to the second query must come first.
"""
MSG_ID_FIRST = 1
MSG_ID_SECOND = 2
buff1, msgid1 = utils.get_msgbuff('1000.delay.getdnsapi.net.', msgid=1)
buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2)
buff = buff1 + buff2
kresd_sock.sendall(buff)
buf = utils.get_msgbuf('1000.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_FIRST) \
+ utils.get_msgbuf('1.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_SECOND)
kresd_sock.sendall(buf)
msg_answer = utils.receive_parse_answer(kresd_sock)
assert msg_answer.id == MSG_ID_SECOND
assert msg_answer.id == msgid2
......@@ -47,7 +47,7 @@ def send_incorrect_repeatedly(sock, buff, delay=1):
def test_less_than_header(kresd_sock):
"""Prefix is less than the length of the DNS message header."""
wire = utils.prepare_wire()
wire, _ = utils.prepare_wire()
datalen = 11 # DNS header size minus 1
buff = utils.prepare_buffer(wire, datalen)
send_incorrect_repeatedly(kresd_sock, buff)
......@@ -55,7 +55,7 @@ def test_less_than_header(kresd_sock):
def test_greater_than_message(kresd_sock):
"""Prefix is greater than the length of the entire DNS message."""
wire = utils.prepare_wire()
wire, _ = utils.prepare_wire()
datalen = len(wire) + 16
buff = utils.prepare_buffer(wire, datalen)
send_incorrect_repeatedly(kresd_sock, buff)
......@@ -64,7 +64,7 @@ def test_greater_than_message(kresd_sock):
def test_cuts_message(kresd_sock):
"""Prefix is greater than the length of the DNS message header, but shorter than
the entire DNS message."""
wire = utils.prepare_wire()
wire, _ = utils.prepare_wire()
datalen = 14 # DNS Header size plus 2
assert datalen < len(wire)
buff = utils.prepare_buffer(wire, datalen)
......@@ -75,11 +75,10 @@ def test_cuts_message_after_ok(kresd_sock):
"""First, normal DNS message is sent. Afterwards, message with incorrect prefix
(greater than header, less than entire message) is sent. First message must be
answered, then the connection should be closed after timeout."""
normal_msg_id = 1
normal_wire = utils.prepare_wire(normal_msg_id)
normal_wire, normal_msgid = utils.prepare_wire(msgid=1)
normal_buff = utils.prepare_buffer(normal_wire)
cut_wire = utils.prepare_wire()
cut_wire, _ = utils.prepare_wire(msgid=2)
cut_datalen = 14
assert cut_datalen < len(cut_wire)
cut_buff = utils.prepare_buffer(cut_wire, cut_datalen)
......@@ -97,8 +96,8 @@ def test_trailing_garbage(kresd_sock):
"""Prefix is correct, but the message has trailing garbage. The connection must
stay open until all message have been sent and answered."""
for _ in range(10):
msgid = utils.random_msgid()
wire = utils.prepare_wire(msgid) + utils.get_garbage(8)
wire, msgid = utils.prepare_wire()
wire += utils.get_garbage(8)
buff = utils.prepare_buffer(wire)
kresd_sock.sendall(buff)
......
......@@ -5,10 +5,6 @@ import dns
import dns.message
def random_msgid():
return random.randint(1, 65535)
def receive_answer(sock):
answer_total_len = 0
data = sock.recv(2)
......@@ -39,15 +35,15 @@ def receive_parse_answer(sock):
def prepare_wire(
msgid=None,
qname='localhost.',
qtype=dns.rdatatype.A,
qclass=dns.rdataclass.IN):
qclass=dns.rdataclass.IN,
msgid=None):
"""Utility function to generate DNS wire format message"""
msg = dns.message.make_query(qname, qtype, qclass)
if msgid is not None:
msg.id = msgid
return msg.to_wire()
return msg.to_wire(), msg.id
def prepare_buffer(wire, datalen=None):
......@@ -58,22 +54,16 @@ def prepare_buffer(wire, datalen=None):
return struct.pack("!H", datalen) + wire
def get_msgbuf(qname, qtype, msgid):
# TODO remove/refactor in favor of prepare_wire, prepare_buffer
msg = dns.message.make_query(qname, qtype, dns.rdataclass.IN)
msg.id = msgid
data = msg.to_wire()
datalen = len(data)
buf = struct.pack("!H", datalen) + data
return buf
def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
buff = prepare_buffer(wire)
return buff, msgid
def get_garbage(length):
return bytearray(random.getrandbits(8) for _ in range(length))
return bytes(random.getrandbits(8) for _ in range(length))
def get_prefixed_garbage(length):
data = get_garbage(length)
datalen = len(data)
buf = struct.pack("!H", datalen) + data
return buf
return prepare_buffer(data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment