scenario.py 38 KB
Newer Older
1 2
from __future__ import absolute_import

3
import calendar
4
import logging
5
import dns.message
Marek Vavruša's avatar
Marek Vavruša committed
6 7 8
import dns.rrset
import dns.rcode
import dns.dnssec
9
import dns.tsigkeyring
Marek Vavruša's avatar
Marek Vavruša committed
10
import binascii
11 12 13 14 15 16 17 18
import socket
import struct
import os
import sys
import errno
import itertools
import random
import string
Marek Vavruša's avatar
Marek Vavruša committed
19 20 21
import time
from datetime import datetime

22

23 24 25 26 27
def str2bool(v):
    """ Tell whether a JSON-ish string spells a true value (case-insensitive). """
    truthy = {'yes', 'true', 'on', '1'}
    normalized = v.lower()
    return normalized in truthy


28 29 30 31
# Global statistics, updated by Step.__query every time it has sent a query
g_rtt = 0.0  # sum of measured query round-trip times, in milliseconds
g_nqueries = 0  # number of queries counted into g_rtt

32 33 34 35
#
# Element comparators
#

36 37

def create_rr(owner, args, ttl=3600, rdclass='IN', origin='.'):
    """
    Parse RR from tokenized string.

    Arguments:
        owner: owner name; `origin` is appended if it has no trailing dot
        args: tokens that follow the owner: [TTL] [CLASS] TYPE [RDATA ...];
              the list is not modified
        ttl: TTL used when the tokens do not start with a TTL
        rdclass: class used when the tokens do not carry a class
        origin: origin used to complete the owner and relative names in RDATA

    Returns:
        RR set created by dns.rrset.from_text(); it holds one rdata when
        RDATA tokens were present and is empty otherwise.
    """
    # Work on a copy so the caller's token list is not consumed
    args = list(args)
    if not owner.endswith('.'):
        # NOTE(review): origin is appended without a separating dot, which
        # presumably expects origin to be '.' or to start with a dot - confirm
        owner += origin
    try:
        ttl = dns.ttl.from_text(args[0])
        args.pop(0)
    except Exception:
        pass  # optional
    try:
        rdclass = dns.rdataclass.from_text(args[0])
        args.pop(0)
    except Exception:
        pass  # optional
    rdtype = args.pop(0)
    rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
    if len(args) > 0:
        if rr.rdtype == dns.rdatatype.DS:
            # convert textual algorithm identifier to number
            args[1] = str(dns.dnssec.algorithm_from_text(args[1]))
        rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(args),
                                 origin=dns.name.from_text(origin), relativize=False)
        rr.add(rd)
    return rr

62

63 64 65 66 67 68 69 70
def compare_rrs(expected, got):
    """
    Verify that two lists of RR sets hold the same records.

    Raises an Exception for the first record of `expected` missing in `got`,
    then for the first record of `got` missing in `expected`, and finally
    when the two lists differ in length. Returns True otherwise.
    """
    directions = (
        (expected, got, "expected record '%s'"),
        (got, expected, "unexpected record '%s'"),
    )
    for records, pool, template in directions:
        for rr in records:
            if rr not in pool:
                raise Exception(template % rr.to_text())
    n_expected, n_got = len(expected), len(got)
    if n_expected != n_got:
        raise Exception("expected %s records but got %s records "
                        "(a duplicate RR somewhere?)"
                        % (n_expected, n_got))
    return True

Petr Špaček's avatar
Petr Špaček committed
77

78 79 80 81 82 83
def compare_val(expected, got):
    """ Return True for equal values, raise an Exception naming both otherwise. """
    mismatch = expected != got
    if mismatch:
        raise Exception("expected '%s', got '%s'" % (expected, got))
    return True

84

85 86 87 88 89 90
def compare_sub(got, expected):
    """
    Check that `expected` is a subdomain of `got`, throw exception otherwise.

    Beware of the parameter names: Entry.match_part() passes the scripted
    name as `got` and the name from the inspected message as `expected`.
    The names are kept for compatibility; the error message reports the
    values in the order the caller means them.

    Arguments:
        got: parent name (the scripted one)
        expected: name that has to lie at or below `got`; must provide
                  is_subdomain()

    Returns:
        True if `expected.is_subdomain(got)` holds.
    """
    if not expected.is_subdomain(got):
        raise Exception("expected subdomain of '%s', got '%s'" % (got, expected))
    return True

91

92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
def _recv_exact(stream, nbytes):
    """ Read exactly nbytes from a stream socket; None if the peer closed early. """
    chunks = []
    remaining = nbytes
    while remaining > 0:
        chunk = stream.recv(remaining)
        if len(chunk) == 0:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recvfrom_msg(stream, raw=False):
    """
    Receive DNS message from TCP/UDP socket.

    On a stream socket exactly one message is consumed: the two-byte length
    prefix and then as many bytes as the prefix announces, so a following
    message stays in the socket for the next call.

    Returns:
        if raw == False: (DNS message object, peer address)
        if raw == True: (blob, peer address)
        (None, None) if a stream peer closed the connection before
        a complete message arrived

    Raises:
        NotImplementedError: the socket is neither datagram nor stream
    """
    if stream.type & socket.SOCK_DGRAM:
        data, addr = stream.recvfrom(4096)
    elif stream.type & socket.SOCK_STREAM:
        # recv() may return fewer bytes than requested, even for the prefix
        header = _recv_exact(stream, 2)
        if header is None:
            return None, None
        msg_len = struct.unpack("!H", header)[0]
        data = _recv_exact(stream, msg_len)
        if data is None:
            return None, None
        addr = stream.getpeername()[0]
    else:
        raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
    if not raw:
        data = dns.message.from_wire(data, one_rr_per_rrset=True)
    return data, addr


def sendto_msg(stream, message, addr=None):
    """
    Send DNS/UDP/TCP message.

    Datagram sockets get the message as is (via sendto() when `addr` is
    given); stream sockets get it prefixed with its two-byte length.
    A failure to send is ignored on purpose, see below.

    Raises:
        NotImplementedError: the socket is neither datagram nor stream
    """
    if stream.type & socket.SOCK_DGRAM:
        is_stream = False
    elif stream.type & socket.SOCK_STREAM:
        is_stream = True
    else:
        raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
    try:
        if is_stream:
            # sendall(), as send() may transmit only a part of the data
            stream.sendall(struct.pack("!H", len(message)) + message)
        elif addr is None:
            stream.send(message)
        else:
            stream.sendto(message, addr)
    except Exception:  # Failure to respond is OK, resolver should recover
        pass


140
def replay_rrs(rrs, nqueries, destination, args=()):
    """
    Replay list of queries and report statistics.

    Arguments:
        rrs: RR sets whose name, type and class form the questions
        nqueries: number of queries to send; `rrs` are reused cyclically
        destination: address of the tested server, as taken by socket.connect()
        args: options; 'RAND' prepends a random 8-character label to each
              query name, 'DO' asks for DNSSEC records

    Returns:
        (number of queries sent, number of answers received)
    """
    import select
    navail = len(rrs)
    if navail == 0 or nqueries <= 0:
        return 0, 0  # nothing to replay
    randomize = 'RAND' in args
    want_dnssec = 'DO' in args
    alphabet = string.ascii_letters + string.digits
    chunksize = 16  # maximum number of queries waiting for an answer
    queries = []
    for i in range(nqueries if randomize else navail):
        rr = rrs[i % navail]
        name = rr.name
        if randomize:
            prefix = ''.join(random.choice(alphabet) for _ in range(8))
            name = prefix + '.' + rr.name.to_text()
        msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
        if want_dnssec:
            msg.want_dnssec(True)
        queries.append(msg.to_wire())
    # Make a UDP connected socket to the destination
    family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        sock.connect(destination)
        sock.setblocking(False)
        # Play the query set
        # @NOTE: this is only good for relative low-speed replay
        rcvbuf = bytearray(512)
        nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries)
        fdset = [sock]
        while nsent - nwait < nqueries:
            # Poll for writability only while there is something left to
            # send; otherwise lost trailing answers would never time out
            can_send = nsent < nqueries and nwait < chunksize
            to_read, to_write, _ = select.select(fdset, fdset if can_send else [], [], 0.5)
            if to_write:
                try:
                    while nsent < nqueries and nwait < chunksize:
                        sock.send(queries[nsent % navail])
                        nwait += 1
                        nsent += 1
                except socket.error:
                    pass  # EINVAL
            if to_read:
                try:
                    while nwait > 0:
                        sock.recv_into(rcvbuf)
                        nwait -= 1
                        nrcvd += 1
                except socket.error:
                    pass  # nothing more to read at the moment
            if not to_write and not to_read:
                nwait = 0  # Timeout, started dropping packets
                break
    finally:
        sock.close()
    return nsent, nrcvd

189

Marek Vavruša's avatar
Marek Vavruša committed
190 191
class Entry:
    """
    Data entry represents scripted message and extra metadata,
    notably match criteria and reply adjustments.
    """

    # Defaults for records without TTL/class and for replies without RCODE
    default_ttl = 3600
    default_cls = 'IN'
    default_rc = 'NOERROR'

    def __init__(self, lineno=0):
        """ Initialize data entry. """
        self.match_fields = ['opcode', 'qtype', 'qname']
        self.adjust_fields = ['copy_id']
        self.origin = '.'
        self.message = dns.message.Message()
        self.message.use_edns(edns=0, payload=4096)
        self.sections = []
        self.section = None  # current section, set by begin_section()
        self.is_raw_data_entry = False
        self.raw_data_pending = False
        self.raw_data = None
        self.lineno = lineno
        self.mandatory = False
        self.fired = 0

    def match_part(self, code, msg):
        """
        Compare scripted reply to given message using single criteria.

        Returns True if the criterion holds or is not among
        self.match_fields; raises an Exception describing the difference
        otherwise, and also for a selected criterion that is unknown.
        """
        if code not in self.match_fields and 'all' not in self.match_fields:
            return True
        expected = self.message
        if code == 'opcode':
            return compare_val(expected.opcode(), msg.opcode())
        elif code == 'qtype':
            if len(expected.question) == 0:
                return True
            return compare_val(expected.question[0].rdtype, msg.question[0].rdtype)
        elif code == 'qname':
            if len(expected.question) == 0:
                return True
            qname = dns.name.from_text(msg.question[0].name.to_text().lower())
            return compare_val(expected.question[0].name, qname)
        elif code == 'qcase':
            return compare_val(msg.question[0].name.labels, expected.question[0].name.labels)
        elif code == 'subdomain':
            if len(expected.question) == 0:
                return True
            qname = dns.name.from_text(msg.question[0].name.to_text().lower())
            return compare_sub(expected.question[0].name, qname)
        elif code == 'flags':
            return compare_val(dns.flags.to_text(expected.flags), dns.flags.to_text(msg.flags))
        elif code == 'rcode':
            return compare_val(dns.rcode.to_text(expected.rcode()), dns.rcode.to_text(msg.rcode()))
        elif code == 'question':
            return compare_rrs(expected.question, msg.question)
        elif code == 'answer' or code == 'ttl':
            return compare_rrs(expected.answer, msg.answer)
        elif code == 'authority':
            return compare_rrs(expected.authority, msg.authority)
        elif code == 'additional':
            return compare_rrs(expected.additional, msg.additional)
        elif code == 'edns':
            if msg.edns != expected.edns:
                raise Exception('expected EDNS %d, got %d' % (expected.edns, msg.edns))
            if msg.payload != expected.payload:
                raise Exception('expected EDNS bufsize %d, got %d'
                                % (expected.payload, msg.payload))
            return True
        elif code == 'nsid':
            nsid_opt = None
            for opt in expected.options:
                if opt.otype == dns.edns.NSID:
                    nsid_opt = opt
                    break
            # Find matching NSID
            for opt in msg.options:
                if opt.otype == dns.edns.NSID:
                    if nsid_opt is None:
                        raise Exception('unexpected NSID value "%s"' % opt.data)
                    if opt == nsid_opt:
                        return True
                    else:
                        raise Exception('expected NSID "%s", got "%s"' % (nsid_opt.data, opt.data))
            if nsid_opt is not None:
                raise Exception('expected NSID "%s"' % nsid_opt.data)
            return True
        else:
            raise Exception('unknown match request "%s"' % code)

    def match(self, msg):
        """
        Compare scripted reply to given message based on match criteria.

        The criterion 'all' stands for flags, rcode and every section this
        entry defines. Raises an Exception carrying the line number, the
        failed criterion and the text of `msg` on the first difference.
        """
        # Expand 'all' in a copy so the configured criteria stay as they are
        match_fields = list(self.match_fields)
        if 'all' in match_fields:
            match_fields.remove('all')
            match_fields += ['flags', 'rcode'] + self.sections
        for code in match_fields:
            try:
                self.match_part(code, msg)
            except Exception as e:
                errstr = '%s in the response:\n%s' % (str(e), msg.to_text())
                raise Exception("line %d, \"%s\": %s" % (self.lineno, code, errstr))

    def cmp_raw(self, raw_value):
        """
        Compare scripted raw data with the given blob (None allowed on both
        sides); raise an Exception with both values in hex if they differ.
        """
        assert self.is_raw_data_entry
        expected = None
        if self.raw_data is not None:
            expected = binascii.hexlify(self.raw_data)
        got = None
        if raw_value is not None:
            got = binascii.hexlify(raw_value)
        if expected != got:
            raise Exception("raw message comparsion failed: expected %s got %s" % (expected, got))

    def set_match(self, fields):
        """
        Set list of conditions for message comparison

        [all, flags, question, answer, authority, additional, edns]
        """
        self.match_fields = fields

    def adjust_reply(self, query):
        """ Copy scripted reply and adjust to received query. """
        answer = dns.message.from_wire(self.message.to_wire(),
                                       xfr=self.message.xfr,
                                       one_rr_per_rrset=True)
        answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
        if 'copy_id' in self.adjust_fields:
            answer.id = query.id
            # Copy letter-case if the template has QD
            if len(answer.question) > 0:
                answer.question[0].name = query.question[0].name
        if 'copy_query' in self.adjust_fields:
            answer.question = query.question
        # Re-set, as the EDNS might have reset the ext-rcode
        answer.set_rcode(self.message.rcode())

        # sanity check: adjusted answer should be almost the same
        assert len(answer.answer) == len(self.message.answer)
        assert len(answer.authority) == len(self.message.authority)
        assert len(answer.additional) == len(self.message.additional)
        return answer

    def set_adjust(self, fields):
        """ Set reply adjustment fields [copy_id, copy_query] """
        self.adjust_fields = fields

    def set_reply(self, fields):
        """
        Set reply flags and rcode.

        'DO' turns on the DNSSEC OK bit, a token accepted by
        dns.rcode.from_text() becomes the RCODE (the last one wins) and
        every other token is handed to dns.flags.from_text() as a flag.
        """
        eflags = []
        flags = []
        rcode = dns.rcode.from_text(self.default_rc)
        for code in fields:
            if code == 'DO':
                eflags.append(code)
                continue
            try:
                rcode = dns.rcode.from_text(code)
            except Exception:
                flags.append(code)  # not an RCODE, treat as a flag
        self.message.flags = dns.flags.from_text(' '.join(flags))
        self.message.want_dnssec('DO' in eflags)
        self.message.set_rcode(rcode)

    @staticmethod
    def _encode_subnet(subnet):
        """
        Build payload of EDNS option 8 from 'address[/prefix]'.

        The payload consists of family (1 for IPv4, 2 otherwise), source
        prefix length, zero scope prefix length and the address cut to
        the prefix with the bits beyond it cleared.
        """
        net = subnet.split('/')
        subnet_addr = net[0]
        family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET
        # bytearray items are integers on both Python 2 and Python 3
        addr = bytearray(socket.inet_pton(family, subnet_addr))
        prefix = len(addr) * 8
        if len(net) > 1:
            prefix = int(net[1])
        addr = addr[0: (prefix + 7) // 8]
        if prefix % 8 != 0:  # Mask the last byte
            addr[-1] &= 0xFF << (8 - prefix % 8)
        return struct.pack("!HBB", 1 if family == socket.AF_INET else 2,
                           prefix, 0) + bytes(addr)

    def set_edns(self, fields):
        """
        Set EDNS version and bufsize.

        `fields` may start with version and bufsize (both numeric, in this
        order; bufsize 0 disables EDNS), followed by options 'nsid[=value]'
        and 'subnet=address[/prefix]'. The list is not modified.
        """
        version = 0
        bufsize = 4096
        fields = list(fields)  # keep the caller's list intact
        if len(fields) > 0 and fields[0].isdigit():
            version = int(fields.pop(0))
        if len(fields) > 0 and fields[0].isdigit():
            bufsize = int(fields.pop(0))
        if bufsize == 0:
            self.message.use_edns(False)
            return
        opts = []
        for v in fields:
            k, v = tuple(v.split('=', 1)) if '=' in v else (v, True)
            if k.lower() == 'nsid':
                opts.append(dns.edns.GenericOption(dns.edns.NSID, '' if v is True else v))
            if k.lower() == 'subnet':
                opts.append(dns.edns.GenericOption(8, self._encode_subnet(v)))
        self.message.use_edns(edns=version, payload=bufsize, options=opts)

    def begin_raw(self):
        """ Set raw data pending flag. """
        self.raw_data_pending = True

    def begin_section(self, section):
        """ Begin packet section. """
        self.section = section
        self.sections.append(section.lower())

    def add_record(self, owner, args):
        """
        Add record to current packet section.

        While raw data is pending, `owner` is taken as the hex dump of the
        raw message ('NULL' stands for no data) and `args` are ignored.
        Raises an Exception for repeated raw data and for a record outside
        of QUESTION/ANSWER/AUTHORITY/ADDITIONAL.
        """
        if self.raw_data_pending is True:
            if self.raw_data is not None:
                raise Exception('raw data already set in this entry')
            if owner != 'NULL':
                self.raw_data = binascii.unhexlify(owner)
            self.raw_data_pending = False
            self.is_raw_data_entry = True
        else:
            rr = create_rr(owner, args, ttl=self.default_ttl,
                           rdclass=self.default_cls, origin=self.origin)
            if self.section == 'QUESTION':
                if rr.rdtype == dns.rdatatype.AXFR:
                    self.message.xfr = True
                self.__rr_add(self.message.question, rr)
            elif self.section == 'ANSWER':
                self.__rr_add(self.message.answer, rr)
            elif self.section == 'AUTHORITY':
                self.__rr_add(self.message.authority, rr)
            elif self.section == 'ADDITIONAL':
                self.__rr_add(self.message.additional, rr)
            else:
                raise Exception('bad section %s' % self.section)

    def use_tsig(self, fields):
        """ Sign the message with TSIG; fields are [key name, key secret]. """
        tsig_keyname = fields[0]
        tsig_secret = fields[1]
        keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
        self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)

    def __rr_add(self, section, rr):
        """Append to given section."""
        section.append(rr)

    def set_mandatory(self):
        """ Flag this entry as mandatory. """
        self.mandatory = True

433

Marek Vavruša's avatar
Marek Vavruša committed
434 435 436 437
class Range:
    """
    Set of scripted replies which are valid for an inclusive interval
    of step identifiers.
    """
    log = logging.getLogger('pydnstest.scenario.Range')

    def __init__(self, a, b):
        """ Create an empty range for steps a..b. """
        self.a, self.b = a, b
        self.addresses = set()
        self.stored = []
        self.args = {}
        self.received = 0
        self.sent = 0

    def __del__(self):
        """ Log usage counters of the range being discarded. """
        summary = (self.a, self.b, self.addresses, self.received, self.sent)
        self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d', *summary)

    def add(self, entry):
        """ Store one more scripted response. """
        self.stored.append(entry)

    def eligible(self, id, address):
        """
        Tell whether the range may answer in step `id` for given address.

        Inside of the step interval, the range serves any address if it has
        no addresses of its own or if `address` is None.
        """
        if not self.a <= id <= self.b:
            return False
        if address is None:
            return True
        if self.addresses == set():
            return True
        return address in self.addresses

    def reply(self, query):
        """
        Look up the first stored entry matching the query.

        Returns:
            (DNS message object) adjusted to the query, or None if nothing
            matches or the reply is dropped by the 'LOSS' probability
        """
        self.received += 1
        for candidate in self.stored:
            try:
                candidate.match(query)
                resp = candidate.adjust_reply(query)
                # Probabilistic loss
                dropped = ('LOSS' in self.args
                           and random.random() < float(self.args['LOSS']))
                if dropped:
                    return None
                self.sent += 1
                candidate.fired += 1
            except Exception:
                continue  # mismatch, try the next candidate
            return resp
        return None


490 491 492 493 494 495 496 497
class StepLogger(logging.LoggerAdapter):
    """
    Logger adapter which prefixes every message with the identification
    (id and type) of the step taken from the adapter's extra data.
    """
    def process(self, msg, kwargs):
        step_id = self.extra['id']
        step_type = self.extra['type']
        prefixed = '[STEP %s %s] %s' % (step_id, step_type, msg)
        return prefixed, kwargs


Marek Vavruša's avatar
Marek Vavruša committed
498 499 500 501 502
class Step:
    """
    Step represents one scripted action in a given moment,
    each step has an order identifier, type and optionally data entry.
    """
    # Step types recorded as carrying data (see has_data)
    require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']

    def __init__(self, id, type, extra_args):
        """
        Initialize single scenario step.

        For CHECK_ANSWER the arguments REPEAT=<int>, PAUSE=<float> and
        NEXT=<int> are parsed into repeat_if_fail, pause_if_fail and
        next_if_fail; a malformed value raises an Exception.
        """
        self.id = int(id)
        self.type = type
        self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'),
                              {'id': id, 'type': type})
        self.args = extra_args
        self.data = []
        self.queries = []
        self.has_data = self.type in Step.require_data
        self.answer = None
        self.raw_answer = None
        self.repeat_if_fail = 0
        self.pause_if_fail = 0
        self.next_if_fail = -1

        if type == 'CHECK_ANSWER':
            for arg in extra_args:
                param = arg.split('=')
                try:
                    if param[0] == 'REPEAT':
                        self.repeat_if_fail = int(param[1])
                    elif param[0] == 'PAUSE':
                        self.pause_if_fail = float(param[1])
                    elif param[0] == 'NEXT':
                        self.next_if_fail = int(param[1])
                except Exception as e:
                    raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))

    def add(self, entry):
        """ Append a data entry to this step. """
        self.data.append(entry)

    def play(self, ctx):
        """
        Play one step from a scenario.

        Raises an Exception for an unsupported step type; failures of the
        individual actions propagate from their handlers.
        """
        if self.type == 'QUERY':
            self.log.info('')
            if self.data:  # a missing query definition is reported by __query
                self.log.debug(self.data[0].message.to_text())
            # Parse QUERY-specific parameters
            choice, tcp, source = None, False, None
            for v in self.args:
                if '=' in v:  # Key=Value
                    v = v.split('=')
                    if v[0].lower() == 'source':
                        source = v[1]
                elif v.lower() == 'tcp':
                    tcp = True
                else:
                    choice = v
            return self.__query(ctx, tcp=tcp, choice=choice, source=source)
        elif self.type == 'CHECK_OUT_QUERY':
            self.log.info('')
            pass  # Ignore
        elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER':
            self.log.info('')
            return self.__check_answer(ctx)
        elif self.type == 'TIME_PASSES':
            self.log.info('')
            return self.__time_passes()
        elif self.type == 'REPLY' or self.type == 'MOCK':
            self.log.info('')
        elif self.type == 'LOG':
            if not ctx.log:
                raise Exception('scenario has no log interface')
            return ctx.log.match(self.args)
        elif self.type == 'REPLAY':
            self.__replay(ctx)
        elif self.type == 'ASSERT':
            self.__assert(ctx)
        else:
            raise Exception('step %03d type %s unsupported' % (self.id, self.type))

    def __check_answer(self, ctx):
        """ Compare answer from previously resolved query. """
        if len(self.data) == 0:
            raise Exception("response definition required")
        expected = self.data[0]
        if expected.is_raw_data_entry is True:
            # The raw answer is a blob (or None), it has no to_text()
            self.log.debug("raw answer: %r", ctx.last_raw_answer)
            expected.cmp_raw(ctx.last_raw_answer)
        else:
            if ctx.last_answer is None:
                raise Exception("no answer from preceding query")
            self.log.debug("answer: %s", ctx.last_answer.to_text())
            expected.match(ctx.last_answer)

    def __replay(self, ctx):
        """
        Replay self.queries to the first client and log the statistics.

        A leading numeric argument sets the number of queries, the
        INTENSIFY environment variable multiplies it and PRINT[=tag]
        requests a summary line on the info level.
        """
        nqueries = len(self.queries)
        args = list(self.args)  # keep self.args intact for repeated plays
        if len(args) > 0 and args[0].isdigit():
            nqueries = int(args.pop(0))
        destination = ctx.client[list(ctx.client.keys())[0]]
        self.log.info('replaying %d queries to %s@%d (%s)',
                      nqueries, destination[0], destination[1], ' '.join(args))
        if 'INTENSIFY' in os.environ:
            nqueries *= int(os.environ['INTENSIFY'])
        tstart = datetime.now()
        nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, args)
        # Keep/print the statistics
        rtt = (datetime.now() - tstart).total_seconds() * 1000
        pps = 1000 * nrcvd / rtt if rtt > 0 else 0
        self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps)
        tag = None
        for arg in args:
            if arg.upper().startswith('PRINT'):
                _, tag = tuple(arg.split('=', 1)) if '=' in arg else (None, 'replay')
        if tag:
            self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d',
                          tag.ljust(11), pps, rtt, nsent, nrcvd)

    def __query(self, ctx, tcp=False, choice=None, source=None):
        """
        Send query and wait for an answer (if the query is not RAW).

        The received answer is stored in self.answer and ctx.last_answer,
        its wire form in self.raw_answer and ctx.last_raw_answer.
        Socket errors other than ENOBUFS (a timeout included) propagate.
        """
        global g_rtt, g_nqueries
        if len(self.data) == 0:
            raise Exception("query definition required")
        if self.data[0].is_raw_data_entry is True:
            data_to_wire = self.data[0].raw_data
        else:
            # Don't use a message copy as the EDNS data portion is not copied.
            data_to_wire = self.data[0].message.to_wire()
        if choice is None or len(choice) == 0:
            choice = list(ctx.client.keys())[0]
        if choice not in ctx.client:
            raise Exception('step %03d invalid QUERY target: %s' % (self.id, choice))
        # Create socket to test subject
        destination = ctx.client[choice]
        family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
        sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if tcp:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(3)
            if source:
                sock.bind((source, 0))
            sock.connect(destination)
            # Send query to client and wait for response
            tstart = datetime.now()
            while True:
                try:
                    sendto_msg(sock, data_to_wire)
                    break
                except OSError as e:
                    # ENOBUFS, throttle sending
                    if e.errno != errno.ENOBUFS:
                        raise  # retrying would only repeat the failure
                    time.sleep(0.1)
            # Wait for a response for a reasonable time
            answer = None
            if not self.data[0].is_raw_data_entry:
                while True:
                    try:
                        answer, _ = recvfrom_msg(sock, True)
                        break
                    except OSError as e:
                        if e.errno != errno.ENOBUFS:
                            raise  # e.g. timeout, do not spin forever
                        time.sleep(0.1)
            # Track RTT
            rtt = (datetime.now() - tstart).total_seconds() * 1000
        finally:
            sock.close()
        g_nqueries += 1
        g_rtt += rtt
        # Remember last answer for checking later
        self.raw_answer = answer
        ctx.last_raw_answer = answer
        if self.raw_answer is not None:
            self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
        else:
            self.answer = None
        ctx.last_answer = self.answer

    def __time_passes(self):
        """
        Modify system time.

        The timestamp kept in the file named by the environment variable
        FAKETIME_TIMESTAMP_FILE is moved by self.args[1] seconds.
        """
        path = os.environ["FAKETIME_TIMESTAMP_FILE"]
        with open(path, 'r') as time_file:
            line = time_file.readline().strip()
        t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple())
        t += int(self.args[1])
        with open(path, 'w') as time_file:
            time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n")

    def __assert(self, ctx):
        """ Assert that a passed expression evaluates to True. """
        # NOTE(review): eval() executes text taken from the step arguments;
        # only scenarios from a trusted source may be played
        result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
        # Evaluate subexpressions for clarity
        subexpr = []
        for expr in self.args:
            try:
                ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
                subexpr.append(str(ee))
            except Exception:
                subexpr.append(expr)  # not an expression on its own
        assert result is True, '"%s" assertion fails (%s)' % (
                               ' '.join(self.args), ' '.join(subexpr))

703

Marek Vavruša's avatar
Marek Vavruša committed
704
class Scenario:
    """
    A scripted test scenario: scripted answer ranges plus a list of steps.

    Attributes are filled in by the parser after construction:
    ranges and steps hold the parsed objects, current_range and current_step
    track playback state, client holds the mapping passed to play().
    """

    # NOTE(review): 'scenatio' looks like a typo of 'scenario'; the name is
    # kept as-is because logging configuration may refer to it.
    log = logging.getLogger('pydnstest.scenatio.Scenario')

    def __init__(self, info, filename=''):
        """ Initialize scenario with description. """
        self.info = info
        self.file = filename
        self.ranges = []
        self.current_range = None
        self.steps = []
        self.current_step = None
        self.client = {}

    def reply(self, query, address=None):
        """
        Generate answer packet for given query.

        The answer can be DNS message object or a binary blob.
        Returns:
            (answer, boolean "is the answer binary blob?")
            (None, True) when nothing in the scenario matches the query.
        """
        current_step_id = self.current_step.id
        # Unknown address, select any match
        # TODO: workaround until the server supports stub zones
        known_addresses = set()
        for rng in self.ranges:
            known_addresses.update(rng.addresses)
        if address not in known_addresses:
            address = None
        # Find current valid query response range
        for rng in self.ranges:
            if rng.eligible(current_step_id, address):
                self.current_range = rng
                return (rng.reply(query), False)
        # Find any prescripted one-shot replies
        for step in self.steps:
            if step.id < current_step_id or step.type != 'REPLY':
                continue
            try:
                candidate = step.data[0]
                if candidate.is_raw_data_entry is False:
                    candidate.match(query)
                    step.data.remove(candidate)
                    answer = candidate.adjust_reply(query)
                    return (answer, False)
                else:
                    answer = candidate.raw_data
                    return (answer, True)
            except Exception:
                # Best effort: a step with no data left, or an entry that does
                # not match this query, just means "try the next step".
                # Only Exception is caught so that KeyboardInterrupt and
                # SystemExit still propagate.
                continue
        return (None, True)

    def play(self, paddr):
        """
        Play given scenario.

        Steps are played in order.  When a step raises and still has retries
        left (repeat_if_fail), playback optionally sleeps for pause_if_fail
        seconds and resumes either at the same step or at the step whose id
        equals next_if_fail.

        Raises:
            Exception: a step failed with no retries left, next_if_fail names
                a step id that does not exist, or a mandatory entry of some
                range was never fired.
        """
        # Store test subject => address mapping
        self.client = paddr

        i = 0
        while i < len(self.steps):
            step = self.steps[i]
            self.current_step = step
            try:
                step.play(self)
            except Exception as e:
                if step.repeat_if_fail <= 0:
                    raise Exception('%s step %d %s' % (self.file, step.id, str(e)))
                self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)",
                              step.id, e, step.next_if_fail, step.repeat_if_fail)
                step.repeat_if_fail -= 1
                if step.pause_if_fail > 0:
                    time.sleep(step.pause_if_fail)
                if step.next_if_fail != -1:
                    # Translate the target step id into a list position.
                    targets = [j for j, candidate in enumerate(self.steps)
                               if candidate.id == step.next_if_fail]
                    if not targets:
                        raise Exception('step %d: wrong NEXT value "%d"' %
                                        (step.id, step.next_if_fail))
                    i = targets[0]
                # Retry without advancing past the (possibly new) position.
                continue
            i += 1

        for r in self.ranges:
            for e in r.stored:
                if e.mandatory is True and e.fired == 0:
                    raise Exception('Mandatory section at line %d is not fired' % e.lineno)

797

798
def get_next(file_in, skip_empty=True):
    """
    Return next token from the input stream.

    A ';' outside double quotes starts a comment that runs to the end of the
    line.  A double quote preceded by a backslash does not open or close a
    quoted section.

    Args:
        file_in: object with a readline() method returning text lines.
        skip_empty: when True, lines with no tokens are skipped; when False,
            such a line is reported as ('', []).

    Returns:
        (op, tokens) - first word of the line and the list of remaining words,
        or False at end of input (callers use False as an iter() sentinel).
    """
    while True:
        line = file_in.readline()
        if not line:
            return False
        quoted, escaped = False, False
        for pos, char in enumerate(line):
            if char == '\\':
                # A run of backslashes toggles the flag; the flag survives
                # until the next non-backslash character is seen.
                escaped = not escaped
                continue
            if char == '"' and not escaped:
                quoted = not quoted
            elif char == ';' and not quoted:
                line = line[:pos]
                break
            escaped = False
        tokens = line.split()
        if not tokens:
            if skip_empty:
                continue
            return '', []
        op = tokens.pop(0)
        return op, tokens

824 825

def parse_entry(op, args, file_in, in_entry=False):
    """
    Parse entry definition.

    Lines are consumed from file_in until ENTRY_END, a blank line (only
    reported while not inside an explicit ENTRY_BEGIN block) or end of input.

    Args:
        op, args: token of the line that triggered the call; not used here,
            both names are rebound to the lines read from file_in.
        file_in: input object providing lineno() and readline().
        in_entry: True when the caller has already consumed ENTRY_BEGIN.

    Returns:
        The populated Entry.

    Raises:
        Exception: ENTRY_BEGIN encountered while already inside an entry.
    """
    out = Entry(file_in.lineno())
    # The lambda looks up in_entry on every call, so once ENTRY_BEGIN flips it
    # to True the following blank lines are skipped instead of ending the entry.
    for op, args in iter(lambda: get_next(file_in, in_entry), False):
        if op == 'ENTRY_END' or op == '':
            break
        elif op == 'ENTRY_BEGIN':  # Optional, compatibility with Unbound tests
            if in_entry:
                raise Exception('nested ENTRY_BEGIN not supported')
            in_entry = True
        elif op == 'EDNS':
            out.set_edns(args)
        elif op == 'REPLY' or op == 'FLAGS':
            out.set_reply(args)
        elif op == 'MATCH':
            out.set_match(args)
        elif op == 'ADJUST':
            out.set_adjust(args)
        elif op == 'SECTION':
            out.begin_section(args[0])
        elif op == 'RAW':
            out.begin_raw()
        elif op == 'TSIG':
            out.use_tsig(args)
        elif op == 'MANDATORY':
            out.set_mandatory()
        else:
            # Anything else is handed to the entry as a record line.
            out.add_record(op, args)
    return out

857

858 859 860 861 862 863 864 865 866
def parse_queries(out, file_in):
    """
    Read query records into out.queries until a blank line or end of input.

    Each non-empty line is turned into a record by create_rr(); the list is
    attached to `out` before reading starts, and `out` is returned.
    """
    out.queries = []
    while True:
        token = get_next(file_in, False)
        if token is False:
            # End of input
            break
        owner, rest = token
        if owner == '':
            # Blank line terminates the list
            break
        out.queries.append(create_rr(owner, rest))
    return out

867
# Step ID that parse_step() assigns to the next STEP line lacking an explicit
# ID; parse_step() sets it to (last used ID + 1) after every parsed step.
auto_step = 0
868 869


870 871 872 873 874
def parse_step(op, args, file_in):
    """
    Parse step definition.

    Args:
        op: token of the current line, passed through to parse_entry().
        args: remaining tokens of the line, [<id>] <type> [<extra> ...].
        file_in: input object the step body is read from.

    Returns:
        The constructed Step, with its entry and/or queries attached.

    Raises:
        Exception: no tokens follow STEP.
    """
    global auto_step
    if len(args) == 0:
        raise Exception('expected at least STEP <type>')
    # Auto-increment when step ID isn't specified
    if len(args) < 2 or not args[0].isdigit():
        args = [str(auto_step)] + args
    # The next implicit ID always continues from the ID just used.
    auto_step = int(args[0]) + 1
    out = Step(args[0], args[1], args[2:])
    if out.has_data:
        out.add(parse_entry(op, args, file_in))
    # Special steps
    if args[1] == 'REPLAY':
        parse_queries(out, file_in)
    return out


def parse_range(op, args, file_in):
    """
    Parse range definition.

    Args:
        op: token of the current line; rebound while reading the range body.
        args: <from> <to> [address [key[=value] ...]].
        file_in: input object the range body is read from.

    Returns:
        The populated Range.

    Raises:
        Exception: fewer than two arguments were given.
        ValueError: <from> or <to> is not an integer.
    """
    if len(args) < 2:
        raise Exception('expected RANGE_BEGIN <from> <to> [address]')
    out = Range(int(args[0]), int(args[1]))
    # Shortcut for address
    if len(args) > 2:
        out.addresses.add(args[2])
    # Parameters
    if len(args) > 3:
        out.args = {}
        for param in args[3:]:
            if '=' in param:
                # Split on the first '=' only, so values may contain '='.
                key, _, value = param.partition('=')
            else:
                # Bare flag
                key, value = param, True
            out.args[key] = value
    for op, args in iter(lambda: get_next(file_in), False):
        if op == 'ADDRESS':
            out.addresses.add(args[0])
        elif op == 'ENTRY_BEGIN':
            out.add(parse_entry(op, args, file_in, in_entry=True))
        elif op == 'RANGE_END':
            break
    return out


def parse_scenario(op, args, file_in):
    """
    Parse scenario definition.

    Args:
        op: token of the current line; rebound while reading the body.
        args: tokens following SCENARIO_BEGIN; the first one becomes the
            scenario description.
        file_in: input object providing filename() and the scenario body.

    Returns:
        The populated Scenario.

    Raises:
        Exception: no description token follows SCENARIO_BEGIN.
    """
    if not args:
        raise Exception('expected SCENARIO_BEGIN <description>')
    out = Scenario(args[0], file_in.filename())
    for op, args in iter(lambda: get_next(file_in), False):
        if op == 'SCENARIO_END':
            break
        elif op == 'RANGE_BEGIN':
            out.ranges.append(parse_range(op, args, file_in))
        elif op == 'STEP':
            out.steps.append(parse_step(op, args, file_in))
    return out


