scenario.py 35.6 KB
Newer Older
1 2
from __future__ import absolute_import

3
import binascii
4
import calendar
5
import errno
6
import logging
7
import os
8
import posixpath
9
import random
10
import socket
11
import string
12
import struct
Marek Vavruša's avatar
Marek Vavruša committed
13
import time
14
from datetime import datetime
15 16 17

import dns.dnssec
import dns.message
18
import dns.name
19 20 21
import dns.rcode
import dns.rrset
import dns.tsigkeyring
Marek Vavruša's avatar
Marek Vavruša committed
22

23
import pydnstest.augwrap
24
import pydnstest.matchpart
25

26

27 28 29 30 31
def str2bool(v):
    """Interpret a JSON-ish string such as 'yes', 'on' or '1' as a boolean.

    The comparison is case-insensitive; any value outside the truthy set
    ('yes', 'true', 'on', '1') yields False.
    """
    truthy = frozenset(('yes', 'true', 'on', '1'))
    return v.lower() in truthy


32 33 34 35
# Global statistics, accumulated by every QUERY step:
# total round-trip time (milliseconds) and number of queries sent.
g_rtt, g_nqueries = 0.0, 0

36

37 38 39 40 41 42 43 44 45 46 47 48
def _recv_exactly(stream, size):
    """Read exactly `size` bytes from a stream socket.

    Returns the bytes read, or None if the peer closed the connection
    before `size` bytes arrived.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = stream.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recvfrom_msg(stream, raw=False):
    """
    Receive DNS message from TCP/UDP socket.

    For TCP the 2-byte big-endian length prefix is read first, then exactly
    that many bytes, so any following message on the same connection is
    left in the socket for the next call.

    Returns:
        if raw == False: (DNS message object, peer address)
        if raw == True: (blob, peer address)
        (None, None) if a TCP peer closed the connection mid-message.

    Raises:
        NotImplementedError: for socket types other than datagram/stream.
    """
    if stream.type & socket.SOCK_DGRAM:
        data, addr = stream.recvfrom(4096)
    elif stream.type & socket.SOCK_STREAM:
        # recv() may return fewer bytes than requested, even for the prefix.
        prefix = _recv_exactly(stream, 2)
        if prefix is None:
            return None, None
        msg_len = struct.unpack("!H", prefix)[0]
        # Ask only for msg_len bytes: recv(4096) could swallow the next message.
        data = _recv_exactly(stream, msg_len)
        if data is None:
            return None, None
        addr = stream.getpeername()[0]
    else:
        raise NotImplementedError("[recvfrom_msg]: unknown socket type '%i'" % stream.type)
    if raw:
        return data, addr
    msg = dns.message.from_wire(data, one_rr_per_rrset=True)
    return msg, addr
68 69 70 71 72 73 74 75 76 77 78 79 80 81


def sendto_msg(stream, message, addr=None):
    """
    Send DNS message over a UDP or TCP socket.

    UDP: the message is one datagram, sent to `addr` if given, otherwise to
    the socket's connected peer.
    TCP: the message is prefixed with its 2-byte big-endian length and sent
    with sendall(), because send() may transmit only part of the buffer.

    ECONNREFUSED errors are ignored; any other socket error is re-raised.

    Raises:
        NotImplementedError: for socket types other than datagram/stream.
    """
    try:
        if stream.type & socket.SOCK_DGRAM:
            if addr is None:
                stream.send(message)
            else:
                stream.sendto(message, addr)
        elif stream.type & socket.SOCK_STREAM:
            stream.sendall(struct.pack("!H", len(message)) + message)
        else:
            raise NotImplementedError("[sendto_msg]: unknown socket type '%i'" % stream.type)
    except socket.error as ex:
        if ex.errno != errno.ECONNREFUSED:  # TODO Investigate how this can happen
            raise
86 87


88
def replay_rrs(rrs, nqueries, destination, args=None):
    """
    Replay list of queries over UDP and report statistics.

    Args:
        rrs: records whose name/rdtype/rdclass are turned into queries.
        nqueries: number of queries to send; the prepared queries are
            reused round-robin when there are fewer of them.
        destination: (address, port) of the server; an address containing
            ':' selects IPv6.
        args: optional flags. 'RAND' builds `nqueries` queries, each name
            prefixed with a random 8-character label; 'DO' requests DNSSEC.

    Returns:
        (number of queries sent, number of responses received)

    Note: this is only good for relatively low-speed replay.
    """
    import select
    args = [] if args is None else args
    if not rrs or nqueries <= 0:
        # Nothing to replay (and avoids a modulo by zero below).
        return 0, 0
    chunksize = 16  # maximum number of queries in flight
    navail = len(rrs)
    queries = []
    for i in range(nqueries if 'RAND' in args else navail):
        rr = rrs[i % navail]
        name = rr.name
        if 'RAND' in args:
            prefix = ''.join([random.choice(string.ascii_letters + string.digits)
                              for _ in range(8)])
            name = prefix + '.' + rr.name.to_text()
        msg = dns.message.make_query(name, rr.rdtype, rr.rdclass)
        if 'DO' in args:
            msg.want_dnssec(True)
        queries.append(msg.to_wire())
    nprepared = len(queries)
    # Responses are read into a scratch buffer and discarded.
    rcvbuf = bytearray(512)
    nsent, nrcvd, nwait = 0, 0, 0
    family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
    # Make a UDP connected socket to the destination (closed on exit).
    with socket.socket(family, socket.SOCK_DGRAM) as sock:
        sock.connect(destination)
        sock.setblocking(False)
        fdset = [sock]
        while nsent - nwait < nqueries:
            # Poll for writability only while something is left to send;
            # otherwise select() returns at once and the timeout below
            # could never trigger.
            can_send = nsent < nqueries and nwait < chunksize
            to_read, to_write, _ = select.select(fdset, fdset if can_send else [], [], 0.5)
            if to_write:
                try:
                    while nsent < nqueries and nwait < chunksize:
                        sock.send(queries[nsent % nprepared])
                        nwait += 1
                        nsent += 1
                except socket.error:
                    pass  # EINVAL / EAGAIN: retry on the next iteration
            if to_read:
                try:
                    while nwait > 0:
                        sock.recv_into(rcvbuf)
                        nwait -= 1
                        nrcvd += 1
                except socket.error:
                    pass  # EAGAIN: no more responses ready yet
            if not to_write and not to_read:
                break  # Timeout, started dropping packets
    return nsent, nrcvd

137

Marek Vavruša's avatar
Marek Vavruša committed
138 139
class Entry:
    """
    Data entry represents scripted message and extra metadata,
    notably match criteria and reply adjustments.
    """

    # Globals
    default_ttl = 3600
    default_cls = 'IN'
    default_rc = 'NOERROR'

    def __init__(self, node):
        """
        Initialize data entry from a parsed scenario node.

        A node with a `/raw` child holds a hex-encoded message: only
        `raw_data` is filled in and parsing stops there. Otherwise the
        match/adjust criteria, reply line, mandatory flag, TSIG key and the
        message sections are read from the node.
        """
        self.node = node
        self.origin = '.'
        self.message = dns.message.Message()
        self.message.use_edns(edns=0, payload=4096)
        self.fired = 0  # incremented each time this entry is used as a reply

        # RAW
        try:
            self.raw_data = binascii.unhexlify(node["/raw"].value)
            self.is_raw_data_entry = True
            return
        except KeyError:
            self.raw_data = None
            self.is_raw_data_entry = False

        # MATCH
        self.match_fields = [m.value for m in node.match("/match")]
        if not self.match_fields:
            self.match_fields = ['opcode', 'qtype', 'qname']

        # FLAGS
        self.process_reply_line(node)

        # ADJUST
        self.adjust_fields = [m.value for m in node.match("/adjust")]
        if not self.adjust_fields:
            self.adjust_fields = ['copy_id']

        # MANDATORY
        try:
            self.mandatory = list(node.match("/mandatory"))[0]
        except (KeyError, IndexError):
            self.mandatory = None

        # TSIG
        try:
            tsig = list(node.match("/tsig"))[0]
            tsig_keyname = tsig["/keyname"].value
            tsig_secret = tsig["/secret"].value
            keyring = dns.tsigkeyring.from_text({tsig_keyname: tsig_secret})
            self.message.use_tsig(keyring=keyring, keyname=tsig_keyname)
        except (KeyError, IndexError):
            pass

        # SECTIONS & RECORDS
        self.sections = []
        for section in node.match("/section/*"):
            section_name = posixpath.basename(section.path)
            self.sections.append(section_name)
            for record in section.match("/record"):
                self._add_record(section_name, record)

    def _add_record(self, section_name, record):
        """Parse one `/record` node and append it to the named message section."""
        owner = record['/domain'].value
        if not owner.endswith("."):
            owner += self.origin
        try:
            ttl = dns.ttl.from_text(record['/ttl'].value)
        except KeyError:
            ttl = self.default_ttl
        try:
            rdclass = dns.rdataclass.from_text(record['/class'].value)
        except KeyError:
            rdclass = dns.rdataclass.from_text(self.default_cls)
        rdtype = dns.rdatatype.from_text(record['/type'].value)
        rr = dns.rrset.from_text(owner, ttl, rdclass, rdtype)
        if section_name != "question":
            rd = record['/data'].value.split()
            # Empty data adds no rdata (previously a bare [] was passed to rr.add).
            if rd:
                if rdtype == dns.rdatatype.DS:
                    # Normalize the DS algorithm field to its numeric form.
                    rd[1] = str(dns.dnssec.algorithm_from_text(rd[1]))
                rdata = dns.rdata.from_text(rr.rdclass, rr.rdtype, ' '.join(rd),
                                            origin=dns.name.from_text(self.origin),
                                            relativize=False)
                rr.add(rdata)
        if section_name == 'question':
            if rr.rdtype == dns.rdatatype.AXFR:
                self.message.xfr = True
            self.message.question.append(rr)
        elif section_name in ('answer', 'authority', 'additional'):
            getattr(self.message, section_name).append(rr)

    def __str__(self):
        """Render the entry back into its ENTRY_BEGIN ... ENTRY_END text form."""
        txt = 'ENTRY_BEGIN\n'
        if not self.is_raw_data_entry:
            # Raw entries stop parsing before MATCH/ADJUST/REPLY are read,
            # so those attributes do not exist for them.
            txt += 'MATCH {0}\n'.format(' '.join(self.match_fields))
            txt += 'ADJUST {0}\n'.format(' '.join(self.adjust_fields))
            txt += 'REPLY {rcode} {flags}\n'.format(
                rcode=dns.rcode.to_text(self.message.rcode()),
                flags=' '.join([dns.flags.to_text(self.message.flags),
                                dns.flags.edns_to_text(self.message.ednsflags)])
            )
        for sect_name in ['question', 'answer', 'authority', 'additional']:
            sect = getattr(self.message, sect_name)
            if not sect:
                continue
            txt += 'SECTION {n}\n'.format(n=sect_name.upper())
            for rr in sect:
                txt += str(rr)
                txt += '\n'
        if self.is_raw_data_entry:
            txt += 'RAW\n'
            if self.raw_data:
                # hexlify() returns bytes; decode before joining with str.
                txt += binascii.hexlify(self.raw_data).decode('ascii')
            else:
                txt += 'NULL'
            txt += '\n'
        txt += 'ENTRY_END\n'
        return txt

    def process_reply_line(self, node):
        """Extracts flags, rcode and opcode from given node and adjust dns message accordingly"""
        self.fields = [f.value for f in node.match("/reply")]
        if 'DO' in self.fields:
            self.message.want_dnssec(True)
        opcode = self.get_opcode(fields=self.fields)
        rcode = self.get_rcode(fields=self.fields)
        self.message.flags = self.get_flags(fields=self.fields)
        if rcode is not None:
            self.message.set_rcode(rcode)
        if opcode is not None:
            self.message.set_opcode(opcode)

    @classmethod
    def get_flags(cls, fields):
        """From `fields` extracts and returns flags; non-flag fields are ignored."""
        flags = []
        for code in fields:
            try:
                dns.flags.from_text(code)  # throws KeyError on failure
                flags.append(code)
            except KeyError:
                pass
        return dns.flags.from_text(' '.join(flags))

    @classmethod
    def get_rcode(cls, fields):
        """
        From `fields` extracts and returns rcode (None if there is none).
        Throws `ValueError` if there is more than one rcode.
        """
        rcodes = []
        for code in fields:
            try:
                rcodes.append(dns.rcode.from_text(code))
            except dns.rcode.UnknownRcode:
                pass
        if len(rcodes) > 1:
            raise ValueError("Parse failed, too many rcode values.", rcodes)
        if len(rcodes) == 0:
            return None
        return rcodes[0]

    @classmethod
    def get_opcode(cls, fields):
        """
        From `fields` extracts and returns opcode (None if there is none).
        Throws `ValueError` if there is more than one opcode.
        """
        opcodes = []
        for code in fields:
            try:
                opcodes.append(dns.opcode.from_text(code))
            except dns.opcode.UnknownOpcode:
                pass
        if len(opcodes) > 1:
            raise ValueError("Parse failed, too many opcode values.")
        if len(opcodes) == 0:
            return None
        return opcodes[0]

    def match(self, msg):
        """
        Compare scripted reply to given message based on match criteria.

        'all' expands to flags, rcode and every section present in the entry.

        Raises:
            ValueError: describing the first mismatching criterion.
        """
        # Work on a copy so that expanding 'all' does not rewrite the
        # entry's own criteria (which __str__ prints back).
        match_fields = list(self.match_fields)
        if 'all' in match_fields:
            match_fields.remove('all')
            match_fields += ['flags', 'rcode'] + self.sections

        for code in match_fields:
            try:
                pydnstest.matchpart.match_part(self.message, msg, code)
            except pydnstest.matchpart.DataMismatch as ex:
                errstr = '%s in the response:\n%s' % (str(ex), msg.to_text())
                # TODO: line numbers
                raise ValueError("%s, \"%s\": %s" % (self.node.span, code, errstr)) from ex

    def cmp_raw(self, raw_value):
        """
        Compare the entry's raw blob with `raw_value` (either may be None).

        Raises:
            ValueError: when the hex representations differ.
        """
        assert self.is_raw_data_entry
        expected = None
        if self.raw_data is not None:
            expected = binascii.hexlify(self.raw_data)
        got = None
        if raw_value is not None:
            got = binascii.hexlify(raw_value)
        if expected != got:
            raise ValueError("raw message comparsion failed: expected %s got %s" % (expected, got))

    def adjust_reply(self, query):
        """
        Copy scripted reply and adjust to received query.

        The template is round-tripped through wire format to get an
        independent copy and takes its EDNS version/flags from the query.
        With 'copy_id' the query ID (and QNAME letter-case, when both
        messages have a question) is copied; with 'copy_query' the whole
        question section is taken from the query.
        """
        answer = dns.message.from_wire(self.message.to_wire(),
                                       xfr=self.message.xfr,
                                       one_rr_per_rrset=True)
        answer.use_edns(query.edns, query.ednsflags, options=self.message.options)
        if 'copy_id' in self.adjust_fields:
            answer.id = query.id
            # Copy letter-case if the template has QD (and the query has one)
            if answer.question and query.question:
                answer.question[0].name = query.question[0].name
        if 'copy_query' in self.adjust_fields:
            answer.question = query.question
        # Re-set, as the EDNS might have reset the ext-rcode
        answer.set_rcode(self.message.rcode())

        # sanity check: adjusted answer should be almost the same
        assert len(answer.answer) == len(self.message.answer)
        assert len(answer.authority) == len(self.message.authority)
        assert len(answer.additional) == len(self.message.additional)

        return answer

    def set_edns(self, fields):
        """
        Set EDNS version, bufsize and options on the scripted message.

        `fields` is a list of strings: an optional leading version and then
        an optional bufsize (both consumed from the list with pop), followed
        by options `nsid[=VALUE]` and `subnet=ADDR[/PREFIX]`.
        A bufsize of 0 disables EDNS.
        """
        version = 0
        bufsize = 4096
        if fields and fields[0].isdigit():
            version = int(fields.pop(0))
        if fields and fields[0].isdigit():
            bufsize = int(fields.pop(0))
        if bufsize == 0:
            self.message.use_edns(False)
            return
        opts = []
        for field in fields:
            key, val = field.split('=', 1) if '=' in field else (field, True)
            key = key.lower()
            if key == 'nsid':
                # Option payloads are written to the wire, so they must be bytes.
                payload = b'' if val is True else val.encode()
                opts.append(dns.edns.GenericOption(dns.edns.NSID, payload))
            if key == 'subnet':
                opts.append(self._subnet_option(val))
        self.message.use_edns(edns=version, payload=bufsize, options=opts)

    @staticmethod
    def _subnet_option(value):
        """Build an EDNS Client Subnet option (code 8) from 'ADDR[/PREFIX]'."""
        net = value.split('/')
        subnet_addr = net[0]
        family = socket.AF_INET6 if ':' in subnet_addr else socket.AF_INET
        addr = socket.inet_pton(family, subnet_addr)
        prefix = len(addr) * 8
        if len(net) > 1:
            prefix = int(net[1])
        # Keep only the bytes covered by the prefix (integer division).
        addr = addr[:(prefix + 7) // 8]
        if prefix % 8 != 0:  # Mask the host bits of the last byte
            mask = (0xFF << (8 - prefix % 8)) & 0xFF
            addr = addr[:-1] + bytes([addr[-1] & mask])
        return dns.edns.GenericOption(8, struct.pack(
            "!HBB", 1 if family == socket.AF_INET else 2, prefix, 0) + addr)
400

401

Marek Vavruša's avatar
Marek Vavruša committed
402 403 404 405
class Range:
    """
    Range represents a set of scripted queries valid for given step range.
    """
    log = logging.getLogger('pydnstest.scenario.Range')

    def __init__(self, node):
        """
        Initialize reply range from a parsed range node.

        Reads the step interval [from, to], the addresses this range answers
        for (an empty set means any address, see eligible()) and the
        scripted entries.
        """
        self.node = node
        self.a = int(node['/from'].value)
        self.b = int(node['/to'].value)
        assert self.a <= self.b

        address = node["/address"].value
        self.addresses = {address} if address is not None else set()
        self.addresses |= {a.value for a in node.match("/address/*")}
        self.stored = [Entry(n) for n in node.match("/entry")]
        self.args = {}
        self.received = 0
        self.sent = 0

    def __del__(self):
        # Report per-range statistics when the object is destroyed.
        self.log.info('[ RANGE %d-%d ] %s received: %d sent: %d',
                      self.a, self.b, self.addresses, self.received, self.sent)

    def __str__(self):
        txt = '\nRANGE_BEGIN {a} {b}\n'.format(a=self.a, b=self.b)
        for addr in self.addresses:
            txt += '        ADDRESS {0}\n'.format(addr)

        for entry in self.stored:
            txt += '\n'
            txt += str(entry)
        txt += 'RANGE_END\n\n'
        return txt

    def eligible(self, id, address):  # pylint: disable=redefined-builtin
        """
        Return true if this range is eligible for fetching reply.

        The step `id` must lie within [a, b]; `address` matches when it is
        None, when the range lists no addresses, or when it is listed.
        """
        if self.a <= id <= self.b:
            return (None is address
                    or set() == self.addresses
                    or address in self.addresses)
        return False

    def reply(self, query):
        """
        Get answer for given query (adjusted if needed).

        The first stored entry whose match() accepts the query is used.
        With a 'LOSS' argument the reply is dropped with that probability.

        Returns:
            (DNS message object) or None if there is no candidate in this
            range or the reply was dropped

        Raises:
            ValueError: if the 'LOSS' argument is not a number.
        """
        self.received += 1
        for candidate in self.stored:
            try:
                candidate.match(query)
            except ValueError:
                continue  # mismatch, try the next candidate
            # Only match() signals "not this candidate"; failures while
            # building the reply or parsing LOSS must not be hidden as one.
            resp = candidate.adjust_reply(query)
            # Probabilistic loss
            if 'LOSS' in self.args and random.random() < float(self.args['LOSS']):
                return None
            self.sent += 1
            candidate.fired += 1
            return resp
        return None


470
class StepLogger(logging.LoggerAdapter):  # pylint: disable=too-few-public-methods
    """
    Prepend Step identification before each log message.

    The adapter's `extra` mapping must provide the keys 'id' and 'type'.
    """
    def process(self, msg, kwargs):
        """Return `msg` prefixed with '[STEP <id> <type>]'; kwargs pass through."""
        header = (self.extra['id'], self.extra['type'])
        return '[STEP %s %s] %s' % (header + (msg,)), kwargs


Marek Vavruša's avatar
Marek Vavruša committed
478 479 480 481 482
class Step:
    """
    Step represents one scripted action in a given moment,
    each step has an order identifier, type and optionally data entry.
    """
    require_data = ['QUERY', 'CHECK_ANSWER', 'REPLY']

    def __init__(self, node):
        """
        Initialize single scenario step from a parsed step node.

        The node value is the step id, `/type` the step type, the optional
        `/timestamp` the delay applied by TIME_PASSES, and `/entry` children
        the step's data entries.
        """
        self.node = node
        self.id = int(node.value)
        self.type = node["/type"].value
        self.log = StepLogger(logging.getLogger('pydnstest.scenario.Step'),
                              {'id': self.id, 'type': self.type})
        try:
            self.delay = int(node["/timestamp"].value)
        except KeyError:
            pass
        self.data = [Entry(n) for n in node.match("/entry")]
        self.queries = []
        self.has_data = self.type in Step.require_data
        self.answer = None
        self.raw_answer = None
        self.repeat_if_fail = 0
        self.pause_if_fail = 0
        self.next_if_fail = -1

        # TODO Parser currently can't parse CHECK_ANSWER args, player doesn't understand them anyway
        # if type == 'CHECK_ANSWER':
        #     for arg in extra_args:
        #         param = arg.split('=')
        #         try:
        #             if param[0] == 'REPEAT':
        #                 self.repeat_if_fail = int(param[1])
        #             elif param[0] == 'PAUSE':
        #                 self.pause_if_fail = float(param[1])
        #             elif param[0] == 'NEXT':
        #                 self.next_if_fail = int(param[1])
        #         except Exception as e:
        #             raise Exception('step %d - wrong %s arg: %s' % (self.id, param[0], str(e)))

    def __str__(self):
        txt = '\nSTEP {i} {t}'.format(i=self.id, t=self.type)
        if self.repeat_if_fail:
            txt += ' REPEAT {v}'.format(v=self.repeat_if_fail)
        elif self.pause_if_fail:
            txt += ' PAUSE {v}'.format(v=self.pause_if_fail)
        elif self.next_if_fail != -1:
            txt += ' NEXT {v}'.format(v=self.next_if_fail)
        # if self.args:
        #     txt += ' '
        #     txt += ' '.join(self.args)
        txt += '\n'

        for data in self.data:
            txt += str(data)
        return txt

    def play(self, ctx):
        """
        Play one step from a scenario.

        Raises:
            NotImplementedError: for step types this player does not support.
        """
        if self.type == 'QUERY':
            self.log.info('')
            # Without data __query raises a descriptive ValueError.
            if self.data:
                self.log.debug(self.data[0].message.to_text())
            # Parse QUERY-specific parameters
            choice, tcp, source = None, False, None
            return self.__query(ctx, tcp=tcp, choice=choice, source=source)
        elif self.type == 'CHECK_OUT_QUERY':
            self.log.info('')
            pass  # Ignore
        elif self.type in ('CHECK_ANSWER', 'ANSWER'):
            self.log.info('')
            return self.__check_answer(ctx)
        elif self.type == 'TIME_PASSES ELAPSE':
            self.log.info('')
            return self.__time_passes()
        elif self.type in ('REPLY', 'MOCK'):
            self.log.info('')
        # Parser currently doesn't support step types LOG, REPLAY and ASSERT.
        # No test uses them.
        # elif self.type == 'LOG':
        #     if not ctx.log:
        #         raise Exception('scenario has no log interface')
        #     return ctx.log.match(self.args)
        # elif self.type == 'REPLAY':
        #     self.__replay(ctx)
        # elif self.type == 'ASSERT':
        #     self.__assert(ctx)
        else:
            raise NotImplementedError('step %03d type %s unsupported' % (self.id, self.type))

    def __check_answer(self, ctx):
        """
        Compare answer from previously resolved query.

        Raises:
            ValueError: if the step has no data, no answer was received, or
            the answer does not match the expected entry.
        """
        if not self.data:
            raise ValueError("response definition required")
        expected = self.data[0]
        if expected.is_raw_data_entry is True:
            # The raw answer is a bytes blob (or None), not a DNS message,
            # so it has no to_text(); log it as-is.
            self.log.debug("raw answer: %s", ctx.last_raw_answer)
            expected.cmp_raw(ctx.last_raw_answer)
        else:
            if ctx.last_answer is None:
                raise ValueError("no answer from preceding query")
            self.log.debug("answer: %s", ctx.last_answer.to_text())
            expected.match(ctx.last_answer)

    # def __replay(self, ctx, chunksize=8):
    #     nqueries = len(self.queries)
    #     if len(self.args) > 0 and self.args[0].isdigit():
    #         nqueries = int(self.args.pop(0))
    #     destination = ctx.client[ctx.client.keys()[0]]
    #     self.log.info('replaying %d queries to %s@%d (%s)',
    #                   nqueries, destination[0], destination[1], ' '.join(self.args))
    #     if 'INTENSIFY' in os.environ:
    #         nqueries *= int(os.environ['INTENSIFY'])
    #     tstart = datetime.now()
    #     nsent, nrcvd = replay_rrs(self.queries, nqueries, destination, self.args)
    #     # Keep/print the statistics
    #     rtt = (datetime.now() - tstart).total_seconds() * 1000
    #     pps = 1000 * nrcvd / rtt
    #     self.log.debug('sent: %d, received: %d (%d ms, %d p/s)', nsent, nrcvd, rtt, pps)
    #     tag = None
    #     for arg in self.args:
    #         if arg.upper().startswith('PRINT'):
    #             _, tag = tuple(arg.split('=')) if '=' in arg else (None, 'replay')
    #     if tag:
    #         self.log.info('[ REPLAY ] test: %s pps: %5d time: %4d sent: %5d received: %5d',
    #                       tag.ljust(11), pps, rtt, nsent, nrcvd)

    @staticmethod
    def _retry_on_enobufs(operation):
        """
        Call `operation()` and return its result, retrying every 0.1 s while
        it fails with ENOBUFS; any other OSError (including socket timeouts)
        is re-raised instead of being retried forever.
        """
        while True:
            try:
                return operation()
            except OSError as ex:
                if ex.errno != errno.ENOBUFS:
                    raise
                time.sleep(0.1)  # ENOBUFS, throttle

    def __query(self, ctx, tcp=False, choice=None, source=None):
        """
        Send query and wait for an answer (if the query is not RAW).

        The received answer is stored in self.answer and ctx.last_answer;
        its wire form in self.raw_answer and ctx.last_raw_answer. If nothing
        arrives within the socket timeout (3 s) all of them are set to None.

        Raises:
            ValueError: if the step has no data or the target is unknown.
        """
        global g_rtt, g_nqueries
        if not self.data:
            raise ValueError("query definition required")
        if self.data[0].is_raw_data_entry is True:
            data_to_wire = self.data[0].raw_data
        else:
            # Don't use a message copy as the EDNS data portion is not copied.
            data_to_wire = self.data[0].message.to_wire()
        if not choice:
            choice = list(ctx.client.keys())[0]
        if choice not in ctx.client:
            raise ValueError('step %03d invalid QUERY target: %s' % (self.id, choice))
        # Create socket to test subject
        destination = ctx.client[choice]
        family = socket.AF_INET6 if ':' in destination[0] else socket.AF_INET
        sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if tcp:
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
            sock.settimeout(3)
            if source:
                sock.bind((source, 0))
            sock.connect(destination)
            # Send query to client and wait for response
            tstart = datetime.now()
            self._retry_on_enobufs(lambda: sendto_msg(sock, data_to_wire))
            # Wait for a response for a reasonable time (the socket timeout)
            answer = None
            if not self.data[0].is_raw_data_entry:
                try:
                    answer, _ = self._retry_on_enobufs(lambda: recvfrom_msg(sock, True))
                except socket.timeout:
                    answer = None  # no response; CHECK_ANSWER reports it
            rtt = (datetime.now() - tstart).total_seconds() * 1000
        finally:
            sock.close()
        # Track RTT
        g_nqueries += 1
        g_rtt += rtt
        # Remember last answer for checking later
        self.raw_answer = answer
        ctx.last_raw_answer = answer
        if self.raw_answer is not None:
            self.answer = dns.message.from_wire(self.raw_answer, one_rr_per_rrset=True)
        else:
            self.answer = None
        ctx.last_answer = self.answer

    def __time_passes(self):
        """
        Modify system time.

        Advances the '@%Y-%m-%d %H:%M:%S' timestamp stored in the file named
        by the FAKETIME_TIMESTAMP_FILE environment variable by `self.delay`
        seconds. The new value is written to a sibling '.next' file, which
        then replaces the original via os.replace().
        """
        file_old = os.environ["FAKETIME_TIMESTAMP_FILE"]
        file_next = file_old + ".next"
        with open(file_old, 'r') as time_file:
            line = time_file.readline().strip()
        t = time.mktime(datetime.strptime(line, '@%Y-%m-%d %H:%M:%S').timetuple())
        t += self.delay
        with open(file_next, 'w') as time_file:
            time_file.write(datetime.fromtimestamp(t).strftime('@%Y-%m-%d %H:%M:%S') + "\n")
            time_file.flush()
        os.replace(file_next, file_old)
Marek Vavruša's avatar
Marek Vavruša committed
682

683 684 685 686 687 688 689 690 691 692 693 694 695
    # def __assert(self, ctx):
    #     """ Assert that a passed expression evaluates to True. """
    #     result = eval(' '.join(self.args), {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #     # Evaluate subexpressions for clarity
    #     subexpr = []
    #     for expr in self.args:
    #         try:
    #             ee = eval(expr, {'SCENARIO': ctx, 'RANGE': ctx.ranges})
    #             subexpr.append(str(ee))
    #         except:
    #             subexpr.append(expr)
    #     assert result is True, '"%s" assertion fails (%s)' % (
    #                            ' '.join(self.args), ' '.join(subexpr))
696

697

Marek Vavruša's avatar
Marek Vavruša committed
698
class Scenario:
    """ Test scenario built from a parsed node: mock answer ranges plus ordered steps. """
    log = logging.getLogger('pydnstest.scenario.Scenario')

    def __init__(self, node, filename):
        """ Initialize scenario with description. """
        self.node = node
        self.info = node.value
        self.file = filename
        self.ranges = [Range(n) for n in node.match("/range")]
        self.current_range = None
        self.steps = [Step(n) for n in node.match("/step")]
        self.current_step = None
        self.client = {}

    def __str__(self):
        """ Render the scenario as SCENARIO_BEGIN ... SCENARIO_END text. """
        header = 'SCENARIO_BEGIN'
        if self.info:
            header += ' {0}'.format(self.info)
        parts = [header, '\n']
        parts.extend(str(rng) for rng in self.ranges)
        parts.extend(str(step) for step in self.steps)
        parts.append('\nSCENARIO_END')
        return ''.join(parts)

    def reply(self, query, address=None):
        """
        Generate answer packet for given query.

        Ranges eligible for the current step are tried first. After that come
        one-shot REPLY steps whose id is not lower than the current step id.
        The answer can be DNS message object or a binary blob.

        Returns:
            (answer, boolean "is the answer binary blob?")
            (None, True) when nothing provides an answer.
        """
        current_step_id = self.current_step.id
        # Unknown address, select any match
        # TODO: workaround until the server supports stub zones
        all_addresses = set()
        for rng in self.ranges:
            all_addresses.update(rng.addresses)
        if address not in all_addresses:
            address = None
        # Find current valid query response range
        for rng in self.ranges:
            if rng.eligible(current_step_id, address):
                self.current_range = rng
                return rng.reply(query), False
        # Find any prescripted one-shot replies
        for step in self.steps:
            if step.id < current_step_id or step.type != 'REPLY':
                continue
            try:
                candidate = step.data[0]
                if candidate.is_raw_data_entry is False:
                    candidate.match(query)
                    # one-shot: a matched entry is consumed
                    step.data.remove(candidate)
                    answer = candidate.adjust_reply(query)
                    return answer, False
                else:
                    answer = candidate.raw_data
                    return answer, True
            except (IndexError, ValueError):
                # IndexError: no entries left in this step; ValueError presumably
                # signals a non-matching query -- try the next step
                pass
        return None, True

    def play(self, paddr):
        """
        Play given scenario.

        Steps run in order. A step whose play() raises ValueError is retried
        while its repeat_if_fail counter is positive. A retry first sleeps
        pause_if_fail seconds when that is set. It then jumps to the step whose
        id equals next_if_fail, or repeats the same step when next_if_fail is -1.
        After all steps, every mandatory entry stored in any range must have
        fired at least once.

        Raises:
            ValueError: a step failed with no retries left (chained to the
                original error), or NEXT names a step id that does not exist.
            RuntimeError: a mandatory range entry was never fired.
        """
        # Store test subject => address mapping
        self.client = paddr

        i = 0
        while i < len(self.steps):
            step = self.steps[i]
            self.current_step = step
            try:
                step.play(self)
            except ValueError as ex:
                if step.repeat_if_fail <= 0:
                    raise ValueError('%s step %d %s' % (self.file, step.id, str(ex))) from ex
                # next_if_fail == -1 means "retry this very step"
                retry_id = step.id if step.next_if_fail == -1 else step.next_if_fail
                self.log.info("[play] step %d: exception - '%s', retrying step %d (%d left)",
                              step.id, ex, retry_id, step.repeat_if_fail)
                step.repeat_if_fail -= 1
                if step.pause_if_fail > 0:
                    time.sleep(step.pause_if_fail)
                if step.next_if_fail != -1:
                    next_steps = [j for j, candidate in enumerate(self.steps)
                                  if candidate.id == step.next_if_fail]
                    if not next_steps:
                        raise ValueError('step %d: wrong NEXT value "%d"' %
                                         (step.id, step.next_if_fail))
                    # indices come from enumerate(), so they are always in range
                    i = next_steps[0]
                continue
            i += 1

        for rng in self.ranges:
            for entry in rng.stored:
                if entry.mandatory and entry.fired == 0:
                    # TODO: report line numbers
                    raise RuntimeError('Mandatory section at %s not fired' % entry.mandatory.span)
804

805

806
def get_next(file_in, skip_empty=True):
    """
    Return the next line of *file_in*, with any comment removed, split into tokens.

    A ';' starts a comment running to the end of the line, unless it is inside
    double quotes or escaped by a backslash. This is the same rule that keeps
    an escaped '\\"' from toggling quoting.

    Returns:
        (op, args): the first token and the list of remaining tokens
        ('', []): for a blank or comment-only line when skip_empty is False
        False: at end of input
    """
    while True:
        line = file_in.readline()
        if not line:
            return False
        quoted, escaped = False, False
        for i, char in enumerate(line):
            if char == '\\':
                escaped = not escaped
            if not escaped and char == '"':
                quoted = not quoted
            if char == ';' and not quoted and not escaped:
                line = line[:i]
                break
            if char != '\\':
                escaped = False
        tokens = line.split()
        if not tokens:
            if skip_empty:
                continue
            return '', []
        return tokens[0], tokens[1:]

832

833 834 835
def parse_config(scn_cfg, qmin, installdir):
    """
    Transform scene config (key, value) pairs into dict filled with defaults.

    Args:
        scn_cfg: iterable of (key, value) string pairs from the CONFIG section
        qmin: default query-minimization setting, overridden by 'query-minimization'
        installdir: value substituted for '{{INSTALL_DIR}}' in 'feature-list' values

    Returns tuple:
      context dict: {Jinja2 variable: value}
      trust anchor dict: {domain: [TA lines for particular domain]}

    Raises:
        NotImplementedError: unsupported CONFIG key
        ValueError: a 'feature-list' value without '=', or a 'stub-addr' for which
            getaddrinfo() does not return exactly one address
    """
    # defaults
    do_not_query_localhost = True
    harden_glue = True
    sockfamily = 0  # auto-select value for socket.getaddrinfo
    trust_anchor_list = []
    trust_anchor_files = {}
    negative_ta_list = []
    stub_addr = None
    override_timestamp = None

    features = {}
    feature_list_delimiter = ';'
    feature_pair_delimiter = '='

    for k, v in scn_cfg:
        # Enable selectively for some tests
        if k == 'do-not-query-localhost':
            do_not_query_localhost = str2bool(v)
        elif k == 'domain-insecure':
            negative_ta_list.append(v)
        elif k == 'harden-glue':
            harden_glue = str2bool(v)
        elif k == 'query-minimization':
            qmin = str2bool(v)
        elif k == 'trust-anchor':
            trust_anchor = v.strip('"\'')
            trust_anchor_list.append(trust_anchor)
            # group TA lines by their (canonicalized) owner name
            domain = dns.name.from_text(trust_anchor.split()[0]).canonicalize()
            trust_anchor_files.setdefault(domain, []).append(trust_anchor)
        elif k == 'val-override-timestamp':
            override_timestamp = int(v.strip('"\''))
        elif k == 'val-override-date':
            override_date_str = v.strip('"\'')
            # YYYYMMDDHHMMSS split into fixed-width fields for strptime
            override_date_str_arg = '{0} {1} {2} {3} {4} {5}'.format(
                override_date_str[0:4], override_date_str[4:6], override_date_str[6:8],
                override_date_str[8:10], override_date_str[10:12], override_date_str[12:])
            override_date = time.strptime(override_date_str_arg, "%Y %m %d %H %M %S")
            # calendar.timegm() interprets the struct_time as UTC
            override_timestamp = calendar.timegm(override_date)
        elif k == 'stub-addr':
            stub_addr = v.strip('"\'')
        elif k == 'features':
            for f_item in v.split(feature_list_delimiter):
                if not f_item.strip():
                    continue  # tolerate empty items, e.g. a trailing ';'
                if feature_pair_delimiter in f_item:
                    f_key, f_value = [x.strip()
                                      for x in f_item.split(feature_pair_delimiter, 1)]
                else:
                    f_key, f_value = f_item.strip(), ""
                features[f_key] = f_value
        elif k == 'feature-list':
            try:
                f_key, f_value = [x.strip() for x in v.split(feature_pair_delimiter, 1)]
            except ValueError as ex:
                # no '=' in the value: unpacking the one-element list fails
                raise ValueError("can't parse feature-list (%s) in config section (%s)"
                                 % (v, str(ex))) from ex
            f_value = f_value.replace("{{INSTALL_DIR}}", installdir)
            features.setdefault(f_key, []).append(f_value)
        elif k == 'force-ipv6':
            # a false value is accepted and simply keeps automatic selection
            if str2bool(v):
                sockfamily = socket.AF_INET6
        else:
            raise NotImplementedError('unsupported CONFIG key "%s"' % k)

    ctx = {
        "DO_NOT_QUERY_LOCALHOST": str(do_not_query_localhost).lower(),
        "NEGATIVE_TRUST_ANCHORS": negative_ta_list,
        "FEATURES": features,
        "HARDEN_GLUE": str(harden_glue).lower(),
        "INSTALL_DIR": installdir,
        "QMIN": str(qmin).lower(),
        "TRUST_ANCHORS": trust_anchor_list,
        "TRUST_ANCHOR_FILES": trust_anchor_files.keys()
    }
    if stub_addr:
        ctx['ROOT_ADDR'] = stub_addr
        # determine and verify socket family for specified root address
        gai = socket.getaddrinfo(stub_addr, 53, sockfamily, 0,
                                 socket.IPPROTO_UDP, socket.AI_NUMERICHOST)
        if len(gai) != 1:
            raise ValueError('stub-addr "%s" resolves to %d addresses, expected 1'
                             % (stub_addr, len(gai)))
        sockfamily = gai[0][0]
    if not sockfamily:
        sockfamily = socket.AF_INET  # default to IPv4
    ctx['_SOCKET_FAMILY'] = sockfamily
    # 0 is a valid timestamp (the epoch), so test for presence, not truthiness
    if override_timestamp is not None:
        ctx['_OVERRIDE_TIMESTAMP'] = override_timestamp
    return (ctx, trust_anchor_files)
940 941


942
def parse_file(path):
    """
    Parse a scenario file using the Augeas 'Deckard' lens.

    Returns:
        (Scenario, config), where config is a list of [key, value] pairs
        from the CONFIG section. Empty values and values starting with ';'
        are skipped. Text after '#' is dropped. Items without a ':' separator
        are ignored.
    """
    lens_dir = os.path.dirname(__file__)
    # keep the wrapper bound for the whole function, not just its .tree
    aug = pydnstest.augwrap.AugeasWrapper(confpath=path, lens='Deckard', loadpath=lens_dir)
    tree = aug.tree

    config = []
    raw_values = [item.value for item in tree.match("/config/*")]
    for raw in raw_values:
        if not raw or raw.startswith(';'):
            continue
        raw = raw.split('#', 1)[0]
        # e.g. 'query-minimization: on' -> ['query-minimization', 'on']
        pair = [part.strip() for part in raw.split(':', 1)]
        if len(pair) >= 2:
            config.append(pair)

    scenario = Scenario(tree["/scenario"], posixpath.basename(tree.path))
    return scenario, config