scenario.py 39.4 KB
Newer Older
1 2
from __future__ import absolute_import

3
import binascii
4
import calendar
5
import errno
6
import logging
7
import os
8
import posixpath
9
import random
10
import socket
11
import string
12
import struct
Marek Vavruša's avatar
Marek Vavruša committed
13
import time
14
from datetime import datetime
15 16 17

import dns.dnssec
import dns.message
18
import dns.name
19 20 21
import dns.rcode
import dns.rrset
import dns.tsigkeyring
Marek Vavruša's avatar
Marek Vavruša committed
22

23 24
import pydnstest.augwrap

25

26 27 28 29 30
def str2bool(v):
    """
    Interpret a JSON-ish string as a boolean.

    The comparison is case-insensitive; only 'yes', 'true', 'on' and '1'
    count as True, everything else is False.
    """
    normalized = v.lower()
    return normalized in {'yes', 'true', 'on', '1'}


31 32 33 34
# Global statistics shared across the module; Step.__query rebinds them via
# `global g_rtt, g_nqueries` (g_rtt is a float, g_nqueries an int counter).
g_rtt, g_nqueries = 0.0, 0

35 36 37 38
#
# Element comparators
#

39

40 41 42 43
def compare_rrs(expected, got):
    """
    Compare two lists of RR sets and raise ValueError on the first difference.

    Every RR of `expected` must appear in `got`, then every RR of `got` must
    appear in `expected`. Finally the lengths must agree, which catches
    duplicated RRs that membership tests alone cannot detect.

    Returns:
        True when both lists match.
    """
    membership_checks = (
        (expected, got, "expected record '%s'"),
        (got, expected, "unexpected record '%s'"),
    )
    for source, target, template in membership_checks:
        for record in source:
            if record not in target:
                raise ValueError(template % record.to_text())
    want, have = len(expected), len(got)
    if want != have:
        raise ValueError("expected %s records but got %s records "
                         "(a duplicate RR somewhere?)"
                         % (want, have))
    return True

Petr Špaček's avatar
Petr Špaček committed
54

55 56 57
def compare_val(expected, got):
    """
    Check two values for equality.

    Returns:
        True when the values are equal.

    Raises:
        ValueError: naming both values when they differ.
    """
    mismatch = expected != got
    if mismatch:
        raise ValueError("expected '%s', got '%s'" % (expected, got))
    return True

61

62 63 64
def compare_sub(got, expected):
    """
    Check that name `expected` is a subdomain of name `got`, else raise ValueError.

    NOTE: the parameter names read backwards. The caller (Entry.match_part,
    'subdomain' criterion) passes the scripted template name as `got` and the
    received query name as `expected`. The check therefore verifies that the
    query name lies under the template name, and the error message names the
    template name as the required parent.

    Returns:
        True when the subdomain relation holds.
    """
    if not expected.is_subdomain(got):
        # `got` is the parent the name was supposed to be under; `expected`
        # is the name that was actually checked.
        raise ValueError("expected subdomain of '%s', got '%s'" % (got, expected))
    return True

68

69 70 71 72 73 74 75 76 77 78 79 80
def _recv_exact(stream, length):
    """
    Read exactly `length` bytes from a connected stream socket.

    Returns the bytes read, or None if the peer closed the connection before
    `length` bytes arrived.
    """
    chunks = []
    remaining = length
    while remaining > 0:
        # Never ask for more than is still missing, so bytes belonging to a
        # following message stay in the socket buffer.
        chunk = stream.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recvfrom_msg(stream, raw=False):
    """
    Receive DNS message from TCP/UDP socket.

    Over TCP the 2-byte big-endian length prefix is read first. Exactly that
    many bytes are then consumed, so a pipelined follow-up message is left for
    the next call.

    Returns:
        if raw == False: (DNS message object, peer address)
        if raw == True: (blob, peer address)
        (None, None) if a TCP peer closed the connection before a complete
        message arrived.
        For UDP the peer address is whatever recvfrom() reports; for TCP it
        is only the host part of getpeername().

    Raises:
        NotImplementedError: for sockets that are neither datagram nor stream.
    """
    if stream.type & socket.SOCK_DGRAM:
        data, addr = stream.recvfrom(4096)
    elif stream.type & socket.SOCK_STREAM:
        header = _recv_exact(stream, 2)
        if header is None:
            return None, None
        msg_len = struct.unpack("!H", header)[0]
        data = _recv_exact(stream, msg_len)
        if data is None:
            return None, None
        addr = stream.getpeername()[0]
    else:
        raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
    if raw:
        return data, addr
    msg = dns.message.from_wire(data, one_rr_per_rrset=True)
    return msg, addr
100 101 102 103 104 105 106 107 108 109 110 111 112 113


def sendto_msg(stream, message, addr=None):
    """
    Send a DNS message over a UDP or TCP socket.

    UDP: the message is sent as one datagram, to `addr` when given, otherwise
    to the connected peer. TCP: the message is prefixed with its 2-byte
    big-endian length and written with sendall(), because send() may transmit
    only part of the buffer.

    ECONNREFUSED socket errors are ignored; any other socket error propagates.

    Raises:
        NotImplementedError: for sockets that are neither datagram nor stream.
    """
    try:
        if stream.type & socket.SOCK_DGRAM:
            if addr is None:
                stream.send(message)
            else:
                stream.sendto(message, addr)
        elif stream.type & socket.SOCK_STREAM:
            stream.sendall(struct.pack("!H", len(message)) + message)
        else:
            raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
    except socket.error as ex:
        if ex.errno != errno.ECONNREFUSED:  # TODO Investigate how this can happen
            raise
118 119


120
def replay_rrs(rrs, nqueries, destination, args=None):
    """
    Replay a list of queries over UDP and report statistics.

    Args:
        rrs: records to build queries from (each needs .name, .rdtype and
            .rdclass). Without 'RAND' one query per record is built and the
            list is cycled until `nqueries` have been sent.
        nqueries: number of queries to send.
        destination: (address, port) of the server; an address containing ':'
            selects IPv6.
        args: optional flags. 'RAND' prefixes every qname with a random
            8-character label; 'DO' sets the DNSSEC OK bit.

    Returns:
        (nsent, nrcvd) tuple. (0, 0) when there is nothing to replay.
    """
    import select

    if args is None:
        args = ()
    if not rrs:
        # Nothing to build queries from (and `i % navail` below would divide by 0)
        return 0, 0

    navail, queries = len(rrs), []
    chunksize = 16  # maximum number of unanswered queries in flight
    for i in range(nqueries if 'RAND' in args else navail):
        rr = rrs[i % navail]
        name = rr.name
        if 'RAND' in args:
            prefix = ''.join([random.choice(string.ascii_letters + string.digits)
                              for _ in range(8)])
            name = prefix + '.' + rr.name.to_text()
        msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
        if 'DO' in args:
            msg.want_dnssec(True)
        queries.append(msg.to_wire())
    if not queries:
        return 0, 0

    # Make a UDP connected socket to the destination
    family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        sock.connect(destination)
        sock.setblocking(False)
        # Play the query set
        # @NOTE: this is only good for relative low-speed replay
        rcvbuf = bytearray(512)  # bytearray(str) raises TypeError on Python 3
        nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries)
        fdset = [sock]
        while nsent - nwait < nqueries:
            to_read, to_write, _ = select.select(
                fdset, fdset if nwait < chunksize else [], [], 0.5)
            if to_write:
                try:
                    while nsent < nqueries and nwait < chunksize:
                        sock.send(queries[nsent % navail])
                        nwait += 1
                        nsent += 1
                except socket.error:
                    pass  # best effort (e.g. EINVAL); retried on the next select() round
            if to_read:
                try:
                    while nwait > 0:
                        sock.recv_into(rcvbuf)
                        nwait -= 1
                        nrcvd += 1
                except socket.error:
                    pass  # no more datagrams ready on the non-blocking socket
            if not to_write and not to_read:
                break  # Timeout, started dropping packets
    finally:
        sock.close()
    return nsent, nrcvd

169

Marek Vavruša's avatar
Marek Vavruša committed
170 171
class Entry:
    """
    Data entry represents a scripted DNS message and extra metadata,
    notably match criteria and reply adjustments.

    An entry is either RAW (only the wire-format bytes in `raw_data` are kept)
    or structured (MATCH / REPLY / ADJUST / MANDATORY / TSIG / SECTION data
    parsed into `self.message` and the *_fields lists).
    """

    # Defaults applied when a record omits its TTL or class
    default_ttl = 3600
    default_cls = 'IN'
    default_rc = 'NOERROR'

    def __init__(self, node):
        """
        Initialize data entry from a parsed ENTRY node.

        RAW entries stop parsing right after the raw data is read. Because of
        that, `match_fields`, `adjust_fields`, `fields`, `mandatory` and
        `sections` are only set on structured entries.
        """
        self.node = node
        self.origin = '.'
        self.message = dns.message.Message()
        self.message.use_edns(edns=0, payload=4096)
        # Incremented by Range.reply() each time this entry produced a reply
        self.fired = 0

        # RAW
        try:
            self.raw_data = binascii.unhexlify(node["/raw"].value)
            self.is_raw_data_entry = True
            return
        except KeyError:
            self.raw_data = None
            self.is_raw_data_entry = False

        # MATCH (defaults to opcode + qtype + qname)
        self.match_fields = [m.value for m in node.match("/match")]
        if not self.match_fields:
            self.match_fields = ['opcode', 'qtype', 'qname']

        # FLAGS, RCODE and OPCODE from the REPLY line
        self.process_reply_line(node)

        # ADJUST (defaults to copying the query ID)
        self.adjust_fields = [m.value for m in node.match("/adjust")]
        if not self.adjust_fields:
            self.adjust_fields = ['copy_id']

        # MANDATORY
        try:
            self.mandatory = list(node.match("/mandatory"))[0]
        except (KeyError, IndexError):
            self.mandatory = None

        # TSIG
        try:
            tsig = list(node.match("/tsig"))[0]
            tsig_keyname = tsig["/keyname"].value
            tsig_secret = tsig["/secret"].value
            keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
            self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)
        except (KeyError, IndexError):
            pass

        # SECTIONS & RECORDS
        self.sections = []
        for section in node.match("/section/*"):
            section_name = posixpath.basename(section.path)
            self.sections.append(section_name)
            for record in section.match("/record"):
                owner = record['/domain'].value
                if not owner.endswith("."):
                    owner += self.origin
                try:
                    ttl = dns.ttl.from_text(record['/ttl'].value)
                except KeyError:
                    ttl = self.default_ttl
                try:
                    rdclass = dns.rdataclass.from_text(record['/class'].value)
                except KeyError:
                    rdclass = dns.rdataclass.from_text(self.default_cls)
                rdtype = dns.rdatatype.from_text(record['/type'].value)
                rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
                if section_name != "question":
                    rd = record['/data'].value.split()
                    if rd:
                        if rdtype == dns.rdatatype.DS:
                            # DS algorithm may be given by mnemonic; convert to its number
                            rd[1] = str(dns.dnssec.algorithm_from_text(rd[1]))
                        rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(
                            rd), origin=dns.name.from_text(self.origin), relativize=False)
                    rr.add(rd)
                if section_name == 'question':
                    if rr.rdtype == dns.rdatatype.AXFR:
                        self.message.xfr = True
                    self.message.question.append(rr)
                elif section_name == 'answer':
                    self.message.answer.append(rr)
                elif section_name == 'authority':
                    self.message.authority.append(rr)
                elif section_name == 'additional':
                    self.message.additional.append(rr)

    def __str__(self):
        """Render the entry in scenario text format (ENTRY_BEGIN ... ENTRY_END)."""
        txt = 'ENTRY_BEGIN\n'
        if not self.is_raw_data_entry:
            # RAW entries never parse MATCH/ADJUST (see __init__), so they
            # have neither match_fields nor adjust_fields.
            txt += 'MATCH {0}\n'.format(' '.join(self.match_fields))
            txt += 'ADJUST {0}\n'.format(' '.join(self.adjust_fields))
        txt += 'REPLY {rcode} {flags}\n'.format(
            rcode=dns.rcode.to_text(self.message.rcode()),
            flags=' '.join([dns.flags.to_text(self.message.flags),
                            dns.flags.edns_to_text(self.message.ednsflags)])
        )
        for sect_name in ['question', 'answer', 'authority', 'additional']:
            sect = getattr(self.message, sect_name)
            if not sect:
                continue
            txt += 'SECTION {n}\n'.format(n=sect_name.upper())
            for rr in sect:
                txt += str(rr)
                txt += '\n'
        if self.is_raw_data_entry:
            txt += 'RAW\n'
            if self.raw_data:
                # hexlify() returns bytes on Python 3; decode before joining with str
                txt += binascii.hexlify(self.raw_data).decode('ascii')
            else:
                txt += 'NULL'
            txt += '\n'
        txt += 'ENTRY_END\n'
        return txt

    def process_reply_line(self, node):
        """Extract flags, rcode and opcode from given node and adjust dns message accordingly."""
        self.fields = [f.value for f in node.match("/reply")]
        if 'DO' in self.fields:
            self.message.want_dnssec(True)
        opcode = self.get_opcode(fields=self.fields)
        rcode = self.get_rcode(fields=self.fields)
        self.message.flags = self.get_flags(fields=self.fields)
        if rcode is not None:
            self.message.set_rcode(rcode)
        if opcode is not None:
            self.message.set_opcode(opcode)

    @classmethod
    def get_flags(cls, fields):
        """From `fields` extract and return header flags (tokens that are not flags are skipped)."""
        flags = []
        for code in fields:
            try:
                dns.flags.from_text(code)  # throws KeyError on failure
                flags.append(code)
            except KeyError:
                pass
        return dns.flags.from_text(' '.join(flags))

    @classmethod
    def get_rcode(cls, fields):
        """
        From `fields` extract and return rcode, or None if there is none.

        Raises:
            ValueError: if there is more than one rcode.
        """
        rcodes = []
        for code in fields:
            try:
                rcodes.append(dns.rcode.from_text(code))
            except dns.rcode.UnknownRcode:
                pass
        if len(rcodes) > 1:
            raise ValueError("Parse failed, too many rcode values.", rcodes)
        if not rcodes:
            return None
        return rcodes[0]

    @classmethod
    def get_opcode(cls, fields):
        """
        From `fields` extract and return opcode, or None if there is none.

        Raises:
            ValueError: if there is more than one opcode.
        """
        opcodes = []
        for code in fields:
            try:
                opcodes.append(dns.opcode.from_text(code))
            except dns.opcode.UnknownOpcode:
                pass
        if len(opcodes) > 1:
            # Report the offending values, consistent with get_rcode()
            raise ValueError("Parse failed, too many opcode values.", opcodes)
        if not opcodes:
            return None
        return opcodes[0]

    def match_part(self, code, msg):
        """
        Compare scripted reply to given message using a single criterion.

        Criteria not listed in match_fields (and without 'all') always pass.

        Raises:
            ValueError: on mismatch or unknown criterion.
        """
        if code not in self.match_fields and 'all' not in self.match_fields:
            return True
        expected = self.message
        if code == 'opcode':
            return compare_val(expected.opcode(), msg.opcode())
        elif code == 'qtype':
            if not expected.question:
                return True
            return compare_val(expected.question[0].rdtype, msg.question[0].rdtype)
        elif code == 'qname':
            if not expected.question:
                return True
            qname = dns.name.from_text(msg.question[0].name.to_text().lower())
            return compare_val(expected.question[0].name, qname)
        elif code == 'qcase':
            # Same guard as the other question-based criteria
            if not expected.question:
                return True
            # Exact label comparison (letter case matters); template first so
            # the error message reports expected/got correctly.
            return compare_val(expected.question[0].name.labels, msg.question[0].name.labels)
        elif code == 'subdomain':
            if not expected.question:
                return True
            qname = dns.name.from_text(msg.question[0].name.to_text().lower())
            return compare_sub(expected.question[0].name, qname)
        elif code == 'flags':
            return compare_val(dns.flags.to_text(expected.flags), dns.flags.to_text(msg.flags))
        elif code == 'rcode':
            return compare_val(dns.rcode.to_text(expected.rcode()), dns.rcode.to_text(msg.rcode()))
        elif code == 'question':
            return compare_rrs(expected.question, msg.question)
        elif code == 'answer' or code == 'ttl':
            return compare_rrs(expected.answer, msg.answer)
        elif code == 'authority':
            return compare_rrs(expected.authority, msg.authority)
        elif code == 'additional':
            return compare_rrs(expected.additional, msg.additional)
        elif code == 'edns':
            if msg.edns != expected.edns:
                raise ValueError('expected EDNS %d, got %d' % (expected.edns, msg.edns))
            if msg.payload != expected.payload:
                raise ValueError('expected EDNS bufsize %d, got %d'
                                 % (expected.payload, msg.payload))
        elif code == 'nsid':
            nsid_opt = None
            for opt in expected.options:
                if opt.otype == dns.edns.NSID:
                    nsid_opt = opt
                    break
            # Find matching NSID
            for opt in msg.options:
                if opt.otype == dns.edns.NSID:
                    if not nsid_opt:
                        raise ValueError('unexpected NSID value "%s"' % opt.data)
                    if opt == nsid_opt:
                        return True
                    else:
                        raise ValueError('expected NSID "%s", got "%s"' % (nsid_opt.data, opt.data))
            if nsid_opt:
                raise ValueError('expected NSID "%s"' % nsid_opt.data)
        else:
            raise ValueError('unknown match request "%s"' % code)

    def match(self, msg):
        """
        Compare scripted reply to given message based on match criteria.

        'all' expands to flags, rcode and every section present in the entry.

        Raises:
            ValueError: with the node location, failing criterion and the
                received message text.
        """
        # Work on a copy: expanding 'all' must not rewrite self.match_fields
        # (which __str__ prints and match_part consults).
        match_fields = list(self.match_fields)
        if 'all' in match_fields:
            match_fields.remove('all')
            match_fields += ['flags'] + ['rcode'] + self.sections
        for code in match_fields:
            try:
                self.match_part(code, msg)
            except ValueError as ex:
                errstr = '%s in the response:\n%s' % (str(ex), msg.to_text())
                # TODO: line numbers
                raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr))

    def cmp_raw(self, raw_value):
        """
        Compare the scripted raw bytes with `raw_value` (bytes or None).

        Raises:
            ValueError: if the hex-encoded values differ.
        """
        assert self.is_raw_data_entry
        expected = None
        if self.raw_data is not None:
            expected = binascii.hexlify(self.raw_data)
        got = None
        if raw_value is not None:
            got = binascii.hexlify(raw_value)
        if expected != got:
            raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))

    def adjust_reply(self, query):
        """
        Copy scripted reply and adjust it to the received query.

        The copy takes EDNS version/flags from the query (with the scripted
        EDNS options). 'copy_id' copies the query ID and qname letter-case;
        'copy_query' replaces the question section with the query's.
        """
        answer = dns.message.from_wire(self.message.to_wire(),
                                       xfr=self.message.xfr,
                                       one_rr_per_rrset=True)
        answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
        if 'copy_id' in self.adjust_fields:
            answer.id = query.id
            # Copy letter-case if the template has QD
            if answer.question:
                answer.question[0].name = query.question[0].name
        if 'copy_query' in self.adjust_fields:
            answer.question = query.question
        # Re-set, as the EDNS might have reset the ext-rcode
        answer.set_rcode(self.message.rcode())

        # sanity check: adjusted answer should be almost the same
        assert len(answer.answer) == len(self.message.answer)
        assert len(answer.authority) == len(self.message.authority)
        assert len(answer.additional) == len(self.message.additional)

        return answer

    def set_edns(self, fields):
        """
        Set EDNS version, bufsize and options on the scripted message.

        `fields` is a list of tokens. It starts with an optional numeric EDNS
        version, then an optional numeric bufsize; both are popped from the
        list in place. The remaining tokens are options:
        'nsid' / 'nsid=<value>' and 'subnet=<addr>[/<prefix>]'.
        A bufsize of 0 disables EDNS.
        """
        version = 0
        bufsize = 4096
        if fields and fields[0].isdigit():
            version = int(fields.pop(0))
        if fields and fields[0].isdigit():
            bufsize = int(fields.pop(0))
        if bufsize == 0:
            self.message.use_edns(False)
            return
        opts = []
        for token in fields:
            # Split on the first '=' only, so option values may contain '='
            k, sep, v = token.partition('=')
            if not sep:
                v = True
            if k.lower() == 'nsid':
                opts.append(dns.edns.GenericOption(dns.edns.NSID, '' if v is True else v))
            if k.lower() == 'subnet':
                net = v.split('/')
                subnet_addr = net[0]
                family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET
                addr = socket.inet_pton(family, subnet_addr)
                prefix = len(addr) * 8
                if len(net) > 1:
                    prefix = int(net[1])
                # Keep only the bytes covered by the prefix. Integer division:
                # '/' yields a float on Python 3, which is not a valid slice index.
                addr = bytearray(addr[:(prefix + 7) // 8])
                if prefix % 8 != 0:  # Mask the host bits in the last byte
                    addr[-1] &= (0xFF << (8 - prefix % 8)) & 0xFF
                # Option code 8 (client subnet): family (1 = IPv4, 2 = IPv6),
                # source prefix length, scope prefix length 0, then the address.
                opts.append(dns.edns.GenericOption(8, struct.pack(
                    "!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + bytes(addr)))
        self.message.use_edns(edns=version, payload=bufsize, options=opts)
493

494

Marek Vavruša's avatar
Marek Vavruša committed
495 496 497 498
class Range:
    """
    A set of scripted replies that is valid for a range of scenario steps.

    The range is active for step ids a..b (inclusive). It can optionally be
    restricted to a set of server addresses.
    """
    log = logging.getLogger('pydnstest.scenario.Range')

    def __init__(self, node):
        """Build the range from a parsed RANGE node (bounds, addresses, entries)."""
        self.node = node
        self.a = int(node['/from'].value)
        self.b = int(node['/to'].value)

        single = node["/address"].value
        self.addresses = set() if single is None else {single}
        self.addresses |= set(child.value for child in node.match("/address/*"))
        self.stored = list(map(Entry, node.match("/entry")))
        # Extra options; reply() honours 'LOSS' as a drop probability.
        self.args = {}
        # Counters reported when the range is garbage-collected
        self.received = self.sent = 0

    def __del__(self):
        self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d',
                      self.a, self.b, self.addresses, self.received, self.sent)

    def __str__(self):
        """Render the range in scenario text format (RANGE_BEGIN ... RANGE_END)."""
        parts = ['\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b)]
        parts.extend('        ADDRESS {0}\n'.format(addr) for addr in self.addresses)
        for entry in self.stored:
            parts.append('\n')
            parts.append(str(entry))
        parts.append('RANGE_END\n\n')
        return ''.join(parts)

    def eligible(self, id, address):
        """
        Tell whether this range may answer step `id` sent to `address`.

        The step id must lie within [a, b]. The address matches when it is
        None, when the range lists no addresses, or when it is listed.
        """
        if not self.a <= id <= self.b:
            return False
        if None is address:
            return True
        if set() == self.addresses:
            return True
        return address in self.addresses

    def reply(self, query):
        """
        Get answer for given query (adjusted if needed).

        The first stored entry whose match() does not raise ValueError
        produces the reply. With a 'LOSS' arg the reply is dropped with that
        probability.

        Returns:
            (DNS message object) or None if there is no candidate in this range
        """
        self.received += 1
        for candidate in self.stored:
            try:
                candidate.match(query)
                resp = candidate.adjust_reply(query)
                # Probabilistic loss
                if 'LOSS' in self.args and random.random() < float(self.args['LOSS']):
                    return None
                self.sent += 1
                candidate.fired += 1
                return resp
            except ValueError:
                continue
        return None


562
class StepLogger(logging.LoggerAdapter):  # pylint: disable=too-few-public-methods
    """
    Logger adapter that prepends the step identification to each log message.
    """

    def process(self, msg, kwargs):
        """Return `msg` prefixed with '[STEP <id> <type>]' and `kwargs` unchanged."""
        step_id = self.extra['id']
        step_type = self.extra['type']
        prefixed = '[STEP %s %s] %s' % (step_id, step_type, msg)
        return prefixed, kwargs


Marek Vavruša's avatar
Marek Vavruša committed
570 571 572 573 574
class Step:
    """
    Step represents one scripted action in a given moment,
    each step has an order identifier, type and optionally data entry.
    """
575
    require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']
Marek Vavruša's avatar
Marek Vavruša committed
576

577
    def __init__(self, node):
        """
        Build a scenario step from a parsed STEP node.

        The node value is the step id, '/type' is the step type, and an
        optional '/timestamp' is stored as `self.delay` (the attribute is only
        set when the timestamp is present). Every '/entry' child becomes an
        Entry in `self.data`.
        """
        self.node = node
        self.id = int(node.value)
        self.type = node["/type"].value
        base_logger = logging.getLogger('pydnstest.scenario.Step')
        self.log = StepLogger(base_logger, {'id': self.id, 'type': self.type})
        try:
            timestamp = node["/timestamp"].value
        except KeyError:
            pass
        else:
            self.delay = int(timestamp)
        self.data = [Entry(child) for child in node.match("/entry")]
        self.queries = []
        self.has_data = self.type in Step.require_data
        self.answer = self.raw_answer = None
        # Failure-handling knobs; see the TODO below for how they would be set.
        self.repeat_if_fail = self.pause_if_fail = 0
        self.next_if_fail = -1

        # TODO CHECK_ANSWER steps are meant to accept REPEAT=<n>, PAUSE=<seconds>
        # and NEXT=<step id> arguments, which would set repeat_if_fail,
        # pause_if_fail and next_if_fail respectively (raising on malformed
        # values). The parser currently can't parse CHECK_ANSWER args, and the
        # player doesn't understand them anyway, so the defaults above always apply.

    def __str__(self):
        """
        Render the step in scenario text format.

        The output is a 'STEP <id> <type>' header with at most one failure
        modifier (REPEAT, PAUSE or NEXT, in that priority), followed by every
        entry of the step.
        """
        parts = ['\nSTEP {i} {t}'.format(i=self.id, t=self.type)]
        if self.repeat_if_fail:
            parts.append(' REPEAT {v}'.format(v=self.repeat_if_fail))
        elif self.pause_if_fail:
            parts.append(' PAUSE {v}'.format(v=self.pause_if_fail))
        elif self.next_if_fail != -1:
            parts.append(' NEXT {v}'.format(v=self.next_if_fail))
        parts.append('\n')
        parts.extend(str(entry) for entry in self.data)
        return ''.join(parts)
Marek Vavruša's avatar
Marek Vavruša committed
629

630
    def play(self, ctx):
        """
        Play one step from a scenario against the context `ctx`.

        QUERY sends the scripted query. CHECK_ANSWER/ANSWER verify the
        preceding answer. TIME_PASSES ELAPSE delegates to the time-passing
        handler. CHECK_OUT_QUERY, REPLY and MOCK are only logged.

        Returns:
            Whatever the delegated handler returns (None for log-only types).

        Raises:
            NotImplementedError: for any other step type.
        """
        step_type = self.type
        if step_type in ('CHECK_OUT_QUERY', 'REPLY', 'MOCK'):
            # Nothing to execute for these types at play time; just log them.
            self.log.info('')
            return None
        if step_type == 'QUERY':
            self.log.info('')
            self.log.debug(self.data[0].message.to_text())
            # No QUERY-specific parameters are parsed yet, so use the defaults
            return self.__query(ctx, tcp=False, choice=None, source=None)
        if step_type in ('CHECK_ANSWER', 'ANSWER'):
            self.log.info('')
            return self.__check_answer(ctx)
        if step_type == 'TIME_PASSES ELAPSE':
            self.log.info('')
            return self.__time_passes()
        # The parser doesn't support the LOG, REPLAY and ASSERT step types and
        # no test uses them, so they end up here as well.
        raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type))
661

Marek Vavruša's avatar
Marek Vavruša committed
662 663
    def __check_answer(self, ctx):
        """
        Compare the answer from the previously resolved query with this step's entry.

        RAW entries compare `ctx.last_raw_answer` byte-wise (Entry.cmp_raw).
        Other entries match `ctx.last_answer` using the entry's criteria.

        Raises:
            ValueError: if the step has no entry, there is no answer to a
                non-raw check, or the answer does not match.
        """
        if not self.data:
            raise ValueError("response definition required")
        expected = self.data[0]
        if expected.is_raw_data_entry is True:
            raw = ctx.last_raw_answer
            # cmp_raw() treats the raw answer as bytes or None, neither of
            # which has to_text(); log it hex-encoded instead.
            self.log.debug("raw answer: %s",
                           binascii.hexlify(raw) if raw is not None else None)
            expected.cmp_raw(raw)
        else:
            if ctx.last_answer is None:
                raise ValueError("no answer from preceding query")
            self.log.debug("answer: %s", ctx.last_answer.to_text())
            expected.match(ctx.last_answer)

676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697
    # def __replay(self, ctx, chunksize=8):
    #     nqueries = len(self.queries)
    #     if len(self.args) > 0 and self.args[0].isdigit():
    #         nqueries = int(self.args.pop(0))
    #     destination = ctx.client[ctx.client.keys()[0]]
    #     self.log.info('replaying %d queries to %s@%d (%s)',
    #                   nqueries, destination[0], destination[1], ' '.join(self.args))
    #     if 'INTENSIFY' in os.environ:
    #         nqueries *= int(os.environ['INTENSIFY'])
    #     tstart = datetime.now()
    #     nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args)
    #     # Keep/print the statistics
    #     rtt = (datetime.now() - tstart).total_seconds() * 1000
    #     pps = 1000 * nrcvd / rtt
    #     self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps)
    #     tag = None
    #     for arg in self.args:
    #         if arg.upper().startswith('PRINT'):
    #             _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
    #     if tag:
    #         self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d',
    #                       tag.ljust(11), pps, rtt, nsent, nrcvd)
698

699
    def __query(self, ctx, tcp=False, choice=None, source=None):
        """
        Send query and wait for an answer (if the query is not RAW).

        The query comes from self.data[0]. RAW entries are sent as-is and no
        answer is read. Other entries are serialized with message.to_wire().
        The target is ctx.client[choice]; an empty/None choice selects the first
        configured client. The socket has a 3 second timeout and is always
        closed before this method returns.

        The received answer is stored in self.answer and ctx.last_answer, and
        the undecoded bytes in self.raw_answer and ctx.last_raw_answer. All of
        them are None for RAW queries. The elapsed time is added to the
        module-level g_rtt / g_nqueries statistics.

        Raises:
            ValueError: missing query definition or unknown QUERY target.
            OSError: socket errors other than ENOBUFS (ENOBUFS is retried
                after a short pause).
        """
        if not self.data:
            raise ValueError("query definition required")
        entry = self.data[0]
        if entry.is_raw_data_entry is True:
            data_to_wire = entry.raw_data
        else:
            # Don't use a message copy as the EDNS data portion is not copied.
            data_to_wire = entry.message.to_wire()
        if not choice:
            choice = list(ctx.client.keys())[0]
        if choice not in ctx.client:
            raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
        # Create socket to test subject
        destination = ctx.client[choice]
        family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
        answer = None
        # The context manager closes the socket on every exit path; previously
        # one socket was leaked per query.
        with socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if tcp:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(3)
            if source:
                sock.bind((source, 0))
            sock.connect(destination)
            # Send query to client and wait for response
            tstart = datetime.now()
            while True:
                try:
                    sendto_msg(sock, data_to_wire)
                    break
                except OSError as ex:
                    # ENOBUFS: throttle sending and retry. Any other error is
                    # fatal; retrying it would loop forever.
                    if ex.errno != errno.ENOBUFS:
                        raise
                    time.sleep(0.1)
            # Wait for a response, bounded by the socket timeout set above
            if not entry.is_raw_data_entry:
                while True:
                    try:
                        answer, _ = recvfrom_msg(sock, True)
                        break
                    except OSError as ex:
                        if ex.errno != errno.ENOBUFS:
                            raise
                        time.sleep(0.1)
            # Track RTT
            rtt = (datetime.now() - tstart).total_seconds() * 1000
        global g_rtt, g_nqueries
        g_nqueries += 1
        g_rtt += rtt
        # Remember last answer for checking later
        self.raw_answer = answer
        ctx.last_raw_answer = answer
        if self.raw_answer is not None:
            self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
        else:
            self.answer = None
        ctx.last_answer = self.answer

762
    def __time_passes(self):
        """
        Advance the faked system clock by self.delay seconds.

        The current '@YYYY-MM-DD HH:MM:SS' stamp is read from the file named by
        the FAKETIME_TIMESTAMP_FILE environment variable. It is shifted by
        self.delay (via local-time epoch conversion) and written to a sibling
        '.next' file, which is then renamed over the original.
        """
        stamp_format = '@%Y-%m-%d %H:%M:%S'
        current_path = os.environ["FAKETIME_TIMESTAMP_FILE"]
        staging_path = current_path + ".next"
        with open(current_path, 'r') as source_file:
            stamp = source_file.readline().strip()
        parsed = datetime.strptime(stamp, stamp_format)
        shifted = time.mktime(parsed.timetuple()) + self.delay
        new_stamp = datetime.fromtimestamp(shifted).strftime(stamp_format)
        with open(staging_path, 'w') as target_file:
            target_file.write(new_stamp + "\n")
            target_file.flush()
        # Write-then-rename, so the original path is swapped in a single os.replace call.
        os.replace(staging_path, current_path)
Marek Vavruša's avatar
Marek Vavruša committed
774

775 776 777 778 779 780 781 782 783 784 785 786 787
    # def __assert(self, ctx):
    #     """ Assert that a passed expression evaluates to True. """
    #     result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #     # Evaluate subexpressions for clarity
    #     subexpr = []
    #     for expr in self.args:
    #         try:
    #             ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #             subexpr.append(str(ee))
    #         except:
    #             subexpr.append(expr)
    #     assert result is True, '"%s" assertion fails (%s)' % (
    #                            ' '.join(self.args), ' '.join(subexpr))
788

789

Marek Vavruša's avatar
Marek Vavruša committed
790
class Scenario:
    """
    A parsed test scenario: ordered steps plus the ranges used by reply()
    to generate answers for incoming queries.
    """

    log = logging.getLogger('pydnstest.scenario.Scenario')

    def __init__(self, node, filename):
        """ Initialize scenario with description. """
        self.node = node
        self.info = node.value
        self.file = filename
        self.ranges = [Range(n) for n in node.match("/range")]
        self.current_range = None
        self.steps = [Step(n) for n in node.match("/step")]
        self.current_step = None
        self.client = {}

    def __str__(self):
        header = 'SCENARIO_BEGIN'
        if self.info:
            header += ' {0}'.format(self.info)
        body = ''.join(map(str, self.ranges)) + ''.join(map(str, self.steps))
        return header + '\n' + body + "\nSCENARIO_END"

    def reply(self, query, address=None):
        """
        Generate answer packet for given query.

        Ranges are consulted first. If none is eligible for the current step
        and address, REPLY steps whose id is >= the current step id are tried
        in order; a matching non-raw entry is consumed (removed from the step).

        Args:
            query: the query to answer (passed to range/entry matching).
            address: address compared against the ranges' addresses; an
                address unknown to all ranges is treated as None.
        Returns:
            (answer, boolean "is the answer binary blob?"), or (None, True)
            when nothing matches.

        Requires self.current_step to be set (normally done by play()).
        """
        current_step_id = self.current_step.id
        # Unknown address, select any match
        # TODO: workaround until the server supports stub zones
        known_addresses = set()
        for rng in self.ranges:
            known_addresses.update(rng.addresses)
        if address not in known_addresses:
            address = None
        # Find current valid query response range
        for rng in self.ranges:
            if rng.eligible(current_step_id, address):
                self.current_range = rng
                return rng.reply(query), False
        # Find any prescripted one-shot replies
        for step in self.steps:
            if step.id < current_step_id or step.type != 'REPLY':
                continue
            try:
                candidate = step.data[0]
                if candidate.is_raw_data_entry is False:
                    candidate.match(query)
                    step.data.remove(candidate)
                    return candidate.adjust_reply(query), False
                return candidate.raw_data, True
            except (IndexError, ValueError):
                # No data left in this step, or the query did not match it.
                pass
        return None, True

    def play(self, paddr):
        """
        Play given scenario.

        Steps run in order. A step that raises ValueError is retried while its
        repeat_if_fail budget lasts: it optionally sleeps pause_if_fail
        seconds, then continues with the step whose id equals next_if_fail,
        or re-runs the same step when next_if_fail is -1. Without retries left,
        the error is re-raised prefixed with the scenario file and step id.

        Args:
            paddr: test subject => address mapping, stored as self.client.
        Raises:
            ValueError: a step failed with no retries left, or its NEXT value
                names a step id that does not exist.
            RuntimeError: a mandatory range entry was never fired.
        """
        # Store test subject => address mapping
        self.client = paddr

        i = 0
        while i < len(self.steps):
            step = self.steps[i]
            self.current_step = step
            try:
                step.play(self)
            except ValueError as ex:
                if step.repeat_if_fail <= 0:
                    raise ValueError('%s step %d %s' % (self.file, step.id, str(ex))) from ex
                self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)",
                              step.id, ex, step.next_if_fail, step.repeat_if_fail)
                step.repeat_if_fail -= 1
                if step.pause_if_fail > 0:
                    time.sleep(step.pause_if_fail)
                if step.next_if_fail != -1:
                    next_steps = [j for j, candidate in enumerate(self.steps)
                                  if candidate.id == step.next_if_fail]
                    if not next_steps:
                        raise ValueError('step %d: wrong NEXT value "%d"' %
                                         (step.id, step.next_if_fail))
                    # j comes from enumerate(self.steps), so it is always a valid index.
                    i = next_steps[0]
                # Without i += 1: re-run the current step or the NEXT target.
                continue
            i += 1

        for rng in self.ranges:
            for entry in rng.stored:
                if entry.mandatory and entry.fired == 0:
                    # TODO: report line numbers
                    raise RuntimeError('Mandatory section at %s not fired' % entry.mandatory.span)
896

897

898
def get_next(file_in, skip_empty=True):
    """
    Read the next tokenized line from file_in.

    A ';' outside double quotes starts a comment that runs to the end of the
    line. A backslash escapes the following '"' so it does not toggle quoting.
    The remaining text is split on whitespace.

    Returns:
        False at end of input; otherwise (first_token, [remaining tokens]).
        Blank or comment-only lines are skipped, unless skip_empty is false,
        in which case ('', []) is returned for them.
    """
    while True:
        raw = file_in.readline()
        if not raw:
            return False
        content = raw
        in_quotes = False
        after_backslash = False
        for pos, char in enumerate(raw):
            if char == '\\':
                after_backslash = not after_backslash
            if char == '"' and not after_backslash:
                in_quotes = not in_quotes
            if char == ';' and not in_quotes:
                content = raw[:pos]
                break
            if char != '\\':
                after_backslash = False
        words = content.split()
        if words:
            return words[0], words[1:]
        if not skip_empty:
            return '', []

924

925 926 927
def parse_config(scn_cfg, qmin, installdir):
    """
    Transform scene config (key, value) pairs into dict filled with defaults.

    Args:
        scn_cfg: iterable of (key, value) string pairs from the CONFIG section.
        qmin: default for query minimization; the 'query-minimization' key
            overrides it.
        installdir: exported as INSTALL_DIR and substituted for
            '{{INSTALL_DIR}}' in feature-list values.

    Returns tuple:
      context dict: {Jinja2 variable: value}
      trust anchor dict: {domain: [TA lines for particular domain]}

    Raises:
        NotImplementedError: unsupported config key.
        ValueError: a feature-list entry without '=', or a stub-addr that does
            not map to exactly one numeric address.
    """
    # defaults
    do_not_query_localhost = True
    harden_glue = True
    sockfamily = 0  # auto-select value for socket.getaddrinfo
    trust_anchor_list = []
    trust_anchor_files = {}
    negative_ta_list = []
    stub_addr = None
    override_timestamp = None

    features = {}
    feature_list_delimiter = ';'
    feature_pair_delimiter = '='

    for k, v in scn_cfg:
        # Enable selectively for some tests
        if k == 'do-not-query-localhost':
            do_not_query_localhost = str2bool(v)
        elif k == 'domain-insecure':
            negative_ta_list.append(v)
        elif k == 'harden-glue':
            harden_glue = str2bool(v)
        elif k == 'query-minimization':
            qmin = str2bool(v)
        elif k == 'trust-anchor':
            trust_anchor = v.strip('"\'')
            trust_anchor_list.append(trust_anchor)
            domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize()
            trust_anchor_files.setdefault(domain, []).append(trust_anchor)
        elif k == 'val-override-timestamp':
            override_timestamp = int(v.strip('"\''))
        elif k == 'val-override-date':
            override_date_str = v.strip('"\'')
            ovr_yr = override_date_str[0:4]
            ovr_mnt = override_date_str[4:6]
            ovr_day = override_date_str[6:8]
            ovr_hr = override_date_str[8:10]
            ovr_min = override_date_str[10:12]
            ovr_sec = override_date_str[12:]
            override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format(
                ovr_yr, ovr_mnt, ovr_day, ovr_hr, ovr_min, ovr_sec)
            override_date = time.strptime(override_date_str_arg, "%Y %m %d %H %M %S")
            override_timestamp = calendar.timegm(override_date)
        elif k == 'stub-addr':
            stub_addr = v.strip('"\'')
        elif k == 'features':
            # "key=value" items set a value; a bare "key" maps to "".
            for f_item in v.split(feature_list_delimiter):
                f_key, _, f_value = f_item.partition(feature_pair_delimiter)
                features[f_key.strip()] = f_value.strip()
        elif k == 'feature-list':
            f_key, sep, f_value = v.partition(feature_pair_delimiter)
            if not sep:
                raise ValueError("can't parse feature-list (%s) in config section (%s)"
                                 % (v, "missing '%s'" % feature_pair_delimiter))
            f_value = f_value.strip().replace("{{INSTALL_DIR}}", installdir)
            features.setdefault(f_key.strip(), []).append(f_value)
        elif k == 'force-ipv6':
            # A false value used to fall through to "unsupported CONFIG key".
            if str2bool(v):
                sockfamily = socket.AF_INET6
        else:
            raise NotImplementedError('unsupported CONFIG key "%s"' % k)

    ctx = {
        "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(),
        "NEGATIVE_TRUST_ANCHORS": negative_ta_list,
        "FEATURES": features,
        "HARDEN_GLUE": str(harden_glue).lower(),
        "INSTALL_DIR": installdir,
        "QMIN": str(qmin).lower(),
        "TRUST_ANCHORS": trust_anchor_list,
        "TRUST_ANCHOR_FILES": trust_anchor_files.keys()
    }
    if stub_addr:
        ctx['ROOT_ADDR'] = stub_addr
        # determine and verify socket family for specified root address
        gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0,
                                 socket.IPPROTO_UDP, socket.AI_NUMERICHOST)
        if len(gai) != 1:
            raise ValueError('stub-addr %s maps to %d addresses, expected exactly 1'
                             % (stub_addr, len(gai)))
        sockfamily = gai[0][0]
    if not sockfamily:
        sockfamily = socket.AF_INET  # default to IPv4
    ctx['_SOCKET_FAMILY'] = sockfamily
    # "is not None": a timestamp of 0 (the epoch) is a valid override.
    if override_timestamp is not None:
        ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp
    return (ctx, trust_anchor_files)
1032 1033


1034
def parse_file(path):
    """
    Parse scenario from a file.

    The file is loaded through Augeas with the 'Deckard' lens. CONFIG lines
    starting with ';' are ignored, text after '#' is dropped, and each
    remaining 'key: value' line becomes a [key, value] pair (split on the
    first ':', both parts stripped).

    Returns:
        (Scenario built from the /scenario node, list of [key, value] pairs)
    """
    aug = pydnstest.augwrap.AugeasWrapper(
        confpath=path, lens='Deckard', loadpath=os.path.dirname(__file__))
    root = aug.tree
    config = []
    for text in [entry.value for entry in root.match("/config/*")]:
        if not text or text.startswith(';'):
            continue
        # Strip trailing '#' comment, if any.
        text = text.partition('#')[0]
        # Break to key-value pairs, e.g.: ['minimization', 'on']
        key_value = [part.strip() for part in text.split(':', 1)]
        if len(key_value) >= 2:
            config.append(key_value)
    scenario = Scenario(root["/scenario"], posixpath.basename(root.path))
    return scenario, config