Commit 54fab1fb authored by Štěpán Balážik's avatar Štěpán Balážik Committed by Petr Špaček

rplint: add typehint to rplint

parent 8ebed192
......@@ -5,6 +5,7 @@ import glob
import itertools
import os
import sys
from typing import Any, Callable, Iterable, Iterator, Optional, List, Union, Set
import dns.name
......@@ -12,6 +13,8 @@ import pydnstest.augwrap
import pydnstest.matchpart
import pydnstest.scenario
Element = Union["Entry", "Step", pydnstest.scenario.Range]
RCODES = {"NOERROR", "FORMERR", "SERVFAIL", "NXDOMAIN", "NOTIMP", "REFUSED", "YXDOMAIN", "YXRRSET",
"NXRRSET", "NOTAUTH", "NOTZONE", "BADVERS", "BADSIG", "BADKEY", "BADTIME", "BADMODE",
"BADNAME", "BADALG", "BADTRUNC", "BADCOOKIE"}
......@@ -27,15 +30,16 @@ class RplintError(ValueError):
super().__init__(msg)
def get_line_number(file, char_number):
def get_line_number(file: str, char_number: int) -> int:
pos = 0
for number, line in enumerate(open(file)):
pos += len(line)
if pos >= char_number:
return number + 2
return 0
def is_empty(iterable):
def is_empty(iterable: Iterator[Any]) -> bool:
try:
next(iterable)
except StopIteration:
......@@ -44,7 +48,7 @@ def is_empty(iterable):
class Entry:
def __init__(self, node):
def __init__(self, node: pydnstest.augwrap.AugeasNode) -> None:
self.match = {m.value for m in node.match("/match")}
self.adjust = {a.value for a in node.match("/adjust")}
self.authority = list(node.match("/section/authority/record"))
......@@ -54,21 +58,24 @@ class Entry:
class Step:
def __init__(self, node):
def __init__(self, node: pydnstest.augwrap.AugeasNode) -> None:
self.node = node
self.type = node["/type"].value
try:
self.entry = Entry(node["/entry"])
self.entry = Entry(node["/entry"]) # type: Optional[Entry]
except KeyError:
self.entry = None
class RplintFail:
def __init__(self, test, element=None, etc=""):
def __init__(self, test: "RplintTest",
element: Optional[Element] = None,
etc: str = "") -> None:
self.path = test.path
self.element = element
self.element = element # type: Optional[Element]
self.line = get_line_number(self.path, element.node.char if element is not None else 0)
self.etc = etc
self.check = None
self.check = None # type: Optional[Callable[[RplintTest], List[RplintFail]]]
def __str__(self):
if self.etc:
......@@ -79,7 +86,7 @@ class RplintFail:
class RplintTest:
def __init__(self, path):
def __init__(self, path: str) -> None:
aug = pydnstest.augwrap.AugeasWrapper(confpath=os.path.realpath(path),
lens='Deckard',
loadpath=os.path.join(os.path.dirname(__file__),
......@@ -96,7 +103,7 @@ class RplintTest:
self.ranges = [pydnstest.scenario.Range(n) for n in self.node.match("/scenario/range")]
self.fails = None
self.fails = None # type: Optional[List[RplintFail]]
self.checks = [
entry_more_than_one_rcode,
entry_no_qname_qtype_copy_query,
......@@ -115,7 +122,7 @@ class RplintTest:
step_duplicate_id,
]
def run_checks(self):
def run_checks(self) -> bool:
"""returns True iff all tests passed"""
self.fails = []
for check in self.checks:
......@@ -128,12 +135,14 @@ class RplintTest:
return True
return False
def print_fails(self):
def print_fails(self) -> None:
if self.fails is None:
raise RuntimeError("Maybe you should run some test first…")
for fail in self.fails:
print(fail)
def config_trust_anchor_trailing_period_missing(test):
def config_trust_anchor_trailing_period_missing(test: RplintTest) -> List[RplintFail]:
"""Trust-anchor option in configuration contains domain without trailing period"""
for conf in test.config:
if conf[0] == "trust-anchor":
......@@ -142,7 +151,7 @@ def config_trust_anchor_trailing_period_missing(test):
return []
def scenario_timestamp(test):
def scenario_timestamp(test: RplintTest) -> List[RplintFail]:
"""RRSSIG record present in test but no val-override-date or val-override-timestamp in config"""
rrsigs = []
for entry in test.entries:
......@@ -156,7 +165,7 @@ def scenario_timestamp(test):
return rrsigs
def entry_no_qname_qtype_copy_query(test):
def entry_no_qname_qtype_copy_query(test: RplintTest) -> List[RplintFail]:
"""ENTRY without qname and qtype in MATCH and without copy_query in ADJUST"""
fails = []
for entry in test.range_entries:
......@@ -167,7 +176,7 @@ def entry_no_qname_qtype_copy_query(test):
return fails
def entry_ns_in_authority(test):
def entry_ns_in_authority(test: RplintTest) -> List[RplintFail]:
"""ENTRY has authority section with NS records, consider using MATCH subdomain"""
fails = []
for entry in test.range_entries:
......@@ -178,7 +187,7 @@ def entry_ns_in_authority(test):
return fails
def entry_more_than_one_rcode(test):
def entry_more_than_one_rcode(test: RplintTest) -> List[RplintFail]:
"""ENTRY has more than one rcode in MATCH"""
fails = []
for entry in test.entries:
......@@ -187,7 +196,7 @@ def entry_more_than_one_rcode(test):
return fails
def scenario_ad_or_rrsig_no_ta(test):
def scenario_ad_or_rrsig_no_ta(test: RplintTest) -> List[RplintFail]:
"""AD or RRSIG present in test but no trust-anchor present in config"""
dnssec = []
for entry in test.entries:
......@@ -205,43 +214,45 @@ def scenario_ad_or_rrsig_no_ta(test):
return dnssec
def step_query_match(test):
def step_query_match(test: RplintTest) -> List[RplintFail]:
"""STEP QUERY has a MATCH rule"""
return [RplintFail(test, step) for step in test.steps if step.type == "QUERY" and step.entry.match]
return [RplintFail(test, step) for step in test.steps if step.type == "QUERY" and
step.entry and step.entry.match]
def step_check_answer_no_match(test):
def step_check_answer_no_match(test: RplintTest) -> List[RplintFail]:
"""ENTRY in STEP CHECK_ANSWER has no MATCH rule"""
return [RplintFail(test, step) for step in test.steps if step.type == "CHECK_ANSWER" and
not step.entry.match]
step.entry and not step.entry.match]
def step_unchecked_rcode(test):
def step_unchecked_rcode(test: RplintTest) -> List[RplintFail]:
"""ENTRY specifies rcode but STEP MATCH does not check for it."""
fails = []
for step in test.steps:
if step.type == "CHECK_ANSWER" and "all" not in step.entry.match:
if step.type == "CHECK_ANSWER" and step.entry and "all" not in step.entry.match:
if step.entry.reply & RCODES and "rcode" not in step.entry.match:
fails.append(RplintFail(test, step.entry))
return fails
def step_unchecked_match(test):
def step_unchecked_match(test: RplintTest) -> List[RplintFail]:
"""ENTRY specifies flags but MATCH does not check for them"""
fails = []
for step in test.steps:
if step.type == "CHECK_ANSWER":
entry = step.entry
if "all" not in entry.match and entry.reply - RCODES and "flags" not in entry.match:
if entry and "all" not in entry.match and entry.reply - RCODES and \
"flags" not in entry.match:
fails.append(RplintFail(test, entry, str(entry.reply - RCODES)))
return fails
def step_section_unchecked(test):
def step_section_unchecked(test: RplintTest) -> List[RplintFail]:
"""ENTRY has non-empty sections but MATCH does not check for all of them"""
fails = []
for step in test.steps:
if step.type == "CHECK_ANSWER" and "all" not in step.entry.match:
if step.type == "CHECK_ANSWER" and step.entry and "all" not in step.entry.match:
for section in SECTIONS:
if not is_empty(step.node.match("/entry/section/" + section + "/*")):
if section not in step.entry.match:
......@@ -249,7 +260,7 @@ def step_section_unchecked(test):
return fails
def range_overlapping_ips(test):
def range_overlapping_ips(test: RplintTest) -> List[RplintFail]:
"""RANGE has common IPs with some previous overlapping RANGE"""
fails = []
for r1, r2 in itertools.combinations(test.ranges, 2):
......@@ -261,7 +272,7 @@ def range_overlapping_ips(test):
return fails
def range_shadowing_match_rules(test):
def range_shadowing_match_rules(test: RplintTest) -> List[RplintFail]:
"""ENTRY has no effect since one of previous entries has the same or broader match rules"""
fails = []
for r in test.ranges:
......@@ -275,10 +286,10 @@ def range_shadowing_match_rules(test):
return fails
def step_duplicate_id(test):
def step_duplicate_id(test: RplintTest) -> List[RplintFail]:
"""STEP has the same ID as one of previous ones"""
fails = []
step_numbers = set()
step_numbers = set() # type: Set[int]
for step in test.steps:
if step.node.value in step_numbers:
fails.append(RplintFail(test, step))
......@@ -292,12 +303,13 @@ def step_duplicate_id(test):
# if "copy_id" not in adjust:
# entry_error(test, entry, "copy_id should be in ADJUST")
def test_run_rplint(rpl_path):
def test_run_rplint(rpl_path: str) -> None:
t = RplintTest(rpl_path)
passed = t.run_checks()
if not passed:
raise RplintError(t.fails)
if __name__ == '__main__':
try:
test_path = sys.argv[1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment