Files
scylladb/test/cqlpy/cassandra_tests/porting.py
Avi Kivity f3eade2f62 treewide: relicense to ScyllaDB-Source-Available-1.0
Drop the AGPL license in favor of a source-available license.
See the blog post [1] for details.

[1] https://www.scylladb.com/2024/12/18/why-were-moving-to-a-source-available-license/
2024-12-18 17:45:13 +02:00

389 lines
15 KiB
Python

# Copyright 2020-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#############################################################################
# This file contains various utility functions which are useful for porting
# Cassandra's tests into Python. Most of them are intended to resemble those
# in Cassandra's test/unit/org/apache/cassandra/cql3/CQLTester.java which is
# used by many of those tests.
import pytest
import re
import collections
import struct
import time
from ..util import unique_name
from contextlib import contextmanager
from cassandra.protocol import SyntaxException, InvalidRequest
from cassandra.protocol import ConfigurationException
from cassandra.util import SortedSet, OrderedMapSerializedKey
from cassandra.query import UNSET_VALUE
from .. import nodetool
# A utility function for creating a new temporary table with a given schema.
# Because Scylla becomes slower when a huge number of uniquely-named tables
# are created and deleted (see https://github.com/scylladb/scylla/issues/7620)
# we keep here a list of previously used but now deleted table names, and
# reuse one of these names when possible.
# This function can be used in a "with", as:
# with create_table(cql, test_keyspace, '...') as table:
previously_used_table_names = []


@contextmanager
def create_table(cql, keyspace, schema):
    """Create a temporary table and drop it when the "with" block exits.

    The table name is taken from the pool of previously used (and since
    dropped) names; unique_name() supplies a fresh one only when the pool
    is empty. After the table is dropped its name goes back to the pool.

    Yields the qualified "keyspace.table" name.
    """
    # The pool is only mutated in place, never rebound, so no "global"
    # declaration is needed.
    if previously_used_table_names:
        name = previously_used_table_names.pop()
    else:
        name = unique_name()
    qualified = keyspace + "." + name
    cql.execute("CREATE TABLE " + qualified + " " + schema)
    try:
        yield qualified
    finally:
        cql.execute("DROP TABLE " + qualified)
        previously_used_table_names.append(name)
@contextmanager
def create_type(cql, keyspace, cmd):
    """Create a uniquely named user-defined type for the "with" block.

    "cmd" is the text following the type name in CREATE TYPE. Yields the
    qualified "keyspace.type" name; the type is dropped on exit.
    """
    qualified = keyspace + "." + unique_name()
    create_stmt = "CREATE TYPE " + qualified + " " + cmd
    drop_stmt = "DROP TYPE " + qualified
    cql.execute(create_stmt)
    try:
        yield qualified
    finally:
        cql.execute(drop_stmt)
@contextmanager
def create_function(cql, keyspace, arg):
    """Create a uniquely named function for the "with" block.

    "arg" is the text following the function name in CREATE FUNCTION.
    Yields the qualified "keyspace.function" name; the function is
    dropped on exit.
    """
    qualified = keyspace + "." + unique_name()
    create_stmt = "CREATE FUNCTION " + qualified + " " + arg
    drop_stmt = "DROP FUNCTION " + qualified
    cql.execute(create_stmt)
    try:
        yield qualified
    finally:
        cql.execute(drop_stmt)
@contextmanager
def create_materialized_view(cql, keyspace, arg):
    """Create a uniquely named materialized view for the "with" block.

    "arg" is the text following the view name in CREATE MATERIALIZED VIEW.
    Yields the qualified "keyspace.view" name; the view is dropped on exit.
    """
    qualified = keyspace + "." + unique_name()
    create_stmt = "CREATE MATERIALIZED VIEW " + qualified + " " + arg
    drop_stmt = "DROP MATERIALIZED VIEW " + qualified
    cql.execute(create_stmt)
    try:
        yield qualified
    finally:
        cql.execute(drop_stmt)
@contextmanager
def create_keyspace(cql, arg):
    """Create a uniquely named keyspace for the "with" block.

    "arg" is the text following WITH in CREATE KEYSPACE. Yields the
    keyspace name; the keyspace is dropped on exit.
    """
    name = unique_name()
    create_stmt = "CREATE KEYSPACE " + name + " WITH " + arg
    drop_stmt = "DROP KEYSPACE " + name
    cql.execute(create_stmt)
    try:
        yield name
    finally:
        cql.execute(drop_stmt)
# utility function to substitute table name in command.
# For easy conversion of Java code, support %s as used in Java. Also support
# it *multiple* times (always interpolating the same table name). In Java,
# a %1$s is needed for this.
def subs_table(cmd, table):
    """Return cmd with every "%s" occurrence replaced by the table name."""
    pieces = cmd.split('%s')
    return table.join(pieces)
def execute(cql, table, cmd, *args):
    """Run cmd against table and return the driver's result.

    The table name is substituted into cmd first. When bound values are
    passed in args, the statement is prepared and executed with them;
    otherwise it is executed directly as a string.
    """
    statement = subs_table(cmd, table)
    if not args:
        return cql.execute(statement)
    return cql.execute(cql.prepare(statement), args)
def execute_with_paging(cql, table, cmd, page_size, *args):
    """Prepare cmd for table and execute it with fetch_size = page_size.

    The statement is always prepared, even when args is empty, because
    the fetch size is set on the prepared statement object.
    """
    statement = cql.prepare(subs_table(cmd, table))
    statement.fetch_size = page_size
    return cql.execute(statement, args)
def execute_without_paging(cql, table, cmd, *args):
    """Prepare cmd for table and execute it with fetch_size set to 0.

    As the function name indicates, a fetch size of 0 is used here to
    request the result without paging.
    """
    statement = cql.prepare(subs_table(cmd, table))
    statement.fetch_size = 0
    return cql.execute(statement, args)
def assert_invalid_throw(cql, table, type, cmd, *args):
    """Assert that executing cmd on table raises an exception of "type"."""
    expected_failure = pytest.raises(type)
    with expected_failure:
        execute(cql, table, cmd, *args)


assertInvalidThrow = assert_invalid_throw
def assert_invalid(cql, table, cmd, *args):
    """Assert that executing cmd on table raises InvalidRequest.

    The original Java assertInvalid() accepts *any* exception. That is
    fragile: a silly typo in the statement causes a SyntaxException and
    the check appears to pass without testing anything. Most tests do
    fine with InvalidRequest, so that is what is required here. If this
    ever turns out to be too strict, an alternative is to accept any
    exception *except* SyntaxException.
    """
    assert_invalid_throw(cql, table, InvalidRequest, cmd, *args)


assertInvalid = assert_invalid
def assert_invalid_syntax(cql, table, cmd, *args):
    """Assert that executing cmd on table raises SyntaxException."""
    assert_invalid_throw(cql, table, SyntaxException, cmd, *args)


assertInvalidSyntax = assert_invalid_syntax
def assert_invalid_message(cql, table, message, cmd, *args):
    """Assert that cmd raises InvalidRequest whose text contains message.

    message is matched literally (it is regex-escaped before matching).
    """
    literal = re.escape(message)
    with pytest.raises(InvalidRequest, match=literal):
        execute(cql, table, cmd, *args)


assertInvalidMessage = assert_invalid_message
def assert_invalid_message_ignore_case(cql, table, message, cmd, *args):
    """Like assert_invalid_message(), but the literal match ignores case."""
    pattern = re.compile(re.escape(message), re.IGNORECASE)
    with pytest.raises(InvalidRequest, match=pattern):
        execute(cql, table, cmd, *args)
def assert_invalid_syntax_message(cql, table, message, cmd, *args):
    """Assert that cmd raises SyntaxException whose text contains message.

    message is matched literally (it is regex-escaped before matching).
    """
    literal = re.escape(message)
    with pytest.raises(SyntaxException, match=literal):
        execute(cql, table, cmd, *args)


assertInvalidSyntaxMessage = assert_invalid_syntax_message
def assert_invalid_message_re(cql, table, message, cmd, *args):
    """Assert that cmd raises InvalidRequest matching the regex "message".

    Unlike assert_invalid_message(), message is used as a regular
    expression as-is, without escaping.
    """
    expected_failure = pytest.raises(InvalidRequest, match=message)
    with expected_failure:
        execute(cql, table, cmd, *args)


assertInvalidMessageRE = assert_invalid_message_re
def assert_invalid_throw_message(cql, table, message, typ, cmd, *args):
    """Assert that cmd raises exception type typ whose text contains message.

    message is matched literally (it is regex-escaped before matching).
    """
    literal = re.escape(message)
    with pytest.raises(typ, match=literal):
        execute(cql, table, cmd, *args)


assertInvalidThrowMessage = assert_invalid_throw_message
def assert_invalid_throw_message_re(cql, table, message, typ, cmd, *args):
    """Assert that cmd raises exception type typ matching the regex "message".

    message is used as a regular expression as-is, without escaping.
    """
    expected_failure = pytest.raises(typ, match=message)
    with expected_failure:
        execute(cql, table, cmd, *args)
# Cassandra's CQLTester.java has a much more elaborate implementation
# of these functions, which carefully prints out the differences when
# the assertion fails. We can also do this in the future, but pytest's
# assert rewriting is already good enough to make such elaborate
# comparison code less critical.
def assert_row_count(result, expected):
    """Assert that iterating result yields exactly "expected" rows."""
    actual = sum(1 for _ in result)
    assert actual == expected


assertRowCount = assert_row_count
def assert_empty(result):
    """Assert that iterating result yields no rows at all."""
    rows = list(result)
    assert not rows


assertEmpty = assert_empty
def assertArrayEquals(a, b):
    """Assert a == b; counterpart of the Java assertion of the same name."""
    assert a == b
def assertEquals(a, b):
    """Assert a == b; counterpart of the Java assertion of the same name."""
    assert a == b
def getRows(results):
    """Materialize results into a plain list of rows."""
    return [*results]
# Result objects contain some strange types specific to the CQL driver, which
# normally compare well against normal Python types, but in some cases of nested
# types they do not, and require some cleanup:
def result_cleanup(item):
    """Convert a driver-specific result value into a comparable Python value.

    Driver maps become dicts with frozen keys, driver sorted sets become
    sets of frozen members, and values that are not equal to themselves
    (such as NaN) become a fixed marker string. Anything else is
    returned unchanged.
    """
    if isinstance(item, OrderedMapSerializedKey):
        # Because of issue #7856, item.items() can't be used if the test
        # used a prepared statement with a wrongly ordered nested item,
        # so the raw _items list is read instead.
        return {freeze(entry[0]): entry[1] for entry in item._items}
    if isinstance(item, SortedSet):
        return {freeze(member) for member in item}
    # A bizarre object like NaN is not equal to itself, so it could never
    # compare equal to anything in the expected rows. Replace it with a
    # marker that lets assertRows work with NaN.
    return "NaN-Like Object" if item != item else item
def assert_rows_list(result, expected):
    """Assert that result holds exactly the expected rows, in order.

    Each result row is passed column by column through result_cleanup()
    and then compared with the corresponding entry of expected.
    """
    rows = list(result)
    assert len(rows) == len(expected)
    for actual, wanted in zip(rows, expected):
        # Rows with a values() method come from ordered_dict_factory or
        # dict_factory; all others (namedtuple_factory) iterate directly
        # over their column values.
        columns = actual.values() if hasattr(actual, 'values') else actual
        cleaned = [result_cleanup(col) for col in columns]
        assert cleaned == wanted
def assert_rows(result, *expected):
    """Variadic form of assert_rows_list(): each extra argument is one row."""
    assert_rows_list(result, expected)


assertRows = assert_rows
assertRowsNet = assert_rows
# Check if results is one of two possible result sets.
# Can be useful in cases where Cassandra and Scylla results are
# expected to be different and we consider that fine.
def assert_rows2(result, expected1, expected2):
    """Assert that result equals expected1 or expected2, row by row.

    A candidate matches only if it has the same number of rows as the
    result *and* every row is equal. The number of rows is checked per
    candidate: zip() stops at the shorter sequence, so without that a
    result that merely has the length of one candidate but is a prefix
    of the other would wrongly be accepted.

    Rows may be sequences of column values (namedtuple_factory) or
    mappings (ordered_dict_factory / dict_factory), as in
    assert_rows_list().
    """
    def cleaned(row):
        columns = row.values() if hasattr(row, 'values') else row
        return [result_cleanup(col) for col in columns]

    rows = [cleaned(row) for row in result]

    def matches(expected):
        return len(rows) == len(expected) and all(
            actual == wanted for actual, wanted in zip(rows, expected))

    assert matches(expected1) or matches(expected2)
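# To compare two lists of items (such as result rows or dicts) without
# regard for order, the multiset() function below converts a list into a
# multiset (a set with duplicates) where order doesn't matter, so the
# multisets can be compared. Members of a multiset must be hashable, so
# freeze() first converts each item into an immutable equivalent.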
def freeze(item):
    """Return a hashable, immutable equivalent of item.

    Containers are converted recursively:
      * dicts and driver maps become a frozenset of (key, value) pairs;
      * namedtuples (e.g. a Row), lists and plain tuples become tuples;
      * sets and driver sorted sets become frozensets.
    Anything else is returned unchanged.

    Plain tuples are converted as well, so a tuple that holds a list,
    dict or set can still be hashed by multiset().
    """
    # Builtin containers are recognized first, driver types afterwards.
    if isinstance(item, dict):
        return frozenset((freeze(key), freeze(value)) for key, value in item.items())
    if hasattr(item, '_asdict'):
        # Probably a namedtuple, e.g., a Row. For simplicity the names
        # are dropped and a regular tuple is produced. _asdict() returns
        # an *ordered* dict, so the result is a tuple, not a frozenset.
        return tuple(freeze(value) for value in item._asdict().values())
    if isinstance(item, (list, tuple)):
        return tuple(freeze(value) for value in item)
    if isinstance(item, set):
        # Members of a builtin set are already hashable.
        return frozenset(item)
    if isinstance(item, OrderedMapSerializedKey):
        # Because of issue #7856, item.items() can't be used if the test
        # used a prepared statement with a wrongly ordered nested item,
        # so the raw _items list is read instead.
        return frozenset((freeze(entry[0]), freeze(entry[1])) for entry in item._items)
    if isinstance(item, SortedSet):
        return frozenset(freeze(member) for member in item)
    return item
def multiset(items):
return collections.Counter([freeze(item) for item in items])
def assert_rows_ignoring_order(result, *expected):
    """Assert that result holds exactly the expected rows, in any order.

    Duplicates matter: both sides are compared as multisets.
    """
    rows = list(result)
    assert len(rows) == len(expected)
    actual_counts = multiset(rows)
    expected_counts = multiset(expected)
    assert actual_counts == expected_counts


assertRowsIgnoringOrder = assert_rows_ignoring_order
# Unfortunately, collections.Counter objects don't implement comparison
# operators <, <=, >, >=, as sets do. Let's implement c1 <= c2:
def is_subset(c1, c2):
    """Return True if every element of c1 occurs at least as often in c2."""
    for element, count in c1.items():
        if not (count <= c2[element]):
            return False
    return True
def assert_rows_ignoring_order_and_extra(result, *expected):
    """Assert that every expected row appears in result.

    Order is ignored and result may contain additional rows; duplicates
    in expected must occur at least as many times in result.
    """
    rows = list(result)
    wanted = multiset(expected)
    available = multiset(rows)
    assert is_subset(wanted, available)
def assert_all_rows(cql, table, *expected):
    """Assert that a full scan of table returns the expected rows.

    The original Cassandra implementation compares with assert_rows(),
    i.e. order-sensitively, so all such tests would break if the default
    partitioner ever changed. That makes little sense, so the comparison
    here ignores order.
    """
    rows = execute(cql, table, "SELECT * FROM %s")
    return assert_rows_ignoring_order(rows, *expected)
# Note that the Python driver has a function _clean_column_name() which
# "cleans" column names in result sets by dropping any non-alphanumeric
# character at the beginning or end of the name and replacing such characters
# in the middle by a "_". So, for example the column name "m['1']" is
# converted to "m__1". This cleanup is needed because of the limitations
# of Python's namedtuple.
# Wrap your query with the original_column_names() context manager below to
# avoid using namedtuple in a query, and avoid this cleanup.
def assert_column_names(result, *expected):
    """Assert that the columns of result are named exactly as expected.

    The names are read from the first row of result, so the result must
    contain at least one row; an empty result fails the assertion with
    an explicit message.
    """
    first = result.one()
    # one() yields None for an empty result; fail clearly instead of
    # with an AttributeError on None below.
    assert first is not None, \
        "assert_column_names() needs at least one row to read column names from"
    if hasattr(first, '_fields'):
        # namedtuple_factory rows
        assert first._fields == expected
    else:
        # ordered_dict_factory or dict_factory gets here:
        assert tuple(first.keys()) == expected


assertColumnNames = assert_column_names
# Use the "original_column_names" context manager to monkey-patch the query
# wrapped by it to NOT use the column name "cleanup" described above
# for assert_column_names().
# For example:
# with original_column_names(cql) as cql:
# assertColumnNames(execute(cql, table, "SELECT a+a, a+b"),
# "a+a", "a+b")
@contextmanager
def original_column_names(cql):
    """Temporarily make queries return rows built by ordered_dict_factory.

    The row factory of the first execution profile of cql's cluster is
    replaced for the duration of the "with" block and restored on exit.
    Yields cql itself.
    """
    from cassandra.query import ordered_dict_factory
    # Assume there's just one execution profile and patch that one.
    profile = next(iter(cql.cluster.profile_manager.profiles.values()))
    previous_factory = profile.row_factory
    profile.row_factory = ordered_dict_factory
    try:
        yield cql
    finally:
        profile.row_factory = previous_factory
def flush(cql, table):
    """Flush table by delegating to nodetool.flush()."""
    nodetool.flush(cql, table)
def compact(cql, table):
    """Compact table by delegating to nodetool.compact()."""
    nodetool.compact(cql, table)
# Looping "for _ in before_and_after_flush(cql, table)" runs code twice -
# once immediately, then a flush, and then a second time.
def before_and_after_flush(cql, table):
    """Yield None twice, flushing table between the two yields.

    A "for" loop over this generator therefore runs its body once
    before the flush and once after it.
    """
    for needs_flush in (False, True):
        if needs_flush:
            flush(cql, table)
        yield None
# Utility function for truncating a number to single-precision floating point,
# which is what we can expect when reading a "float" column from Scylla.
# Unfortunately, I couldn't find a function from the Cassandra driver or
# Python to do this more easily (maybe I'm missing something?)
def to_float(num):
    """Round num to single precision by packing and unpacking it as 'f'."""
    single = struct.Struct('f')
    (value,) = single.unpack(single.pack(num))
    return value
# Index creation is asynchronous, this method searches in the system table
# IndexInfo for the specified index and returns true if it finds it, which
# indicates the index was built on the pre-existing data.
# If we haven't found it after 30 seconds we give-up.
def wait_for_index(cql, table, index, timeout=30, poll_interval=0.1):
    """Wait until index is listed in system."IndexInfo" for table's keyspace.

    table is a qualified "keyspace.table" name; only the keyspace part
    is used in the lookup.

    The lookup is retried every poll_interval seconds until the row is
    found or timeout seconds have passed. It is always attempted at
    least once, even with a zero timeout.

    Returns True if the index was found, False if the wait timed out.
    """
    keyspace = table.split('.')[0]
    # Read the specific clustering key directly instead of scanning the
    # whole partition and filtering on index_name in Python. This is the
    # most efficient option, but needs clustering key slices to work
    # correctly in the virtual reader (issue #8600).
    query = f"SELECT index_name FROM system.\"IndexInfo\" WHERE table_name = '{keyspace}' and index_name = '{index}'"
    # A monotonic clock keeps the timeout immune to wall-clock changes.
    deadline = time.monotonic() + timeout
    while True:
        if list(cql.execute(query)):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval)
# user_type("a", 1, "b", 2) creates a named tuple with component names "a", "b"
# and values 1, 2. The return of this function can be used to bind to a UDT.
# The number of arguments is assumed to be even.
def user_type(*args):
    """Build a 'user_type' namedtuple from alternating names and values.

    Arguments at even positions are the field names, those at odd
    positions the corresponding values.
    """
    names = args[::2]
    values = args[1::2]
    factory = collections.namedtuple('user_type', names)
    return factory(*values)


userType = user_type
# a row(...) used in assertRows can simply be a list
def row(*args):
    """Return the arguments as a list, the form assertRows compares against."""
    return [*args]
# Java tests use "null" where Python needs "None"; this alias lets ported
# test code keep the Java spelling.
null = None
def unset():
    """Return the driver's UNSET_VALUE marker."""
    return UNSET_VALUE
# Java spells its boolean literals in lowercase; these aliases let ported
# test code keep using "true" and "false".
true = True
false = False