Files
scylladb/test/cluster/util.py
Artsiom Mishuta cd1679934c test/pylib: use exponential backoff in wait_for()
Change wait_for() defaults from period=1s/no backoff to period=0.1s
with 1.5x backoff capped at 1.0s. This catches fast conditions in
100ms instead of 1000ms, benefiting ~100 call sites automatically.

Add completion logging with elapsed time and iteration count.

Tested locally with test/cluster/test_fencing.py::test_fence_hints (dev mode),
log output:

  wait_for(at_least_one_hint_failed) completed in 0.83s (4 iterations)
  wait_for(exactly_one_hint_sent) completed in 1.34s (5 iterations)

Fixes SCYLLADB-738

Closes scylladb/scylladb#29173
2026-03-24 23:49:49 +02:00

555 lines
24 KiB
Python

#
# Copyright (C) 2022-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
"""
Test consistency of schema changes with topology changes.
"""
import asyncio
import logging
import functools
import operator
import time
import re
from contextlib import asynccontextmanager, contextmanager, suppress
from cassandra.cluster import ConnectionException, ConsistencyLevel, NoHostAvailable, Session, SimpleStatement # type: ignore # pylint: disable=no-name-in-module
from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module
from test.pylib.internal_types import ServerInfo, HostID
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import get_host_api_address, read_barrier
from test.pylib.util import wait_for, wait_for_cql_and_get_hosts, get_available_host, unique_name
from typing import Optional, List, Union
# Module-level logger named after this module; several helpers below log through it.
logger = logging.getLogger(__name__)
# Matches a UUID written in 8-4-4-4-12 hexadecimal form anywhere in a string;
# capture group 1 holds the UUID text.
UUID_REGEX = re.compile(r"([0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12})")
async def reconnect_driver(manager: ManagerClient) -> Session:
    """Close the manager's driver session, open a fresh one and return it.

    Workaround for scylladb/python-driver#295. After a node restart, an
    already existing session may reconnect to that node several times. Even
    when a query against the node has been verified to work (e.g. with
    `wait_for_cql`, which selects from system.local), the driver may reconnect
    once more afterwards and make subsequent queries fail.

    A brand new session *should* not suffer from this. There is no 100%
    guarantee, but the chance of hitting the problem should be significantly
    lower.
    """
    logging.info("Reconnecting driver")
    manager.driver_close()
    await manager.driver_connect()
    logging.info("Driver reconnected")

    session = manager.cql
    assert session
    return session
async def get_token_ring_host_ids(manager: ManagerClient, srv: ServerInfo) -> set[str]:
    """Return the host IDs of the normal token owners as seen by `srv`.

    Queries two REST endpoints on `srv`: the token -> endpoint map, whose
    values are the endpoints owning tokens, and the endpoint -> host ID map.
    The result keeps only host IDs whose endpoint appears in the token map.
    """
    api_client = manager.api.client

    tokens_to_endpoints = await api_client.get_json("/storage_service/tokens_endpoint", srv.ip_addr)
    owner_ips = {entry["value"] for entry in tokens_to_endpoints}
    logger.info(f"Normal endpoints' IPs by {srv}: {owner_ips}")

    endpoints_to_host_ids = await api_client.get_json('/storage_service/host_id', srv.ip_addr)
    # The full set of host IDs is computed for logging only.
    every_host_id = {entry["value"] for entry in endpoints_to_host_ids}
    logger.info(f"All host IDs by {srv}: {every_host_id}")

    owner_host_ids = {entry["value"] for entry in endpoints_to_host_ids if entry["key"] in owner_ips}
    logger.info(f"Normal endpoints' host IDs by {srv}: {owner_host_ids}")
    return owner_host_ids
async def get_current_group0_config(manager: ManagerClient, srv: ServerInfo) -> set[tuple[str, bool]]:
    """Return the current Raft group 0 configuration as known by `srv`.

    Each tuple in the result is `(raft_id, is_voter)`: the Raft ID of a member
    (which is equal to its Host ID) and whether that member can vote.

    A read barrier is executed on `srv` before its local tables are read.
    """
    cql = manager.cql
    assert cql

    # Waits up to 60 seconds for CQL to become available on `srv`.
    hosts = await wait_for_cql_and_get_hosts(cql, [srv], time.time() + 60)
    host = hosts[0]
    await read_barrier(manager.api, srv.ip_addr)

    id_rows = await cql.run_async(
        "select value from system.scylla_local where key = 'raft_group0_id'",
        host=host)
    group0_id = id_rows[0].value

    members = await cql.run_async(
        f"select server_id, can_vote from system.raft_state where group_id = {group0_id} and disposition = 'CURRENT'",
        host=host)
    result = {(str(member.server_id), bool(member.can_vote)) for member in members}
    logger.info(f"Group 0 members by {srv}: {result}")
    return result
async def get_topology_coordinator(manager: ManagerClient) -> HostID:
    """Return the host ID of the topology coordinator.

    The value is the Raft leader reported by an available host, queried after
    a read barrier on that host.
    """
    # Allow up to 60 seconds to find a host that is available.
    available_host = await get_available_host(manager.cql, time.time() + 60)
    api_address = get_host_api_address(available_host)
    await read_barrier(manager.api, api_address)
    leader_id = await manager.api.get_raft_leader(api_address)
    return leader_id
async def get_topology_version(cql: Session, host: Host) -> int:
    """Return the `version` column of the 'topology' row in system.topology, read from `host`."""
    query = "select version from system.topology where key = 'topology'"
    result = await cql.run_async(query, host=host)
    first_row = result[0]
    return first_row.version
async def find_server_by_host_id(manager: ManagerClient, servers: List[ServerInfo], host_id: HostID) -> ServerInfo:
    """Return the first server in `servers` whose host ID equals `host_id`.

    Host IDs are looked up one server at a time through the manager, in list
    order, and the search stops at the first match.

    Raises:
        Exception: if none of the servers has the requested host ID.
    """
    for candidate in servers:
        candidate_host_id = await manager.get_host_id(candidate.server_id)
        if candidate_host_id == host_id:
            return candidate
    raise Exception(f"Host ID {host_id} not found in {servers}")
async def check_token_ring_and_group0_consistency(manager: ManagerClient) -> None:
    """Assert that normal token owners and group 0 members are the same set
    of hosts from the point of view of every currently running server.

    Note: in the presence of zero-token nodes the two sets never match.
    """
    for srv in await manager.running_servers():
        members = await get_current_group0_config(manager, srv)
        # Only the Raft/host ID matters here; voter status is ignored.
        group0_ids = {raft_id for raft_id, _ in members}
        token_ring_ids = await get_token_ring_host_ids(manager, srv)
        assert token_ring_ids == group0_ids
async def wait_for_token_ring_and_group0_consistency(manager: ManagerClient, deadline: float) -> None:
    """Wait until token ring and group 0 membership agree on every running server.

    Weaker version of `check_token_ring_and_group0_consistency`.

    In the Raft-based topology, a decommissioning node is removed from group 0
    after the decommission request is considered finished (and the token ring
    is updated). Moreover, in the gossip-based topology, the token ring is not
    immediately updated after bootstrap/replace/decommission - the normal
    tokens propagate through gossip.

    Because of that the equality is polled, server by server, until it holds
    or `deadline` is reached.
    """
    for srv in await manager.running_servers():
        # The predicate is awaited inside this iteration, so closing over
        # `srv` is safe. Its name is left unchanged since wait_for() receives
        # the function object itself.
        async def token_ring_and_group0_match():
            members = await get_current_group0_config(manager, srv)
            member_ids = {member[0] for member in members}
            ring_ids = await get_token_ring_host_ids(manager, srv)
            mismatch = ring_ids ^ member_ids
            if not mismatch:
                return True
            logger.warning(f"Group 0 members and token ring members don't yet match"
                           f" according to {srv}, symmetric difference: {mismatch}")
            # None tells wait_for() to keep polling.
            return None

        await wait_for(token_ring_and_group0_match, deadline, period=.5)
async def delete_discovery_state_and_group0_id(cql: Session, host: Host) -> None:
    """Truncate system.discovery and delete the 'raft_group0_id' value from
    system.scylla_local, running both statements on `host`."""
    statements = (
        "truncate table system.discovery",
        "delete value from system.scylla_local where key = 'raft_group0_id'",
    )
    for statement in statements:
        await cql.run_async(statement, host=host)
async def delete_raft_group_data(group_id: str, cql: Session, host: Host) -> None:
    """Delete the rows of Raft group `group_id` from system.raft,
    system.raft_snapshots and system.raft_snapshot_config on `host`."""
    statements = (
        f'delete from system.raft where group_id = {group_id}',
        f'delete from system.raft_snapshots where group_id = {group_id}',
        f'delete from system.raft_snapshot_config where group_id = {group_id}',
    )
    # Executed sequentially, in the order listed above.
    for statement in statements:
        await cql.run_async(statement, host=host)
async def wait_for_cdc_generations_publishing(cql: Session, hosts: list[Host], deadline: float):
    """Wait, host by host, until system.topology reports no unpublished CDC generations.

    A host is considered done when its `unpublished_cdc_generations` column is
    either NULL or empty. Polling for each host stops at `deadline`.
    """
    for host in hosts:
        # Awaited within this iteration, so closing over `host` is safe.
        async def all_generations_published():
            rows = await cql.run_async("SELECT unpublished_cdc_generations FROM system.topology", host=host)
            assert len(rows) != 0
            pending = rows[0].unpublished_cdc_generations
            if pending is None or len(pending) == 0:
                return True
            # None tells wait_for() to keep polling.
            return None

        await wait_for(all_generations_published, deadline=deadline)
async def check_system_topology_and_cdc_generations_v3_consistency(manager: ManagerClient, live_hosts: list[Host], cqls: Optional[list[Session]] = None, ignored_hosts: Optional[list[Host]] = None):
    """Assert that system.topology and system.cdc_generations_v3 are sane and
    agree across all live hosts.

    For every host in `live_hosts` the following is checked:
      * rows of live hosts are fully populated, in the "normal" state, and
        their token count matches `num_tokens`;
      * every other row belongs to one of `ignored_hosts`, is in the "left"
        state and owns no tokens;
      * the set of live rows is exactly `live_hosts`;
      * `enabled_features` equals the intersection of the live rows'
        `supported_features` and contains SUPPORTS_CONSISTENT_TOPOLOGY_CHANGES;
      * every committed CDC generation is present in system.cdc_generations_v3;
      * the live rows are identical to those fetched from the first live host.

    Args:
        manager: provides the default CQL session (`manager.cql`).
        live_hosts: hosts to query; must not be empty.
        cqls: optional per-host sessions, paired positionally with
            `live_hosts`. Defaults to `manager.cql` for every host.
        ignored_hosts: hosts whose rows are allowed to be in the "left"
            state. Defaults to no hosts.

    Raises:
        AssertionError: if any of the checks fails.
    """
    # The cqls parameter is a temporary workaround for testing the recovery mode in the presence of live zero-token
    # nodes. A zero-token node requires a different cql session not to be ignored by the driver because of empty tokens
    # in the system.peers table.
    assert len(live_hosts) != 0

    # A None sentinel is used instead of a mutable `[]` default so that no
    # list object is shared between calls.
    if ignored_hosts is None:
        ignored_hosts = []

    logging.info(f"Nodes that will be ignored by check_system_topology_and_cdc_generations_v3_consistency: {ignored_hosts}")

    if cqls is None:
        cqls = [manager.cql] * len(live_hosts)

    live_host_ids = frozenset(host.host_id for host in live_hosts)
    ignored_host_ids = frozenset(host.host_id for host in ignored_hosts)

    # Fetch system.topology from all live hosts concurrently.
    topo_results = await asyncio.gather(*(cql.run_async("SELECT * FROM system.topology", host=host) for cql, host in zip(cqls, live_hosts)))

    for host, topo_res in zip(live_hosts, topo_results):
        logging.info(f"Dumping the state of system.topology as seen by {host}:")
        for row in topo_res:
            logging.info(f"  {row}")

    for cql, host, topo_res in zip(cqls, live_hosts, topo_results):
        assert len(topo_res) != 0

        for row in topo_res:
            num_tokens = 0 if row.tokens is None else len(row.tokens)
            if row.host_id in live_host_ids:
                assert row.datacenter is not None
                assert row.ignore_msb is not None
                assert row.node_state == "normal"
                assert row.num_tokens is not None
                assert row.rack is not None
                assert row.release_version is not None
                assert row.supported_features is not None
                assert row.shard_count is not None
                assert num_tokens == row.num_tokens
            else:
                assert row.host_id is not None
                assert row.host_id in ignored_host_ids
                assert row.node_state == "left"
                assert num_tokens == 0

        live_topo_res = [row for row in topo_res if row.host_id in live_host_ids]

        # The checks below read the values from the first live row.
        assert live_topo_res[0].committed_cdc_generations is not None
        committed_generations = frozenset(gen[1] for gen in live_topo_res[0].committed_cdc_generations)

        assert live_topo_res[0].fence_version is not None
        assert live_topo_res[0].upgrade_state == "done"

        assert live_host_ids == frozenset(row.host_id for row in live_topo_res)

        # Enabled features must be exactly what all live nodes support.
        computed_enabled_features = functools.reduce(operator.and_, (frozenset(row.supported_features) for row in live_topo_res))
        assert live_topo_res[0].enabled_features is not None
        enabled_features = frozenset(live_topo_res[0].enabled_features)
        assert enabled_features == computed_enabled_features
        assert "SUPPORTS_CONSISTENT_TOPOLOGY_CHANGES" in enabled_features

        cdc_res = await cql.run_async("SELECT * FROM system.cdc_generations_v3", host=host)
        assert len(cdc_res) != 0

        all_generations = frozenset(row.id for row in cdc_res)
        assert committed_generations.issubset(all_generations)

        # Check that the contents fetched from the current host are the same as for other nodes
        # (ignoring the dead rows as the orphan node remover might have removed them
        # at a different time on different nodes).
        assert [row for row in topo_results[0] if row.host_id in live_host_ids] == live_topo_res
async def check_node_log_for_failed_mutations(manager: ManagerClient, server: ServerInfo):
    """Assert that the log of `server` has no "Failed to apply mutation from"
    matches when grepped with the (TRACE|DEBUG|INFO) filter expression."""
    logging.info(f"Checking that node {server} had no failed mutations")
    server_log = await manager.server_open_log(server.server_id)
    matches = await server_log.grep(
        expr="Failed to apply mutation from",
        filter_expr="(TRACE|DEBUG|INFO)",
    )
    assert len(matches) == 0
async def start_writes(cql: Session, rf: int, cl: ConsistencyLevel, concurrency: int = 3,
                       ks_name: Optional[str] = None, node_shutdowns: bool = False):
    """Start background workers that keep writing to a test table.

    Creates (if missing) a NetworkTopologyStrategy keyspace with replication
    factor `rf` and a table `tbl` in it, then spawns `concurrency` tasks that
    repeatedly run the same INSERT at consistency level `cl`.

    Args:
        cql: session used for schema setup and for the writes.
        rf: replication factor of the keyspace.
        cl: consistency level of the INSERT statement.
        concurrency: number of writer tasks.
        ks_name: keyspace to use; a unique name is generated when None.
        node_shutdowns: when True, a NoHostAvailable whose per-host errors are
            all ConnectionException instances is tolerated instead of failing
            the worker.

    Returns:
        An async callable that stops the workers and waits for them; awaiting
        it propagates any exception raised by a worker.
    """
    logging.info(f"Starting to asynchronously write, concurrency = {concurrency}")

    stop_event = asyncio.Event()

    keyspace = unique_name() if ks_name is None else ks_name
    await cql.run_async(f"CREATE KEYSPACE IF NOT EXISTS {keyspace} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {rf}}}")
    await cql.run_async(f"USE {keyspace}")
    await cql.run_async("CREATE TABLE IF NOT EXISTS tbl (pk int PRIMARY KEY, v int)")

    # In the test we only care about whether operations report success or not
    # and whether they trigger errors in the nodes' logs. Inserting the same
    # value repeatedly is enough for our purposes.
    insert = SimpleStatement("INSERT INTO tbl (pk, v) VALUES (0, 0)", consistency_level=cl)

    async def do_writes(worker_id: int):
        successes = 0
        while not stop_event.is_set():
            started_at = time.time()
            try:
                await cql.run_async(insert)
                successes += 1
            except NoHostAvailable as e:
                for err in e.errors.values():
                    # ConnectionException can be raised when the node is shutting down.
                    if not (node_shutdowns and isinstance(err, ConnectionException)):
                        logger.error(f"Write started {time.time() - started_at}s ago failed: {e}")
                        raise
            except Exception as e:
                logging.error(f"Write started {time.time() - started_at}s ago failed: {e}")
                raise
            await asyncio.sleep(0.01)
        logging.info(f"Worker #{worker_id} did {successes} successful writes")

    workers = [asyncio.create_task(do_writes(worker_id)) for worker_id in range(concurrency)]

    async def finish():
        logging.info("Stopping write workers")
        stop_event.set()
        await asyncio.gather(*workers)

    return finish
async def trigger_snapshot(manager, server: ServerInfo) -> None:
    """Ask `server`, through its REST API, to trigger a snapshot of Raft group 0.

    The group 0 ID is read from system.scylla_local using the manager's CQL
    session (the query is not pinned to a particular host), then posted to
    the `/raft/trigger_snapshot/<group0_id>` endpoint of `server`.
    """
    cql = manager.get_cql()
    id_rows = await cql.run_async(
        "select value from system.scylla_local where key = 'raft_group0_id'")
    group0_id = id_rows[0].value

    # The request is addressed by IP, so no driver Host lookup is needed.
    await manager.api.client.post(f"/raft/trigger_snapshot/{group0_id}", host=server.ip_addr)
async def trigger_stepdown(manager, server: ServerInfo) -> None:
    """Post to the `/raft/trigger_stepdown` REST endpoint of `server`.

    Only the REST API is involved; the request is addressed by the server's
    IP, so neither a CQL session nor a driver Host lookup is needed.
    """
    await manager.api.client.post("/raft/trigger_stepdown", host=server.ip_addr)
async def get_coordinator_host_ids(manager: ManagerClient) -> list[str]:
    """Return coordinator host IDs extracted from the group 0 history.

    Selects the `description` of every system.group0_history record that
    starts with 'Starting new topology coordinator' and extracts the UUID from
    each one, skipping descriptions without a UUID. IDs are returned in the
    order the rows come back; the first element is the active coordinator.

    Raises:
        AssertionError: if no coordinator ID was found.
    """
    query = SimpleStatement("select description from system.group0_history "
                            "where key = 'history' and description LIKE 'Starting new topology coordinator%' ALLOW FILTERING;")
    rows = await manager.get_cql().run_async(query)

    # get_uuid_from_str() returns "" when nothing matches; drop those.
    coordinators_ids = [host_id for row in rows if (host_id := get_uuid_from_str(row.description))]
    assert len(coordinators_ids) > 0, f"No coordinator ids {coordinators_ids} were found"
    return coordinators_ids
async def get_coordinator_host(manager: ManagerClient) -> ServerInfo:
    """Return the ServerInfo of the running server that is the current coordinator.

    The coordinator's host ID is the first entry of `get_coordinator_host_ids`.
    Running servers are probed one by one; a server whose host ID cannot be
    obtained is skipped.

    Raises:
        AssertionError: if no running server has the coordinator's host ID.
    """
    coordinator_host_id = (await get_coordinator_host_ids(manager))[0]
    seen_host_ids = []
    for candidate in await manager.running_servers():
        # Best effort: any failure while probing a server is ignored.
        with suppress(Exception):
            candidate_host_id = await manager.get_host_id(candidate.server_id)
            seen_host_ids.append(candidate_host_id)
            if candidate_host_id == coordinator_host_id:
                return candidate
    raise AssertionError(f"Node with host id {coordinator_host_id} was not found in cluster host ids {seen_host_ids}")
async def ensure_group0_leader_on(manager: ManagerClient, server: ServerInfo, timeout_seconds = 60):
    """Make sure the Raft group 0 leader runs on `server`.

    Repeatedly asks `server` (after a read barrier) who the leader is. While
    the leader is another node, a stepdown is triggered on that node and the
    check is repeated. Assumes that servers are not added concurrently.

    Raises:
        RuntimeError: if the deadline (`timeout_seconds` from the start) has
            passed when a new iteration begins.
    """
    deadline = time.time() + timeout_seconds
    servers_by_host = await manager.all_servers_by_host_id()
    desired_host_id = await manager.get_host_id(server.server_id)

    while time.time() <= deadline:
        await read_barrier(manager.api, server.ip_addr)
        leader = await manager.api.get_raft_leader(server.ip_addr)
        if leader == desired_host_id:
            return
        if not leader:
            logger.info("no leader")
            continue

        logger.info(f"group0 leader is {leader}, want {desired_host_id}")
        leader_server = servers_by_host[leader]
        logger.info(f"triggering stepdown of {leader}/{leader_server.ip_addr}")
        await manager.api.client.post("/raft/trigger_stepdown", host=leader_server.ip_addr)

    raise RuntimeError("timed out")
async def get_non_coordinator_host(manager: ManagerClient) -> ServerInfo | None:
    """Return the first running server that is not the coordinator, or None if there is none."""
    coordinator = await get_coordinator_host(manager=manager)
    coordinator_id = coordinator.server_id
    for candidate in await manager.running_servers():
        if candidate.server_id != coordinator_id:
            return candidate
    return None
def get_uuid_from_str(string: str) -> str:
    """Return the first UUID found in `string`, or an empty string if there is none."""
    found = UUID_REGEX.search(string)
    return found.group(1) if found else ""
async def wait_new_coordinator_elected(manager: ManagerClient, expected_num_of_elections: int, deadline: float) -> None:
    """Wait until a new coordinator has been elected.

    Polls until the number of coordinator IDs returned by
    `get_coordinator_host_ids` equals `expected_num_of_elections` and the
    latest coordinator host ID differs from the previous one, or until
    `deadline` is reached.
    """
    async def new_coordinator_elected():
        coordinators_ids = await get_coordinator_host_ids(manager)
        enough_elections = len(coordinators_ids) == expected_num_of_elections
        # The comparison is only evaluated once the count matches.
        if enough_elections and coordinators_ids[0] != coordinators_ids[1]:
            return True
        logger.warning("New coordinator was not elected %s", coordinators_ids)
        # None tells wait_for() to keep polling.
        return None

    await wait_for(new_coordinator_elected, deadline=deadline)
async def create_new_test_keyspace(cql: Session, opts, host=None):
    """Create a keyspace with a freshly generated unique name and return that name.

    `opts` is appended verbatim after the keyspace name in the CREATE
    statement; `host` is passed through to the session.
    """
    name = unique_name()
    # Use CREATE KEYSPACE IF NOT EXISTS as a workaround for
    # https://github.com/scylladb/python-driver/issues/317
    statement = f"CREATE KEYSPACE IF NOT EXISTS {name} {opts}"
    await cql.run_async(statement, host=host)
    return name
@asynccontextmanager
async def new_test_keyspace(manager: ManagerClient, opts, host=None):
    """Async context manager that yields the name of a new temporary keyspace.

    The keyspace is created with the given options and dropped when the
    `async with` body finishes normally. If the body raises, the keyspace is
    left in place for investigation and the exception propagates.

    Usage:
        async with new_test_keyspace(manager, '...') as keyspace:
    """
    keyspace = await create_new_test_keyspace(manager.get_cql(), opts, host)
    try:
        yield keyspace
    except BaseException:
        logger.info(f"Error happened while using keyspace '{keyspace}', the keyspace is left in place for investigation")
        raise
    # Only reached when the body completed without raising.
    await manager.get_cql().run_async(f"DROP KEYSPACE {keyspace}", host=host)
# Names of tables that were created by new_test_table() and have since been
# dropped; they are handed out again to later callers.
previously_used_table_names = []
@asynccontextmanager
async def new_test_table(manager: ManagerClient, keyspace, schema, extra="", host=None, reuse_tables=True):
    """Async context manager that yields a new temporary table (as "keyspace.table").

    Because Scylla becomes slower when a huge number of uniquely-named tables
    are created and deleted (see https://github.com/scylladb/scylla/issues/7620),
    names of tables that were already dropped are kept in
    `previously_used_table_names` and reused when possible. Pass
    `reuse_tables=False` to always get a brand new name.

    The table is dropped when the `async with` body exits, whether or not it
    raised.

    Usage:
        async with new_test_table(manager, keyspace, '...') as table:
    """
    if not reuse_tables:
        table_name = unique_name()
    elif previously_used_table_names:
        table_name = previously_used_table_names.pop()
    else:
        table_name = unique_name()

    table = f"{keyspace}.{table_name}"
    await manager.get_cql().run_async(f"CREATE TABLE {table}({schema}){extra}", host=host)
    try:
        yield table
    finally:
        await manager.get_cql().run_async(f"DROP TABLE {table}", host=host)
        # The name goes back to the pool only after a successful drop.
        if reuse_tables:
            previously_used_table_names.append(table_name)
@asynccontextmanager
async def new_materialized_view(manager: ManagerClient, table, select, pk, where, extra=""):
    """Async context manager that yields a new temporary materialized view.

    The view gets a unique name in the keyspace of `table` (the part of
    `table` before the first '.') and is dropped when the `async with` body
    exits, whether or not it raised.
    """
    keyspace_name = table.split('.')[0]
    mv = f"{keyspace_name}.{unique_name()}"
    create_statement = f"CREATE MATERIALIZED VIEW {mv} AS SELECT {select} FROM {table} WHERE {where} PRIMARY KEY ({pk}) {extra}"
    await manager.get_cql().run_async(create_statement)
    try:
        yield mv
    finally:
        await manager.get_cql().run_async(f"DROP MATERIALIZED VIEW {mv}")
async def keyspace_has_tablets(manager: ManagerClient, keyspace: str) -> bool:
    """Tell whether the given keyspace uses tablets.

    True when system_schema.scylla_keyspaces has a row for the keyspace and
    that row has a non-null `initial_tablets` attribute.
    Adapted from its counterpart in the cqlpy test: cqlpy/util.py::keyspace_has_tablets.
    """
    session = manager.get_cql()
    query = f"SELECT * FROM system_schema.scylla_keyspaces WHERE keyspace_name='{keyspace}'"
    rows = list(await session.run_async(query))
    if len(rows) == 0:
        return False
    # getattr with a default covers rows that lack the column altogether.
    return getattr(rows[0], "initial_tablets", None) is not None
async def get_raft_log_size(cql, host) -> int:
    """Return the result of counting the "index" column of system.raft on `host`."""
    rows = await cql.run_async('select count("index") from system.raft', host=host)
    first_row = rows[0]
    return first_row[0]
async def get_raft_snap_id(cql, host) -> str:
    """Return `snapshot_id` from the single row selected (limit 1) from system.raft on `host`."""
    rows = await cql.run_async("select snapshot_id from system.raft limit 1", host=host)
    first_row = rows[0]
    return first_row.snapshot_id
@contextmanager
def disable_schema_agreement_wait(cql: Session):
    """Context manager that sets `max_schema_agreement_wait` of the session's
    cluster to 0 for the duration of the `with` body.

    The previous value is restored on exit, also when the body raises.
    """
    cluster = cql.cluster
    assert hasattr(cluster, "max_schema_agreement_wait")
    saved_wait = cluster.max_schema_agreement_wait
    cluster.max_schema_agreement_wait = 0
    try:
        yield
    finally:
        cluster.max_schema_agreement_wait = saved_wait
# A single replication option value: either a plain string or a list of strings.
ReplicationOption = Union[str, List[str]]
# Expanded replication options: option name -> option value.
ReplicationOptions = dict[str, ReplicationOption]
def parse_replication_options(replication_column) -> ReplicationOptions:
    """Expand a flattened replication-options map into a nested one.

    Input is the value of the "replication_v2" or "replication" column from
    system_schema.keyspaces. Keys of the form "<name>:<index>" are gathered
    into a list stored under "<name>", with each value placed at position
    <index>; all other keys are copied as they are.

    Example: {"dc0:0": "r1", "dc0:1": "r2"} becomes {"dc0": ["r1", "r2"]}.
    See docs/dev/system_schema_keyspace.md
    """
    expanded = {}
    for flat_key, option in replication_column.items():
        name, separator, position = flat_key.partition(':')
        if not separator:
            expanded[flat_key] = option
            continue

        slots = expanded.setdefault(name, [])
        index = int(position)
        # Pad with None so that `index` becomes a valid position.
        missing = index + 1 - len(slots)
        if missing > 0:
            slots.extend([None] * missing)
        # A negative index only ensures that the list exists; its value is dropped.
        if index >= 0:
            slots[index] = option
    return expanded
def get_replication(cql, keyspace) -> ReplicationOptions:
    """Return the expanded replication options of the given keyspace.

    Reads system_schema.keyspaces and prefers the `replication_v2` column,
    falling back to `replication` when the former is empty or null.

    Example result: {"dc1": "2", "dc2": ["rack1", "rack2"]}
    """
    query = f"SELECT replication, replication_v2 FROM system_schema.keyspaces WHERE keyspace_name='{keyspace}'"
    keyspace_row = cql.execute(query).one()
    flattened = keyspace_row.replication_v2 or keyspace_row.replication
    return parse_replication_options(flattened)
def get_replica_count(rf: ReplicationOption) -> int:
    """Return the replica count corresponding to the given replication option.

    A list option (e.g. a list of racks) counts one replica per element; any
    other option is converted with `int()`.

    Examples:
        get_replica_count(["rack1", "rack2"]) == 2
        get_replica_count("2") == 2

    Raises:
        ValueError: if a non-list option is not a valid integer literal.
    """
    # isinstance (rather than an exact type check) also accepts list subclasses.
    if isinstance(rf, list):
        return len(rf)
    return int(rf)