Commit 54fab1fb authored by Štěpán Balážik's avatar Štěpán Balážik Committed by Petr Špaček

rplint: add typehint to rplint

parent 8ebed192
...@@ -5,6 +5,7 @@ import glob ...@@ -5,6 +5,7 @@ import glob
import itertools import itertools
import os import os
import sys import sys
from typing import Any, Callable, Iterable, Iterator, Optional, List, Union, Set
import dns.name import dns.name
...@@ -12,6 +13,8 @@ import pydnstest.augwrap ...@@ -12,6 +13,8 @@ import pydnstest.augwrap
import pydnstest.matchpart import pydnstest.matchpart
import pydnstest.scenario import pydnstest.scenario
Element = Union["Entry", "Step", pydnstest.scenario.Range]
RCODES = {"NOERROR", "FORMERR", "SERVFAIL", "NXDOMAIN", "NOTIMP", "REFUSED", "YXDOMAIN", "YXRRSET", RCODES = {"NOERROR", "FORMERR", "SERVFAIL", "NXDOMAIN", "NOTIMP", "REFUSED", "YXDOMAIN", "YXRRSET",
"NXRRSET", "NOTAUTH", "NOTZONE", "BADVERS", "BADSIG", "BADKEY", "BADTIME", "BADMODE", "NXRRSET", "NOTAUTH", "NOTZONE", "BADVERS", "BADSIG", "BADKEY", "BADTIME", "BADMODE",
"BADNAME", "BADALG", "BADTRUNC", "BADCOOKIE"} "BADNAME", "BADALG", "BADTRUNC", "BADCOOKIE"}
...@@ -27,15 +30,16 @@ class RplintError(ValueError): ...@@ -27,15 +30,16 @@ class RplintError(ValueError):
super().__init__(msg) super().__init__(msg)
def get_line_number(file, char_number): def get_line_number(file: str, char_number: int) -> int:
pos = 0 pos = 0
for number, line in enumerate(open(file)): for number, line in enumerate(open(file)):
pos += len(line) pos += len(line)
if pos >= char_number: if pos >= char_number:
return number + 2 return number + 2
return 0
def is_empty(iterable): def is_empty(iterable: Iterator[Any]) -> bool:
try: try:
next(iterable) next(iterable)
except StopIteration: except StopIteration:
...@@ -44,7 +48,7 @@ def is_empty(iterable): ...@@ -44,7 +48,7 @@ def is_empty(iterable):
class Entry: class Entry:
def __init__(self, node): def __init__(self, node: pydnstest.augwrap.AugeasNode) -> None:
self.match = {m.value for m in node.match("/match")} self.match = {m.value for m in node.match("/match")}
self.adjust = {a.value for a in node.match("/adjust")} self.adjust = {a.value for a in node.match("/adjust")}
self.authority = list(node.match("/section/authority/record")) self.authority = list(node.match("/section/authority/record"))
...@@ -54,21 +58,24 @@ class Entry: ...@@ -54,21 +58,24 @@ class Entry:
class Step: class Step:
def __init__(self, node): def __init__(self, node: pydnstest.augwrap.AugeasNode) -> None:
self.node = node self.node = node
self.type = node["/type"].value self.type = node["/type"].value
try: try:
self.entry = Entry(node["/entry"]) self.entry = Entry(node["/entry"]) # type: Optional[Entry]
except KeyError: except KeyError:
self.entry = None self.entry = None
class RplintFail: class RplintFail:
def __init__(self, test, element=None, etc=""): def __init__(self, test: "RplintTest",
element: Optional[Element] = None,
etc: str = "") -> None:
self.path = test.path self.path = test.path
self.element = element self.element = element # type: Optional[Element]
self.line = get_line_number(self.path, element.node.char if element is not None else 0) self.line = get_line_number(self.path, element.node.char if element is not None else 0)
self.etc = etc self.etc = etc
self.check = None self.check = None # type: Optional[Callable[[RplintTest], List[RplintFail]]]
def __str__(self): def __str__(self):
if self.etc: if self.etc:
...@@ -79,7 +86,7 @@ class RplintFail: ...@@ -79,7 +86,7 @@ class RplintFail:
class RplintTest: class RplintTest:
def __init__(self, path): def __init__(self, path: str) -> None:
aug = pydnstest.augwrap.AugeasWrapper(confpath=os.path.realpath(path), aug = pydnstest.augwrap.AugeasWrapper(confpath=os.path.realpath(path),
lens='Deckard', lens='Deckard',
loadpath=os.path.join(os.path.dirname(__file__), loadpath=os.path.join(os.path.dirname(__file__),
...@@ -96,7 +103,7 @@ class RplintTest: ...@@ -96,7 +103,7 @@ class RplintTest:
self.ranges = [pydnstest.scenario.Range(n) for n in self.node.match("/scenario/range")] self.ranges = [pydnstest.scenario.Range(n) for n in self.node.match("/scenario/range")]
self.fails = None self.fails = None # type: Optional[List[RplintFail]]
self.checks = [ self.checks = [
entry_more_than_one_rcode, entry_more_than_one_rcode,
entry_no_qname_qtype_copy_query, entry_no_qname_qtype_copy_query,
...@@ -115,7 +122,7 @@ class RplintTest: ...@@ -115,7 +122,7 @@ class RplintTest:
step_duplicate_id, step_duplicate_id,
] ]
def run_checks(self): def run_checks(self) -> bool:
"""returns True iff all tests passed""" """returns True iff all tests passed"""
self.fails = [] self.fails = []
for check in self.checks: for check in self.checks:
...@@ -128,12 +135,14 @@ class RplintTest: ...@@ -128,12 +135,14 @@ class RplintTest:
return True return True
return False return False
def print_fails(self): def print_fails(self) -> None:
if self.fails is None:
raise RuntimeError("Maybe you should run some test first…")
for fail in self.fails: for fail in self.fails:
print(fail) print(fail)
def config_trust_anchor_trailing_period_missing(test): def config_trust_anchor_trailing_period_missing(test: RplintTest) -> List[RplintFail]:
"""Trust-anchor option in configuration contains domain without trailing period""" """Trust-anchor option in configuration contains domain without trailing period"""
for conf in test.config: for conf in test.config:
if conf[0] == "trust-anchor": if conf[0] == "trust-anchor":
...@@ -142,7 +151,7 @@ def config_trust_anchor_trailing_period_missing(test): ...@@ -142,7 +151,7 @@ def config_trust_anchor_trailing_period_missing(test):
return [] return []
def scenario_timestamp(test): def scenario_timestamp(test: RplintTest) -> List[RplintFail]:
"""RRSSIG record present in test but no val-override-date or val-override-timestamp in config""" """RRSSIG record present in test but no val-override-date or val-override-timestamp in config"""
rrsigs = [] rrsigs = []
for entry in test.entries: for entry in test.entries:
...@@ -156,7 +165,7 @@ def scenario_timestamp(test): ...@@ -156,7 +165,7 @@ def scenario_timestamp(test):
return rrsigs return rrsigs
def entry_no_qname_qtype_copy_query(test): def entry_no_qname_qtype_copy_query(test: RplintTest) -> List[RplintFail]:
"""ENTRY without qname and qtype in MATCH and without copy_query in ADJUST""" """ENTRY without qname and qtype in MATCH and without copy_query in ADJUST"""
fails = [] fails = []
for entry in test.range_entries: for entry in test.range_entries:
...@@ -167,7 +176,7 @@ def entry_no_qname_qtype_copy_query(test): ...@@ -167,7 +176,7 @@ def entry_no_qname_qtype_copy_query(test):
return fails return fails
def entry_ns_in_authority(test): def entry_ns_in_authority(test: RplintTest) -> List[RplintFail]:
"""ENTRY has authority section with NS records, consider using MATCH subdomain""" """ENTRY has authority section with NS records, consider using MATCH subdomain"""
fails = [] fails = []
for entry in test.range_entries: for entry in test.range_entries:
...@@ -178,7 +187,7 @@ def entry_ns_in_authority(test): ...@@ -178,7 +187,7 @@ def entry_ns_in_authority(test):
return fails return fails
def entry_more_than_one_rcode(test): def entry_more_than_one_rcode(test: RplintTest) -> List[RplintFail]:
"""ENTRY has more than one rcode in MATCH""" """ENTRY has more than one rcode in MATCH"""
fails = [] fails = []
for entry in test.entries: for entry in test.entries:
...@@ -187,7 +196,7 @@ def entry_more_than_one_rcode(test): ...@@ -187,7 +196,7 @@ def entry_more_than_one_rcode(test):
return fails return fails
def scenario_ad_or_rrsig_no_ta(test): def scenario_ad_or_rrsig_no_ta(test: RplintTest) -> List[RplintFail]:
"""AD or RRSIG present in test but no trust-anchor present in config""" """AD or RRSIG present in test but no trust-anchor present in config"""
dnssec = [] dnssec = []
for entry in test.entries: for entry in test.entries:
...@@ -205,43 +214,45 @@ def scenario_ad_or_rrsig_no_ta(test): ...@@ -205,43 +214,45 @@ def scenario_ad_or_rrsig_no_ta(test):
return dnssec return dnssec
def step_query_match(test): def step_query_match(test: RplintTest) -> List[RplintFail]:
"""STEP QUERY has a MATCH rule""" """STEP QUERY has a MATCH rule"""
return [RplintFail(test, step) for step in test.steps if step.type == "QUERY" and step.entry.match] return [RplintFail(test, step) for step in test.steps if step.type == "QUERY" and
step.entry and step.entry.match]
def step_check_answer_no_match(test): def step_check_answer_no_match(test: RplintTest) -> List[RplintFail]:
"""ENTRY in STEP CHECK_ANSWER has no MATCH rule""" """ENTRY in STEP CHECK_ANSWER has no MATCH rule"""
return [RplintFail(test, step) for step in test.steps if step.type == "CHECK_ANSWER" and return [RplintFail(test, step) for step in test.steps if step.type == "CHECK_ANSWER" and
not step.entry.match] step.entry and not step.entry.match]
def step_unchecked_rcode(test): def step_unchecked_rcode(test: RplintTest) -> List[RplintFail]:
"""ENTRY specifies rcode but STEP MATCH does not check for it.""" """ENTRY specifies rcode but STEP MATCH does not check for it."""
fails = [] fails = []
for step in test.steps: for step in test.steps:
if step.type == "CHECK_ANSWER" and "all" not in step.entry.match: if step.type == "CHECK_ANSWER" and step.entry and "all" not in step.entry.match:
if step.entry.reply & RCODES and "rcode" not in step.entry.match: if step.entry.reply & RCODES and "rcode" not in step.entry.match:
fails.append(RplintFail(test, step.entry)) fails.append(RplintFail(test, step.entry))
return fails return fails
def step_unchecked_match(test): def step_unchecked_match(test: RplintTest) -> List[RplintFail]:
"""ENTRY specifies flags but MATCH does not check for them""" """ENTRY specifies flags but MATCH does not check for them"""
fails = [] fails = []
for step in test.steps: for step in test.steps:
if step.type == "CHECK_ANSWER": if step.type == "CHECK_ANSWER":
entry = step.entry entry = step.entry
if "all" not in entry.match and entry.reply - RCODES and "flags" not in entry.match: if entry and "all" not in entry.match and entry.reply - RCODES and \
"flags" not in entry.match:
fails.append(RplintFail(test, entry, str(entry.reply - RCODES))) fails.append(RplintFail(test, entry, str(entry.reply - RCODES)))
return fails return fails
def step_section_unchecked(test): def step_section_unchecked(test: RplintTest) -> List[RplintFail]:
"""ENTRY has non-empty sections but MATCH does not check for all of them""" """ENTRY has non-empty sections but MATCH does not check for all of them"""
fails = [] fails = []
for step in test.steps: for step in test.steps:
if step.type == "CHECK_ANSWER" and "all" not in step.entry.match: if step.type == "CHECK_ANSWER" and step.entry and "all" not in step.entry.match:
for section in SECTIONS: for section in SECTIONS:
if not is_empty(step.node.match("/entry/section/" + section + "/*")): if not is_empty(step.node.match("/entry/section/" + section + "/*")):
if section not in step.entry.match: if section not in step.entry.match:
...@@ -249,7 +260,7 @@ def step_section_unchecked(test): ...@@ -249,7 +260,7 @@ def step_section_unchecked(test):
return fails return fails
def range_overlapping_ips(test): def range_overlapping_ips(test: RplintTest) -> List[RplintFail]:
"""RANGE has common IPs with some previous overlapping RANGE""" """RANGE has common IPs with some previous overlapping RANGE"""
fails = [] fails = []
for r1, r2 in itertools.combinations(test.ranges, 2): for r1, r2 in itertools.combinations(test.ranges, 2):
...@@ -261,7 +272,7 @@ def range_overlapping_ips(test): ...@@ -261,7 +272,7 @@ def range_overlapping_ips(test):
return fails return fails
def range_shadowing_match_rules(test): def range_shadowing_match_rules(test: RplintTest) -> List[RplintFail]:
"""ENTRY has no effect since one of previous entries has the same or broader match rules""" """ENTRY has no effect since one of previous entries has the same or broader match rules"""
fails = [] fails = []
for r in test.ranges: for r in test.ranges:
...@@ -275,10 +286,10 @@ def range_shadowing_match_rules(test): ...@@ -275,10 +286,10 @@ def range_shadowing_match_rules(test):
return fails return fails
def step_duplicate_id(test): def step_duplicate_id(test: RplintTest) -> List[RplintFail]:
"""STEP has the same ID as one of previous ones""" """STEP has the same ID as one of previous ones"""
fails = [] fails = []
step_numbers = set() step_numbers = set() # type: Set[int]
for step in test.steps: for step in test.steps:
if step.node.value in step_numbers: if step.node.value in step_numbers:
fails.append(RplintFail(test, step)) fails.append(RplintFail(test, step))
...@@ -292,12 +303,13 @@ def step_duplicate_id(test): ...@@ -292,12 +303,13 @@ def step_duplicate_id(test):
# if "copy_id" not in adjust: # if "copy_id" not in adjust:
# entry_error(test, entry, "copy_id should be in ADJUST") # entry_error(test, entry, "copy_id should be in ADJUST")
def test_run_rplint(rpl_path): def test_run_rplint(rpl_path: str) -> None:
t = RplintTest(rpl_path) t = RplintTest(rpl_path)
passed = t.run_checks() passed = t.run_checks()
if not passed: if not passed:
raise RplintError(t.fails) raise RplintError(t.fails)
if __name__ == '__main__': if __name__ == '__main__':
try: try:
test_path = sys.argv[1] test_path = sys.argv[1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment