pytests: refactor utils to generate msg ids

parent 5621cbfa
...@@ -6,7 +6,6 @@ import ssl ...@@ -6,7 +6,6 @@ import ssl
import subprocess import subprocess
import time import time
import dns
import jinja2 import jinja2
import pytest import pytest
...@@ -43,9 +42,8 @@ def make_ssl_context(): ...@@ -43,9 +42,8 @@ def make_ssl_context():
def ping_alive(sock): def ping_alive(sock):
msgid = utils.random_msgid() buff, msgid = utils.get_msgbuff()
buf = utils.get_msgbuf('localhost.', dns.rdatatype.A, msgid) sock.sendall(buff)
sock.sendall(buf)
answer = utils.receive_parse_answer(sock) answer = utils.receive_parse_answer(sock)
return answer.id == msgid return answer.id == msgid
......
"""TCP Connection Management tests""" """TCP Connection Management tests"""
import dns
import dns.message
import utils import utils
...@@ -13,16 +10,12 @@ def test_ignore_garbage(kresd_sock): ...@@ -13,16 +10,12 @@ def test_ignore_garbage(kresd_sock):
Expected: garbage must be ignored and the second query must be answered Expected: garbage must be ignored and the second query must be answered
""" """
MSG_ID = 1 msg_buff, msgid = utils.get_msgbuff()
garbage_buff = utils.get_prefixed_garbage(1024)
msg = utils.get_msgbuf('localhost.', dns.rdatatype.A, MSG_ID) kresd_sock.sendall(garbage_buff + msg_buff)
garbage = utils.get_prefixed_garbage(1024)
buf = garbage + msg
kresd_sock.sendall(buf)
msg_answer = utils.receive_parse_answer(kresd_sock) msg_answer = utils.receive_parse_answer(kresd_sock)
assert msg_answer.id == msgid
assert msg_answer.id == MSG_ID
def test_pipelining(kresd_sock): def test_pipelining(kresd_sock):
...@@ -31,13 +24,10 @@ def test_pipelining(kresd_sock): ...@@ -31,13 +24,10 @@ def test_pipelining(kresd_sock):
Expected: answer to the second query must come first. Expected: answer to the second query must come first.
""" """
MSG_ID_FIRST = 1 buff1, msgid1 = utils.get_msgbuff('1000.delay.getdnsapi.net.', msgid=1)
MSG_ID_SECOND = 2 buff2, msgid2 = utils.get_msgbuff('1.delay.getdnsapi.net.', msgid=2)
buff = buff1 + buff2
kresd_sock.sendall(buff)
buf = utils.get_msgbuf('1000.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_FIRST) \
+ utils.get_msgbuf('1.delay.getdnsapi.net.', dns.rdatatype.A, MSG_ID_SECOND)
kresd_sock.sendall(buf)
msg_answer = utils.receive_parse_answer(kresd_sock) msg_answer = utils.receive_parse_answer(kresd_sock)
assert msg_answer.id == msgid2
assert msg_answer.id == MSG_ID_SECOND
...@@ -47,7 +47,7 @@ def send_incorrect_repeatedly(sock, buff, delay=1): ...@@ -47,7 +47,7 @@ def send_incorrect_repeatedly(sock, buff, delay=1):
def test_less_than_header(kresd_sock): def test_less_than_header(kresd_sock):
"""Prefix is less than the length of the DNS message header.""" """Prefix is less than the length of the DNS message header."""
wire = utils.prepare_wire() wire, _ = utils.prepare_wire()
datalen = 11 # DNS header size minus 1 datalen = 11 # DNS header size minus 1
buff = utils.prepare_buffer(wire, datalen) buff = utils.prepare_buffer(wire, datalen)
send_incorrect_repeatedly(kresd_sock, buff) send_incorrect_repeatedly(kresd_sock, buff)
...@@ -55,7 +55,7 @@ def test_less_than_header(kresd_sock): ...@@ -55,7 +55,7 @@ def test_less_than_header(kresd_sock):
def test_greater_than_message(kresd_sock): def test_greater_than_message(kresd_sock):
"""Prefix is greater than the length of the entire DNS message.""" """Prefix is greater than the length of the entire DNS message."""
wire = utils.prepare_wire() wire, _ = utils.prepare_wire()
datalen = len(wire) + 16 datalen = len(wire) + 16
buff = utils.prepare_buffer(wire, datalen) buff = utils.prepare_buffer(wire, datalen)
send_incorrect_repeatedly(kresd_sock, buff) send_incorrect_repeatedly(kresd_sock, buff)
...@@ -64,7 +64,7 @@ def test_greater_than_message(kresd_sock): ...@@ -64,7 +64,7 @@ def test_greater_than_message(kresd_sock):
def test_cuts_message(kresd_sock): def test_cuts_message(kresd_sock):
"""Prefix is greater than the length of the DNS message header, but shorter than """Prefix is greater than the length of the DNS message header, but shorter than
the entire DNS message.""" the entire DNS message."""
wire = utils.prepare_wire() wire, _ = utils.prepare_wire()
datalen = 14 # DNS Header size plus 2 datalen = 14 # DNS Header size plus 2
assert datalen < len(wire) assert datalen < len(wire)
buff = utils.prepare_buffer(wire, datalen) buff = utils.prepare_buffer(wire, datalen)
...@@ -75,11 +75,10 @@ def test_cuts_message_after_ok(kresd_sock): ...@@ -75,11 +75,10 @@ def test_cuts_message_after_ok(kresd_sock):
"""First, normal DNS message is sent. Afterwards, message with incorrect prefix """First, normal DNS message is sent. Afterwards, message with incorrect prefix
(greater than header, less than entire message) is sent. First message must be (greater than header, less than entire message) is sent. First message must be
answered, then the connection should be closed after timeout.""" answered, then the connection should be closed after timeout."""
normal_msg_id = 1 normal_wire, normal_msgid = utils.prepare_wire(msgid=1)
normal_wire = utils.prepare_wire(normal_msg_id)
normal_buff = utils.prepare_buffer(normal_wire) normal_buff = utils.prepare_buffer(normal_wire)
cut_wire = utils.prepare_wire() cut_wire, _ = utils.prepare_wire(msgid=2)
cut_datalen = 14 cut_datalen = 14
assert cut_datalen < len(cut_wire) assert cut_datalen < len(cut_wire)
cut_buff = utils.prepare_buffer(cut_wire, cut_datalen) cut_buff = utils.prepare_buffer(cut_wire, cut_datalen)
...@@ -97,8 +96,8 @@ def test_trailing_garbage(kresd_sock): ...@@ -97,8 +96,8 @@ def test_trailing_garbage(kresd_sock):
"""Prefix is correct, but the message has trailing garbage. The connection must """Prefix is correct, but the message has trailing garbage. The connection must
stay open until all message have been sent and answered.""" stay open until all message have been sent and answered."""
for _ in range(10): for _ in range(10):
msgid = utils.random_msgid() wire, msgid = utils.prepare_wire()
wire = utils.prepare_wire(msgid) + utils.get_garbage(8) wire += utils.get_garbage(8)
buff = utils.prepare_buffer(wire) buff = utils.prepare_buffer(wire)
kresd_sock.sendall(buff) kresd_sock.sendall(buff)
......
...@@ -5,10 +5,6 @@ import dns ...@@ -5,10 +5,6 @@ import dns
import dns.message import dns.message
def random_msgid():
return random.randint(1, 65535)
def receive_answer(sock): def receive_answer(sock):
answer_total_len = 0 answer_total_len = 0
data = sock.recv(2) data = sock.recv(2)
...@@ -39,15 +35,15 @@ def receive_parse_answer(sock): ...@@ -39,15 +35,15 @@ def receive_parse_answer(sock):
def prepare_wire( def prepare_wire(
msgid=None,
qname='localhost.', qname='localhost.',
qtype=dns.rdatatype.A, qtype=dns.rdatatype.A,
qclass=dns.rdataclass.IN): qclass=dns.rdataclass.IN,
msgid=None):
"""Utility function to generate DNS wire format message""" """Utility function to generate DNS wire format message"""
msg = dns.message.make_query(qname, qtype, qclass) msg = dns.message.make_query(qname, qtype, qclass)
if msgid is not None: if msgid is not None:
msg.id = msgid msg.id = msgid
return msg.to_wire() return msg.to_wire(), msg.id
def prepare_buffer(wire, datalen=None): def prepare_buffer(wire, datalen=None):
...@@ -58,22 +54,16 @@ def prepare_buffer(wire, datalen=None): ...@@ -58,22 +54,16 @@ def prepare_buffer(wire, datalen=None):
return struct.pack("!H", datalen) + wire return struct.pack("!H", datalen) + wire
def get_msgbuf(qname, qtype, msgid): def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
# TODO remove/refactor in favor of prepare_wire, prepare_buffer wire, msgid = prepare_wire(qname, qtype, msgid=msgid)
msg = dns.message.make_query(qname, qtype, dns.rdataclass.IN) buff = prepare_buffer(wire)
msg.id = msgid return buff, msgid
data = msg.to_wire()
datalen = len(data)
buf = struct.pack("!H", datalen) + data
return buf
def get_garbage(length): def get_garbage(length):
return bytearray(random.getrandbits(8) for _ in range(length)) return bytes(random.getrandbits(8) for _ in range(length))
def get_prefixed_garbage(length): def get_prefixed_garbage(length):
data = get_garbage(length) data = get_garbage(length)
datalen = len(data) return prepare_buffer(data)
buf = struct.pack("!H", datalen) + data
return buf
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment