match.py 9.22 KB
Newer Older
1 2 3 4
import collections
import logging
from typing import (  # noqa
    Any, Dict, Hashable, Iterator, Mapping, Optional, Sequence, Tuple)
5 6 7

import dns.rdatatype
from dns.rrset import RRset
8
import dns.message
9

10
from .database import DNSReply
11
from .typing import FieldLabel, MismatchValue, ResolverID
12 13


14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
class DataMismatch(Exception):
    """Difference between an expected and a received value.

    Both values are normalized on construction: strings are kept as-is,
    other sequences become (possibly nested) lists of normalized items,
    RRsets are rendered with ``str()``, and anything else is cast to ``str``
    with a logged warning.

    Instances compare equal and hash equal based on ``key``, so they can be
    used as dict keys / set members for aggregating identical mismatches.
    """

    def __init__(self, exp_val: MismatchValue, got_val: MismatchValue) -> None:
        """
        :param exp_val: expected value
        :param got_val: value actually received
        :raises RuntimeError: if the normalized values are equal (not a mismatch)
        """
        def convert_val_type(val: Any) -> MismatchValue:
            if isinstance(val, str):
                return val
            if isinstance(val, collections.abc.Sequence):
                return [convert_val_type(item) for item in val]
            if isinstance(val, dns.rrset.RRset):
                return str(val)
            logging.warning(
                'DataMismatch: unknown value type (%s), casting to str', type(val),
                stack_info=True)
            return str(val)

        exp_val = convert_val_type(exp_val)
        got_val = convert_val_type(got_val)

        super().__init__(exp_val, got_val)
        if exp_val == got_val:
            raise RuntimeError("exp_val == got_val ({})".format(exp_val))
        self.exp_val = exp_val
        self.got_val = got_val

    @staticmethod
    def format_value(value: MismatchValue) -> str:
        """Render a normalized value; list items (recursively) are space-joined."""
        if isinstance(value, list):
            return ' '.join(DataMismatch.format_value(item) for item in value)
        return str(value)

    def __str__(self) -> str:
        return "expected '{}' got '{}'".format(
            self.format_value(self.exp_val),
            self.format_value(self.got_val))

    def __repr__(self) -> str:
        return 'DataMismatch({}, {})'.format(self.exp_val, self.got_val)

    def __eq__(self, other) -> bool:
        # Compare the same representation that __hash__ uses, so that equal
        # objects always have equal hashes (required for dict/set usage).
        if not isinstance(other, DataMismatch):
            return NotImplemented
        return self.key == other.key

    @property
    def key(self) -> Tuple[Hashable, Hashable]:
        """Hashable (expected, got) pair; lists are turned into tuples recursively."""
        def make_hashable(value):
            if isinstance(value, list):
                return tuple(make_hashable(item) for item in value)
            return value

        return (make_hashable(self.exp_val), make_hashable(self.got_val))

    def __hash__(self) -> int:
        return hash(self.key)


69 70 71 72 73 74 75
def compare_val(exp_val: MismatchValue, got_val: MismatchValue):
    """Return True when both values are equal.

    Otherwise raise DataMismatch built from the str() forms of the values.
    """
    if exp_val != got_val:
        expected_text, got_text = str(exp_val), str(got_val)
        raise DataMismatch(expected_text, got_text)
    return True


76
def compare_rrs(expected: Sequence[RRset], got: Sequence[RRset]):
    """Return True when both lists hold the same RRsets, ignoring order.

    Raise DataMismatch(expected, got) if an RRset is missing on either side
    or if the list lengths differ (which catches duplicates).
    """
    if (any(rrset not in got for rrset in expected)
            or any(rrset not in expected for rrset in got)
            or len(expected) != len(got)):
        raise DataMismatch(expected, got)
    return True


89 90 91 92
def compare_rrs_types(
            exp_val: Sequence[RRset],
            got_val: Sequence[RRset],
            compare_rrsigs: bool):
    """Check that the sets of RR types in both sections match.

    :param exp_val: RRsets of the expected section
    :param got_val: RRsets of the received section
    :param compare_rrsigs: when False, only non-RRSIG RRsets are considered
        and their types are compared; when True, only RRSIG RRsets are
        considered and the types they cover are compared
    :returns: True when the type sets are equal (consistent with the other
        compare_* helpers)
    :raises DataMismatch: with sorted textual type lists when they differ
    """
    def rr_ordering_key(rrset):
        # RRSIG RRsets are told apart by the type they cover
        return rrset.covers if compare_rrsigs else rrset.rdtype

    def key_to_text(rrtype):
        if not compare_rrsigs:
            return dns.rdatatype.to_text(rrtype)
        return 'RRSIG(%s)' % dns.rdatatype.to_text(rrtype)

    def filter_by_rrsig(seq, rrsig):
        # keep only RRSIG RRsets when rrsig is True, only non-RRSIG ones otherwise
        for el in seq:
            el_rrsig = el.rdtype == dns.rdatatype.RRSIG
            if el_rrsig == rrsig:
                yield el

    exp_types = frozenset(rr_ordering_key(rrset)
                          for rrset in filter_by_rrsig(exp_val, compare_rrsigs))
    got_types = frozenset(rr_ordering_key(rrset)
                          for rrset in filter_by_rrsig(got_val, compare_rrsigs))
    if exp_types != got_types:
        raise DataMismatch(
            tuple(key_to_text(i) for i in sorted(exp_types)),
            tuple(key_to_text(i) for i in sorted(got_types)))
    return True


118 119 120 121 122
def match_part(
            exp_msg: dns.message.Message,
            got_msg: dns.message.Message,
            criteria: FieldLabel
        ) -> bool:
    """Compare scripted reply to given message using a single criterion.

    :param exp_msg: expected (reference) DNS message
    :param got_msg: DNS message actually received
    :param criteria: part of the message to compare: 'opcode', 'flags',
        'rcode', 'question', 'answer', 'ttl', 'answertypes', 'answerrrsigs',
        'authority', 'additional', 'edns' or 'nsid'
    :returns: True when the selected part matches
    :raises DataMismatch: when the selected part differs
    :raises NotImplementedError: for an unknown criterion
    """
    if criteria == 'opcode':
        return compare_val(exp_msg.opcode(), got_msg.opcode())
    if criteria == 'flags':
        return compare_val(dns.flags.to_text(exp_msg.flags), dns.flags.to_text(got_msg.flags))
    if criteria == 'rcode':
        return compare_val(dns.rcode.to_text(exp_msg.rcode()), dns.rcode.to_text(got_msg.rcode()))
    if criteria == 'question':
        compare_rrs(exp_msg.question, got_msg.question)
        if not exp_msg.question:  # 0 RRs, nothing else to compare
            return True
        assert len(exp_msg.question) == 1, "multiple question in single DNS query unsupported"
        # Compare raw labels as well, to catch differences the RRset comparison
        # above does not report (presumably letter case in the owner name).
        # Argument order is (expected, got), like every other compare_val call.
        return compare_val(exp_msg.question[0].name.labels, got_msg.question[0].name.labels)
    if criteria in ('answer', 'ttl'):
        # NOTE(review): if dnspython's RRset equality ignores TTLs, 'ttl' checks
        # nothing beyond 'answer' — confirm against the dnspython version in use.
        return compare_rrs(exp_msg.answer, got_msg.answer)
    if criteria == 'answertypes':
        compare_rrs_types(exp_msg.answer, got_msg.answer, compare_rrsigs=False)
        return True
    if criteria == 'answerrrsigs':
        compare_rrs_types(exp_msg.answer, got_msg.answer, compare_rrsigs=True)
        return True
    if criteria == 'authority':
        return compare_rrs(exp_msg.authority, got_msg.authority)
    if criteria == 'additional':
        return compare_rrs(exp_msg.additional, got_msg.additional)
    if criteria == 'edns':
        if got_msg.edns != exp_msg.edns:
            raise DataMismatch(str(exp_msg.edns), str(got_msg.edns))
        if got_msg.payload != exp_msg.payload:
            raise DataMismatch(str(exp_msg.payload), str(got_msg.payload))
        return True
    if criteria == 'nsid':
        # only the first NSID option of each message is considered
        exp_nsid = next(
            (opt for opt in exp_msg.options if opt.otype == dns.edns.NSID), None)
        got_nsid = next(
            (opt for opt in got_msg.options if opt.otype == dns.edns.NSID), None)
        if got_nsid is None:
            if exp_nsid is not None:
                raise DataMismatch(str(exp_nsid.data), '')
            return True  # neither message carries NSID
        if exp_nsid is None:
            raise DataMismatch('', str(got_nsid.data))
        if got_nsid == exp_nsid:
            return True
        raise DataMismatch(str(exp_nsid.data), str(got_nsid.data))
    raise NotImplementedError('unknown match request "%s"' % criteria)