925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017
def parse_config(scn_cfg, qmin, installdir):
    """
    Transform scene config (key, value) pairs into dict filled with defaults.

    Args:
        scn_cfg: iterable of (key, value) string pairs.
        qmin: default for query minimization; overridden by the
            'query-minimization' key.
        installdir: replaces '{{INSTALL_DIR}}' in 'feature-list' values and is
            exported as INSTALL_DIR.

    Returns:
        dict with the keys DO_NOT_QUERY_LOCALHOST, FEATURES, HARDEN_GLUE,
        INSTALL_DIR, QMIN, TRUST_ANCHORS and _SOCKET_FAMILY, plus ROOT_ADDR
        when 'stub-addr' is given and _OVERRIDE_TIMESTAMP when a non-zero
        override time is given.

    Raises:
        Exception: 'features' or 'feature-list' value cannot be parsed.
        AssertionError: 'stub-addr' does not map to exactly one address.
    """
    # defaults
    do_not_query_localhost = True
    harden_glue = True
    sockfamily = 0  # auto-select value for socket.getaddrinfo
    trust_anchor_list = []
    stub_addr = None
    override_timestamp = None

    features = {}
    feature_list_delimiter = ';'
    feature_pair_delimiter = '='

    for k, v in scn_cfg:
        # Enable selectively for some tests
        if k == 'do-not-query-localhost':
            do_not_query_localhost = str2bool(v)
        elif k == 'harden-glue':
            harden_glue = str2bool(v)
        elif k == 'query-minimization':
            qmin = str2bool(v)
        elif k == 'trust-anchor':
            trust_anchor_list.append(v.strip('"\''))
        elif k == 'val-override-timestamp':
            override_timestamp = int(v.strip('"\''))
        elif k == 'val-override-date':
            # Value layout: YYYYMMDDhhmmss, interpreted as UTC.
            override_date_str = v.strip('"\'')
            bounds = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12), (12, None))
            fields = [override_date_str[start:end] for start, end in bounds]
            override_date = time.strptime(' '.join(fields), "%Y %m %d %H %M %S")
            override_timestamp = calendar.timegm(override_date)
        elif k == 'stub-addr':
            stub_addr = v.strip('"\'')
        elif k == 'features':
            # 'key=value; flag; ...' - later keys overwrite earlier ones.
            try:
                for f_item in v.split(feature_list_delimiter):
                    if feature_pair_delimiter in f_item:
                        f_key, f_value = [x.strip()
                                          for x
                                          in f_item.split(feature_pair_delimiter, 1)]
                    else:
                        f_key = f_item.strip()
                        f_value = ""
                    features[f_key] = f_value
            except Exception as e:
                raise Exception("can't parse features (%s) in config section (%s)" % (v, str(e)))
        elif k == 'feature-list':
            # 'key=value' - values of repeated keys accumulate in a list.
            try:
                f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)]
                if f_key not in features:
                    features[f_key] = []
                f_value = f_value.replace("{{INSTALL_DIR}}", installdir)
                features[f_key].append(f_value)
            except Exception as e:
                raise Exception("can't parse feature-list (%s) in config section (%s)"
                                % (v, str(e)))
        elif k == 'force-ipv6' and v.upper() == 'TRUE':
            sockfamily = socket.AF_INET6

    ctx = {
        "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(),
        "FEATURES": features,
        "HARDEN_GLUE": str(harden_glue).lower(),
        "INSTALL_DIR": installdir,
        "QMIN": str(qmin).lower(),
        "TRUST_ANCHORS": trust_anchor_list,
    }
    if stub_addr:
        ctx['ROOT_ADDR'] = stub_addr
        # determine and verify socket family for specified root address
        gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0,
                                 socket.IPPROTO_UDP, socket.AI_NUMERICHOST)
        # Explicit raise instead of an `assert` statement so the check is not
        # stripped when running under `python -O`.
        if len(gai) != 1:
            raise AssertionError('stub-addr "%s" must map to exactly one address, got %d'
                                 % (stub_addr, len(gai)))
        sockfamily = gai[0][0]
    if not sockfamily:
        sockfamily = socket.AF_INET  # default to IPv4
    ctx['_SOCKET_FAMILY'] = sockfamily
    # NOTE(review): a timestamp of 0 is treated the same as "no override" -
    # confirm this is intended.
    if override_timestamp:
        ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp
    return ctx


1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033
def parse_file(file_in):
    """
    Parse scenario from a file.

    The file may start with a configuration section of 'key: value' lines
    terminated by CONFIG_END; lines starting with ';' and text after '#' are
    ignored there.  The scenario itself starts at SCENARIO_BEGIN.

    Args:
        file_in: input object providing readline(), filename() and lineno().

    Returns:
        (scenario, config) where config is a list of [key, value] pairs.

    Raises:
        Exception: any parsing problem, prefixed with '<file>#<line>: ';
            "IGNORE (missing scenario)" when no SCENARIO_BEGIN is found.
    """
    try:
        config = []
        line = file_in.readline()
        while line:
            # Zero-configuration
            if line.startswith('SCENARIO_BEGIN'):
                # Whitespace-split like get_next() does, so the description
                # carries no trailing newline or empty tokens.
                return parse_scenario(line, line.split()[1:], file_in), config
            if line.startswith('CONFIG_END'):
                break
            if not line.startswith(';'):
                if '#' in line:
                    line = line[0:line.index('#')]
                # Break to key-value pairs
                # e.g.: ['minimization', 'on']
                kv = [x.strip() for x in line.split(':', 1)]
                if len(kv) >= 2:
                    config.append(kv)
            line = file_in.readline()

        for op, args in iter(lambda: get_next(file_in), False):
            if op == 'SCENARIO_BEGIN':
                return parse_scenario(op, args, file_in), config
        raise Exception("IGNORE (missing scenario)")
    except Exception as e:
        raise Exception('%s#%d: %s' % (file_in.filename(), file_in.lineno(), str(e)))