Commit 6d1e42e2 authored by Marek Vavrusa's avatar Marek Vavrusa

scenario: step "QUERY" supports forced TCP

first parameter to STEP QUERY may be "TCP"
in this case, the test subject is queried over TCP
parent f6482333
......@@ -10,6 +10,7 @@ import itertools
import time
from datetime import datetime
from dprint import dprint
from testserver import recvfrom_msg, sendto_msg
class Entry:
"""
......@@ -131,6 +132,16 @@ class Entry:
self.message.want_dnssec('DO' in eflags)
self.message.set_rcode(rcode)
def set_edns(self, fields):
""" Set EDNS version and bufsize. """
version = 0
bufsize = 4096
if len(fields) > 0:
version = int(fields[0])
if len(fields) > 1:
bufsize = int(fields[1])
self.message.use_edns(edns = version, payload = bufsize)
def begin_raw(self):
""" Set raw data pending flag. """
self.raw_data_pending = True
......@@ -311,7 +322,7 @@ class Step:
dtag = '[ STEP %03d ] %s' % (self.id, self.type)
if self.type == 'QUERY':
dprint(dtag, self.data[0].message.to_text())
return self.__query(ctx)
return self.__query(ctx, tcp = 'TCP' in self.args)
elif self.type == 'CHECK_OUT_QUERY':
dprint(dtag, '')
pass # Ignore
......@@ -342,7 +353,7 @@ class Step:
dprint("[ __check_answer ]", ctx.last_answer.to_text())
expected.match(ctx.last_answer)
def __query(self, ctx):
def __query(self, ctx, tcp = False):
""" Resolve a query. """
if len(self.data) == 0:
raise Exception("query definition required")
......@@ -352,9 +363,10 @@ class Step:
# Don't use a message copy as the EDNS data portion is not copied.
data_to_wire = self.data[0].message.to_wire()
# Send query to client and wait for response
sock = ctx.child_tcp if tcp else ctx.child_udp
while True:
try:
ctx.child_sock.send(data_to_wire)
sendto_msg(sock, data_to_wire)
break
except OSError, e:
# ENOBUFS, throttle sending
......@@ -365,7 +377,7 @@ class Step:
if not self.data[0].is_raw_data_entry:
while True:
try:
answer, addr = ctx.child_sock.recvfrom(4096)
answer, _ = recvfrom_msg(sock, True)
break
except OSError, e:
if e.errno == errno.ENOBUFS:
......@@ -398,7 +410,8 @@ class Scenario:
self.ranges = []
self.steps = []
self.current_step = None
self.child_sock = None
self.child_udp = None
self.child_tcp = None
self.force_ipv6 = False
def reply(self, query, address = None):
......@@ -434,9 +447,13 @@ class Scenario:
def play(self, family, paddr):
""" Play given scenario. """
self.child_sock = socket.socket(family, socket.SOCK_DGRAM)
self.child_sock.settimeout(3)
self.child_sock.connect(paddr)
# Connect to tested subject
self.child_udp = socket.socket(family, socket.SOCK_DGRAM)
self.child_udp.settimeout(3)
self.child_udp.connect(paddr)
self.child_tcp = socket.socket(family, socket.SOCK_STREAM)
self.child_tcp.settimeout(3)
self.child_tcp.connect(paddr)
if len(self.steps) == 0:
raise ('no steps in this scenario')
......@@ -469,8 +486,8 @@ class Scenario:
raise Exception('step #%d %s' % (step.id, str(e)))
i = i + 1
finally:
self.child_sock.close()
self.child_sock = None
self.child_udp.close()
self.child_tcp.close()
def get_next(file_in):
""" Return next token from the input stream. """
......@@ -493,7 +510,9 @@ def parse_entry(op, args, file_in):
for op, args in iter(lambda: get_next(file_in), False):
if op == 'ENTRY_END':
break
elif op == 'REPLY':
elif op == 'EDNS':
out.set_edns(args)
elif op == 'REPLY' or op == 'FLAGS':
out.set_reply(args)
elif op == 'MATCH':
out.set_match(args)
......@@ -515,13 +534,11 @@ def parse_step(op, args, file_in):
global auto_step
if len(args) == 0:
raise Exception('expected at least STEP <type>')
if len(args) < 2:
args = [str(auto_step), args[0]]
auto_step = int(args[0]) + 1 # Add 1 when step ID isn't specified
extra_args = []
if len(args) > 2:
extra_args = args[2:]
out = Step(args[0], args[1], extra_args)
# Auto-increment when step ID isn't specified
if len(args) < 2 or not args[0].isdigit():
args = [str(auto_step)] + args
auto_step = int(args[0]) + 1
out = Step(args[0], args[1], args[2:])
if out.has_data:
op, args = get_next(file_in)
if op == 'ENTRY_BEGIN':
......
......@@ -10,7 +10,7 @@ import struct
import binascii
from dprint import dprint
def recvfrom_msg(stream):
def recvfrom_msg(stream, raw = False):
""" Receive DNS/UDP/TCP message. """
if stream.type == socket.SOCK_DGRAM:
data, addr = stream.recvfrom(4096)
......@@ -30,13 +30,18 @@ def recvfrom_msg(stream):
addr = stream.getpeername()[0]
else:
raise Exception ("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
return dns.message.from_wire(data), addr
if not raw:
data = dns.message.from_wire(data)
return data, addr
def sendto_msg(stream, message, addr):
def sendto_msg(stream, message, addr = None):
""" Send DNS/UDP/TCP message. """
try:
if stream.type == socket.SOCK_DGRAM:
stream.sendto(message, addr)
if addr is None:
stream.send(message)
else:
stream.sendto(message, addr)
elif stream.type == socket.SOCK_STREAM:
data = struct.pack("!H",len(message)) + message
stream.send(data)
......
......@@ -229,7 +229,7 @@ ENTRY_END
RANGE_END
; get cname in cache. use MX query
STEP 1 QUERY
STEP 1 QUERY TCP
ENTRY_BEGIN
REPLY RD
SECTION QUESTION
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment