Files
scylladb/test/pylib/cql_repl.py
Evgeniy Naydanov cdc4b520da test.py: cql: run tests using bare pytest command
Create a custom pytest test collector for .cql files and
move CQL test execution logic from `CQLApprovalTest` class
and `pylib/cql_repl/cql_repl.py` file to `CqlTest.runtest()`
method.

In result, the only difference between CQLApproval and Python
suite types is suffixes of test files.
2025-06-03 07:54:51 +00:00

220 lines
7.2 KiB
Python

#
# Copyright (C) 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
from __future__ import annotations
import difflib
import filecmp
import re
import shutil
import time
from collections.abc import Iterable
from functools import cached_property
from itertools import islice
from typing import TYPE_CHECKING
import pytest
from _pytest.python import Module
from _pytest.fixtures import TopRequest
from cassandra import ConsistencyLevel
from cassandra.query import SimpleStatement
from tabulate import tabulate
from test.pylib.suite.base import palette
if TYPE_CHECKING:
from pathlib import Path
from typing import Any
from _pytest._code.code import ExceptionInfo, TerminalRepr, TracebackStyle
from cassandra.cluster import Session
CQL_TEST_SUFFIX = "_test.cql"
COMMENT_RE = re.compile(r"^\s*((--|//).*)?$") # comments allowed by CQL - -- and //
DELIMITER_RE = re.compile(r"^(?!\s*(--|//)).*;\s*$") # a comment is not a delimiter even if ends with one
IP_ADDRESS_RE = re.compile(r"\d+\.\d+\.\d+\.\d+")
class CqlTestException(Exception):
...
class CqlFile(Module):
obj = None
def funcnamefilter(self, name: str) -> bool:
return True
def collect(self) -> Iterable[CqlTest]:
assert self.ihook.pytest_pycollect_makeitem(
collector=self,
name=self.name,
obj=self.obj,
) is None, ".cql file will be not collected as a Python module"
yield CqlTest.from_parent(parent=self, name=self.name.removesuffix(CQL_TEST_SUFFIX))
class CqlTest(pytest.Item):
def __init__(self, name: str, parent: CqlFile):
super().__init__(name=name, parent=parent)
# Stuff needed for fixture support.
self.obj = None
self._fixtureinfo = self.session._fixturemanager.getfixtureinfo(node=self, func=None, cls=None)
self.fixturenames = self._fixtureinfo.names_closure
self._initrequest()
def _initrequest(self) -> None:
self.funcargs: dict[str, Any] = {} # fixtures need to be used automatically to appear in this dict
self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type]
def setup(self) -> None:
self._request._fillfixtures()
@cached_property
def cql(self) -> Session:
"""CQL connection to the DB host."""
return self.funcargs["cql"]
@cached_property
def keyspace(self) -> str:
"""Name of a keyspace used by the test."""
return self.funcargs["keyspace"]
@cached_property
def output_path(self) -> Path:
"""Path to a file to collect the test output."""
return self.funcargs["output_path"]
@cached_property
def result_path(self) -> Path:
"""Path to a file with the expected test output."""
return self.path.with_suffix(".result")
@cached_property
def reject_path(self) -> Path:
"""Path to a file to store the test output if it will be different from .result file."""
return self.path.with_suffix(".reject")
def runtest(self) -> None:
with self.path.open(encoding="utf-8") as ifile, self.output_path.open(mode="a", encoding="utf-8") as ofile:
self.cql.set_keyspace(self.keyspace)
for line in ifile:
ofile.write("> ")
if COMMENT_RE.match(line):
ofile.write(line)
continue
query_bits = [line]
# Read the rest of input until delimiter or EOF.
while not DELIMITER_RE.match(line):
line = ifile.readline()
if not line:
break
query_bits.append(line)
ofile.write("> ".join(query_bits))
stmt = SimpleStatement(
query_string="".join(query_bits),
consistency_level=ConsistencyLevel.ONE,
serial_consistency_level=ConsistencyLevel.SERIAL,
)
try:
result = self.cql.execute(stmt)
if not result.column_names:
ofile.write("OK\n")
else:
ofile.write(tabulate(
tabular_data=prettify(result.current_rows),
headers=result.column_names,
tablefmt="psql",
))
ofile.write("\n")
except Exception as exc:
# Replace IP addresses with 127.0.0.1 in the message to make the output stable.
ofile.write(f"{IP_ADDRESS_RE.sub(repl="127.0.0.1", string=str(exc))}\n")
if not self.output_path.is_file():
raise CqlTestException("No output file")
try:
if not self.result_path.is_file():
raise CqlTestException("No result file")
if not filecmp.cmp(self.result_path, self.output_path):
raise CqlTestException(
f"Test output does not match expected result:\n"
f"Expected: {self.result_path}\n"
f"Actual (copied to): {self.reject_path}\n"
f"Diff:\n{format_unified_diff(fromfile=self.result_path, tofile=self.output_path, head=61)}"
)
except CqlTestException:
# Move the .reject file close to the .result file so that it's easy to analyze the diff or
# overwrite .result with .reject.
shutil.move(self.output_path, self.reject_path)
raise
self.output_path.unlink()
def repr_failure(self,
excinfo: ExceptionInfo[BaseException],
style: TracebackStyle | None = None) -> str | TerminalRepr:
if excinfo.errisinstance(CqlTestException):
return excinfo.value.args[0]
return super().repr_failure(excinfo, style)
def prettify(rows: Any) -> Any:
"""Recursively replace all None's with "null" and convert all iterables to lists."""
if rows is None:
return "null"
if isinstance(rows, str | float | int | bool) or not isinstance(rows, Iterable):
return rows
return [prettify(row) for row in rows]
def colorize_unified_diff_line(line: str) -> str:
match line[0]:
case "+":
return palette.diff_in(line)
case "-":
return palette.diff_out(line)
case "@":
return palette.diff_mark(line)
return line
def format_unified_diff(fromfile: Path,
tofile: Path,
n_context_lines: int = 10, # number of diff context lines
head: int | None = None) -> str: # return first `head` lines of the output
with fromfile.open(encoding="utf-8") as frm, tofile.open(encoding="utf-8") as to:
diff = difflib.unified_diff(
a=frm.readlines(),
b=to.readlines(),
fromfile=str(fromfile),
tofile=str(tofile),
fromfiledate=time.ctime(fromfile.stat().st_mtime),
tofiledate=time.ctime(tofile.stat().st_mtime),
n=n_context_lines,
)
return "".join(colorize_unified_diff_line(line) for line in islice(diff, head))