171 172 173


def match(
            expected: DNSReply,
            got: DNSReply,
            match_fields: Sequence[FieldLabel]
        ) -> Iterator[Tuple[FieldLabel, DataMismatch]]:
    """Yield (field, mismatch) pairs between an expected and a received reply.

    A timeout on exactly one side yields a single 'timeout' mismatch. Otherwise,
    differing wire-format parse results yield a single 'malformed' mismatch;
    identical malformations only log a warning. In both of these situations no
    other field is compared. Otherwise every field in match_fields is checked
    with match_part and each DataMismatch raised is yielded.
    """
    exp_msg, exp_res = expected.parse_wire()
    got_msg, got_res = got.parse_wire()
    exp_malformed = exp_res != DNSReply.WIREFORMAT_VALID
    got_malformed = got_res != DNSReply.WIREFORMAT_VALID

    if expected.timeout or got.timeout:
        if not expected.timeout:
            yield 'timeout', DataMismatch('answer', 'timeout')
        elif not got.timeout:
            yield 'timeout', DataMismatch('timeout', 'answer')
        return  # don't attempt to match any other fields

    if exp_malformed or got_malformed:
        if exp_res == got_res:
            logging.warning(
                'match: DNS replies malformed in the same way! (%s)', exp_res)
            return
        yield 'malformed', DataMismatch(exp_res, got_res)
        return  # don't attempt to match any other fields

    for field in match_fields:
        try:
            match_part(exp_msg, got_msg, field)
        except DataMismatch as mismatch:
            yield field, mismatch
204 205 206


def diff_pair(
            answers: Mapping[ResolverID, DNSReply],
            criteria: Sequence[FieldLabel],
            name1: ResolverID,
            name2: ResolverID
        ) -> Iterator[Tuple[FieldLabel, DataMismatch]]:
    """Yield (field, mismatch) pairs between the replies of two resolvers.

    The reply of name1 plays the role of the expected reply, that of name2
    the received one.
    """
    expected_reply = answers[name1]
    got_reply = answers[name2]
    yield from match(expected_reply, got_reply, criteria)


def transitive_equality(
            answers: Mapping[ResolverID, DNSReply],
            criteria: Sequence[FieldLabel],
            resolvers: Sequence[ResolverID]
        ) -> bool:
    """Return True if every resolver's answer matches the first resolver's.

    Matching is treated as an equivalence relation, so each resolver is
    compared only against the first one; the scan stops at the first
    resolver whose answer differs.
    """
    assert len(resolvers) >= 2
    reference = resolvers[0]
    for other in resolvers[1:]:
        if any(diff_pair(answers, criteria, reference, other)):
            return False
    return True


def compare(
            answers: Mapping[ResolverID, DNSReply],
            criteria: Sequence[FieldLabel],
            target: ResolverID
        ) -> Tuple[bool, Optional[Mapping[FieldLabel, DataMismatch]]]:
    """Compare the target resolver's reply with the other resolvers' replies.

    Returns (False, None) when the target has no entry in answers, when no
    other resolver has an entry, or when the other resolvers disagree among
    themselves. Otherwise returns (True, diffs): diffs maps each mismatching
    field to the DataMismatch between the first other resolver's reply
    (treated as expected) and the target's reply.
    """
    others = list(answers.keys())
    if target not in others:
        return (False, None)  # HACK, target did not reply
    others.remove(target)
    if not others:
        return (False, None)  # HACK, not enough targets to compare

    # with two or more references, they must agree among themselves first
    if len(others) > 1 and not transitive_equality(answers, criteria, others):
        return (False, None)

    reference = others[0]
    return (True, dict(diff_pair(answers, criteria, reference, target)))