test.py: add type annotations

Add type annotations where possible.
This commit is contained in:
Konstantin Osipov
2022-06-17 21:30:03 +03:00
parent 2470b1d888
commit fd3d08e560

174
test.py
View File

@@ -35,7 +35,7 @@ from test.pylib.artifact_registry import ArtifactRegistry
from test.pylib.host_registry import HostRegistry
from test.pylib.pool import Pool
from test.pylib.scylla_server import ScyllaServer, ScyllaCluster
from typing import Dict, List
from typing import Dict, List, Callable, Any, Iterable, Optional, Awaitable
output_is_a_tty = sys.stdout.isatty()
@@ -43,13 +43,13 @@ all_modes = set(['debug', 'release', 'dev', 'sanitize', 'coverage'])
debug_modes = set(['debug', 'sanitize'])
def create_formatter(*decorators):
def create_formatter(*decorators) -> Callable[[Any], str]:
"""Return a function which decorates its argument with the given
color/style if stdout is a tty, and leaves intact otherwise."""
def color(arg):
def color(arg: Any) -> str:
return "".join(decorators) + str(arg) + colorama.Style.RESET_ALL
def nocolor(arg):
def nocolor(arg: Any) -> str:
return str(arg)
return color if output_is_a_tty else nocolor
@@ -77,20 +77,20 @@ class TestSuite(ABC):
E.g. it can be unit tests, boost tests, or CQL tests."""
# All existing test suites, one suite per path/mode.
suites: Dict[str, ABC] = dict()
suites: Dict[str, 'TestSuite'] = dict()
artifacts = ArtifactRegistry()
hosts = HostRegistry()
FLAKY_RETRIES = 5
_next_id = 0
def __init__(self, path, cfg, options, mode):
def __init__(self, path: str, cfg: dict, options: argparse.Namespace, mode: str) -> None:
self.path = path
self.name = os.path.basename(self.path)
self.cfg = cfg
self.options = options
self.mode = mode
self.suite_key = os.path.join(path, mode)
self.tests = []
self.tests: List['Test'] = []
self.pending_test_count = 0
# The number of failed tests
self.n_failed = 0
@@ -120,16 +120,16 @@ class TestSuite(ABC):
self.disabled_tests.update(skip_in_m - run_in_m)
@property
def next_id(self):
def next_id(self) -> int:
TestSuite._next_id += 1
return TestSuite._next_id
@staticmethod
def test_count():
def test_count() -> int:
return TestSuite._next_id
@staticmethod
def load_cfg(path):
def load_cfg(path: str) -> dict:
with open(os.path.join(path, "suite.yaml"), "r") as cfg_file:
cfg = yaml.safe_load(cfg_file.read())
if not isinstance(cfg, dict):
@@ -137,7 +137,7 @@ class TestSuite(ABC):
return cfg
@staticmethod
def opt_create(path, options, mode):
def opt_create(path: str, options: argparse.Namespace, mode: str) -> 'TestSuite':
"""Return a subclass of TestSuite with name cfg["type"].title + TestSuite.
Ensures there is only one suite instance per path."""
suite_key = os.path.join(path, mode)
@@ -159,24 +159,25 @@ class TestSuite(ABC):
if not SpecificTestSuite:
raise RuntimeError("Failed to load tests in {}: suite type '{}' not found".format(path, kind))
suite = SpecificTestSuite(path, cfg, options, mode)
assert suite is not None
TestSuite.suites[suite_key] = suite
return suite
@staticmethod
def all_tests():
def all_tests() -> Iterable['Test']:
return itertools.chain(*[suite.tests for suite in
TestSuite.suites.values()])
@property
@abstractmethod
def pattern(self):
def pattern(self) -> str:
pass
@abstractmethod
async def add_test(self, shortname):
async def add_test(self, shortname: str) -> None:
pass
async def run(self, test, options):
async def run(self, test: 'Test', options: argparse.Namespace):
try:
for i in range(1, self.FLAKY_RETRIES):
await test.run(options)
@@ -214,7 +215,7 @@ class TestSuite(ABC):
if options.skip_pattern and options.skip_pattern in t:
continue
async def add_test(shortname):
async def add_test(shortname) -> None:
# Add variants of the same test sequentially
# so that case cache has a chance to populate
for i in range(options.repeat):
@@ -238,7 +239,7 @@ class TestSuite(ABC):
class UnitTestSuite(TestSuite):
"""TestSuite instantiation for non-boost unit tests"""
def __init__(self, path, cfg, options, mode):
def __init__(self, path: str, cfg: dict, options: argparse.Namespace, mode: str) -> None:
super().__init__(path, cfg, options, mode)
# Map of custom test command line arguments, if configured
self.custom_args = cfg.get("custom_args", {})
@@ -247,7 +248,7 @@ class UnitTestSuite(TestSuite):
test = UnitTest(self.next_id, shortname, suite, args)
self.tests.append(test)
async def add_test(self, shortname):
async def add_test(self, shortname) -> None:
"""Create a UnitTest class with possibly custom command line
arguments and add it to the list of tests"""
# Skip tests which are not configured, and hence are not built
@@ -261,7 +262,7 @@ class UnitTestSuite(TestSuite):
await self.create_test(shortname, self, a)
@property
def pattern(self):
def pattern(self) -> str:
return "*_test.cc"
@@ -272,10 +273,10 @@ class BoostTestSuite(UnitTestSuite):
# --list_content. Static to share across all modes.
_case_cache: Dict[str, List[str]] = dict()
def __init__(self, path, cfg, options, mode):
def __init__(self, path, cfg: dict, options: argparse.Namespace, mode) -> None:
super().__init__(path, cfg, options, mode)
async def create_test(self, shortname, suite, args):
async def create_test(self, shortname: str, suite, args) -> None:
options = self.options
if options.parallel_cases and (shortname not in self.no_parallel_cases):
fqname = os.path.join(self.mode, self.name, shortname)
@@ -306,7 +307,7 @@ class BoostTestSuite(UnitTestSuite):
test = BoostTest(self.next_id, shortname, suite, args, None)
self.tests.append(test)
def junit_tests(self):
def junit_tests(self) -> Iterable['Test']:
"""Boost tests produce an own XML output, so are not included in a junit report"""
return []
@@ -314,7 +315,7 @@ class BoostTestSuite(UnitTestSuite):
class PythonTestSuite(TestSuite):
"""A collection of Python pytests against a single Scylla instance"""
def __init__(self, path, cfg, options, mode):
def __init__(self, path, cfg: dict, options: argparse.Namespace, mode: str) -> None:
super().__init__(path, cfg, options, mode)
self.scylla_exe = os.path.join("build", self.mode, "scylla")
if self.mode == "coverage":
@@ -329,7 +330,7 @@ class PythonTestSuite(TestSuite):
self.clusters = Pool(cfg.get("pool_size", 2), self.create_cluster)
def topology_for_class(self, class_name, cfg):
def topology_for_class(self, class_name: str, cfg: dict) -> Callable[[], Awaitable]:
def create_server(cluster_name, seed):
cmdline_options = self.cfg.get("extra_scylla_cmdline_options", [])
@@ -363,34 +364,34 @@ class PythonTestSuite(TestSuite):
else:
raise RuntimeError("Unsupported topology name")
async def add_test(self, shortname):
async def add_test(self, shortname) -> None:
test = PythonTest(self.next_id, shortname, self)
self.tests.append(test)
@property
def pattern(self):
def pattern(self) -> str:
return "test_*.py"
class CQLApprovalTestSuite(PythonTestSuite):
"""Run CQL commands against a single Scylla instance"""
def __init__(self, path, cfg, options, mode):
def __init__(self, path, cfg, options: argparse.Namespace, mode) -> None:
super().__init__(path, cfg, options, mode)
async def add_test(self, shortname):
async def add_test(self, shortname: str) -> None:
test = CQLApprovalTest(self.next_id, shortname, self)
self.tests.append(test)
@property
def pattern(self):
def pattern(self) -> str:
return "*test.cql"
class RunTestSuite(TestSuite):
"""TestSuite for test directory with a 'run' script """
def __init__(self, path, cfg, options, mode):
def __init__(self, path: str, cfg, options: argparse.Namespace, mode: str) -> None:
super().__init__(path, cfg, options, mode)
self.scylla_exe = os.path.join("build", self.mode, "scylla")
if self.mode == "coverage":
@@ -399,19 +400,21 @@ class RunTestSuite(TestSuite):
self.scylla_env = dict()
self.scylla_env['SCYLLA'] = self.scylla_exe
async def add_test(self, shortname):
async def add_test(self, shortname) -> None:
test = RunTest(self.next_id, shortname, self)
self.tests.append(test)
@property
def pattern(self):
def pattern(self) -> str:
return "run"
class Test:
"""Base class for CQL, Unit and Boost tests"""
def __init__(self, test_no, shortname, suite):
def __init__(self, test_no: int, shortname: str, suite) -> None:
self.id = test_no
self.path = ""
self.args: List[str] = []
# Name with test suite name
self.name = os.path.join(suite.name, shortname.split('.')[0])
# Name within the suite
@@ -429,28 +432,31 @@ class Test:
self.is_cancelled = False
Test._reset(self)
def reset(self):
def reset(self) -> None:
"""Reset this object, including all derived state."""
for cls in reversed(self.__class__.__mro__):
if hasattr(cls, "_reset"):
cls._reset(self)
_reset = getattr(cls, '_reset', None)
if _reset is not None:
_reset(self)
def _reset(self):
def _reset(self) -> None:
"""Reset the test before a retry, if it is retried as flaky"""
self.success = None
self.success = False
self.time_start: float = 0
self.time_end: float = 0
@abstractmethod
async def run(self, options):
async def run(self, options: argparse.Namespace) -> 'Test':
pass
@abstractmethod
def print_summary(self):
def print_summary(self) -> None:
pass
def get_junit_etree(self):
return None
def check_log(self, trim):
def check_log(self, trim: bool) -> None:
"""Check and trim logs and xml output for tests which have it"""
if trim:
self.log_filename.unlink()
@@ -463,7 +469,7 @@ class UnitTest(Test):
"--blocked-reactor-notify-ms 2000000 --collectd 0 "
"--max-networking-io-control-blocks=100 ")
def __init__(self, test_no, shortname, suite, args):
def __init__(self, test_no: int, shortname: str, suite, args: str) -> None:
super().__init__(test_no, shortname, suite)
self.path = os.path.join("build", self.mode, "test", self.name)
self.args = shlex.split(args) + UnitTest.standard_args
@@ -473,15 +479,15 @@ class UnitTest(Test):
self.env = dict()
UnitTest._reset(self)
def _reset(self):
def _reset(self) -> None:
"""Reset the test before a retry, if it is retried as flaky"""
pass
def print_summary(self):
def print_summary(self) -> None:
print("Output of {} {}:".format(self.path, " ".join(self.args)))
print(read_log(self.log_filename))
async def run(self, options):
async def run(self, options) -> Test:
self.success = await run_test(self, options, env=self.env)
logging.info("Test #%d %s", self.id, "succeeded" if self.success else "failed ")
return self
@@ -490,7 +496,8 @@ class UnitTest(Test):
class BoostTest(UnitTest):
"""A unit test which can produce its own XML output"""
def __init__(self, test_no, shortname, suite, args, casename):
def __init__(self, test_no: int, shortname: str, suite, args: str,
casename: Optional[str]) -> None:
boost_args = []
if casename:
shortname += '.' + casename
@@ -505,12 +512,13 @@ class BoostTest(UnitTest):
self.args = boost_args + self.args
self.casename = casename
BoostTest._reset(self)
self.__junit_etree: Optional[ET.ElementTree] = None
def _reset(self):
def _reset(self) -> None:
"""Reset the test before a retry, if it is retried as flaky"""
self.__junit_etree = None
def get_junit_etree(self):
def get_junit_etree(self) -> ET.ElementTree:
def adjust_suite_name(name):
# Normalize "path/to/file.cc" to "path.to.file" to conform to
# Jenkins expectations that the suite name is a class name. ".cc"
@@ -535,7 +543,7 @@ class BoostTest(UnitTest):
os.unlink(self.xmlout)
return self.__junit_etree
def check_log(self, trim):
def check_log(self, trim: bool) -> None:
self.get_junit_etree()
super().check_log(trim)
@@ -547,7 +555,7 @@ class BoostTest(UnitTest):
class CQLApprovalTest(Test):
"""Run a sequence of CQL commands against a standlone Scylla"""
def __init__(self, test_no, shortname, suite):
def __init__(self, test_no: int, shortname: str, suite) -> None:
super().__init__(test_no, shortname, suite)
# Path to cql_repl driver, in the given build mode
self.path = "pytest"
@@ -562,22 +570,22 @@ class CQLApprovalTest(Test):
]
CQLApprovalTest._reset(self)
def _reset(self):
def _reset(self) -> None:
"""Reset the test before a retry, if it is retried as flaky"""
self.is_before_test_ok = False
self.is_executed_ok = False
self.is_new = False
self.is_after_test_ok = False
self.is_equal_result = None
self.is_equal_result = False
self.summary = "not run"
self.unidiff = None
self.unidiff: Optional[str] = None
self.server_log = None
self.env = dict()
self.env: Dict[str, str] = dict()
old_tmpfile = pathlib.Path(self.tmpfile)
if old_tmpfile.exists():
old_tmpfile.unlink()
async def run(self, options):
async def run(self, options: argparse.Namespace) -> Test:
self.success = False
self.summary = "failed"
@@ -612,6 +620,7 @@ Check test log at {}.""".format(self.log_filename))
if self.is_equal_result is False:
self.unidiff = format_unidiff(self.result, self.tmpfile)
set_summary("failed: test output does not match expected result")
assert self.unidiff is not None
logging.info("\n{}".format(palette.nocolor(self.unidiff)))
else:
self.success = True
@@ -642,7 +651,7 @@ Check test log at {}.""".format(self.log_filename))
return self
def print_summary(self):
def print_summary(self) -> None:
print("Test {} ({}) {}".format(palette.path(self.name), self.mode,
self.summary))
if self.is_executed_ok is False:
@@ -657,7 +666,7 @@ Check test log at {}.""".format(self.log_filename))
class RunTest(Test):
"""Run tests in a directory started by a run script"""
def __init__(self, test_no, shortname, suite):
def __init__(self, test_no: int, shortname: str, suite) -> None:
super().__init__(test_no, shortname, suite)
self.path = os.path.join(suite.path, shortname)
self.xmlout = os.path.join(suite.options.tmpdir, self.mode, "xml", self.uname + ".xunit.xml")
@@ -668,11 +677,11 @@ class RunTest(Test):
"""Reset the test before a retry, if it is retried as flaky"""
pass
def print_summary(self):
def print_summary(self) -> None:
print("Output of {} {}:".format(self.path, " ".join(self.args)))
print(read_log(self.log_filename))
async def run(self, options):
async def run(self, options: argparse.Namespace) -> Test:
# This test can and should be killed gently, with SIGTERM, not with SIGKILL
self.success = await run_test(self, options, gentle_kill=True, env=self.suite.scylla_env)
logging.info("Test #%d %s", self.id, "succeeded" if self.success else "failed ")
@@ -682,7 +691,7 @@ class RunTest(Test):
class PythonTest(Test):
"""Run a pytest collection of cases against a standalone Scylla"""
def __init__(self, test_no, shortname, suite):
def __init__(self, test_no: int, shortname: str, suite) -> None:
super().__init__(test_no, shortname, suite)
self.path = "pytest"
self.xmlout = os.path.join(self.suite.options.tmpdir, self.mode, "xml", self.uname + ".xunit.xml")
@@ -691,20 +700,20 @@ class PythonTest(Test):
os.path.join(suite.path, shortname + ".py")]
PythonTest._reset(self)
def _reset(self):
def _reset(self) -> None:
"""Reset the test before a retry, if it is retried as flaky"""
self.server_log = None
self.is_before_test_ok = False
self.is_after_test_ok = False
def print_summary(self):
def print_summary(self) -> None:
print("Output of {} {}:".format(self.path, " ".join(self.args)))
print(read_log(self.log_filename))
if self.server_log:
print("Server log of the first server:")
print(self.server_log)
async def run(self, options):
async def run(self, options: argparse.Namespace) -> Test:
async with self.suite.clusters.instance() as cluster:
self.args.insert(0, "--host={}".format(cluster[0].host))
try:
@@ -729,24 +738,24 @@ class PythonTest(Test):
class TabularConsoleOutput:
"""Print test progress to the console"""
def __init__(self, verbose, test_count):
def __init__(self, verbose: bool, test_count: int) -> None:
self.verbose = verbose
self.test_count = test_count
self.print_newline = False
self.last_test_no = 0
self.last_line_len = 1
def print_start_blurb(self):
def print_start_blurb(self) -> None:
print("="*80)
print("{:10s} {:^8s} {:^7s} {:8s} {}".format("[N/TOTAL]", "SUITE", "MODE", "RESULT", "TEST"))
print("-"*78)
def print_end_blurb(self):
def print_end_blurb(self) -> None:
if self.print_newline:
print("")
print("-"*78)
def print_progress(self, test):
def print_progress(self, test: Test) -> None:
self.last_test_no += 1
status = ""
if test.success:
@@ -776,12 +785,11 @@ class TabularConsoleOutput:
print(msg)
self.print_newline = False
else:
if hasattr(test, 'time_end') and test.time_end > 0:
msg += " {:.2f}s".format(test.time_end - test.time_start)
msg += " {:.2f}s".format(test.time_end - test.time_start)
print(msg)
async def run_test(test, options, gentle_kill=False, env=dict()):
async def run_test(test: Test, options: argparse.Namespace, gentle_kill=False, env=dict()) -> bool:
"""Run test program, return True if success else False"""
with test.log_filename.open("wb") as log:
@@ -867,7 +875,7 @@ async def run_test(test, options, gentle_kill=False, env=dict()):
return False
def setup_signal_handlers(loop, signaled):
def setup_signal_handlers(loop, signaled) -> None:
async def shutdown(loop, signo, signaled):
print("\nShutdown requested... Aborting tests:"),
@@ -882,7 +890,7 @@ def setup_signal_handlers(loop, signaled):
loop.add_signal_handler(signo, lambda: asyncio.create_task(shutdown(loop, signo, signaled)))
def parse_cmd_line():
def parse_cmd_line() -> argparse.Namespace:
""" Print usage and process command line options. """
parser = argparse.ArgumentParser(description="Scylla test runner")
@@ -989,7 +997,7 @@ def parse_cmd_line():
return args
async def find_tests(options):
async def find_tests(options: argparse.Namespace) -> None:
for f in glob.glob(os.path.join("test", "*")):
if os.path.isdir(f) and os.path.isfile(os.path.join(f, "suite.yaml")):
@@ -1010,7 +1018,7 @@ async def find_tests(options):
print("Found {} tests.".format(TestSuite.test_count()))
async def run_all_tests(signaled, options):
async def run_all_tests(signaled: asyncio.Event, options: argparse.Namespace) -> None:
console = TabularConsoleOutput(options.verbose, TestSuite.test_count())
signaled_task = asyncio.create_task(signaled.wait())
pending = set([signaled_task])
@@ -1055,7 +1063,7 @@ async def run_all_tests(signaled, options):
console.print_end_blurb()
def read_log(log_filename: pathlib.Path):
def read_log(log_filename: pathlib.Path) -> str:
"""Intelligently read test log output"""
try:
with log_filename.open("r") as log:
@@ -1067,7 +1075,7 @@ def read_log(log_filename: pathlib.Path):
return "===Error reading log {}===".format(e)
def print_summary(failed_tests, options):
def print_summary(failed_tests, options: argparse.Namespace) -> None:
if failed_tests:
print("The following test(s) have failed: {}".format(
palette.path(" ".join([t.name for t in failed_tests]))))
@@ -1079,7 +1087,7 @@ def print_summary(failed_tests, options):
len(failed_tests), TestSuite.test_count()))
def format_unidiff(fromfile, tofile):
def format_unidiff(fromfile: str, tofile: str) -> str:
with open(fromfile, "r") as frm, open(tofile, "r") as to:
buf = StringIO()
diff = difflib.unified_diff(
@@ -1104,7 +1112,7 @@ def format_unidiff(fromfile, tofile):
return buf.getvalue()
def write_junit_report(tmpdir, mode):
def write_junit_report(tmpdir: str, mode: str) -> None:
junit_filename = os.path.join(tmpdir, mode, "xml", "junit.xml")
total = 0
failed = 0
@@ -1132,7 +1140,7 @@ def write_junit_report(tmpdir, mode):
ET.ElementTree(xml_results).write(f, encoding="unicode")
def write_consolidated_boost_junit_xml(tmpdir, mode):
def write_consolidated_boost_junit_xml(tmpdir: str, mode: str) -> None:
xml = ET.Element("TestLog")
for suite in TestSuite.suites.values():
for test in suite.tests:
@@ -1145,7 +1153,7 @@ def write_consolidated_boost_junit_xml(tmpdir, mode):
et.write(f'{tmpdir}/{mode}/xml/boost.xunit.xml', encoding='unicode')
def open_log(tmpdir):
def open_log(tmpdir: str) -> None:
pathlib.Path(tmpdir).mkdir(parents=True, exist_ok=True)
logging.basicConfig(
filename=os.path.join(tmpdir, "test.py.log"),
@@ -1157,7 +1165,7 @@ def open_log(tmpdir):
logging.critical("Started %s", " ".join(sys.argv))
async def main():
async def main() -> int:
options = parse_cmd_line()
@@ -1179,7 +1187,7 @@ async def main():
raise
if signaled.is_set():
return -signaled.signo
return -signaled.signo # type: ignore
failed_tests = [t for t in TestSuite.all_tests() if t.success is not True]
@@ -1197,7 +1205,7 @@ async def main():
return 0 if not failed_tests else 1
async def workaround_python26789():
async def workaround_python26789() -> int:
"""Workaround for https://bugs.python.org/issue26789.
We'd like to print traceback if there is an internal error
in test.py. However, traceback module calls asyncio