scenario.py 38.6 KB
Newer Older
1
# FIXME pylint: disable=too-many-lines
2
from abc import ABC
3
import binascii
4
import calendar
5
from datetime import datetime
6
import errno
7
import logging
8
import os
9
import posixpath
10
import random
11
import socket
12
import string
13
import struct
Marek Vavruša's avatar
Marek Vavruša committed
14
import time
15
from typing import Optional
16 17 18

import dns.dnssec
import dns.message
19
import dns.name
20 21 22
import dns.rcode
import dns.rrset
import dns.tsigkeyring
Marek Vavruša's avatar
Marek Vavruša committed
23

24
import pydnstest.augwrap
25
import pydnstest.matchpart
26

27

28 29 30 31 32
def str2bool(v):
    """
    Interpret a JSON-ish textual flag as a boolean.

    Comparison is case-insensitive: 'yes', 'true', 'on' and '1' are True,
    any other string is False.
    """
    truthy = {'yes', 'true', 'on', '1'}
    return v.lower() in truthy


33 34 35 36
# Global statistics, accumulated by Step.__query():
# g_nqueries counts the queries sent, g_rtt sums their round-trip times in milliseconds.
g_nqueries = 0
g_rtt = 0.0

37

38 39 40 41 42 43 44 45 46 47 48 49
def _recv_exactly(stream, length):
    """
    Read exactly `length` bytes from a stream socket.

    Returns the bytes, or None if the peer closed the connection before
    `length` bytes arrived.
    """
    chunks = []
    remaining = length
    while remaining > 0:
        # Never ask for more than this message needs, so bytes of a following
        # pipelined message stay in the socket for the next call.
        chunk = stream.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recvfrom_msg(stream, raw=False):
    """
    Receive DNS message from TCP/UDP socket.

    For TCP the 2-byte length prefix is honoured exactly, so a stream that
    carries several back-to-back messages yields one message per call.

    Returns:
        if raw == False: (DNS message object, peer address)
        if raw == True: (blob, peer address)
        (None, None) if a TCP peer closed the connection before a complete
        message (length prefix + body) arrived.

    Raises:
        NotImplementedError: for socket types other than datagram/stream.
    """
    if stream.type & socket.SOCK_DGRAM:
        # 65535 matches DNSMessage.to_wire(max_size=65535), so large UDP
        # messages are not silently truncated.
        data, addr = stream.recvfrom(65535)
    elif stream.type & socket.SOCK_STREAM:
        header = _recv_exactly(stream, 2)
        if header is None:
            return None, None
        msg_len = struct.unpack("!H", header)[0]
        data = _recv_exactly(stream, msg_len)
        if data is None:
            return None, None
        addr = stream.getpeername()[0]
    else:
        raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
    if raw:
        return data, addr
    msg = dns.message.from_wire(data, one_rr_per_rrset=True)
    return msg, addr
69 70 71 72 73 74 75 76 77 78 79 80 81 82


def sendto_msg(stream, message, addr=None):
    """
    Send DNS/UDP/TCP message.

    UDP: the message is sent as one datagram, to `addr` when given, otherwise
    to the connected peer.
    TCP: the message is prefixed with its 2-byte length and written with
    sendall(), because a single send() may transmit only part of the buffer.

    ECONNREFUSED is ignored; any other socket error propagates.

    Raises:
        NotImplementedError: for socket types other than datagram/stream.
    """
    try:
        if stream.type & socket.SOCK_DGRAM:
            if addr is None:
                stream.send(message)
            else:
                stream.sendto(message, addr)
        elif stream.type & socket.SOCK_STREAM:
            stream.sendall(struct.pack("!H", len(message)) + message)
        else:
            raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
    except socket.error as ex:
        if ex.errno != errno.ECONNREFUSED:  # TODO Investigate how this can happen
            raise
87 88


89
def replay_rrs(rrs, nqueries, destination, args=None):
    """
    Replay list of queries and report statistics.

    One query is built per record in `rrs`. If 'RAND' is in `args`, instead
    `nqueries` queries are built, each with a random 8-character label
    prepended to the owner name. If 'DO' is in `args`, the DO bit is set.
    `nqueries` queries are then sent over a connected, non-blocking UDP
    socket to `destination`, cycling through the prepared list and keeping at
    most 16 queries awaiting a response. Replay stops early when select()
    times out (0.5 s) with nothing to read or write.

    Returns:
        (number of queries sent, number of responses received)
    """
    import select

    if args is None:
        args = []
    if not rrs:
        # Nothing to build queries from (the modulo below would divide by zero).
        return 0, 0

    chunksize = 16  # maximum number of queries awaiting a response
    randomize = 'RAND' in args
    alphabet = string.ascii_letters + string.digits
    navail = len(rrs)
    queries = []
    for i in range(nqueries if randomize else navail):
        rr = rrs[i % navail]
        name = rr.name
        if randomize:
            prefix = ''.join(random.choice(alphabet) for _ in range(8))
            name = prefix + '.' + rr.name.to_text()
        msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
        if 'DO' in args:
            msg.want_dnssec(True)
        queries.append(msg.to_wire())

    # Make a UDP connected socket to the destination; closed on exit
    family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
    with socket.socket(family, socket.SOCK_DGRAM) as sock:
        sock.connect(destination)
        sock.setblocking(False)
        # Play the query set
        # @NOTE: this is only good for relative low-speed replay
        # Responses are only counted, so a small (truncating) buffer is enough.
        rcvbuf = bytearray(512)
        nsent, nrcvd, nwait, navail = 0, 0, 0, len(queries)
        fdset = [sock]
        while nsent - nwait < nqueries:
            to_read, to_write, _ = select.select(
                fdset, fdset if nwait < chunksize else [], [], 0.5)
            if to_write:
                try:
                    while nsent < nqueries and nwait < chunksize:
                        sock.send(queries[nsent % navail])
                        nwait += 1
                        nsent += 1
                except socket.error:
                    pass  # EINVAL
            if to_read:
                try:
                    while nwait > 0:
                        sock.recv_into(rcvbuf)
                        nwait -= 1
                        nrcvd += 1
                except socket.error:
                    pass
            if not to_write and not to_read:
                nwait = 0  # Timeout, started dropping packets
                break
    return nsent, nrcvd

140

141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
class DNSBlob(ABC):
    """Base type for anything the test server can put on the wire as a reply."""

    def __str__(self) -> str:
        return '<DNSBlob>'

    def to_wire(self) -> bytes:
        """Serialize to DNS wire format; subclasses must override this."""
        raise NotImplementedError


class DNSMessage(DNSBlob):
    """DNSBlob backed by a dns.message.Message object."""

    def __init__(self, message: dns.message.Message) -> None:
        assert message is not None
        self.message = message

    def __str__(self) -> str:
        return str(self.message)

    def to_wire(self) -> bytes:
        """Render the wrapped message, allowing the full 64 KiB message size."""
        return self.message.to_wire(max_size=0xFFFF)


class DNSReply(DNSMessage):
    """Scripted DNS reply, optionally adjusted to match the query it answers."""

    def __init__(
                self,
                message: dns.message.Message,
                query: Optional[dns.message.Message] = None,
                copy_id: bool = False,
                copy_query: bool = False
            ) -> None:
        """
        Wrap `message`; if copy_id or copy_query is set, adjust it against `query`.

        Raises:
            ValueError: if an adjustment is requested but `query` is None.
        """
        super().__init__(message)
        if copy_id or copy_query:
            if query is None:
                raise ValueError("query must be provided to adjust copy_id/copy_query")
            self.adjust_reply(query, copy_id, copy_query)

    def adjust_reply(
                self,
                query: dns.message.Message,
                copy_id: bool = True,
                copy_query: bool = True
            ) -> None:
        """
        Replace self.message with a copy of it tailored to `query`.

        The copy always takes EDNS version and flags from `query`, keeping the
        template's EDNS options. With `copy_id` it takes the query ID and the
        query's QNAME (to mirror letter case). With `copy_query` it takes the
        whole question section. The template's rcode is preserved.
        """
        # Round-trip through wire format to get an independent copy of the template.
        answer = dns.message.from_wire(self.message.to_wire(),
                                       xfr=self.message.xfr,
                                       one_rr_per_rrset=True)
        answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
        if copy_id:
            answer.id = query.id
            # Copy letter-case if both the template and the query have QD
            # (a query without a question section would otherwise raise IndexError).
            if answer.question and query.question:
                answer.question[0].name = query.question[0].name
        if copy_query:
            # New list, so adding/removing entries in the reply's question
            # section does not modify the query's list.
            answer.question = list(query.question)
        # Re-set, as the EDNS might have reset the ext-rcode
        answer.set_rcode(self.message.rcode())

        # sanity check: adjusted answer should be almost the same
        assert len(answer.answer) == len(self.message.answer)
        assert len(answer.authority) == len(self.message.authority)
        assert len(answer.additional) == len(self.message.additional)
        self.message = answer
200

201 202 203 204 205 206 207 208 209 210 211

class DNSReplyRaw(DNSBlob):
    """Reply given as raw wire bytes, optionally patched with the query's ID."""

    def __init__(
                self,
                wire: bytes,
                query: Optional[dns.message.Message] = None,
                copy_id: bool = False
            ) -> None:
        """
        Store `wire`; with `copy_id`, overwrite its message ID with query.id.

        Raises:
            ValueError: if copy_id is set but `query` is None.
        """
        assert wire is not None
        self.wire = wire
        if not copy_id:
            return
        if query is None:
            raise ValueError("query must be provided to adjust copy_id")
        self.adjust_reply(query, copy_id)

    def adjust_reply(
                self,
                query: dns.message.Message,
                copy_id: bool = True
            ) -> None:
        """
        Replace the first two wire bytes (the message ID) with query.id.

        Raises:
            ValueError: if the wire data is shorter than two bytes.
        """
        if not copy_id:
            return
        if len(self.wire) < 2:
            raise ValueError(
                'wire data must contain at least 2 bytes to adjust query id')
        self.wire = struct.pack('!H', query.id) + bytes(self.wire[2:])

    def to_wire(self) -> bytes:
        return self.wire

    def __str__(self) -> str:
        return '<DNSReplyRaw>'


class DNSReplyServfail(DNSMessage):
    """SERVFAIL reply built from `query` with dns.message.make_response()."""

    def __init__(self, query: dns.message.Message) -> None:
        servfail = dns.message.make_response(query)
        servfail.set_rcode(dns.rcode.SERVFAIL)
        super().__init__(servfail)
241 242


Marek Vavruša's avatar
Marek Vavruša committed
243 244
class Entry:
    """
    Data entry represents scripted message and extra metadata,
    notably match criteria and reply adjustments.
    """

    # Globals
    default_ttl = 3600
    default_cls = 'IN'
    default_rc = 'NOERROR'

    def __init__(self, node):
        """
        Initialize data entry from a parsed scenario node.

        Processes, in order: RAW data, MATCH fields, REPLY flags/rcode/opcode,
        ADJUST fields, the MANDATORY marker, the TSIG key and record sections.
        """
        self.node = node
        self.origin = '.'
        self.message = dns.message.Message()
        self.message.use_edns(edns=0, payload=4096)
        self.fired = 0  # incremented by Range.reply() each time this entry answers

        # RAW
        self.raw_data = None  # type: Optional[bytes]
        self.is_raw_data_entry = self.process_raw()

        # MATCH
        self.match_fields = self.process_match()

        # FLAGS
        self.process_reply_line()

        # ADJUST
        self.adjust_fields = {m.value for m in node.match("/adjust")}

        # MANDATORY
        try:
            self.mandatory = list(node.match("/mandatory"))[0]
        except (KeyError, IndexError):
            self.mandatory = None

        # TSIG
        self.process_tsig()

        # SECTIONS & RECORDS
        self.sections = self.process_sections()

    def process_raw(self):
        """Load hex-encoded /raw data into self.raw_data; return True if it was present."""
        try:
            self.raw_data = binascii.unhexlify(self.node["/raw"].value)
            return True
        except KeyError:
            return False

    def process_match(self):
        """
        Return the set of fields to compare, or None if the entry has no MATCH.

        'all' and 'question' are expanded into the individual fields they imply.
        """
        try:
            self.node["/match_present"]
        except KeyError:
            return None

        fields = {m.value for m in self.node.match("/match")}

        if 'all' in fields:
            fields.remove("all")
            fields |= {"opcode", "qtype", "qname", "flags",
                       "rcode", "answer", "authority", "additional"}

        if 'question' in fields:
            fields.remove("question")
            fields |= {"qtype", "qname"}

        return fields

    def process_reply_line(self):
        """Extracts flags, rcode and opcode from given node and adjust dns message accordingly"""
        self.fields = [f.value for f in self.node.match("/reply")]
        if 'DO' in self.fields:
            self.message.want_dnssec(True)
        opcode = self.get_opcode(fields=self.fields)
        rcode = self.get_rcode(fields=self.fields)
        self.message.flags = self.get_flags(fields=self.fields)
        if rcode is not None:
            self.message.set_rcode(rcode)
        if opcode is not None:
            self.message.set_opcode(opcode)

    def process_tsig(self):
        """Sign the message with the /tsig key (keyname + secret) if the entry has one."""
        try:
            tsig = list(self.node.match("/tsig"))[0]
            tsig_keyname = tsig["/keyname"].value
            tsig_secret = tsig["/secret"].value
            keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
            self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)
        except (KeyError, IndexError):
            pass

    def process_sections(self):
        """
        Fill the message sections from the /section/* nodes.

        Owner names without a trailing dot get self.origin appended. TTL and
        class default to default_ttl / default_cls. Records outside the question
        section get rdata parsed from /data; for DS the algorithm mnemonic is
        converted to its number first. A question of type AXFR sets message.xfr.

        Returns:
            set of section names that were present.
        """
        sections = set()
        for section in self.node.match("/section/*"):
            section_name = posixpath.basename(section.path)
            sections.add(section_name)
            for record in section.match("/record"):
                owner = record['/domain'].value
                if not owner.endswith("."):
                    owner += self.origin
                try:
                    ttl = dns.ttl.from_text(record['/ttl'].value)
                except KeyError:
                    ttl = self.default_ttl
                try:
                    rdclass = dns.rdataclass.from_text(record['/class'].value)
                except KeyError:
                    rdclass = dns.rdataclass.from_text(self.default_cls)
                rdtype = dns.rdatatype.from_text(record['/type'].value)
                rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
                if section_name != "question":
                    rd = record['/data'].value.split()
                    if rd:
                        if rdtype == dns.rdatatype.DS:
                            rd[1] = str(dns.dnssec.algorithm_from_text(rd[1]))
                        rd = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(
                            rd), origin=dns.name.from_text(self.origin), relativize=False)
                    rr.add(rd)
                if section_name == 'question':
                    if rr.rdtype == dns.rdatatype.AXFR:
                        self.message.xfr = True
                    self.message.question.append(rr)
                elif section_name == 'answer':
                    self.message.answer.append(rr)
                elif section_name == 'authority':
                    self.message.authority.append(rr)
                elif section_name == 'additional':
                    self.message.additional.append(rr)
        return sections

    def __str__(self):
        """Render the entry back in ENTRY_BEGIN ... ENTRY_END scenario syntax."""
        txt = 'ENTRY_BEGIN\n'
        # match_fields is None when the entry had no MATCH; emit no MATCH line
        # then instead of failing on join(None). Fields are sorted so the output
        # is deterministic.
        if not self.is_raw_data_entry and self.match_fields is not None:
            txt += 'MATCH {0}\n'.format(' '.join(sorted(self.match_fields)))
        txt += 'ADJUST {0}\n'.format(' '.join(sorted(self.adjust_fields)))
        txt += 'REPLY {rcode} {flags}\n'.format(
            rcode=dns.rcode.to_text(self.message.rcode()),
            flags=' '.join([dns.flags.to_text(self.message.flags),
                            dns.flags.edns_to_text(self.message.ednsflags)])
        )
        for sect_name in ['question', 'answer', 'authority', 'additional']:
            sect = getattr(self.message, sect_name)
            if not sect:
                continue
            txt += 'SECTION {n}\n'.format(n=sect_name.upper())
            for rr in sect:
                txt += str(rr)
                txt += '\n'
        if self.is_raw_data_entry:
            txt += 'RAW\n'
            if self.raw_data:
                # hexlify() returns bytes; decode before appending to str
                txt += binascii.hexlify(self.raw_data).decode('ascii')
            else:
                txt += 'NULL'
            txt += '\n'
        txt += 'ENTRY_END\n'
        return txt

    @classmethod
    def get_flags(cls, fields):
        """From `fields` extracts and returns flags"""
        flags = []
        for code in fields:
            try:
                dns.flags.from_text(code)  # throws KeyError on failure
                flags.append(code)
            except KeyError:
                pass
        return dns.flags.from_text(' '.join(flags))

    @classmethod
    def get_rcode(cls, fields):
        """
        From `fields` extracts and returns rcode.
        Throws `ValueError` if there are more then one rcodes
        """
        rcodes = []
        for code in fields:
            try:
                rcodes.append(dns.rcode.from_text(code))
            except dns.rcode.UnknownRcode:
                pass
        if len(rcodes) > 1:
            raise ValueError("Parse failed, too many rcode values.", rcodes)
        if not rcodes:
            return None
        return rcodes[0]

    @classmethod
    def get_opcode(cls, fields):
        """
        From `fields` extracts and returns opcode.
        Throws `ValueError` if there are more then one opcodes
        """
        opcodes = []
        for code in fields:
            try:
                opcodes.append(dns.opcode.from_text(code))
            except dns.opcode.UnknownOpcode:
                pass
        if len(opcodes) > 1:
            raise ValueError("Parse failed, too many opcode values.")
        if not opcodes:
            return None
        return opcodes[0]

    def match(self, msg):
        """
        Compare scripted reply to given message based on match criteria.

        Raises:
            ValueError: on the first field that does not match.
        """
        for code in self.match_fields:
            try:
                pydnstest.matchpart.match_part(self.message, msg, code)
            except pydnstest.matchpart.DataMismatch as ex:
                errstr = '%s in the response:\n%s' % (str(ex), msg.to_text())
                # TODO: line numbers
                raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr))

    def cmp_raw(self, raw_value):
        """
        Compare `raw_value` with this entry's raw data (both may be None).

        Raises:
            ValueError: if they differ.
        """
        assert self.is_raw_data_entry
        expected = None
        if self.raw_data is not None:
            expected = binascii.hexlify(self.raw_data)
        got = None
        if raw_value is not None:
            got = binascii.hexlify(raw_value)
        if expected != got:
            raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))

    # Annotation is quoted so it is not evaluated when the class is defined.
    def reply(self, query) -> 'Optional[DNSBlob]':
        """
        Build the reply for `query`, or None if ADJUST contains do_not_answer.

        Raw entries yield DNSReplyRaw (ID copied with ADJUST raw_data); others
        yield DNSReply honouring ADJUST copy_id / copy_query.
        """
        if 'do_not_answer' in self.adjust_fields:
            return None
        if self.is_raw_data_entry:
            copy_id = 'raw_data' in self.adjust_fields
            assert self.raw_data is not None
            return DNSReplyRaw(self.raw_data, query, copy_id)
        copy_id = 'copy_id' in self.adjust_fields
        copy_query = 'copy_query' in self.adjust_fields
        return DNSReply(self.message, query, copy_id, copy_query)

    @staticmethod
    def _subnet_option_data(subnet):
        """
        Build the payload of the EDNS client-subnet option (code 8).

        `subnet` is '<address>[/<prefix>]'. Without a prefix, the full address
        length is used. The address is truncated to the bytes covered by the
        prefix, and host bits in the last byte are cleared.
        """
        net = subnet.split('/')
        subnet_addr = net[0]
        family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET
        addr = socket.inet_pton(family, subnet_addr)
        prefix = len(addr) * 8
        if len(net) > 1:
            prefix = int(net[1])
        addr = addr[:(prefix + 7) // 8]
        if prefix % 8 != 0:  # Mask the last byte
            addr = addr[:-1] + bytes([addr[-1] & (0xFF << (8 - prefix % 8)) & 0xFF])
        # Family code: 1 for IPv4, 2 for IPv6; scope prefix length is 0.
        return struct.pack("!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr

    def set_edns(self, fields):
        """
        Set EDNS version and bufsize.

        `fields` may start with up to two integers (version, then buffer size),
        which are popped from the list. A buffer size of 0 disables EDNS.
        Remaining items are options: 'nsid' or 'nsid=<value>', and
        'subnet=<address>[/<prefix>]'.

        Raises:
            ValueError: if 'subnet' is given without a value.
        """
        version = 0
        bufsize = 4096
        if fields and fields[0].isdigit():
            version = int(fields.pop(0))
        if fields and fields[0].isdigit():
            bufsize = int(fields.pop(0))
        if bufsize == 0:
            self.message.use_edns(False)
            return
        opts = []
        for field in fields:
            key, value = field.split('=', 1) if '=' in field else (field, True)
            if key.lower() == 'nsid':
                # Option payloads are bytes; a bare 'nsid' sends an empty payload.
                data = b'' if value is True else value.encode()
                opts.append(dns.edns.GenericOption(dns.edns.NSID, data))
            if key.lower() == 'subnet':
                if value is True:
                    raise ValueError("subnet option requires a value: subnet=<address>[/<prefix>]")
                opts.append(dns.edns.GenericOption(8, self._subnet_option_data(value)))
        self.message.use_edns(edns=version, payload=bufsize, options=opts)
513

514

Marek Vavruša's avatar
Marek Vavruša committed
515 516 517 518
class Range:
    """
    Range represents a set of scripted queries valid for given step range.
    """
    log = logging.getLogger('pydnstest.scenario.Range')

    def __init__(self, node):
        """
        Initialize reply range from a parsed RANGE node.

        The range is valid for step ids a..b (inclusive). It can be restricted
        to a set of server addresses. `stored` holds its scripted entries.
        """
        self.node = node
        self.a = int(node['/from'].value)
        self.b = int(node['/to'].value)
        assert self.a <= self.b

        address = node["/address"].value
        self.addresses = {address} if address is not None else set()
        self.addresses |= {a.value for a in node.match("/address/*")}
        self.stored = [Entry(n) for n in node.match("/entry")]
        self.args = {}
        self.received = 0
        self.sent = 0

    def __del__(self):
        self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d',
                      self.a, self.b, self.addresses, self.received, self.sent)

    def __str__(self):
        txt = '\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b)
        for addr in self.addresses:
            txt += '        ADDRESS {0}\n'.format(addr)

        for entry in self.stored:
            txt += '\n'
            txt += str(entry)
        txt += 'RANGE_END\n\n'
        return txt

    def eligible(self, ident, address):
        """
        Return true if this range is eligible for fetching reply.

        `ident` must lie within [a, b]. In addition, `address` must be None,
        the range must have no address restriction, or `address` must be one
        of the range's addresses.
        """
        if self.a <= ident <= self.b:
            return (None is address
                    or set() == self.addresses
                    or address in self.addresses)
        return False

    # Annotations are quoted so they are not evaluated when the class is defined.
    def reply(self, query: 'dns.message.Message') -> 'Optional[DNSBlob]':
        """
        Get answer for given query (adjusted if needed).

        The first stored entry whose match() does not raise ValueError provides
        the answer. If self.args has 'LOSS', that answer is replaced by SERVFAIL
        with the given probability. SERVFAIL is also returned when no entry
        matches.
        """
        self.received += 1
        for candidate in self.stored:
            # Only a mismatch moves on to the next entry; errors while building
            # the reply or parsing LOSS must not be mistaken for a mismatch.
            try:
                candidate.match(query)
            except ValueError:
                continue
            resp = candidate.reply(query)
            # Probabilistic loss
            if 'LOSS' in self.args:
                if random.random() < float(self.args['LOSS']):
                    return DNSReplyServfail(query)
            self.sent += 1
            candidate.fired += 1
            return resp
        return DNSReplyServfail(query)
Marek Vavruša's avatar
Marek Vavruša committed
576 577


578
class StepLogger(logging.LoggerAdapter):  # pylint: disable=too-few-public-methods
    """
    Prepend Step identification before each log message.
    """
    def process(self, msg, kwargs):
        """Prefix `msg` with '[STEP <id> <type>]'; `kwargs` pass through unchanged."""
        step_id = self.extra['id']
        step_type = self.extra['type']
        prefixed = '[STEP %s %s] %s' % (step_id, step_type, msg)
        return prefixed, kwargs


Marek Vavruša's avatar
Marek Vavruša committed
586 587 588 589 590
class Step:
    """
    Step represents one scripted action in a given moment,
    each step has an order identifier, type and optionally data entry.
    """
    require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']

    def __init__(self, node):
        """ Initialize single scenario step. """
        self.node = node
        self.id = int(node.value)
        self.type = node["/type"].value
        self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'),
                              {'id': self.id, 'type': self.type})
        # self.delay stays unset when the node has no /timestamp;
        # __time_passes() reads it.
        try:
            self.delay = int(node["/timestamp"].value)
        except KeyError:
            pass
        self.data = [Entry(n) for n in node.match("/entry")]
        self.queries = []
        self.has_data = self.type in Step.require_data
        self.answer = None
        self.raw_answer = None
        self.repeat_if_fail = 0
        self.pause_if_fail = 0
        self.next_if_fail = -1

        # TODO Parser currently can't parse CHECK_ANSWER args, player doesn't understand them anyway
        # if type == 'CHECK_ANSWER':
        #     for arg in extra_args:
        #         param = arg.split('=')
        #         try:
        #             if param[0] == 'REPEAT':
        #                 self.repeat_if_fail = int(param[1])
        #             elif param[0] == 'PAUSE':
        #                 self.pause_if_fail = float(param[1])
        #             elif param[0] == 'NEXT':
        #                 self.next_if_fail = int(param[1])
        #         except Exception as e:
        #             raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))

    def __str__(self):
        txt = '\nSTEP {i} {t}'.format(i=self.id, t=self.type)
        if self.repeat_if_fail:
            txt += ' REPEAT {v}'.format(v=self.repeat_if_fail)
        elif self.pause_if_fail:
            txt += ' PAUSE {v}'.format(v=self.pause_if_fail)
        elif self.next_if_fail != -1:
            txt += ' NEXT {v}'.format(v=self.next_if_fail)
        # if self.args:
        #     txt += ' '
        #     txt += ' '.join(self.args)
        txt += '\n'

        for data in self.data:
            txt += str(data)
        return txt

    def play(self, ctx):
        """
        Play one step from a scenario.

        QUERY sends the scripted query. CHECK_ANSWER/ANSWER verifies the last
        answer. 'TIME_PASSES ELAPSE' moves the faked clock forward.
        CHECK_OUT_QUERY, REPLY and MOCK only log.

        Raises:
            NotImplementedError: for any other step type.
        """
        if self.type == 'QUERY':
            self.log.info('')
            # Guarded so a missing entry reaches __query's descriptive ValueError.
            if self.data:
                self.log.debug(self.data[0].message.to_text())
            # Parse QUERY-specific parameters
            choice, tcp, source = None, False, None
            return self.__query(ctx, tcp=tcp, choice=choice, source=source)
        elif self.type == 'CHECK_OUT_QUERY':  # ignore
            self.log.info('')
            return None
        elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER':
            self.log.info('')
            return self.__check_answer(ctx)
        elif self.type == 'TIME_PASSES ELAPSE':
            self.log.info('')
            return self.__time_passes()
        elif self.type == 'REPLY' or self.type == 'MOCK':
            self.log.info('')
            return None
        # Parser currently doesn't support step types LOG, REPLAY and ASSERT.
        # No test uses them.
        # elif self.type == 'LOG':
        #     if not ctx.log:
        #         raise Exception('scenario has no log interface')
        #     return ctx.log.match(self.args)
        # elif self.type == 'REPLAY':
        #     self.__replay(ctx)
        # elif self.type == 'ASSERT':
        #     self.__assert(ctx)
        else:
            raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type))

    def __check_answer(self, ctx):
        """
        Compare answer from previously resolved query.

        Raises:
            ValueError: if the step has no expected entry, no parsed answer
                exists for a non-raw entry, or the answer does not match.
        """
        if not self.data:
            raise ValueError("response definition required")
        expected = self.data[0]
        if expected.is_raw_data_entry is True:
            raw = ctx.last_raw_answer
            # The raw answer is wire bytes or None (nothing was received), so
            # log it hex-encoded; bytes/None have no to_text().
            self.log.debug("raw answer: %s",
                           binascii.hexlify(raw) if raw is not None else None)
            expected.cmp_raw(raw)
        else:
            if ctx.last_answer is None:
                raise ValueError("no answer from preceding query")
            self.log.debug("answer: %s", ctx.last_answer.to_text())
            expected.match(ctx.last_answer)

    # def __replay(self, ctx, chunksize=8):
    #     nqueries = len(self.queries)
    #     if len(self.args) > 0 and self.args[0].isdigit():
    #         nqueries = int(self.args.pop(0))
    #     destination = ctx.client[ctx.client.keys()[0]]
    #     self.log.info('replaying %d queries to %s@%d (%s)',
    #                   nqueries, destination[0], destination[1], ' '.join(self.args))
    #     if 'INTENSIFY' in os.environ:
    #         nqueries *= int(os.environ['INTENSIFY'])
    #     tstart = datetime.now()
    #     nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args)
    #     # Keep/print the statistics
    #     rtt = (datetime.now() - tstart).total_seconds() * 1000
    #     pps = 1000 * nrcvd / rtt
    #     self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps)
    #     tag = None
    #     for arg in self.args:
    #         if arg.upper().startswith('PRINT'):
    #             _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
    #     if tag:
    #         self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d',
    #                       tag.ljust(11), pps, rtt, nsent, nrcvd)

    def __query(self, ctx, tcp=False, choice=None, source=None):
        """
        Send query and wait for an answer (if the query is not RAW).

        The received answer is stored in self.answer and ctx.last_answer
        (parsed), and in self.raw_answer and ctx.last_raw_answer (wire bytes).
        The module-level g_nqueries / g_rtt statistics are updated.

        Raises:
            ValueError: if the step has no query entry or `choice` is not a
                known client.
            RuntimeError: if no answer arrives within 5 seconds.
        """
        global g_rtt, g_nqueries
        if not self.data:
            raise ValueError("query definition required")
        if self.data[0].is_raw_data_entry is True:
            data_to_wire = self.data[0].raw_data
        else:
            # Don't use a message copy as the EDNS data portion is not copied.
            data_to_wire = self.data[0].message.to_wire()
        if not choice:
            choice = list(ctx.client.keys())[0]
        if choice not in ctx.client:
            raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
        # Create socket to test subject; closed when leaving the with-block
        destination = ctx.client[choice]
        family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
        with socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if tcp:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(3)
            if source:
                sock.bind((source, 0))
            sock.connect(destination)
            # Send query to client and wait for response
            tstart = datetime.now()
            while True:
                try:
                    sendto_msg(sock, data_to_wire)
                    break
                except OSError as ex:
                    # ENOBUFS: throttle and retry. Anything else is a real
                    # failure; retrying it would loop forever.
                    if ex.errno != errno.ENOBUFS:
                        raise
                    time.sleep(0.1)
            # Wait for a response for a reasonable time
            answer = None
            if not self.data[0].is_raw_data_entry:
                while True:
                    if (datetime.now() - tstart).total_seconds() > 5:
                        raise RuntimeError("Server took too long to respond")
                    try:
                        answer, _ = recvfrom_msg(sock, True)
                        break
                    except OSError as ex:
                        # Other errors (e.g. socket timeouts) are retried until
                        # the 5 s deadline above.
                        if ex.errno == errno.ENOBUFS:
                            time.sleep(0.1)
            # Track RTT
            rtt = (datetime.now() - tstart).total_seconds() * 1000
        g_nqueries += 1
        g_rtt += rtt
        # Remember last answer for checking later
        self.raw_answer = answer
        ctx.last_raw_answer = answer
        if self.raw_answer is not None:
            self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
        else:
            self.answer = None
        ctx.last_answer = self.answer

    def __time_passes(self):
        """
        Modify system time.

        Reads the '@%Y-%m-%d %H:%M:%S' timestamp from the file named by
        $FAKETIME_TIMESTAMP_FILE, adds self.delay seconds, writes the result
        to '<file>.next' and renames that over the original file.
        """
        file_old = os.environ["FAKETIME_TIMESTAMP_FILE"]
        file_next = os.environ["FAKETIME_TIMESTAMP_FILE"] + ".next"
        with open(file_old, 'r') as time_file:
            line = time_file.readline().strip()
        t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple())
        t += self.delay
        with open(file_next, 'w') as time_file:
            time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n")
            time_file.flush()
        os.replace(file_next, file_old)

    # def __assert(self, ctx):
    #     """ Assert that a passed expression evaluates to True. """
    #     result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #     # Evaluate subexpressions for clarity
    #     subexpr = []
    #     for expr in self.args:
    #         try:
    #             ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #             subexpr.append(str(ee))
    #         except:
    #             subexpr.append(expr)
    #     assert result is True, '"%s" assertion fails (%s)' % (
    #                            ' '.join(self.args), ' '.join(subexpr))
807

808

Marek Vavruša's avatar
Marek Vavruša committed
809
class Scenario:
    """A parsed test scenario: response ranges plus an ordered list of steps."""

    log = logging.getLogger('pydnstest.scenario.Scenario')

    def __init__(self, node, filename):
        """
        Initialize scenario with description.

        node: parsed scenario tree node; its value is kept as `info` and its
              /range and /step children are wrapped in Range / Step objects.
        filename: scenario file name, used as a prefix in error messages.
        """
        self.node = node
        self.info = node.value
        self.file = filename
        self.ranges = [Range(n) for n in node.match("/range")]
        self.current_range = None
        self.steps = [Step(n) for n in node.match("/step")]
        self.current_step = None
        self.client = {}

    def __str__(self):
        txt = 'SCENARIO_BEGIN'
        if self.info:
            txt += ' {0}'.format(self.info)
        txt += '\n'
        for range_ in self.ranges:
            txt += str(range_)
        for step in self.steps:
            txt += str(step)
        txt += "\nSCENARIO_END"
        return txt

    def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]:
        """
        Generate answer packet for given query.

        The first range eligible for the current step id and `address`
        answers (and becomes current_range). Otherwise the first entry of each
        REPLY step with id >= the current step id is tried; a matching entry is
        consumed (removed from step.data) and used. When nothing answers, a
        SERVFAIL reply is returned.
        """
        current_step_id = self.current_step.id
        # Unknown address, select any match
        # TODO: workaround until the server supports stub zones
        all_addresses = set()  # type: ignore
        for rng in self.ranges:
            all_addresses.update(rng.addresses)
        if address not in all_addresses:
            address = None
        # Find current valid query response range
        for rng in self.ranges:
            if rng.eligible(current_step_id, address):
                self.current_range = rng
                return rng.reply(query)
        # Find any prescripted one-shot replies
        for step in self.steps:
            if step.id < current_step_id or step.type != 'REPLY':
                continue
            try:
                candidate = step.data[0]
                # presumably match() raises ValueError on mismatch; an empty
                # step.data raises IndexError - both mean "try the next step"
                candidate.match(query)
                step.data.remove(candidate)
                return candidate.reply(query)
            except (IndexError, ValueError):
                pass
        return DNSReplyServfail(query)

    def _find_step_index(self, step_id):
        """Return the position in self.steps of the first step with id `step_id`, or None."""
        for idx, candidate in enumerate(self.steps):
            if candidate.id == step_id:
                return idx
        return None

    def play(self, paddr):
        """
        Play given scenario.

        paddr: test subject => address mapping, stored as self.client.

        Steps run in order. When a step raises ValueError and still has
        retries left (repeat_if_fail > 0), the failure is logged, the retry
        counter is decremented, the optional pause (pause_if_fail seconds) is
        taken, and execution continues at the step whose id is next_if_fail
        (or repeats the same step when next_if_fail is -1). Without retries
        left, the error is re-raised prefixed with the file name and step id.
        Finally every mandatory entry of every range must have fired.

        Raises:
            ValueError: a step failed without retries left, next_if_fail names
                        a non-existent step, or a mandatory entry never fired.
        """
        # Store test subject => address mapping
        self.client = paddr

        i = 0
        while i < len(self.steps):
            step = self.steps[i]
            self.current_step = step
            try:
                step.play(self)
            except ValueError as ex:
                if step.repeat_if_fail <= 0:
                    raise ValueError('%s step %d %s' % (self.file, step.id, str(ex))) from ex
                self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)",
                              step.id, ex, step.next_if_fail, step.repeat_if_fail)
                step.repeat_if_fail -= 1
                if step.pause_if_fail > 0:
                    time.sleep(step.pause_if_fail)
                if step.next_if_fail != -1:
                    next_step = self._find_step_index(step.next_if_fail)
                    if next_step is None:
                        raise ValueError('step %d: wrong NEXT value "%d"' %
                                         (step.id, step.next_if_fail))
                    i = next_step
                # without a NEXT target the same step (index i) is retried
                continue
            i += 1

        for rng in self.ranges:
            for entry in rng.stored:
                if entry.mandatory and entry.fired == 0:
                    # TODO: report line numbers
                    raise ValueError('Mandatory section at %s not fired' % entry.mandatory.span)
904

905

906
def get_next(file_in, skip_empty=True):
    """
    Return next token from the input stream.

    Reads lines from `file_in` and strips a ';' comment unless the ';' is
    inside double quotes (a backslash-escaped '"' does not toggle quoting).

    Returns:
        False at end of input;
        (first_token, remaining_tokens) for the next non-empty line;
        ('', []) for an empty/comment-only line when skip_empty is False.
    """
    while True:
        raw = file_in.readline()
        if not raw:
            return False
        in_quotes = False
        after_backslash = False
        cut = len(raw)
        for pos, ch in enumerate(raw):
            if ch == '\\':
                after_backslash = not after_backslash
            if ch == '"' and not after_backslash:
                in_quotes = not in_quotes
            if ch == ';' and not in_quotes:
                cut = pos
                break
            if ch != '\\':
                after_backslash = False
        words = raw[:cut].split()
        if words:
            return words[0], words[1:]
        if not skip_empty:
            return '', []

932

933
def _parse_override_date(value):
    """Convert a (possibly quoted) YYYYMMDDHHMMSS value to a POSIX timestamp, read as UTC."""
    date_str = value.strip('"\'')
    fields = (date_str[0:4], date_str[4:6], date_str[6:8],
              date_str[8:10], date_str[10:12], date_str[12:])
    parsed = time.strptime(' '.join(fields), "%Y %m %d %H %M %S")
    return calendar.timegm(parsed)


def parse_config(scn_cfg, qmin, installdir):  # FIXME: pylint: disable=too-many-statements
    """
    Transform scene config (key, value) pairs into dict filled with defaults.

    Args:
        scn_cfg: iterable of (key, value) string pairs from the CONFIG section.
        qmin: default query-minimization flag, overridden by 'query-minimization'.
        installdir: exported as INSTALL_DIR and substituted for '{{INSTALL_DIR}}'
                    in 'feature-list' values.

    Returns tuple:
      context dict: {Jinja2 variable: value}
      trust anchor dict: {domain: [TA lines for particular domain]}

    Raises:
        NotImplementedError: for an unsupported config key.
        ValueError: for a 'feature-list' value without '=' or a malformed
                    override timestamp/date.
    """
    # defaults
    do_not_query_localhost = True
    harden_glue = True
    sockfamily = 0  # auto-select value for socket.getaddrinfo
    trust_anchor_list = []
    trust_anchor_files = {}
    negative_ta_list = []
    stub_addr = None
    override_timestamp = None

    features = {}
    feature_list_delimiter = ';'
    feature_pair_delimiter = '='

    for k, v in scn_cfg:
        # Enable selectively for some tests
        if k == 'do-not-query-localhost':
            do_not_query_localhost = str2bool(v)
        elif k == 'domain-insecure':
            negative_ta_list.append(v)
        elif k == 'harden-glue':
            harden_glue = str2bool(v)
        elif k == 'query-minimization':
            qmin = str2bool(v)
        elif k == 'trust-anchor':
            trust_anchor = v.strip('"\'')
            trust_anchor_list.append(trust_anchor)
            domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize()
            trust_anchor_files.setdefault(domain, []).append(trust_anchor)
        elif k == 'val-override-timestamp':
            override_timestamp = int(v.strip('"\''))
        elif k == 'val-override-date':
            override_timestamp = _parse_override_date(v)
        elif k == 'stub-addr':
            stub_addr = v.strip('"\'')
        elif k == 'features':
            # "key1=value1;key2;..." - an item without '=' gets an empty value
            for f_item in v.split(feature_list_delimiter):
                if feature_pair_delimiter in f_item:
                    f_key, f_value = [x.strip()
                                      for x in f_item.split(feature_pair_delimiter, 1)]
                else:
                    f_key, f_value = f_item.strip(), ""
                features[f_key] = f_value
        elif k == 'feature-list':
            # "key=value"; values accumulate in a list per key
            try:
                f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)]
            except ValueError as ex:  # no '=' in the value -> unpacking fails
                raise ValueError("can't parse feature-list (%s) in config section (%s)"
                                 % (v, str(ex))) from ex
            f_value = f_value.replace("{{INSTALL_DIR}}", installdir)
            features.setdefault(f_key, []).append(f_value)
        elif k == 'force-ipv6':
            # a false value is accepted and simply keeps the default family
            if str2bool(v):
                sockfamily = socket.AF_INET6
        else:
            raise NotImplementedError('unsupported CONFIG key "%s"' % k)

    ctx = {
        "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(),
        "NEGATIVE_TRUST_ANCHORS": negative_ta_list,
        "FEATURES": features,
        "HARDEN_GLUE": str(harden_glue).lower(),
        "INSTALL_DIR": installdir,
        "QMIN": str(qmin).lower(),
        "TRUST_ANCHORS": trust_anchor_list,
        "TRUST_ANCHOR_FILES": trust_anchor_files.keys()
    }
    if stub_addr:
        ctx['ROOT_ADDR'] = stub_addr
        # determine and verify socket family for specified root address
        gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0,
                                 socket.IPPROTO_UDP, socket.AI_NUMERICHOST)
        assert len(gai) == 1
        sockfamily = gai[0][0]
    if not sockfamily:
        sockfamily = socket.AF_INET  # default to IPv4
    ctx['_SOCKET_FAMILY'] = sockfamily
    # 0 (the epoch) is a valid override, so test against None explicitly
    if override_timestamp is not None:
        ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp
    return (ctx, trust_anchor_files)
1040 1041


1042
def parse_file(path):
    """
    Parse scenario from a file.

    The file is loaded through the Augeas 'Deckard' lens found next to this
    module. Returns (Scenario, config), where config is a list of
    [key, value] pairs from the CONFIG section: empty entries and entries
    starting with ';' are skipped, text after '#' is dropped, and entries
    without a ':' separator are ignored.
    """
    # keep the wrapper bound to a name so it stays alive while `tree` is used
    augeas = pydnstest.augwrap.AugeasWrapper(
        confpath=path, lens='Deckard', loadpath=os.path.dirname(__file__))
    tree = augeas.tree

    config = []
    for entry in [child.value for child in tree.match("/config/*")]:
        if not entry or entry.startswith(';'):
            continue
        hash_pos = entry.find('#')
        if hash_pos != -1:
            entry = entry[:hash_pos]
        # Break to key-value pairs, e.g.: ['minimization', 'on']
        pair = [part.strip() for part in entry.split(':', 1)]
        if len(pair) >= 2:
            config.append(pair)

    scenario = Scenario(tree["/scenario"], posixpath.basename(tree.path))
    return scenario, config