server.py 32.7 KB
Newer Older
1 2
#!/usr/bin/env python3

3
import glob
4
import inspect
5
import psutil
6
import re
7 8 9 10
import random
import shutil
import socket
import time
11 12
import dns.message
import dns.query
13
import dns.update
14
from subprocess import Popen, PIPE, check_call, CalledProcessError
15
from dnstest.utils import *
16
import dnstest.inquirer
17 18 19 20
import dnstest.params as params
import dnstest.keys
import dnstest.response
import dnstest.update
21

22 23 24 25
def zone_arg_check(zone):
    '''Unwrap a one-item zone list into the zone object itself.

    Arguments that are not lists are returned unchanged.  A list holding
    anything other than exactly one item raises Failed.
    '''
    # Convert one item list to single object.
    if isinstance(zone, list):
        if len(zone) != 1:
            raise Failed("One zone required")
        return zone[0]
    return zone


class KnotConf(object):
    '''Knot server config generator.

    Text accumulates in self.conf: sections are written as `name {` ... `}`,
    items as `name value;` (or `name "value";`), and every nesting level is
    indented by one tab.  Closing a top-level section adds an empty line.
    '''

    def __init__(self):
        self.conf = ""
        self.indent = ""

    def _emit(self, text):
        '''Append one line of text at the current indentation.'''
        self.conf += self.indent + text + "\n"

    def sub(self):
        self.indent = self.indent + "\t"

    def unsub(self):
        # Dropping a tab from an empty indent leaves it empty.
        self.indent = self.indent[:-1] if self.indent else ""

    def begin(self, name):
        self._emit("%s {" % name)
        self.sub()

    def end(self):
        self.unsub()
        self._emit("}")
        # Separate top-level sections with an empty line.
        if self.indent == "":
            self.conf += "\n"

    def item(self, name, value):
        self._emit("%s %s;" % (name, value))

    def item_str(self, name, value):
        self._emit("%s \"%s\";" % (name, value))


class BindConf(object):
    '''Bind server config generator.

    Text accumulates in self.conf: sections are written as
    `name ["string"] {` ... `};`, items as `name [value];`, and each nesting
    level is indented by one tab.  Closing a top-level section adds an empty
    line.
    '''

    def __init__(self):
        self.conf = ""
        self.indent = ""

    def sub(self):
        '''Indent subsequent lines one level deeper.'''
        self.indent += "\t"

    def unsub(self):
        '''Indent subsequent lines one level shallower.'''
        self.indent = self.indent[:-1]

    def begin(self, name, string=None):
        '''Open a section, optionally followed by a quoted argument.'''
        if string:
            self.conf += "%s%s \"%s\" {\n" % (self.indent, name, string)
        else:
            self.conf += "%s%s {\n" % (self.indent, name)
        self.sub()

    def end(self):
        '''Close the innermost section.'''
        self.unsub()
        self.conf += "%s};\n" % (self.indent)
        if not self.indent:
            self.conf += "\n"

    def item(self, name, value=None):
        '''Add `name value;`, or a bare `name;` when no value is given.

        Only None means "no value"; falsy values such as 0 are written out.
        '''
        if value is None:
            self.conf += "%s%s;\n" % (self.indent, name)
        else:
            self.conf += "%s%s %s;\n" % (self.indent, name, value)

    def item_str(self, name, value):
        '''Add `name "value";`.'''
        self.conf += "%s%s \"%s\";\n" % (self.indent, name, value)


class Zone(object):
    '''DNS zone description as configured on one server.'''

    def __init__(self, zone_file=None, ddns=False, ixfr=False):
        # Zone file object; the zone name is taken from its .name attribute.
        self.zfile = zone_file
        # Master server when the owning server is a slave for this zone.
        self.master = None
        # Servers acting as slaves for this zone.
        self.slaves = set()
        # Whether dynamic updates are allowed.
        self.ddns = ddns
        # ixfr from differences
        self.ixfr = ixfr
        # modules: list of (module, param) pairs
        self.query_modules = []

    @property
    def name(self):
        '''Zone name, as reported by the zone file object.'''
        return self.zfile.name

    def add_query_module(self, module, param):
        '''Register a query module with its parameter for this zone.'''
        self.query_modules.append((module, param))


class Server(object):
    '''Specification of DNS server.

    Base class holding the settings, zones and process handle of one tested
    server.  Subclasses set daemon_bin/control_bin and provide get_config().
    '''

    START_WAIT = 2
    START_WAIT_VALGRIND = 5
    STOP_TIMEOUT = 30
    COMPILE_TIMEOUT = 60
    DIG_TIMEOUT = 5

    # Packet flag names accepted by dig(); each maps to dns.flags.<NAME> and
    # to the dig command-line option "+<name>".
    _DIG_FLAGS = ("AA", "TC", "RD", "RA", "AD", "CD")

    # Instance counter.
    count = 0

    def __init__(self):
        self.proc = None
        # Command prefix (e.g. a valgrind invocation) put before daemon_bin.
        self.valgrind = []
        self.start_params = None
        self.reload_params = None
        self.flush_params = None
        self.compile_params = None

        self.data_dir = None

        self.dnssec_enable = None

        self.nsid = None
        self.ident = None
        self.version = None

        # IP version (4 or 6) of self.addr.
        self.ip = None
        self.addr = None
        self.port = None
        self.ctlport = None
        self.ctlkey = None
        self.ctlkeyfile = None
        self.tsig = None
        self.tsig_test = None

        # Zone name -> Zone.
        self.zones = dict()

        self.ratelimit = None
        self.disable_any = None
        self.disable_notify = None
        self.max_conn_idle = None
        self.zonefile_sync = None

        self.inquirer = None

        # Working directory.
        self.dir = None
        # Name of server instance.
        self.name = None
        # Paths of the daemon's stdout/stderr log files.
        self.fout = None
        self.ferr = None
        # Path of the generated configuration file (see gen_confile()).
        self.confile = None

    def _check_socket(self, proto, port):
        '''Return True once the server process alone has proto/port bound.

        Polls `lsof` up to five times, one second apart.
        '''
        if self.ip == 4:
            iface = "%i%s@%s:%i" % (self.ip, proto, self.addr, port)
        else:
            iface = "%i%s@[%s]:%i" % (self.ip, proto, self.addr, port)

        for _ in range(5):
            proc = Popen(["lsof", "-t", "-i", iface],
                         stdout=PIPE, stderr=PIPE, universal_newlines=True)
            (out, err) = proc.communicate()

            # Create list of pids excluding last empty line.
            pids = list(filter(None, out.split("\n")))

            # Check for successful bind.
            if len(pids) == 1 and str(self.proc.pid) in pids:
                return True

            time.sleep(1)

        return False

    def set_master(self, zone, slave=None, ddns=False, ixfr=False):
        '''Set the server as a master for the zone'''

        if zone.name not in self.zones:
            master_file = zone.clone(self.dir + "/master")
            z = Zone(master_file, ddns, ixfr)
            self.zones[zone.name] = z
        else:
            z = self.zones[zone.name]

        if slave:
            z.slaves.add(slave)

    def set_slave(self, zone, master, ddns=False, ixfr=False):
        '''Set the server as a slave for the zone'''

        if zone.name in self.zones:
            raise Failed("Can't set zone='%s' as a slave" % zone.name)

        slave_file = zone.clone(self.dir + "/slave", exists=False)
        z = Zone(slave_file, ddns, ixfr)
        z.master = master
        self.zones[zone.name] = z

    def compile(self, mode="a"):
        '''Run control_bin with compile_params, logging to fout/ferr.

        mode is the open() mode of the log files ("w" truncates them).
        Raises Failed if the command can't be run or exceeds COMPILE_TIMEOUT.
        '''
        try:
            # Popen needs file objects (or descriptors), not path strings.
            with open(self.fout, mode=mode) as fout, \
                 open(self.ferr, mode=mode) as ferr:
                proc = Popen([self.control_bin] + self.compile_params,
                             stdout=fout, stderr=ferr)
                try:
                    proc.communicate(timeout=Server.COMPILE_TIMEOUT)
                except Exception:
                    # Don't leave a hung compiler running behind us.
                    proc.kill()
                    proc.wait()
                    raise
        except Exception:
            raise Failed("Can't compile server='%s'" % self.name)

    def start(self, clean=False):
        '''Start the daemon; clean=True truncates the fout/ferr logs first.'''
        mode = "w" if clean else "a"

        try:
            if self.compile_params:
                self.compile(mode=mode)
                # Logs were already truncated by compile(); keep its output.
                mode = "a"

            if self.daemon_bin is not None:
                # The child gets its own copies of the descriptors, so the
                # parent's handles can be closed right after spawning.
                with open(self.fout, mode=mode) as fout, \
                     open(self.ferr, mode=mode) as ferr:
                    self.proc = Popen(self.valgrind + [self.daemon_bin] +
                                      self.start_params,
                                      stdout=fout, stderr=ferr)

            if self.valgrind:
                time.sleep(Server.START_WAIT_VALGRIND)
            else:
                time.sleep(Server.START_WAIT)
        except OSError:
            raise Failed("Can't start server='%s'" % self.name)

        # Start inquirer if enabled.
        if params.test.stress and self.inquirer:
            self.inquirer.start(self)

    def _ctl_call(self, ctl_params, action):
        '''Run control_bin with ctl_params; raise Failed on a non-zero exit.

        Output is appended to call.out/call.err in the working directory;
        action names the operation in the error message.
        '''
        try:
            with open(self.dir + "/call.out", mode="a") as out, \
                 open(self.dir + "/call.err", mode="a") as err:
                check_call([self.control_bin] + ctl_params,
                           stdout=out, stderr=err)
            time.sleep(Server.START_WAIT)
        except CalledProcessError as e:
            self.backtrace()
            raise Failed("Can't %s server='%s', ret='%i'" %
                         (action, self.name, e.returncode))

    def reload(self):
        '''Reload the server through its control binary.'''
        self._ctl_call(self.reload_params, "reload")

    def flush(self):
        '''Flush the server through its control binary.'''
        self._ctl_call(self.flush_params, "flush")

    def running(self):
        '''Return True if the server process is running or sleeping.'''
        if not self.proc:
            return False

        try:
            proc = psutil.Process(self.proc.pid)
            status = proc.status
            # psutil 2.0.0+ makes status a function
            if psutil.version_info[0] >= 2:
                status = proc.status()
        except psutil.NoSuchProcess:
            return False

        return status in (psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING,
                          psutil.STATUS_DISK_SLEEP)

    @staticmethod
    def _valgrind_num(line, match):
        '''Parse the number (with "," separators) right after match in line.'''
        return int(line[match.end():].lstrip().split(" ")[0].replace(",", ""))

    def _valgrind_check(self):
        '''Scan valgrind summaries in the stderr log and flag leaks/errors.

        Any lost bytes, more than 32 reachable bytes or any reported error
        marks the run with set_err("VALGRIND").
        '''
        if not self.valgrind:
            return

        check_log("VALGRIND CHECK %s" % self.name)

        lock = False
        lost = 0
        reachable = 0
        errcount = 0

        try:
            f = open(self.ferr, "r")
        except Exception:
            detail_log("No err log file")
            detail_log(SEP)
            return

        with f:
            for line in f:
                if re.search("(HEAP|LEAK) SUMMARY", line):
                    lost = 0
                    reachable = 0
                    errcount = 0
                    lock = True
                    continue

                # Only lines inside a summary section are of interest.
                if not lock:
                    continue

                lost_line = re.search("lost:", line)
                if lost_line:
                    lost += self._valgrind_num(line, lost_line)
                    continue

                reach_line = re.search("reachable:", line)
                if reach_line:
                    reachable += self._valgrind_num(line, reach_line)
                    continue

                err_line = re.search("ERROR SUMMARY:", line)
                if err_line:
                    errcount += self._valgrind_num(line, err_line)

                    if lost > 0 or reachable > 32 or errcount > 0:
                        set_err("VALGRIND")
                        detail_log("%s memcheck: lost(%i B), reachable(%i B), " \
                                   "errcount(%i)" \
                                   % (self.name, lost, reachable, errcount))

                    lock = False

        detail_log(SEP)

    def backtrace(self):
        '''Dump all thread backtraces via gdb + vgdb (valgrind runs only).'''
        if not self.valgrind:
            return

        check_log("BACKTRACE %s" % self.name)

        try:
            with open(self.dir + "/gdb.out", mode="a") as out, \
                 open(self.dir + "/gdb.err", mode="a") as err:
                check_call([params.gdb_bin, "-ex", "set confirm off", "-ex",
                            "target remote | %s --pid=%s" %
                            (params.vgdb_bin, self.proc.pid),
                            "-ex", "info threads",
                            "-ex", "thread apply all bt full", "-ex", "q",
                            self.daemon_bin],
                           stdout=out, stderr=err)
        except Exception:
            detail_log("!Failed to get backtrace")

        detail_log(SEP)

    def stop(self, check=True):
        '''Terminate the server, killing it if it doesn't exit in time.

        With check=True the valgrind log is analysed afterwards.
        '''
        if self.proc:
            try:
                self.proc.terminate()
                self.proc.wait(Server.STOP_TIMEOUT)
            except ProcessLookupError:
                pass
            except Exception:
                self.backtrace()
                check_log("WARNING: KILLING %s" % self.name)
                detail_log(SEP)
                self.kill()

        if check:
            self._valgrind_check()

        if self.inquirer:
            self.inquirer.stop()

    def kill(self):
        '''Kill the server process and remove its leftover vgdb pipes.'''
        if self.proc:
            # Store PID before kill.
            pid = self.proc.pid

            self.proc.kill()

            # Remove uncleaned vgdb pipes.
            for pipe in glob.glob("/tmp/vgdb-pipe*-%s-*" % pid):
                try:
                    os.remove(pipe)
                except OSError:
                    pass

        if self.inquirer:
            self.inquirer.stop()

    def gen_confile(self):
        '''Write the output of get_config() to self.confile.'''
        with open(self.confile, mode="w") as f:
            f.write(self.get_config())

    def dig(self, rname, rtype, rclass="IN", udp=None, serial=None,
            timeout=None, tries=3, flags="", bufsize=None, edns=None,
            nsid=False, dnssec=False, log_no_sep=False):
        '''Query the server and return a dnstest.response.Response.

        The equivalent dig command line is logged.  AXFR always uses TCP;
        IXFR uses TCP unless udp is given and needs serial; other types pick
        TCP or UDP at random unless udp is given.  flags is a space separated
        list of header flag names (AA, TC, RD, RA, AD, CD).  Up to `tries`
        attempts are made, each limited by `timeout` seconds (DIG_TIMEOUT by
        default).  Raises Failed when no attempt succeeds.
        '''
        key_params = self.tsig_test.key_params if self.tsig_test else dict()

        # Convert one item zone list to zone name.
        if isinstance(rname, list):
            if len(rname) != 1:
                raise Failed("One zone required")
            rname = rname[0].name

        rtype_upper = rtype.upper()
        rtype_str = rtype_upper

        # Set port type.
        if rtype_upper == "AXFR":
            # Always use TCP.
            udp = False
        elif rtype_upper == "IXFR":
            # Use TCP if not specified.
            udp = udp if udp is not None else False
            rtype_str += "=%i" % int(serial)
        else:
            # Use TCP or UDP at random if not specified.
            udp = udp if udp is not None else random.choice([True, False])

        dig_flags = "+notcp" if udp else "+tcp"
        dig_flags += " +retry=%i" % (tries - 1)

        # Set timeout.
        if timeout is None:
            timeout = self.DIG_TIMEOUT
        dig_flags += " +time=%i" % timeout

        # Prepare query (useless for XFR).
        query = dns.message.make_query(rname, rtype, rclass)

        # Remove implicit RD flag.
        query.flags &= ~dns.flags.RD

        # Set packet flags; unknown names are ignored.
        for flag in flags.split():
            if flag in self._DIG_FLAGS:
                query.flags |= getattr(dns.flags, flag)
                dig_flags += " +" + flag.lower()

        # Set EDNS.
        if edns is not None or bufsize or nsid:
            class NsidFix(object):
                '''Current pythondns doesn't implement NSID option.'''
                def __init__(self):
                    self.otype = dns.edns.NSID
                def to_wire(self, file=None):
                    pass

            edns = int(edns) if edns else 0
            dig_flags += " +edns=%i" % edns

            payload = int(bufsize) if bufsize else 1280
            dig_flags += " +bufsize=%i" % payload

            if nsid:
                options = [NsidFix()]
                dig_flags += " +nsid"
            else:
                options = None

            query.use_edns(edns=edns, payload=payload, options=options)

        # Set DO flag.
        if dnssec:
            query.want_dnssec()
            dig_flags += " +dnssec +bufsize=%i" % query.payload

        # Store function arguments (after the adjustments above) for possible
        # comparison.  Named frame_info so the global `params` module is not
        # shadowed inside this method.
        args = dict()
        frame_info = inspect.getargvalues(inspect.currentframe())
        for arg_name in frame_info.args:
            if arg_name != "self":
                args[arg_name] = frame_info.locals[arg_name]

        check_log("DIG %s %s %s @%s -p %i %s" %
                  (rname, rtype_str, rclass, self.addr, self.port, dig_flags))
        if key_params:
            detail_log("%s:%s:%s" %
                (self.tsig_test.alg, self.tsig_test.name, self.tsig_test.key))

        for _ in range(tries):
            try:
                if rtype_upper == "AXFR":
                    resp = dns.query.xfr(self.addr, rname, rtype, rclass,
                                         port=self.port, lifetime=timeout,
                                         use_udp=udp, **key_params)
                elif rtype_upper == "IXFR":
                    resp = dns.query.xfr(self.addr, rname, rtype, rclass,
                                         port=self.port, lifetime=timeout,
                                         use_udp=udp, serial=int(serial),
                                         **key_params)
                elif udp:
                    resp = dns.query.udp(query, self.addr, port=self.port,
                                         timeout=timeout)
                else:
                    resp = dns.query.tcp(query, self.addr, port=self.port,
                                         timeout=timeout)

                if not log_no_sep:
                    detail_log(SEP)

                return dnstest.response.Response(self, resp, query, args)
            except dns.exception.Timeout:
                pass
            except Exception:
                # Back off before retrying after a non-timeout failure.
                time.sleep(timeout)

        raise Failed("Can't query server='%s' for '%s %s %s'" % \
                     (self.name, rname, rclass, rtype))

    def create_sock(self, socket_type):
        '''Create an unconnected socket of the server's address family.'''
        family = socket.AF_INET
        if self.ip == 6:
            family = socket.AF_INET6
        return socket.socket(family, socket_type)

    def send_raw(self, data, sock=None):
        '''Send data (str encoded as UTF-8, or bytes) to the server.

        Without sock, a temporary UDP socket is created and closed again.
        Raises Failed if not all bytes were sent.
        '''
        if isinstance(data, (bytes, bytearray)):
            payload = bytes(data)
        else:
            payload = bytes(data, 'utf-8')

        own_sock = sock is None
        if own_sock:
            sock = self.create_sock(socket.SOCK_DGRAM)
        try:
            sent = sock.sendto(payload, (self.addr, self.port))
        finally:
            if own_sock:
                sock.close()

        # Compare against the encoded length, not the str length.
        if sent != len(payload):
            raise Failed("Can't send RAW data (%d bytes) to server='%s'" %
                         (len(payload), self.name))

    def zone_wait(self, zone, serial=None):
        '''Try to get SOA record with serial higher than specified.

        Polls the server for up to 60 attempts, two seconds apart.  Without
        serial, any SOA answer is accepted.  Returns the SOA serial seen;
        raises Failed when the answer has no SOA or the wait times out.
        '''
        zone = zone_arg_check(zone)

        _serial = 0

        check_log("ZONE WAIT %s: %s" % (self.name, zone.name))

        for _ in range(60):
            try:
                resp = self.dig(zone.name, "SOA", udp=True, tries=1,
                                timeout=2, log_no_sep=True)
            except Exception:
                pass
            else:
                if resp.resp.rcode() == 0:
                    if not resp.resp.answer:
                        raise Failed("No SOA in ANSWER, zone='%s', server='%s'" %
                                     (zone.name, self.name))

                    soa = str((resp.resp.answer[0]).to_rdataset())
                    _serial = int(soa.split()[5])

                    # serial=0 is a real serial, not "no serial".
                    if serial is None or serial < _serial:
                        break
            time.sleep(2)
        else:
            self.backtrace()
            raise Failed("Can't get SOA%s, zone='%s', server='%s'" %
                         (" serial > %i" % serial if serial is not None else "",
                          zone.name, self.name))

        detail_log(SEP)

        return _serial

    def zones_wait(self, zone_list, serials=None):
        '''Wait for every zone; return a dict of zone name -> SOA serial.

        serials, if given, maps zone names to serials each zone must exceed.
        '''
        new_serials = dict()

        for zone in zone_list:
            old_serial = serials[zone.name] if serials else None
            new_serial = self.zone_wait(zone, serial=old_serial)
            new_serials[zone.name] = new_serial

        return new_serials

    def zone_verify(self, zone):
        '''Run dnssec_verify() on the server's copy of the zone file.'''
        zone = zone_arg_check(zone)

        self.zones[zone.name].zfile.dnssec_verify()

    def check_nsec(self, zone, nsec3=False, nonsec=False):
        '''Query a non-existent name with DO set and check its NSEC proof.'''
        zone = zone_arg_check(zone)

        resp = self.dig("0-x-not-existing-x-0." + zone.name, "ANY", dnssec=True)
        resp.check_nsec(nsec3=nsec3, nonsec=nonsec)

    def update(self, zone):
        '''Return a dnstest.update.Update for zone.

        tsig_test's key parameters are passed along when a test key is set.
        '''
        zone = zone_arg_check(zone)

        key_params = self.tsig_test.key_params if self.tsig_test else dict()

        return dnstest.update.Update(self, dns.update.Update(zone.name,
                                                             **key_params))

    def gen_key(self, zone, **args):
        '''Generate a key for zone in self.keydir and return it.

        NOTE(review): keydir is not defined on this class; the visible
        subclass providing it is Knot.
        '''
        zone = zone_arg_check(zone)

        prepare_dir(self.keydir)
        key = dnstest.keys.Key(self.keydir, zone.name, **args)
        key.generate()

        return key

    def use_keys(self, zone):
        '''Copy the zone's key files from zone.key_dir to self.keydir.

        A file is copied when its name contains the lowercased zone name
        without the trailing dot.
        '''
        zone = zone_arg_check(zone)

        # Copy generated keys to server key directory.
        prepare_dir(self.keydir)

        prefix = zone.name[:-1].lower()
        for file_name in os.listdir(zone.key_dir):
            if prefix in file_name:
                full_file_name = os.path.join(zone.key_dir, file_name)
                if os.path.isfile(full_file_name):
                    shutil.copy(full_file_name, self.keydir)

    def enable_nsec3(self, zone, **args):
        '''Enable NSEC3 in the server's copy of the zone file.'''
        zone = zone_arg_check(zone)

        self.zones[zone.name].zfile.enable_nsec3(**args)

    def disable_nsec3(self, zone):
        '''Disable NSEC3 in the server's copy of the zone file.'''
        zone = zone_arg_check(zone)

        self.zones[zone.name].zfile.disable_nsec3()

    def backup_zone(self, zone):
        '''Back up the server's copy of the zone file.'''
        zone = zone_arg_check(zone)

        self.zones[zone.name].zfile.backup()

    def update_zonefile(self, zone, version=None, random=False):
        '''Replace the zone file with a stored version or a random update.'''
        zone = zone_arg_check(zone)

        if random:
            self.zones[zone.name].zfile.update_rnd()
        else:
            self.zones[zone.name].zfile.upd_file(storage=self.data_dir,
                                                 version=version)

    def add_query_module(self, zone, module, param):
        '''Attach a query module with its parameter to the zone.'''
        zone = zone_arg_check(zone)

        self.zones[zone.name].add_query_module(module, param)


class Bind(Server):
    '''Bind server: control settings and named.conf generation.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if not params.bind_bin:
            raise Skip("No Bind")

        self.daemon_bin = params.bind_bin
        self.control_bin = params.bind_ctl
        # Key for the control channel (the "controls" statement below).
        self.ctlkey = dnstest.keys.Tsig(alg="hmac-md5")

    def listening(self):
        '''Return True if DNS TCP/UDP and control TCP sockets are bound.'''
        tcp = super()._check_socket("tcp", self.port)
        udp = super()._check_socket("udp", self.port)
        ctltcp = super()._check_socket("tcp", self.ctlport)
        return (tcp and udp and ctltcp)

    def _str(self, conf, name, value):
        '''Add a quoted item unless value is empty or exactly the flag True.'''
        if value and value != True:
            conf.item_str(name, value)

    @staticmethod
    def _key(conf, key, comment=None):
        '''Add a `key "<name>" { algorithm ...; secret "..."; };` statement.'''
        conf.begin("key", key.name)
        if comment:
            conf.item(comment)
        conf.item("algorithm", key.alg)
        conf.item_str("secret", key.key)
        conf.end()

    def get_config(self):
        '''Return named.conf text; also sets start/reload/flush params.'''
        s = BindConf()
        s.begin("options")
        self._str(s, "server-id", self.ident)
        self._str(s, "version", self.version)
        s.item_str("directory", self.dir)
        s.item_str("key-directory", self.dir)
        s.item_str("managed-keys-directory", self.dir)
        s.item_str("session-keyfile", self.dir + "/session.key")
        s.item_str("pid-file", "bind.pid")
        if self.ip == 4:
            s.item("listen-on port", "%i { %s; }" % (self.port, self.addr))
            s.item("listen-on-v6", "{ }")
        else:
            s.item("listen-on", "{ }")
            s.item("listen-on-v6 port", "%i { %s; }" % (self.port, self.addr))
        s.item("auth-nxdomain", "no")
        s.item("recursion", "no")
        s.item("masterfile-format", "text")
        s.item("max-refresh-time", "2")
        s.item("max-retry-time", "2")
        s.item("transfers-in", "30")
        s.item("transfers-out", "30")
        s.end()

        self._key(s, self.ctlkey)

        s.begin("controls")
        s.item("inet %s port %i allow { %s; } keys { %s; }"
               % (self.addr, self.ctlport, self.addr, self.ctlkey.name))
        s.end()

        if self.tsig:
            # Duplicity check: a key name may be defined only once, including
            # when a remote server shares a name with one of our own keys.
            keys = set()

            def add_key(key, comment=None):
                if key and key.name not in keys:
                    self._key(s, key, comment)
                    keys.add(key.name)

            add_key(self.tsig, "# Local key")
            add_key(self.tsig_test, "# Test key")

            for zone in sorted(self.zones):
                z = self.zones[zone]
                if z.master:
                    add_key(z.master.tsig)
                for slave in z.slaves:
                    add_key(slave.tsig)

        for zone in sorted(self.zones):
            z = self.zones[zone]
            s.begin("zone", z.name)
            s.item_str("file", z.zfile.path)
            s.item("check-names", "warn")
            if z.master:
                s.item("type", "slave")

                if self.tsig:
                    s.item("allow-notify", "{ key %s; }" % z.master.tsig.name)
                    s.item("masters", "{ %s port %i key %s; }" \
                           % (z.master.addr, z.master.port, z.master.tsig.name))
                else:
                    s.item("allow-notify", "{ %s; }" % z.master.addr)
                    s.item("masters", "{ %s port %i; }" \
                           % (z.master.addr, z.master.port))
            else:
                s.item("type", "master")
                s.item("notify", "explicit")

            if z.ixfr and not z.master:
                s.item("ixfr-from-differences", "yes")

            if z.slaves:
                slaves = ""
                for slave in z.slaves:
                    if self.tsig:
                        slaves += "%s port %i key %s; " \
                                  % (slave.addr, slave.port, self.tsig.name)
                    else:
                        slaves += "%s port %i; " % (slave.addr, slave.port)
                s.item("also-notify", "{ %s}" % slaves)

            if z.ddns:
                if self.tsig:
                    upd = "key %s; " % self.tsig_test.name
                else:
                    upd = "%s; " % self.addr

                if z.master:
                    s.item("allow-update-forwarding", "{ %s}" % upd)
                else:
                    s.item("allow-update", "{ %s}" % upd)

            if self.tsig:
                s.item("allow-transfer", "{ key %s; key %s; }" %
                       (self.tsig.name, self.tsig_test.name))
            else:
                s.item("allow-transfer", "{ %s; }" % self.addr)
            s.end()

        self.start_params = ["-c", self.confile, "-g"]
        self.reload_params = ["-s", self.addr, "-p", str(self.ctlport), \
                              "-k", self.ctlkeyfile, "reload"]
        self.flush_params = ["-s", self.addr, "-p", str(self.ctlport), \
                             "-k", self.ctlkeyfile, "flush"]

        return s.conf


class Knot(Server):
    '''Knot DNS server used by the test suite.

    get_config() renders the Knot configuration text and ctl() drives the
    running daemon through the control binary.  Construction raises Skip
    when no Knot binary is configured in dnstest.params.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if not params.knot_bin:
            raise Skip("No Knot")

        self.daemon_bin = params.knot_bin
        self.control_bin = params.knot_ctl
        self.inquirer = dnstest.inquirer.Inquirer()

    @property
    def keydir(self):
        '''DNSSEC key directory: the "keys" subdirectory of self.dir.'''
        return os.path.join(self.dir, "keys")

    def listening(self):
        '''Return whether _check_socket succeeds for TCP and UDP on self.port.'''
        tcp = super()._check_socket("tcp", self.port)
        udp = super()._check_socket("udp", self.port)
        return (tcp and udp)

    def _on_str_hex(self, conf, name, value):
        '''Emit option *name* into *conf* with *value* formatted for Knot.

        Booleans are written as on/off, integers and "0x"-prefixed strings
        are written bare, any other non-empty value is quoted.  None and
        empty values emit nothing.
        '''
        # bool must be checked first: it is a subclass of int, and the old
        # "value == True" / "value == False" tests also matched the
        # integers 1 and 0 (turning e.g. "rate-limit 1" into "rate-limit on").
        if isinstance(value, bool):
            conf.item(name, "on" if value else "off")
        elif isinstance(value, int):
            conf.item(name, value)
        elif value:
            if value[:2] == "0x":
                conf.item(name, value)
            else:
                conf.item_str(name, value)

    @staticmethod
    def _conf_key(conf, tsig):
        '''Emit one key entry: "<name>" <alg> "<secret>";'''
        conf.item_str("\"%s\" %s" % (tsig.name, tsig.alg), tsig.key)

    @staticmethod
    def _conf_remote(conf, name, addr, tsig, port=None):
        '''Emit one remote: address, then port and key name when given.'''
        conf.begin(name)
        conf.item("address", addr)
        if port is not None:
            conf.item("port", port)
        if tsig:
            conf.item_str("key", tsig.name)
        conf.end()

    def get_config(self):
        '''Return the Knot configuration text for this server.

        Also (re)sets start_params, reload_params and flush_params, all of
        which pass "-c self.confile".
        '''
        s = KnotConf()
        self._conf_system(s)
        self._conf_interfaces(s)
        if self.tsig:
            self._conf_keys(s)
        self._conf_remotes(s)
        self._conf_zones(s)
        self._conf_log(s)

        self.start_params = ["-c", self.confile]
        self.reload_params = ["-c", self.confile, "reload"]
        self.flush_params = ["-c", self.confile, "flush"]

        return s.conf

    def _conf_system(self, s):
        '''Emit the "system" section (identity, rate limit, rundir, ...).'''
        s.begin("system")
        self._on_str_hex(s, "identity", self.ident)
        self._on_str_hex(s, "version", self.version)
        self._on_str_hex(s, "nsid", self.nsid)
        self._on_str_hex(s, "rate-limit", self.ratelimit)
        s.item_str("rundir", self.dir)
        if self.max_conn_idle:
            s.item("max-conn-idle", self.max_conn_idle)
        s.end()

    def _conf_interfaces(self, s):
        '''Emit the "control" socket and the listening "interfaces" sections.'''
        s.begin("control")
        s.item_str("listen-on", "knot.sock")
        s.end()

        s.begin("interfaces")
        s.begin("ipv4" if self.ip == 4 else "ipv6")
        s.item("address", self.addr)
        s.item("port", self.port)
        s.end()
        s.end()

    def _conf_keys(self, s):
        '''Emit the "keys" section: local and test keys, then remote keys.'''
        s.begin("keys")
        self._conf_key(s, self.tsig)
        self._conf_key(s, self.tsig_test)

        # Duplicity check, seeded with the keys already written above so a
        # remote sharing a local key name is not declared a second time.
        keys = {self.tsig.name, self.tsig_test.name}
        for zone in sorted(self.zones):
            z = self.zones[zone]
            remote_keys = [z.master.tsig] if z.master else []
            remote_keys.extend(slave.tsig for slave in z.slaves)
            for t in remote_keys:
                if t and t.name not in keys:
                    self._conf_key(s, t)
                    keys.add(t.name)
        s.end()

    def _conf_remotes(self, s):
        '''Emit the "remotes" section: "local", "test" and all masters/slaves.'''
        s.begin("remotes")
        self._conf_remote(s, "local", self.addr, self.tsig)
        self._conf_remote(s, "test", self.addr, self.tsig_test)

        servers = set() # Duplicity check.
        for zone in sorted(self.zones):
            z = self.zones[zone]
            remotes = ([z.master] if z.master else []) + list(z.slaves)
            for remote in remotes:
                if remote.name not in servers:
                    # Each remote references its own key (slaves previously
                    # wrote self.tsig.name by mistake).
                    self._conf_remote(s, remote.name, remote.addr,
                                      remote.tsig, remote.port)
                    servers.add(remote.name)
        s.end()

    def _conf_zones(self, s):
        '''Emit the "zones" section: global zone options plus every zone.'''
        s.begin("zones")
        s.item_str("storage", self.dir)
        s.item("zonefile-sync", self.zonefile_sync or "1d")
        s.item("notify-timeout", "5")
        s.item("notify-retries", "5")
        s.item("semantic-checks", "on")
        if self.disable_any:
            s.item("disable-any", "on")
        if self.dnssec_enable:
            s.item_str("dnssec-keydir", self.keydir)
            s.item("dnssec-enable", "on")
        for zone in sorted(self.zones):
            self._conf_zone(s, self.zones[zone])
        s.end()

    def _conf_zone(self, s, z):
        '''Emit the configuration block of a single zone *z*.'''
        s.begin(z.name)
        s.item_str("file", z.zfile.path)

        if z.master:
            if not self.disable_notify:
                s.item("notify-in", z.master.name)
            s.item("xfr-in", z.master.name)

        if z.slaves:
            s.item("notify-out", ", ".join(slave.name for slave in z.slaves))

        # Transfers are allowed to the "local" and "test" remotes.
        s.item("xfr-out", "local, test")

        if z.ddns:
            s.item("update-in", "test")

        if z.ixfr and not z.master:
            s.item("ixfr-from-differences", "on")

        if z.query_modules:
            s.begin("query_module")
            for query_module in z.query_modules:
                s.item_str(query_module[0], query_module[1])
            s.end()

        s.end()

    def _conf_log(self, s):
        '''Emit the "log" section: all debug output to stdout.'''
        s.begin("log")
        s.begin("stdout")
        s.item("any", "debug")
        s.end()
        for sink in ("stderr", "syslog"):  # Declared without any rules.
            s.begin(sink)
            s.end()
        s.end()

    def ctl(self, params):
        '''Run the control binary with *params* (a space-separated string).

        The call's stdout/stderr are appended to call.out/call.err in
        self.dir.  After a successful call, sleeps Server.START_WAIT.

        Raises Failed (after calling self.backtrace()) when the control
        binary exits with a non-zero status.
        '''
        # NOTE: the "params" argument shadows the module-level params import
        # inside this method; the name is kept for keyword callers.
        args = [self.control_bin] + self.start_params + params.split()
        try:
            # "with" closes both log files (previously leaked on every call).
            with open(self.dir + "/call.out", mode="a") as out, \
                 open(self.dir + "/call.err", mode="a") as err:
                check_call(args, stdout=out, stderr=err)
        except CalledProcessError as e:
            self.backtrace()
            raise Failed("Can't control='%s' server='%s', ret='%i'" %
                         (params, self.name, e.returncode))
        time.sleep(Server.START_WAIT)

1009
class Nsd(Server):
    '''NSD server wrapper; construction raises Skip without an NSD binary.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not params.nsd_bin:
            raise Skip("No NSD")
        self.daemon_bin = params.nsd_bin
        self.control_bin = params.nsd_ctl

    def get_config(self):
        '''Set the daemon start and rebuild parameters for self.confile.

        Unlike the other server classes no configuration text is produced;
        this returns None.  NOTE(review): confirm callers handle that.
        '''
        confile_args = ["-c", self.confile]
        self.start_params = confile_args + ["-d"]
        self.compile_params = confile_args + ["rebuild"]

1022
class Dummy(Server):
1023 1024
    ''' Dummy name server. '''

1025 1026
    def __init__(self, *args, **kwargs):
        '''Initialize the base server; a dummy has no daemon or control binary.'''
        super().__init__(*args, **kwargs)
        self.daemon_bin = self.control_bin = None

    def get_config(self):
        '''Return an empty configuration; the dummy server needs none.'''
        empty = ''
        return empty
1032

1033
    def start(self, clean=None):
        '''Pretend to start; *clean* is ignored and True is always returned.'''
        started = True
        return started

1036 1037
    def listening(self):
        '''Report the dummy server as always listening.'''
        # Fake listening
        return True
1038 1039 1040

    def running(self):
        '''Report the dummy server as always running.'''
        # Fake running
        return True