#!/usr/bin/env python3

import argparse
from itertools import zip_longest
import logging
from multiprocessing import pool
import random
import subprocess
import sys
from typing import (  # noqa
    Any, AbstractSet, Iterable, Iterator, Mapping, Sequence, Tuple, TypeVar,
    Union)

from respdiff import cli, sendrecv
from respdiff.database import (
    DNSReply, DNSRepliesFactory, key2qid, LMDB, MetaDatabase,
    ResolverID, qid2key, QKey, WireFormat)
from respdiff.dataformat import Diff, DiffReport, FieldLabel, ReproData, QID  # noqa
from respdiff.match import compare
from respdiff.query import get_query_iterator


T = TypeVar('T')


def restart_resolver(script_path: str) -> None:
    """Run the resolver's restart script, logging (not raising) on failure.

    A non-zero exit code or a non-executable script is reported as a warning
    so the reproduction loop can continue with the resolver as-is.
    """
    try:
        subprocess.check_call(script_path)
    except subprocess.CalledProcessError as exc:
        logging.warning('Resolver restart failed (exit code %d): %s',
                        exc.returncode, script_path)
    except PermissionError:
        # exception details carry no extra info here; the path suffices
        logging.warning('Resolver restart failed (permission error): %s',
                        script_path)


def get_restart_scripts(config: Mapping[str, Any]) -> Mapping[ResolverID, str]:
    """Map each configured resolver to its restart script.

    Resolvers lacking a ``restart_script`` config entry are skipped
    with a warning.
    """
    scripts = {}
    for name in config['servers']['names']:
        section = config.get(name) or {}
        if 'restart_script' in section:
            scripts[name] = section['restart_script']
        else:
            logging.warning('No restart script available for "%s"!', name)
    return scripts


def disagreement_query_stream(
            lmdb,
            report: DiffReport,
            skip_unstable: bool = True,
            skip_non_reproducible: bool = True,
            shuffle: bool = True
        ) -> Iterator[Tuple[QKey, WireFormat]]:
    """
    Yield (qkey, query wire format) for each disagreement worth re-testing.

    Query wire data is read from LMDB. Depending on the flags, diffs whose
    reprocounter shows an unstable upstream or imperfect reproducibility so
    far are skipped.

    Raises:
        RuntimeError: report lacks target_disagreements or reprodata
    """
    if report.target_disagreements is None or report.reprodata is None:
        raise RuntimeError("Report doesn't contain necessary data!")
    qids = report.target_disagreements.keys()  # type: Union[Sequence[QID], AbstractSet[QID]]
    if shuffle:
        # create a new, randomized list from disagreements
        # NOTE: random.sample() requires a sequence (Python >= 3.11 raises
        # TypeError on a dict keys view), so materialize the keys first
        qids = random.sample(list(qids), len(qids))
    queries = get_query_iterator(lmdb, qids)
    for qid, qwire in queries:
        diff = report.target_disagreements[qid]
        reprocounter = report.reprodata[qid]
        # verify if answers are stable
        if skip_unstable and reprocounter.retries != reprocounter.upstream_stable:
            logging.debug('Skipping QID %7d: unstable upstream', diff.qid)
            continue
        if skip_non_reproducible and reprocounter.retries != reprocounter.verified:
            logging.debug('Skipping QID %7d: not 100 %% reproducible', diff.qid)
            continue
        yield qid2key(qid), qwire


def chunker(iterable: Iterable[T], size: int) -> Iterator[Iterable[T]]:
    """
    Collect data into fixed-length chunks or blocks

    chunker([x, y, z], 2) --> [x, y], [z, None]
    """
    # one shared iterator unpacked `size` times: zip_longest drains it in
    # lockstep, padding the final short chunk with None
    shared = iter(iterable)
    return zip_longest(*(shared for _ in range(size)))


def process_answers(
            qkey: QKey,
            answers: Mapping[ResolverID, DNSReply],
            report: DiffReport,
            criteria: Sequence[FieldLabel],
            target: ResolverID
        ) -> None:
    """Record one reproduction attempt for the query identified by qkey.

    Updates the query's reprocounter in-place: bumps retries, and when the
    other resolvers agree, checks whether the re-measured diff matches the
    original disagreement.

    Raises:
        RuntimeError: report lacks target_disagreements or reprodata
    """
    if report.target_disagreements is None or report.reprodata is None:
        raise RuntimeError("Report doesn't contain necessary data!")

    qid = key2qid(qkey)
    counter = report.reprodata[qid]
    others_agree, mismatches = compare(answers, criteria, target)

    counter.retries += 1
    if not others_agree:
        return
    counter.upstream_stable += 1

    assert mismatches is not None
    if Diff(qid, mismatches) == report.target_disagreements[qid]:
        counter.verified += 1


def main():
    """Re-send previously mismatching queries and record reproducibility.

    Reads the diff report (JSON), replays each target disagreement against
    the configured resolvers (restarting them between chunks to clear
    caches), and updates the report's reprodata. The report is saved even
    on interrupt.
    """
    cli.setup_logging()
    parser = argparse.ArgumentParser(
        description='attempt to reproduce original diffs from JSON report')
    cli.add_arg_envdir(parser)
    cli.add_arg_config(parser)
    cli.add_arg_datafile(parser)
    parser.add_argument('-s', '--sequential', action='store_true', default=False,
                        help='send one query at a time (slower, but more reliable)')

    args = parser.parse_args()
    sendrecv.module_init(args)
    datafile = cli.get_datafile(args)
    report = DiffReport.from_json(datafile)
    restart_scripts = get_restart_scripts(args.cfg)
    servers = args.cfg['servers']['names']
    dnsreplies_factory = DNSRepliesFactory(servers)

    if args.sequential:
        nproc = 1
    else:
        nproc = args.cfg['sendrecv']['jobs']

    if report.reprodata is None:
        report.reprodata = ReproData()

    with LMDB(args.envdir, readonly=True) as lmdb:
        lmdb.open_db(LMDB.QUERIES)

        try:
            MetaDatabase(lmdb, servers, create=False)  # check version and servers
        except NotImplementedError as exc:
            logging.critical(exc)
            sys.exit(1)

        dstream = disagreement_query_stream(lmdb, report)
        try:
            with pool.Pool(processes=nproc) as p:
                done = 0
                for chunk in chunker(dstream, nproc):
                    # restart resolvers and clear their cache
                    for script in restart_scripts.values():
                        restart_resolver(script)

                    # chunker() pads the last chunk with None -- drop the padding
                    process_args = [qdata for qdata in chunk if qdata is not None]
                    for qkey, replies_data in p.imap_unordered(
                            sendrecv.worker_perform_single_query,
                            process_args,
                            chunksize=1):
                        replies = dnsreplies_factory.parse(replies_data)
                        process_answers(qkey, replies, report,
                                        args.cfg['diff']['criteria'],
                                        args.cfg['diff']['target'])

                    done += len(process_args)
                    # lazy %-formatting for consistency with the other log calls
                    logging.info('Processed %4d queries', done)
        finally:
            # make sure data is saved in case of interrupt
            report.export_json(datafile)


if __name__ == '__main__':
    main()