Create a custom pytest test collector for .cql files and move the CQL test execution logic from the `CQLApprovalTest` class and the `pylib/cql_repl/cql_repl.py` file into the `CqlTest.runtest()` method. As a result, the only difference between the CQLApproval and Python suite types is the suffix of their test files.
220 lines
7.2 KiB
Python
220 lines
7.2 KiB
Python
#
|
|
# Copyright (C) 2025-present ScyllaDB
|
|
#
|
|
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
|
#
|
|
|
|
from __future__ import annotations
|
|
|
|
import difflib
|
|
import filecmp
|
|
import re
|
|
import shutil
|
|
import time
|
|
from collections.abc import Iterable
|
|
from functools import cached_property
|
|
from itertools import islice
|
|
from typing import TYPE_CHECKING
|
|
|
|
import pytest
|
|
from _pytest.python import Module
|
|
from _pytest.fixtures import TopRequest
|
|
from cassandra import ConsistencyLevel
|
|
from cassandra.query import SimpleStatement
|
|
from tabulate import tabulate
|
|
|
|
from test.pylib.suite.base import palette
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from _pytest._code.code import ExceptionInfo, TerminalRepr, TracebackStyle
|
|
from cassandra.cluster import Session
|
|
|
|
|
|
# File name suffix of CQL approval tests; it is stripped to form the test item name.
CQL_TEST_SUFFIX = "_test.cql"

# Matches a line that is empty, whitespace-only, or a comment (CQL allows both `--` and `//`).
COMMENT_RE = re.compile(pattern=r"^\s*((--|//).*)?$")

# Matches a line that ends a statement: it ends with `;` (plus optional trailing whitespace)
# and is not itself a comment, because a comment is not a delimiter even if it ends with `;`.
DELIMITER_RE = re.compile(pattern=r"^(?!\s*(--|//)).*;\s*$")

# Matches a dotted-quad IPv4-looking address; used to mask addresses in error messages.
IP_ADDRESS_RE = re.compile(pattern=r"\d+\.\d+\.\d+\.\d+")
|
|
|
|
|
|
class CqlTestException(Exception):
    """Failure of a CQL approval test; its first argument is the message shown to the user."""
|
|
|
|
|
|
class CqlFile(Module):
    """Pytest collector for a single `.cql` test file.

    The file is not a Python module, so rather than importing it, the collector yields exactly one
    `CqlTest` item named after the file with the `_test.cql` suffix removed.
    """

    # There is no Python module object behind a .cql file.
    obj = None

    def funcnamefilter(self, name: str) -> bool:
        """Accept every name unconditionally."""
        return True

    def collect(self) -> Iterable[CqlTest]:
        """Yield the single `CqlTest` item that represents this file."""
        # Sanity check: the Python item factory must not turn this file into anything.  The hook call
        # is intentionally kept inside the assert.
        assert (
            self.ihook.pytest_pycollect_makeitem(collector=self, name=self.name, obj=self.obj) is None
        ), ".cql file will be not collected as a Python module"

        test_name = self.name.removesuffix(CQL_TEST_SUFFIX)
        yield CqlTest.from_parent(parent=self, name=test_name)
|
|
|
|
|
|
class CqlTest(pytest.Item):
    """A single CQL approval test backed by a `.cql` file.

    `runtest()` executes the file statement by statement and appends a transcript to `output_path`.
    Each input line is prefixed with "> " and each statement is followed by "OK", a result table, or
    an error message.  The transcript is then compared with the `.result` file that sits next to the
    test.  If the `.result` file is missing or the transcript differs, the transcript is moved to the
    `.reject` file next to the test.  Otherwise it is deleted.
    """

    def __init__(self, name: str, parent: CqlFile):
        super().__init__(name=name, parent=parent)

        # Stuff needed for fixture support.  This item is not backed by a Python function, so the
        # fixture info is computed with `func=None`.
        self.obj = None
        self._fixtureinfo = self.session._fixturemanager.getfixtureinfo(node=self, func=None, cls=None)
        self.fixturenames = self._fixtureinfo.names_closure
        self._initrequest()

    def _initrequest(self) -> None:
        """Create the (initially empty) `funcargs` dict and the top-level fixture request."""
        self.funcargs: dict[str, Any] = {}  # fixtures need to be used automatically to appear in this dict
        self._request = TopRequest(self, _ispytest=True)  # type: ignore[arg-type]

    def setup(self) -> None:
        """Fill fixture values; the properties below read them back from `self.funcargs`."""
        self._request._fillfixtures()

    @cached_property
    def cql(self) -> Session:
        """CQL connection to the DB host."""

        return self.funcargs["cql"]

    @cached_property
    def keyspace(self) -> str:
        """Name of a keyspace used by the test."""

        return self.funcargs["keyspace"]

    @cached_property
    def output_path(self) -> Path:
        """Path to a file to collect the test output."""

        return self.funcargs["output_path"]

    @cached_property
    def result_path(self) -> Path:
        """Path to a file with the expected test output."""

        return self.path.with_suffix(".result")

    @cached_property
    def reject_path(self) -> Path:
        """Path to a file to store the test output if it will be different from .result file."""

        return self.path.with_suffix(".reject")

    def runtest(self) -> None:
        """Run the .cql script and compare its transcript with the expected `.result` file.

        Raises:
            CqlTestException: if there is no transcript, no `.result` file, or the transcript differs
                from the `.result` file.  In the last two cases, the transcript is moved to
                `reject_path`.  On success, the transcript is deleted.
        """
        self._write_transcript()

        if not self.output_path.is_file():
            raise CqlTestException("No output file")

        try:
            self._check_output()
        except CqlTestException:
            # Move the .reject file close to the .result file so that it's easy to analyze the diff or
            # overwrite .result with .reject.
            shutil.move(self.output_path, self.reject_path)
            raise

        self.output_path.unlink()

    def _write_transcript(self) -> None:
        """Execute the .cql script statement by statement, appending the transcript to `output_path`."""
        with self.path.open(encoding="utf-8") as ifile, self.output_path.open(mode="a", encoding="utf-8") as ofile:
            self.cql.set_keyspace(self.keyspace)

            for line in ifile:
                ofile.write("> ")

                # Blank and comment lines are echoed as-is and not sent to the server.
                if COMMENT_RE.match(line):
                    ofile.write(line)
                    continue

                query_bits = [line]

                # Read the rest of input until delimiter or EOF.
                while not DELIMITER_RE.match(line):
                    line = ifile.readline()
                    if not line:
                        break
                    query_bits.append(line)

                # The first "> " is already written; prefix every continuation line as well.
                ofile.write("> ".join(query_bits))
                ofile.write(self._execute_statement("".join(query_bits)))

    def _execute_statement(self, query: str) -> str:
        """Execute one CQL statement and return its newline-terminated transcript entry.

        Any exception is rendered as text instead of being raised, so that expected errors become
        part of the approved output.
        """
        stmt = SimpleStatement(
            query_string=query,
            consistency_level=ConsistencyLevel.ONE,
            serial_consistency_level=ConsistencyLevel.SERIAL,
        )

        try:
            result = self.cql.execute(stmt)
            if not result.column_names:
                return "OK\n"
            table = tabulate(
                tabular_data=prettify(result.current_rows),
                headers=result.column_names,
                tablefmt="psql",
            )
            return f"{table}\n"
        except Exception as exc:
            # Replace IP addresses with 127.0.0.1 in the message to make the output stable.
            # Single quotes inside the f-string keep this valid on Python < 3.12.
            return f"{IP_ADDRESS_RE.sub(repl='127.0.0.1', string=str(exc))}\n"

    def _check_output(self) -> None:
        """Raise CqlTestException unless the transcript's contents equal the `.result` file's contents."""
        if not self.result_path.is_file():
            raise CqlTestException("No result file")

        # shallow=False: compare file contents, not just os.stat() signatures (type, size, mtime),
        # which could make two different files of equal size and mtime compare equal.
        if not filecmp.cmp(self.result_path, self.output_path, shallow=False):
            raise CqlTestException(
                f"Test output does not match expected result:\n"
                f"Expected: {self.result_path}\n"
                f"Actual (copied to): {self.reject_path}\n"
                f"Diff:\n{format_unified_diff(fromfile=self.result_path, tofile=self.output_path, head=61)}"
            )

    def repr_failure(self,
                     excinfo: ExceptionInfo[BaseException],
                     style: TracebackStyle | None = None) -> str | TerminalRepr:
        """Show only the message for CqlTestException; fall back to pytest's default otherwise."""
        if excinfo.errisinstance(CqlTestException):
            return excinfo.value.args[0]
        return super().repr_failure(excinfo, style)
|
|
|
|
|
|
def prettify(rows: Any) -> Any:
|
|
"""Recursively replace all None's with "null" and convert all iterables to lists."""
|
|
|
|
if rows is None:
|
|
return "null"
|
|
|
|
if isinstance(rows, str | float | int | bool) or not isinstance(rows, Iterable):
|
|
return rows
|
|
|
|
return [prettify(row) for row in rows]
|
|
|
|
|
|
def colorize_unified_diff_line(line: str) -> str:
|
|
match line[0]:
|
|
case "+":
|
|
return palette.diff_in(line)
|
|
case "-":
|
|
return palette.diff_out(line)
|
|
case "@":
|
|
return palette.diff_mark(line)
|
|
return line
|
|
|
|
|
|
def format_unified_diff(fromfile: Path,
|
|
tofile: Path,
|
|
n_context_lines: int = 10, # number of diff context lines
|
|
head: int | None = None) -> str: # return first `head` lines of the output
|
|
with fromfile.open(encoding="utf-8") as frm, tofile.open(encoding="utf-8") as to:
|
|
diff = difflib.unified_diff(
|
|
a=frm.readlines(),
|
|
b=to.readlines(),
|
|
fromfile=str(fromfile),
|
|
tofile=str(tofile),
|
|
fromfiledate=time.ctime(fromfile.stat().st_mtime),
|
|
tofiledate=time.ctime(tofile.stat().st_mtime),
|
|
n=n_context_lines,
|
|
)
|
|
return "".join(colorize_unified_diff_line(line) for line in islice(diff, head))
|