Files
scylladb/test/cluster/dtest/tools/assertions.py
Dario Mirovic 3c5dd5e5ae test: dtest: migrate setup and tools from dtest
Migrate several functionalities from dtest. These will be used by
the schema_management_test.py tests when they are enabled.

Refs #26932
2025-12-18 12:54:43 +01:00

266 lines
10 KiB
Python

#
# Copyright (C) 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
"""
The assertion methods in this file are used to structure, execute, and test different queries and scenarios.
Use these anytime you are trying to check the content of a table, the row count of a table, if a query should
raise an exception, etc. These methods handle error messaging well, and will help discovering and treating bugs.
An example:
Imagine some table, test:
id | name
1 | John Doe
2 | Jane Doe
We could assert the row count is 2 by using:
assert_row_count(session, 'test', 2)
After inserting [3, 'Alex Smith'], we can ensure the table is correct by:
assert_all(session, "SELECT * FROM test", [[1, 'John Doe'], [2, 'Jane Doe'], [3, 'Alex Smith']])
or we could check the insert was successful:
assert_one(session, "SELECT * FROM test WHERE id = 3", [3, 'Alex Smith'])
We could remove all rows in test, and assert this was successful with:
assert_none(session, "SELECT * FROM test")
Perhaps we want to assert invalid queries will throw an exception:
assert_invalid(session, "SELECT FROM test")
or, maybe after shutting down all the nodes, we want to assert an Unavailable exception is raised:
assert_unavailable(session.execute, "SELECT * FROM test")
OR
assert_exception(session, "SELECT * FROM test", expected=Unavailable)
"""
import re
from cassandra import ConsistencyLevel, InvalidRequest
from cassandra.query import SimpleStatement
from test.cluster.dtest.tools.retrying import retrying
def _assert_exception(fun, *args, **kwargs):
matching = kwargs.pop("matching", None)
expected = kwargs["expected"]
try:
if len(args) == 0:
fun(None)
else:
fun(*args)
except expected as e:
if matching is not None:
msg = str(e)
assert re.search(matching, msg), f"Raised exception '{msg}' failed to match with '{matching}'"
except Exception as e:
raise e
else:
assert False, "Expecting query to raise an exception, but nothing was raised."
def assert_exception(session, query, matching=None, expected=None):
    """
    Execute a query and assert that it raises the expected exception.
    @param session Session to use; its execute method is called with the query
    @param query Query to run
    @param matching Optional regular expression searched for in the message of the raised exception
    @param expected Exception class expected to be raised. Must be provided
    Raises AssertionError when `expected` is None, since that indicates a mistake in the calling test.
    Examples:
    assert_exception(session, "SELECT * FROM test", expected=Unavailable)
    """
    if expected is None:
        # Argument validation must not rely on `assert`, which is removed when Python runs with -O.
        raise AssertionError("Expected exception should not be None. Your test code is wrong, please set `expected`.")
    _assert_exception(session.execute, query, matching=matching, expected=expected)
def assert_invalid(session, query, matching=None, expected=InvalidRequest):
    """
    Run a query that is expected to be rejected and assert that it fails.
    @param session Session to use
    @param query Invalid query to run
    @param matching Optional regular expression that must be found in the message of the raised exception
    @param expected Exception expected to be raised by the invalid query. Default InvalidRequest
    Examples:
    assert_invalid(session, 'DROP USER nonexistent', "nonexistent doesn't exist")
    """
    # Thin convenience wrapper: only the default for `expected` differs from assert_exception.
    assert_exception(session, query, matching, expected)
def assert_one(session, query, expected, cl=None):
    """
    Assert that a query returns exactly one row equal to the expected row.
    @param session Session to use
    @param query Query to run
    @param expected Expected content of the single returned row
    @param cl Optional Consistency Level setting. Default None, passed to SimpleStatement unchanged
    Examples:
    assert_one(session, "LIST USERS", ['cassandra', True])
    assert_one(session, query, [0, 0])
    """
    from test.cluster.dtest.tools.data import rows_to_list  # to avoid cyclic dependency

    statement = SimpleStatement(query, consistency_level=cl)
    list_res = rows_to_list(session.execute(statement))
    # The result is compared against a one-element list, so extra or missing rows fail as well.
    assert list_res == [expected], f"Expected {[expected]} from {query}, but got {list_res}"
@retrying(num_attempts=1, sleep_time=10)
def assert_all(
    session,
    query,
    expected,
    cl=ConsistencyLevel.ONE,
    ignore_order=False,
    num_attempts=1,
    sleep_time=10,
    result_as_string=False,
    print_result_on_failure=True,
    timeout=None,
):
    """
    Assert query returns all expected items optionally in the correct order
    @param session Session in use
    @param query Query to run
    @param expected Expected results from query
    @param cl Optional Consistency Level setting. Default ONE
    @param ignore_order Optional boolean flag; when set, expected and actual results are both passed through list_to_hashed_dict before being compared
    @param timeout Optional query timeout, in seconds. When None, the query is executed without an explicit timeout argument
    @param num_attempts defines how many times to try to assert data in case failure. Not read by the function body; used in retrying decorator
    @param sleep_time defines how many seconds to sleep between attempts. Not read by the function body; used in retrying decorator
    @param result_as_string when set, the result is fetched through get_list_res (which receives this flag) instead of rows_to_list
    @param print_result_on_failure when set, the failure message contains the expected and actual results; otherwise only the row counts and the query
    Examples:
    assert_all(session, "LIST USERS", [['aleksey', False], ['cassandra', True]])
    assert_all(self.session1, "SELECT * FROM ttl_table;", [[1, 42, 1, 1]])
    """
    from test.cluster.dtest.tools.data import get_list_res, rows_to_list  # to avoid cyclic dependency
    from test.cluster.dtest.tools.misc import list_to_hashed_dict  # to avoid cyclic dependency

    if result_as_string:
        list_res = get_list_res(session, query, cl, ignore_order, result_as_string, timeout=timeout)
    else:
        simple_query = SimpleStatement(query, consistency_level=cl)
        # Only pass `timeout` when one was requested, so the session's own default applies otherwise.
        res = session.execute(simple_query) if timeout is None else session.execute(simple_query, timeout=timeout)
        list_res = rows_to_list(res)
    # NOTE(review): this normalisation is assumed to apply to both branches above; get_list_res also
    # receives ignore_order, so confirm the result is not normalised twice on the result_as_string path.
    if ignore_order:
        expected = list_to_hashed_dict(expected)
        list_res = list_to_hashed_dict(list_res)
    # The message is built before the comparison, so it is computed even when the assertion passes.
    error = f"Expected {expected} from {query}, but got {list_res}" if print_result_on_failure else f"Actual result ({len(list_res)} rows) is not as expected ({len(expected)} rows). Query: {query}"
    assert list_res == expected, error
def assert_almost_equal(*args, **kwargs):
    """
    Assert variable number of arguments all fall within a margin of error.
    The smallest value must be greater than (1 - error) times the largest value, or all values must be equal.
    When every value is negative the same relative check is applied to the magnitudes, so that
    e.g. -100 and -90 are accepted with the default margin just like 100 and 90 are.
    @params *args variable number of numerical arguments to check. At least one is required
    @params error Optional margin of error. Default 0.16
    @params error_message Optional error message to print. Default ''
    Raises AssertionError when the values are not within the margin.
    Raises ValueError (from max) when called without positional arguments.
    Examples:
    assert_almost_equal(sizes[2], init_size)
    assert_almost_equal(ttl_session1, ttl_session2[0][0], error=0.005)
    """
    error = kwargs.get("error", 0.16)
    error_message = kwargs.get("error_message", "")
    vmax = max(args)
    vmin = min(args)
    within_margin = vmin > vmax * (1.0 - error)
    if not within_margin and vmax < 0:
        # All values are negative, so vmin has the largest magnitude. The check above can never
        # succeed here; mirror it so the margin is measured against the largest magnitude instead.
        within_margin = vmax < vmin * (1.0 - error)
    assert within_margin or vmin == vmax, f"values not within {error * 100:.2f}% of the max: {args} ({error_message})"
@retrying(num_attempts=1, sleep_time=10)
def assert_row_count(
    session,
    table_name,
    expected,
    consistency_level=ConsistencyLevel.ONE,
    num_attempts=1,
    sleep_time=10,
    timeout=None,
):
    """
    Assert that a table holds the expected number of rows, using a count(*) query.
    @param session Session to use
    @param table_name Name of the table to query
    @param expected Number of rows expected to be in table
    @param consistency_level Consistency level forwarded to the query helper. Default ONE
    @param num_attempts defines how many times to try to assert data in case failure. Used in retrying decorator
    @param sleep_time defines how many seconds to sleep between attempts. Used in retrying decorator
    @param timeout Forwarded to the query helper as session_timeout
    Examples:
    assert_row_count(self.session1, 'ttl_table', 1)
    """
    from test.cluster.dtest.tools.data import run_query_with_data_processing  # to avoid cyclic dependency

    outcome = run_query_with_data_processing(
        session,
        f"SELECT count(*) FROM {table_name}",
        consistency_level=consistency_level,
        session_timeout=timeout,
    )
    # The helper may hand back a list of rows; in that case the count is the first cell of the first row.
    count = outcome[0][0] if isinstance(outcome, list) else outcome
    assert count == expected, f"Expected a row count of {expected} in table '{table_name}', but got {count}"
@retrying(num_attempts=1, sleep_time=10)
def assert_row_count_in_select_less(
    session,
    query,
    max_rows_expected,
    consistency_level=ConsistencyLevel.ONE,
    num_attempts=1,
    timeout=None,
):
    """
    Assert that a select returns strictly fewer rows than a given bound.
    @param session Session to use
    @param query Query to run
    @param max_rows_expected Exclusive upper bound on the number of returned rows
    @param consistency_level Consistency level forwarded to get_list_res. Default ONE
    @param num_attempts defines how many times to try to assert data in case failure. Not read by the function body; presumably used in retrying decorator
    @param timeout Forwarded to get_list_res
    """
    from test.cluster.dtest.tools.data import get_list_res

    fetched = get_list_res(session, query, consistency_level, timeout=timeout)
    count = len(fetched)
    assert count < max_rows_expected, f'Expected a row count < of {max_rows_expected} in query "{query}", but got {count}'
def assert_length_equal(object_with_length, expected_length):
    """
    Check that len() of an object equals the expected value.
    @param object_with_length The object whose length will be checked
    @param expected_length The expected length of the object
    Raises AssertionError, naming the object and both lengths, when they differ.
    Examples:
    assert_length_equal(res, nb_counter)
    """
    lengths_match = len(object_with_length) == expected_length
    assert lengths_match, f"Expected {object_with_length} to have length {expected_length}, but instead is of length {len(object_with_length)}"
def assert_lists_equal_ignoring_order(list1, list2, sort_key=None):
"""
asserts that the contents of the two provided lists are equal
but ignoring the order that the items of the lists are actually in
:param list1: list to check if it's contents are equal to list2
:param list2: list to check if it's contents are equal to list1
:param sort_key: if the contents of the list are of type dict, the
key to use of each object to sort the overall object with
"""
normalized_list1 = []
for obj in list1:
normalized_list1.append(obj)
normalized_list2 = []
for obj in list2:
normalized_list2.append(obj)
if not sort_key:
sorted_list1 = sorted(normalized_list1, key=lambda elm: elm[0])
sorted_list2 = sorted(normalized_list2, key=lambda elm: elm[0])
elif not sort_key == "id" and "id" in list1[0].keys():
# first always sort by "id"
# that way we get a two factor sort which will increase the chance of ordering lists exactly the same
sorted_list1 = sorted(sorted(normalized_list1, key=lambda elm: elm["id"]), key=lambda elm: elm[sort_key])
sorted_list2 = sorted(sorted(normalized_list2, key=lambda elm: elm["id"]), key=lambda elm: elm[sort_key])
elif isinstance(list1[0]["id"], int | float):
sorted_list1 = sorted(normalized_list1, key=lambda elm: elm[sort_key])
sorted_list2 = sorted(normalized_list2, key=lambda elm: elm[sort_key])
else:
sorted_list1 = sorted(normalized_list1, key=lambda elm: str(elm[sort_key]))
sorted_list2 = sorted(normalized_list2, key=lambda elm: str(elm[sort_key]))
assert sorted_list1 == sorted_list2