scenario.py 38.5 KB
Newer Older
1
# FIXME pylint: disable=too-many-lines
2
from abc import ABC
3
import binascii
4
import calendar
5
from datetime import datetime
6
import errno
7
import logging
8
import os
9
import posixpath
10
import random
11
import socket
12
import string
13
import struct
Marek Vavruša's avatar
Marek Vavruša committed
14
import time
15
from typing import Optional
16 17 18

import dns.dnssec
import dns.message
19
import dns.name
20 21 22
import dns.rcode
import dns.rrset
import dns.tsigkeyring
Marek Vavruša's avatar
Marek Vavruša committed
23

24
import pydnstest.augwrap
25
import pydnstest.matchpart
26

27

28 29 30 31 32
def str2bool(v):
    """Interpret a JSON-ish string: 'yes', 'true', 'on' or '1' (any case) is True, anything else False."""
    truthy_words = ('yes', 'true', 'on', '1')
    normalized = v.lower()
    return normalized in truthy_words


33 34 35 36
# Global statistics updated by Step.__query:
# g_rtt accumulates query round-trip time in milliseconds, g_nqueries counts queries sent.
g_rtt, g_nqueries = 0.0, 0

37

38 39 40 41 42 43 44 45 46 47 48 49
def _recv_exactly(stream, size):
    """
    Read exactly `size` bytes from a stream socket.

    Returns the bytes read, or None if the peer closed the connection before
    `size` bytes arrived.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        # Never ask for more than is still missing, so bytes of the next
        # length-prefixed message stay queued in the socket.
        chunk = stream.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recvfrom_msg(stream, raw=False):
    """
    Receive DNS message from TCP/UDP socket.

    On stream sockets the message must be preceded by a 2-byte big-endian
    length prefix; exactly one framed message is consumed per call.

    Returns:
        if raw == False: (DNS message object, peer address)
        if raw == True: (blob, peer address)
        (None, None) if a stream peer closed the connection before a complete
        message (prefix and body) was received.
        For datagram sockets the peer address is what recvfrom() returns; for
        stream sockets it is only the host part of getpeername().

    Raises:
        NotImplementedError: if the socket is neither datagram nor stream.
    """
    if stream.type & socket.SOCK_DGRAM:
        data, addr = stream.recvfrom(4096)
    elif stream.type & socket.SOCK_STREAM:
        header = _recv_exactly(stream, 2)
        if header is None:
            return None, None
        msg_len = struct.unpack("!H", header)[0]
        data = _recv_exactly(stream, msg_len)
        if data is None:
            return None, None
        addr = stream.getpeername()[0]
    else:
        raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
    if raw:
        return data, addr
    msg = dns.message.from_wire(data, one_rr_per_rrset=True)
    return msg, addr
69 70 71 72 73 74 75 76 77 78 79 80 81 82


def sendto_msg(stream, message, addr=None):
    """
    Send DNS message over a UDP or TCP socket.

    Datagram sockets send `message` as-is: to `addr` when given, otherwise to
    the connected peer. Stream sockets prepend the 2-byte big-endian length
    prefix and transmit the whole frame.

    ECONNREFUSED is ignored; any other socket error propagates.

    Raises:
        NotImplementedError: if the socket is neither datagram nor stream.
    """
    try:
        if stream.type & socket.SOCK_DGRAM:
            if addr is None:
                stream.send(message)
            else:
                stream.sendto(message, addr)
        elif stream.type & socket.SOCK_STREAM:
            data = struct.pack("!H", len(message)) + message
            # send() may write only part of the buffer on a stream socket;
            # sendall() keeps writing until the whole frame is out.
            stream.sendall(data)
        else:
            raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
    except OSError as ex:
        if ex.errno != errno.ECONNREFUSED:  # TODO Investigate how this can happen
            raise
87 88


89
def _build_replay_queries(rrs, nqueries, args):
    """Build wire-format queries from `rrs` ('RAND' randomizes names, 'DO' sets the DO bit)."""
    navail = len(rrs)
    count = nqueries if 'RAND' in args else navail
    alphabet = string.ascii_letters + string.digits
    queries = []
    for i in range(count):
        rr = rrs[i % navail]
        name = rr.name
        if 'RAND' in args:
            prefix = ''.join([random.choice(alphabet) for _ in range(8)])
            name = prefix + '.' + rr.name.to_text()
        msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
        if 'DO' in args:
            msg.want_dnssec(True)
        queries.append(msg.to_wire())
    return queries


def replay_rrs(rrs, nqueries, destination, args=None, chunksize=16):
    """
    Replay list of queries and report statistics.

    Queries built from `rrs` are sent over a connected, non-blocking UDP
    socket to `destination`, with at most `chunksize` queries outstanding at
    a time. Replay stops once `nqueries` answers have been accounted for or
    when the socket stays idle for 0.5 s.

    @NOTE: this is only good for relative low-speed replay

    Returns:
        (number of queries sent, number of answers received);
        (0, 0) when there is nothing to replay.
    """
    import select

    if args is None:
        args = []
    # Nothing to send: also avoids modulo-by-zero on an empty `rrs`.
    if not rrs or nqueries <= 0:
        return 0, 0
    queries = _build_replay_queries(rrs, nqueries, args)
    navail = len(queries)
    family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
    rcvbuf = bytearray(512)  # answers are only counted, not parsed
    nsent, nrcvd, nwait = 0, 0, 0
    # Make a UDP connected socket to the destination; `with` closes it.
    with socket.socket(family, socket.SOCK_DGRAM) as sock:
        sock.connect(destination)
        sock.setblocking(False)
        fdset = [sock]
        # Play the query set
        while nsent - nwait < nqueries:
            to_read, to_write, _ = select.select(
                fdset, fdset if nwait < chunksize else [], [], 0.5)
            if to_write:
                try:
                    while nsent < nqueries and nwait < chunksize:
                        sock.send(queries[nsent % navail])
                        nwait += 1
                        nsent += 1
                except OSError:
                    pass  # EINVAL
            if to_read:
                try:
                    while nwait > 0:
                        sock.recv_into(rcvbuf)
                        nwait -= 1
                        nrcvd += 1
                except OSError:
                    pass  # no more queued datagrams on the non-blocking socket
            if not to_write and not to_read:
                nwait = 0  # Timeout, started dropping packets
                break
    return nsent, nrcvd

140

141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
class DNSBlob(ABC):
    """Common base of every reply payload a scenario can send back."""

    def __str__(self) -> str:
        return '<DNSBlob>'

    def to_wire(self) -> bytes:
        """Return the bytes to put on the wire; concrete subclasses override this."""
        raise NotImplementedError


class DNSMessage(DNSBlob):
    """DNSBlob backed by a dns.message.Message object."""

    def __init__(self, message: dns.message.Message) -> None:
        assert message is not None
        self.message = message

    def __str__(self) -> str:
        return str(self.message)

    def to_wire(self) -> bytes:
        """Serialize the wrapped message, permitting up to 65535 bytes."""
        wire_limit = 65535
        return self.message.to_wire(max_size=wire_limit)


class DNSReply(DNSMessage):
    """DNS reply built from a scripted message, optionally adjusted to the query it answers."""

    def __init__(
            self,
            message: dns.message.Message,
            query: Optional[dns.message.Message] = None,
            copy_id: bool = False,
            copy_query: bool = False
    ) -> None:
        """
        Wrap `message`; if copy_id or copy_query is set, adjust it to `query`.

        Raises:
            ValueError: if an adjustment is requested but `query` is None.
        """
        super().__init__(message)
        if copy_id or copy_query:
            if query is None:
                raise ValueError("query must be provided to adjust copy_id/copy_query")
            self.adjust_reply(query, copy_id, copy_query)

    def adjust_reply(
            self,
            query: dns.message.Message,
            copy_id: bool = True,
            copy_query: bool = True
    ) -> None:
        """
        Replace self.message with a copy adjusted to answer `query`.

        The copy always takes the EDNS version and flags from `query` while
        keeping the template's EDNS options and rcode.
        With `copy_id`, it takes the query ID and, when both messages carry a
        question, the query's question name (letter-case).
        With `copy_query`, the whole question section is taken from `query`.
        """
        # Round-trip through wire format to get an independent copy of the template.
        answer = dns.message.from_wire(self.message.to_wire(),
                                       xfr=self.message.xfr,
                                       one_rr_per_rrset=True)
        answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
        if copy_id:
            answer.id = query.id
            # Copy letter-case if both the template and the query have a QD;
            # a query without a question section must not raise IndexError.
            if answer.question and query.question:
                answer.question[0].name = query.question[0].name
        if copy_query:
            answer.question = query.question
        # Re-set, as the EDNS might have reset the ext-rcode
        answer.set_rcode(self.message.rcode())

        # sanity check: adjusted answer should be almost the same
        assert len(answer.answer) == len(self.message.answer)
        assert len(answer.authority) == len(self.message.authority)
        assert len(answer.additional) == len(self.message.additional)
        self.message = answer
200

201 202 203 204 205 206 207 208 209 210 211

class DNSReplyRaw(DNSBlob):
    """Reply given as raw wire bytes, optionally with the query ID patched in."""

    def __init__(
            self,
            wire: bytes,
            query: Optional[dns.message.Message] = None,
            copy_id: bool = False
    ) -> None:
        assert wire is not None
        self.wire = wire
        if not copy_id:
            return
        if query is None:
            raise ValueError("query must be provided to adjust copy_id")
        self.adjust_reply(query, copy_id)

    def adjust_reply(
            self,
            query: dns.message.Message,
            copy_id: bool = True
    ) -> None:
        """Overwrite the first two wire bytes with `query.id` (network byte order) when copy_id is set."""
        if not copy_id:
            return
        if len(self.wire) < 2:
            raise ValueError(
                'wire data must contain at least 2 bytes to adjust query id')
        self.wire = struct.pack('!H', query.id) + bytes(self.wire[2:])

    def to_wire(self) -> bytes:
        return self.wire

    def __str__(self) -> str:
        return '<DNSReplyRaw>'


class DNSReplyServfail(DNSMessage):
    """SERVFAIL response generated for an arbitrary query."""

    def __init__(self, query: dns.message.Message) -> None:
        servfail = dns.message.make_response(query)
        servfail.set_rcode(dns.rcode.SERVFAIL)
        super().__init__(servfail)
241 242


Marek Vavruša's avatar
Marek Vavruša committed
243 244
class Entry:
    """
    Data entry represents scripted message and extra metadata,
    notably match criteria and reply adjustments.
    """

    # Globals
    default_ttl = 3600
    default_cls = 'IN'
    default_rc = 'NOERROR'

    def __init__(self, node):
        """ Initialize data entry from the parsed scenario `node`. """
        self.node = node
        self.origin = '.'
        self.message = dns.message.Message()
        self.message.use_edns(edns=0, payload=4096)
        self.fired = 0  # how many times this entry has been used for a reply

        # RAW
        self.raw_data = self.process_raw()

        # MATCH
        self.match_fields = self.process_match()

        # FLAGS
        self.process_reply_line()

        # ADJUST
        self.adjust_fields = {m.value for m in node.match("/adjust")}

        # MANDATORY
        try:
            self.mandatory = list(node.match("/mandatory"))[0]
        except (KeyError, IndexError):
            self.mandatory = None

        # TSIG
        self.process_tsig()

        # SECTIONS & RECORDS
        self.sections = self.process_sections()

    def process_raw(self):
        """Return the RAW hex payload decoded to bytes, or None if the entry has none."""
        try:
            return binascii.unhexlify(self.node["/raw"].value)
        except KeyError:
            return None

    def process_match(self):
        """
        Return the set of MATCH criteria, or None if the entry has no MATCH line.

        'all' expands to opcode, qtype, qname, flags, rcode and the three record
        sections; 'question' expands to qtype and qname.
        """
        try:
            self.node["/match_present"]
        except KeyError:
            return None

        fields = {m.value for m in self.node.match("/match")}

        if 'all' in fields:
            fields.remove("all")
            fields |= {"opcode", "qtype", "qname", "flags",
                       "rcode", "answer", "authority", "additional"}

        if 'question' in fields:
            fields.remove("question")
            fields |= {"qtype", "qname"}

        return fields

    def process_reply_line(self):
        """Extracts flags, rcode and opcode from given node and adjust dns message accordingly"""
        self.fields = [f.value for f in self.node.match("/reply")]
        if 'DO' in self.fields:
            self.message.want_dnssec(True)
        opcode = self.get_opcode(fields=self.fields)
        rcode = self.get_rcode(fields=self.fields)
        self.message.flags = self.get_flags(fields=self.fields)
        if rcode is not None:
            self.message.set_rcode(rcode)
        if opcode is not None:
            self.message.set_opcode(opcode)

    def process_tsig(self):
        """Sign the message with TSIG if the entry defines a keyname and secret."""
        try:
            tsig = list(self.node.match("/tsig"))[0]
            tsig_keyname = tsig["/keyname"].value
            tsig_secret = tsig["/secret"].value
            keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
            self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)
        except (KeyError, IndexError):
            pass

    def process_sections(self):
        """
        Fill self.message sections from the entry's SECTION records.

        Returns the set of section names present in the entry.
        """
        sections = set()
        for section in self.node.match("/section/*"):
            section_name = posixpath.basename(section.path)
            sections.add(section_name)
            for record in section.match("/record"):
                owner = record['/domain'].value
                if not owner.endswith("."):
                    owner += self.origin
                try:
                    ttl = dns.ttl.from_text(record['/ttl'].value)
                except KeyError:
                    ttl = self.default_ttl
                try:
                    rdclass = dns.rdataclass.from_text(record['/class'].value)
                except KeyError:
                    rdclass = dns.rdataclass.from_text(self.default_cls)
                rdtype = dns.rdatatype.from_text(record['/type'].value)
                rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
                if section_name != "question":
                    rd = record['/data'].value.split()
                    # Empty data leaves the RRset without rdata; rr.add() expects
                    # an rdata object, not the empty token list.
                    if rd:
                        if rdtype == dns.rdatatype.DS:
                            # Translate the DS algorithm mnemonic to its number.
                            rd[1] = str(dns.dnssec.algorithm_from_text(rd[1]))
                        rdata = dns.rdata.from_text(
                            rr.rdclass, rr.rdtype, ' '.join(rd),
                            origin=dns.name.from_text(self.origin), relativize=False)
                        rr.add(rdata)
                if section_name == 'question':
                    if rr.rdtype == dns.rdatatype.AXFR:
                        self.message.xfr = True
                    self.message.question.append(rr)
                elif section_name == 'answer':
                    self.message.answer.append(rr)
                elif section_name == 'authority':
                    self.message.authority.append(rr)
                elif section_name == 'additional':
                    self.message.additional.append(rr)
        return sections

    def __str__(self):
        txt = 'ENTRY_BEGIN\n'
        # match_fields is None when the entry had no MATCH line.
        if self.raw_data is None and self.match_fields is not None:
            txt += 'MATCH {0}\n'.format(' '.join(sorted(self.match_fields)))
        txt += 'ADJUST {0}\n'.format(' '.join(sorted(self.adjust_fields)))
        txt += 'REPLY {rcode} {flags}\n'.format(
            rcode=dns.rcode.to_text(self.message.rcode()),
            flags=' '.join([dns.flags.to_text(self.message.flags),
                            dns.flags.edns_to_text(self.message.ednsflags)])
        )
        for sect_name in ['question', 'answer', 'authority', 'additional']:
            sect = getattr(self.message, sect_name)
            if not sect:
                continue
            txt += 'SECTION {n}\n'.format(n=sect_name.upper())
            for rr in sect:
                txt += str(rr)
                txt += '\n'
        if self.raw_data is not None:
            txt += 'RAW\n'
            if self.raw_data:
                # hexlify() returns bytes; decode before appending to a str.
                txt += binascii.hexlify(self.raw_data).decode('ascii')
            else:
                txt += 'NULL'
            txt += '\n'
        txt += 'ENTRY_END\n'
        return txt

    @classmethod
    def get_flags(cls, fields):
        """From `fields` extracts and returns flags"""
        flags = []
        for code in fields:
            try:
                dns.flags.from_text(code)  # throws KeyError on failure
                flags.append(code)
            except KeyError:
                pass
        return dns.flags.from_text(' '.join(flags))

    @classmethod
    def get_rcode(cls, fields):
        """
        From `fields` extracts and returns rcode (None if there is none).
        Throws `ValueError` if there is more than one rcode.
        """
        rcodes = []
        for code in fields:
            try:
                rcodes.append(dns.rcode.from_text(code))
            except dns.rcode.UnknownRcode:
                pass
        if len(rcodes) > 1:
            raise ValueError("Parse failed, too many rcode values.", rcodes)
        if not rcodes:
            return None
        return rcodes[0]

    @classmethod
    def get_opcode(cls, fields):
        """
        From `fields` extracts and returns opcode (None if there is none).
        Throws `ValueError` if there is more than one opcode.
        """
        opcodes = []
        for code in fields:
            try:
                opcodes.append(dns.opcode.from_text(code))
            except dns.opcode.UnknownOpcode:
                pass
        if len(opcodes) > 1:
            raise ValueError("Parse failed, too many opcode values.")
        if not opcodes:
            return None
        return opcodes[0]

    def match(self, msg):
        """
        Compare scripted reply to given message based on match criteria.

        Raises:
            ValueError: on the first criterion that does not match.
        """
        for code in self.match_fields:
            try:
                pydnstest.matchpart.match_part(self.message, msg, code)
            except pydnstest.matchpart.DataMismatch as ex:
                errstr = '%s in the response:\n%s' % (str(ex), msg.to_text())
                # TODO: line numbers
                raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr))

    def cmp_raw(self, raw_value):
        """
        Compare `raw_value` (wire bytes or None) with the scripted RAW data.

        Raises:
            ValueError: if the hex representations differ.
        """
        assert self.raw_data is not None
        expected = None if self.raw_data is None else binascii.hexlify(self.raw_data)
        got = None if raw_value is None else binascii.hexlify(raw_value)
        if expected != got:
            raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))

    def reply(self, query) -> Optional[DNSBlob]:
        """
        Build the reply for `query` from this entry.

        Returns None for 'do_not_answer', a DNSReplyRaw for RAW entries, and a
        DNSReply (with 'copy_id'/'copy_query' adjustments) otherwise.
        """
        if 'do_not_answer' in self.adjust_fields:
            return None
        if self.raw_data is not None:
            # NOTE(review): the ID copy for RAW entries is keyed on the ADJUST
            # keyword 'raw_data' -- confirm this is the intended spelling.
            copy_id = 'raw_data' in self.adjust_fields
            return DNSReplyRaw(self.raw_data, query, copy_id)
        copy_id = 'copy_id' in self.adjust_fields
        copy_query = 'copy_query' in self.adjust_fields
        return DNSReply(self.message, query, copy_id, copy_query)

    def set_edns(self, fields):
        """
        Set EDNS version, bufsize and options on the scripted message.

        `fields` holds an optional version and an optional bufsize (leading
        all-digit items), followed by options ``NSID[=value]`` or
        ``SUBNET=addr[/prefix]``. A bufsize of 0 disables EDNS. The caller's
        list is not modified.

        Raises:
            ValueError: if SUBNET is given without an address.
        """
        fields = list(fields)
        version = 0
        bufsize = 4096
        if fields and fields[0].isdigit():
            version = int(fields.pop(0))
        if fields and fields[0].isdigit():
            bufsize = int(fields.pop(0))
        if bufsize == 0:
            self.message.use_edns(False)
            return
        opts = []
        for field in fields:
            key, sep, value = field.partition('=')
            option = key.lower()
            if option == 'nsid':
                opts.append(dns.edns.GenericOption(dns.edns.NSID, value.encode()))
            elif option == 'subnet':
                if not sep:
                    raise ValueError("EDNS subnet option requires an address: %r" % field)
                opts.append(dns.edns.GenericOption(8, self._ecs_payload(value)))
        self.message.use_edns(edns=version, payload=bufsize, options=opts)

    @staticmethod
    def _ecs_payload(subnet):
        """
        Build the payload of EDNS option 8 (client subnet) for 'addr[/prefix]'.

        Layout: family (1 = IPv4, 2 = IPv6) as 16 bits, source prefix length
        and scope prefix length (0) as 8 bits each, then the address truncated
        to the prefix with the unused low bits of the last byte cleared.
        Without '/prefix' the full address length is used.
        """
        net = subnet.split('/')
        subnet_addr = net[0]
        family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET
        addr = socket.inet_pton(family, subnet_addr)
        prefix = len(addr) * 8
        if len(net) > 1:
            prefix = int(net[1])
        # Keep only the bytes covered by the prefix (integer division).
        addr = addr[:(prefix + 7) // 8]
        if prefix % 8 != 0:  # Mask the last byte
            mask = (0xFF << (8 - prefix % 8)) & 0xFF
            addr = addr[:-1] + bytes([addr[-1] & mask])
        family_code = 1 if family == socket.AF_INET else 2
        return struct.pack("!HBB", family_code, prefix, 0) + addr
511

512

Marek Vavruša's avatar
Marek Vavruša committed
513 514 515 516
class Range:
    """
    Range represents a set of scripted queries valid for given step range.
    """
    log = logging.getLogger('pydnstest.scenario.Range')

    def __init__(self, node):
        """ Initialize reply range from the parsed scenario `node`. """
        self.node = node
        self.a = int(node['/from'].value)
        self.b = int(node['/to'].value)
        assert self.a <= self.b

        address = node["/address"].value
        self.addresses = {address} if address is not None else set()
        self.addresses |= {a.value for a in node.match("/address/*")}
        self.stored = [Entry(n) for n in node.match("/entry")]
        self.args = {}
        self.received = 0
        self.sent = 0

    def __del__(self):
        self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d',
                      self.a, self.b, self.addresses, self.received, self.sent)

    def __str__(self):
        txt = '\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b)
        for addr in self.addresses:
            txt += '        ADDRESS {0}\n'.format(addr)

        for entry in self.stored:
            txt += '\n'
            txt += str(entry)
        txt += 'RANGE_END\n\n'
        return txt

    def eligible(self, ident, address):
        """
        Return true if this range is eligible for fetching reply.

        The step id `ident` must lie within [a, b]. The address must be None,
        or this range must have no addresses, or must list `address`.
        """
        if self.a <= ident <= self.b:
            return (None is address
                    or set() == self.addresses
                    or address in self.addresses)
        return False

    def reply(self, query: dns.message.Message) -> Optional[DNSBlob]:
        """
        Get answer for given query (adjusted if needed).

        The first stored entry whose match criteria accept `query` answers it.
        A SERVFAIL reply is returned when no entry matches, or when the 'LOSS'
        argument (a probability) simulates a lost answer.
        """
        self.received += 1
        for candidate in self.stored:
            # Only a match failure means "try the next entry"; errors raised
            # while building the reply must not be silently skipped.
            try:
                candidate.match(query)
            except ValueError:
                continue
            resp = candidate.reply(query)
            # Probabilistic loss
            if 'LOSS' in self.args:
                if random.random() < float(self.args['LOSS']):
                    return DNSReplyServfail(query)
            self.sent += 1
            candidate.fired += 1
            return resp
        return DNSReplyServfail(query)
Marek Vavruša's avatar
Marek Vavruša committed
574 575


576
class StepLogger(logging.LoggerAdapter):  # pylint: disable=too-few-public-methods
    """
    Prepend the step identification (id and type) to every log message.
    """
    def process(self, msg, kwargs):
        tag = '[STEP %s %s]' % (self.extra['id'], self.extra['type'])
        return '%s %s' % (tag, msg), kwargs


Marek Vavruša's avatar
Marek Vavruša committed
584 585 586 587 588
class Step:
    """
    Step represents one scripted action in a given moment,
    each step has an order identifier, type and optionally data entry.
    """
    require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']

    def __init__(self, node):
        """ Initialize single scenario step from the parsed scenario `node`. """
        self.node = node
        self.id = int(node.value)
        self.type = node["/type"].value
        self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'),
                              {'id': self.id, 'type': self.type})
        # Seconds to move the fake clock by (TIME_PASSES); None without a timestamp.
        self.delay = None
        try:
            self.delay = int(node["/timestamp"].value)
        except KeyError:
            pass
        self.data = [Entry(n) for n in node.match("/entry")]
        self.queries = []
        self.has_data = self.type in Step.require_data
        self.answer = None
        self.raw_answer = None
        self.repeat_if_fail = 0
        self.pause_if_fail = 0
        self.next_if_fail = -1

        # TODO Parser currently can't parse CHECK_ANSWER args, player doesn't understand them anyway
        # if type == 'CHECK_ANSWER':
        #     for arg in extra_args:
        #         param = arg.split('=')
        #         try:
        #             if param[0] == 'REPEAT':
        #                 self.repeat_if_fail = int(param[1])
        #             elif param[0] == 'PAUSE':
        #                 self.pause_if_fail = float(param[1])
        #             elif param[0] == 'NEXT':
        #                 self.next_if_fail = int(param[1])
        #         except Exception as e:
        #             raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))

    def __str__(self):
        txt = '\nSTEP {i} {t}'.format(i=self.id, t=self.type)
        if self.repeat_if_fail:
            txt += ' REPEAT {v}'.format(v=self.repeat_if_fail)
        elif self.pause_if_fail:
            txt += ' PAUSE {v}'.format(v=self.pause_if_fail)
        elif self.next_if_fail != -1:
            txt += ' NEXT {v}'.format(v=self.next_if_fail)
        # if self.args:
        #     txt += ' '
        #     txt += ' '.join(self.args)
        txt += '\n'

        for data in self.data:
            txt += str(data)
        return txt

    def play(self, ctx):
        """
        Play one step from a scenario.

        Raises:
            NotImplementedError: for step types the player does not support.
        """
        if self.type == 'QUERY':
            self.log.info('')
            if self.data:  # a missing definition is reported by __query
                self.log.debug(self.data[0].message.to_text())
            # Parse QUERY-specific parameters
            choice, tcp, source = None, False, None
            return self.__query(ctx, tcp=tcp, choice=choice, source=source)
        elif self.type == 'CHECK_OUT_QUERY':  # ignore
            self.log.info('')
            return None
        elif self.type == 'CHECK_ANSWER' or self.type == 'ANSWER':
            self.log.info('')
            return self.__check_answer(ctx)
        elif self.type == 'TIME_PASSES ELAPSE':
            self.log.info('')
            return self.__time_passes()
        elif self.type == 'REPLY' or self.type == 'MOCK':
            self.log.info('')
            return None
        # Parser currently doesn't support step types LOG, REPLAY and ASSERT.
        # No test uses them.
        # elif self.type == 'LOG':
        #     if not ctx.log:
        #         raise Exception('scenario has no log interface')
        #     return ctx.log.match(self.args)
        # elif self.type == 'REPLAY':
        #     self.__replay(ctx)
        # elif self.type == 'ASSERT':
        #     self.__assert(ctx)
        else:
            raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type))

    def __check_answer(self, ctx):
        """
        Compare answer from previously resolved query.

        Raises:
            ValueError: missing expected data, no answer, or a mismatch.
        """
        if not self.data:
            raise ValueError("response definition required")
        expected = self.data[0]
        if expected.raw_data is not None:
            # ctx.last_raw_answer holds raw reply bytes (or None), not a Message.
            raw = ctx.last_raw_answer
            self.log.debug("raw answer: %s",
                           None if raw is None else binascii.hexlify(raw).decode('ascii'))
            expected.cmp_raw(raw)
        else:
            if ctx.last_answer is None:
                raise ValueError("no answer from preceding query")
            self.log.debug("answer: %s", ctx.last_answer.to_text())
            expected.match(ctx.last_answer)

    # def __replay(self, ctx, chunksize=8):
    #     nqueries = len(self.queries)
    #     if len(self.args) > 0 and self.args[0].isdigit():
    #         nqueries = int(self.args.pop(0))
    #     destination = ctx.client[ctx.client.keys()[0]]
    #     self.log.info('replaying %d queries to %s@%d (%s)',
    #                   nqueries, destination[0], destination[1], ' '.join(self.args))
    #     if 'INTENSIFY' in os.environ:
    #         nqueries *= int(os.environ['INTENSIFY'])
    #     tstart = datetime.now()
    #     nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args)
    #     # Keep/print the statistics
    #     rtt = (datetime.now() - tstart).total_seconds() * 1000
    #     pps = 1000 * nrcvd / rtt
    #     self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps)
    #     tag = None
    #     for arg in self.args:
    #         if arg.upper().startswith('PRINT'):
    #             _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
    #     if tag:
    #         self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d',
    #                       tag.ljust(11), pps, rtt, nsent, nrcvd)

    def __query(self, ctx, tcp=False, choice=None, source=None):
        """
        Send query and wait for an answer (if the query is not RAW).

        The received answer is stored in self.answer and ctx.last_answer, and
        its undecoded bytes in self.raw_answer and ctx.last_raw_answer.

        Raises:
            ValueError: no query definition, or `choice` is not a known client.
            RuntimeError: the server did not answer within 5 seconds.
            OSError: sending failed with an error other than ENOBUFS.
        """
        global g_rtt, g_nqueries
        if not self.data:
            raise ValueError("query definition required")
        if self.data[0].raw_data is not None:
            data_to_wire = self.data[0].raw_data
        else:
            # Don't use a message copy as the EDNS data portion is not copied.
            data_to_wire = self.data[0].message.to_wire()
        if not choice:
            choice = list(ctx.client.keys())[0]
        if choice not in ctx.client:
            raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
        destination = ctx.client[choice]
        family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
        # Create socket to test subject; `with` closes it on every exit path.
        with socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if tcp:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(3)
            if source:
                sock.bind((source, 0))
            sock.connect(destination)
            # Send query to client and wait for response
            tstart = datetime.now()
            while True:
                try:
                    sendto_msg(sock, data_to_wire)
                    break
                except OSError as ex:
                    # ENOBUFS, throttle sending; retrying any other error
                    # would loop forever.
                    if ex.errno != errno.ENOBUFS:
                        raise
                    time.sleep(0.1)
            # Wait for a response for a reasonable time
            answer = None
            if self.data[0].raw_data is None:
                while True:
                    if (datetime.now() - tstart).total_seconds() > 5:
                        raise RuntimeError("Server took too long to respond")
                    try:
                        answer, _ = recvfrom_msg(sock, True)
                        break
                    except OSError as ex:
                        # Other errors (e.g. the 3 s socket timeout) are retried
                        # until the 5 s deadline above.
                        if ex.errno == errno.ENOBUFS:
                            time.sleep(0.1)
        # Track RTT
        rtt = (datetime.now() - tstart).total_seconds() * 1000
        g_nqueries += 1
        g_rtt += rtt
        # Remember last answer for checking later
        self.raw_answer = answer
        ctx.last_raw_answer = answer
        if self.raw_answer is not None:
            self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
        else:
            self.answer = None
        ctx.last_answer = self.answer

    def __time_passes(self):
        """
        Modify system time.

        Reads the '@YYYY-mm-dd HH:MM:SS' timestamp from the file named by the
        FAKETIME_TIMESTAMP_FILE environment variable and advances it by
        self.delay seconds. The result is written to a sibling '.next' file,
        which is then moved over the original with os.replace().

        Raises:
            ValueError: if the step carries no timestamp (ELAPSE) value.
        """
        if self.delay is None:
            raise ValueError('step %03d: TIME_PASSES requires an ELAPSE value' % self.id)
        file_old = os.environ["FAKETIME_TIMESTAMP_FILE"]
        file_next = file_old + ".next"
        with open(file_old, 'r') as time_file:
            line = time_file.readline().strip()
        t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple())
        t += self.delay
        with open(file_next, 'w') as time_file:
            time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n")
            time_file.flush()
        os.replace(file_next, file_old)
Marek Vavruša's avatar
Marek Vavruša committed
791

792 793 794 795 796 797 798 799 800 801 802 803 804
    # def __assert(self, ctx):
    #     """ Assert that a passed expression evaluates to True. """
    #     result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #     # Evaluate subexpressions for clarity
    #     subexpr = []
    #     for expr in self.args:
    #         try:
    #             ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #             subexpr.append(str(ee))
    #         except:
    #             subexpr.append(expr)
    #     assert result is True, '"%s" assertion fails (%s)' % (
    #                            ' '.join(self.args), ' '.join(subexpr))
805

806

Marek Vavruša's avatar
Marek Vavruša committed
807
class Scenario:
    """ Test scenario: the ranges used to answer queries plus the ordered steps to play. """
    log = logging.getLogger('pydnstest.scenario.Scenario')

    def __init__(self, node, filename):
        """ Initialize scenario with description.

        Args:
            node: parsed scenario node; ``node.value`` is the description and
                its ``/range`` and ``/step`` children become Range / Step objects.
            filename: scenario file name, used as a prefix in error messages.
        """
        self.node = node
        self.info = node.value
        self.file = filename
        self.ranges = [Range(n) for n in node.match("/range")]
        self.current_range = None
        self.steps = [Step(n) for n in node.match("/step")]
        self.current_step = None
        self.client = {}

    def __str__(self):
        """ Render the scenario as SCENARIO_BEGIN ... SCENARIO_END text. """
        txt = 'SCENARIO_BEGIN'
        if self.info:
            txt += ' {0}'.format(self.info)
        txt += '\n'
        for range_ in self.ranges:
            txt += str(range_)
        for step in self.steps:
            txt += str(step)
        txt += "\nSCENARIO_END"
        return txt

    def reply(self, query: dns.message.Message, address=None) -> Optional[DNSBlob]:
        """ Generate answer packet for given query.

        Resolution order:
          1. the first range eligible for the current step and ``address``
             answers (and becomes ``current_range``); an address not listed
             in any range is treated as ``None`` (any address);
          2. otherwise the first one-shot REPLY step at or after the current
             step whose next pending entry matches the query is consumed;
          3. otherwise a SERVFAIL reply is produced.
        """
        current_step_id = self.current_step.id
        # Unknown address, select any match
        # TODO: workaround until the server supports stub zones
        all_addresses = set()  # type: ignore
        for rng in self.ranges:
            all_addresses.update(rng.addresses)
        if address not in all_addresses:
            address = None
        # Find current valid query response range
        for rng in self.ranges:
            if rng.eligible(current_step_id, address):
                self.current_range = rng
                return rng.reply(query)
        # Find any prescripted one-shot replies
        for step in self.steps:
            if step.id < current_step_id or step.type != 'REPLY':
                continue
            try:
                candidate = step.data[0]
                # IndexError: no entries left; ValueError: presumably raised
                # by match() when the query does not fit this candidate.
                candidate.match(query)
                step.data.remove(candidate)
                return candidate.reply(query)
            except (IndexError, ValueError):
                pass
        return DNSReplyServfail(query)

    def play(self, paddr):
        """ Play given scenario.

        Steps run in order.  When a step raises ValueError and it still has
        retries left (``repeat_if_fail``), the counter is decremented, the
        optional pause (``pause_if_fail`` seconds) is slept and execution
        continues at the step whose id equals ``next_if_fail``, or at the
        failing step again when ``next_if_fail`` is -1.

        Args:
            paddr: test subject => address mapping, stored in ``self.client``.

        Raises:
            ValueError: a step failed with no retries left (message prefixed
                with the scenario file and step id), ``next_if_fail`` names
                no existing step, or a mandatory range entry never fired.
        """
        # Store test subject => address mapping
        self.client = paddr

        i = 0
        while i < len(self.steps):
            step = self.steps[i]
            self.current_step = step
            try:
                step.play(self)
            except ValueError as ex:
                if step.repeat_if_fail <= 0:
                    raise ValueError('%s step %d %s' % (self.file, step.id, str(ex))) from ex
                # -1 means "retry this very step", so report its own id then.
                retry_id = step.id if step.next_if_fail == -1 else step.next_if_fail
                self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)",
                              step.id, ex, retry_id, step.repeat_if_fail)
                step.repeat_if_fail -= 1
                if step.pause_if_fail > 0:
                    time.sleep(step.pause_if_fail)
                if step.next_if_fail != -1:
                    target = next((j for j, candidate in enumerate(self.steps)
                                   if candidate.id == step.next_if_fail), None)
                    if target is None:
                        raise ValueError('step %d: wrong NEXT value "%d"' %
                                         (step.id, step.next_if_fail))
                    i = target
                # Do not advance: re-run the failing step or jump to NEXT.
                continue
            i += 1

        for rng in self.ranges:
            for entry in rng.stored:
                if entry.mandatory and entry.fired == 0:
                    # TODO: report line numbers
                    raise ValueError('Mandatory section at %s not fired' % entry.mandatory.span)
902

903

904
def get_next(file_in, skip_empty=True):
    """ Return next token from the input stream.

    A comment starts at the first ``;`` outside double quotes and runs to the
    end of the line.  A ``"`` toggles quoting unless it directly follows an odd
    number of backslashes.

    Returns:
        False at end of input; otherwise ``(op, args)`` where ``op`` is the
        first whitespace-separated token and ``args`` the remaining ones.
        Empty / comment-only lines are skipped unless ``skip_empty`` is False,
        in which case ``('', [])`` is returned for them.
    """
    while True:
        line = file_in.readline()
        if not line:
            return False

        in_quotes = False
        after_backslash = False
        cut = len(line)
        for pos, char in enumerate(line):
            if char == '\\':
                after_backslash = not after_backslash
                continue
            if char == '"' and not after_backslash:
                in_quotes = not in_quotes
            if char == ';' and not in_quotes:
                cut = pos
                break
            after_backslash = False

        words = line[:cut].split()
        if words:
            return words[0], words[1:]
        if not skip_empty:
            return '', []

930

931
def parse_config(scn_cfg, qmin, installdir):  # FIXME: pylint: disable=too-many-statements
    """
    Transform scene config (key, value) pairs into dict filled with defaults.

    Args:
        scn_cfg: iterable of (key, value) string pairs from the CONFIG section.
        qmin: default for query minimization; a ``query-minimization`` key
            overrides it.
        installdir: exported as INSTALL_DIR and substituted for
            ``{{INSTALL_DIR}}`` in ``feature-list`` values.

    Returns tuple:
      context dict: {Jinja2 variable: value}
      trust anchor dict: {domain: [TA lines for particular domain]}

    Raises:
        NotImplementedError: unknown CONFIG key.
        ValueError: ``feature-list`` value without ``=``, or an unparsable
            ``val-override-timestamp`` / ``val-override-date`` value.
    """
    # defaults
    do_not_query_localhost = True
    harden_glue = True
    sockfamily = 0  # auto-select value for socket.getaddrinfo
    trust_anchor_list = []
    trust_anchor_files = {}
    negative_ta_list = []
    stub_addr = None
    override_timestamp = None

    features = {}
    feature_list_delimiter = ';'
    feature_pair_delimiter = '='

    for k, v in scn_cfg:
        # Enable selectively for some tests
        if k == 'do-not-query-localhost':
            do_not_query_localhost = str2bool(v)
        elif k == 'domain-insecure':
            negative_ta_list.append(v)
        elif k == 'harden-glue':
            harden_glue = str2bool(v)
        elif k == 'query-minimization':
            qmin = str2bool(v)
        elif k == 'trust-anchor':
            trust_anchor = v.strip('"\'')
            trust_anchor_list.append(trust_anchor)
            # Group TA lines by their canonicalized owner name (first field).
            domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize()
            trust_anchor_files.setdefault(domain, []).append(trust_anchor)
        elif k == 'val-override-timestamp':
            override_timestamp = int(v.strip('"\''))
        elif k == 'val-override-date':
            # YYYYMMDDhhmmss, converted with calendar.timegm (i.e. as UTC)
            date_str = v.strip('"\'')
            date_parts = (date_str[0:4], date_str[4:6], date_str[6:8],
                          date_str[8:10], date_str[10:12], date_str[12:])
            override_date = time.strptime(' '.join(date_parts), "%Y %m %d %H %M %S")
            override_timestamp = calendar.timegm(override_date)
        elif k == 'stub-addr':
            stub_addr = v.strip('"\'')
        elif k == 'features':
            # "key=value; flag; ..." -- a bare key gets an empty value
            for f_item in v.split(feature_list_delimiter):
                f_key, _, f_value = f_item.partition(feature_pair_delimiter)
                features[f_key.strip()] = f_value.strip()
        elif k == 'feature-list':
            # "key=value" -- values for the same key accumulate in a list
            try:
                f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)]
            except ValueError as ex:
                raise ValueError("can't parse feature-list (%s) in config section (%s)"
                                 % (v, str(ex))) from ex
            f_value = f_value.replace("{{INSTALL_DIR}}", installdir)
            features.setdefault(f_key, []).append(f_value)
        elif k == 'force-ipv6':
            if str2bool(v):
                sockfamily = socket.AF_INET6
        else:
            raise NotImplementedError('unsupported CONFIG key "%s"' % k)

    ctx = {
        "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(),
        "NEGATIVE_TRUST_ANCHORS": negative_ta_list,
        "FEATURES": features,
        "HARDEN_GLUE": str(harden_glue).lower(),
        "INSTALL_DIR": installdir,
        "QMIN": str(qmin).lower(),
        "TRUST_ANCHORS": trust_anchor_list,
        "TRUST_ANCHOR_FILES": trust_anchor_files.keys()
    }
    if stub_addr:
        ctx['ROOT_ADDR'] = stub_addr
        # determine and verify socket family for specified root address
        gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0,
                                 socket.IPPROTO_UDP, socket.AI_NUMERICHOST)
        assert len(gai) == 1
        sockfamily = gai[0][0]
    if not sockfamily:
        sockfamily = socket.AF_INET  # default to IPv4
    ctx['_SOCKET_FAMILY'] = sockfamily
    # 0 (the epoch) is a valid override, so test for presence, not truthiness.
    if override_timestamp is not None:
        ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp
    return (ctx, trust_anchor_files)
1038 1039


1040
def parse_file(path):
    """ Parse scenario from a file.

    The file is loaded through AugeasWrapper using the 'Deckard' lens found
    next to this module.

    Returns:
        (scenario, config): a Scenario built from the ``/scenario`` node and a
        list of [key, value] pairs taken from the ``/config/*`` entries.
        Empty entries and entries starting with ';' are ignored, everything
        from the first '#' on is dropped, and entries without ':' are skipped.
    """
    # NOTE(review): ``aug`` is kept as a local so the wrapper stays referenced
    # while its tree is used -- presumably the nodes depend on it; confirm.
    aug = pydnstest.augwrap.AugeasWrapper(
        confpath=path, lens='Deckard', loadpath=os.path.dirname(__file__))
    root = aug.tree

    config = []
    for entry_text in [child.value for child in root.match("/config/*")]:
        if not entry_text or entry_text.startswith(';'):
            continue
        entry_text = entry_text.split('#', 1)[0]
        # Break to key-value pairs, e.g. ['minimization', 'on']
        pair = [part.strip() for part in entry_text.split(':', 1)]
        if len(pair) >= 2:
            config.append(pair)

    scenario = Scenario(root["/scenario"], posixpath.basename(root.path))
    return scenario, config