Change wait_for() defaults from period=1s/no backoff to period=0.1s with 1.5x backoff capped at 1.0s. This catches fast conditions in 100ms instead of 1000ms, benefiting ~100 call sites automatically. Add completion logging with elapsed time and iteration count. Tested locally with test/cluster/test_fencing.py::test_fence_hints (dev mode); log output: wait_for(at_least_one_hint_failed) completed in 0.83s (4 iterations); wait_for(exactly_one_hint_sent) completed in 1.34s (5 iterations). Fixes SCYLLADB-738 Closes scylladb/scylladb#29173
555 lines
24 KiB
Python
555 lines
24 KiB
Python
#
|
|
# Copyright (C) 2022-present ScyllaDB
|
|
#
|
|
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
|
#
|
|
"""
|
|
Test consistency of schema changes with topology changes.
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
import functools
|
|
import operator
|
|
import time
|
|
import re
|
|
from contextlib import asynccontextmanager, contextmanager, suppress
|
|
|
|
from cassandra.cluster import ConnectionException, ConsistencyLevel, NoHostAvailable, Session, SimpleStatement # type: ignore # pylint: disable=no-name-in-module
|
|
from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module
|
|
from test.pylib.internal_types import ServerInfo, HostID
|
|
from test.pylib.manager_client import ManagerClient
|
|
from test.pylib.rest_client import get_host_api_address, read_barrier
|
|
from test.pylib.util import wait_for, wait_for_cql_and_get_hosts, get_available_host, unique_name
|
|
from typing import Optional, List, Union
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Matches a canonical 8-4-4-4-12 hex UUID anywhere in a string;
# capture group 1 holds the whole UUID.
UUID_REGEX = re.compile(r"([0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12})")
|
|
|
|
|
|
async def reconnect_driver(manager: ManagerClient) -> Session:
    """Replace the driver session; workaround for scylladb/python-driver#295.

    When restarting a node, a pre-existing session connected to the cluster
    may reconnect to the restarted node multiple times. Even if we verify
    that the session can perform a query on that node (e.g. like `wait_for_cql`,
    which tries to select from system.local), the driver may again reconnect
    after that, and following queries may fail.

    The new session created by this function *should* not have this problem,
    although (if I remember correctly) there is no 100% guarantee; still,
    the chance of this problem appearing should be significantly decreased
    with the new session.

    Returns the freshly connected session.
    """
    # Use the module-level logger (the original used the root logger) and
    # plain strings instead of placeholder-less f-strings.
    logger.info("Reconnecting driver")
    manager.driver_close()
    await manager.driver_connect()
    logger.info("Driver reconnected")
    cql = manager.cql
    # driver_connect() must have installed a session on the manager.
    assert cql
    return cql
|
|
|
|
|
|
async def get_token_ring_host_ids(manager: ManagerClient, srv: ServerInfo) -> set[str]:
    """Get the host IDs of normal token owners known by `srv`."""
    tokens_map = await manager.api.client.get_json("/storage_service/tokens_endpoint", srv.ip_addr)
    normal_ips = {entry["value"] for entry in tokens_map}
    logger.info(f"Normal endpoints' IPs by {srv}: {normal_ips}")
    host_id_entries = await manager.api.client.get_json('/storage_service/host_id', srv.ip_addr)
    every_host_id = {entry["value"] for entry in host_id_entries}
    logger.info(f"All host IDs by {srv}: {every_host_id}")
    # Keep only the host IDs whose endpoint owns normal tokens.
    owners = {entry["value"] for entry in host_id_entries if entry["key"] in normal_ips}
    logger.info(f"Normal endpoints' host IDs by {srv}: {owners}")
    return owners
|
|
|
|
|
|
async def get_current_group0_config(manager: ManagerClient, srv: ServerInfo) -> set[tuple[str, bool]]:
    """Get the current Raft group 0 configuration known by `srv`.

    The first element of each tuple is the Raft ID of the node (which is
    equal to the Host ID); the second element indicates whether the node
    is a voter.
    """
    assert manager.cql
    host = (await wait_for_cql_and_get_hosts(manager.cql, [srv], time.time() + 60))[0]
    # Make sure `srv` has applied all committed group 0 state before reading.
    await read_barrier(manager.api, srv.ip_addr)
    group0_id = (await manager.cql.run_async(
        "select value from system.scylla_local where key = 'raft_group0_id'",
        host=host))[0].value
    rows = await manager.cql.run_async(
        f"select server_id, can_vote from system.raft_state where group_id = {group0_id} and disposition = 'CURRENT'",
        host=host)
    members = {(str(row.server_id), bool(row.can_vote)) for row in rows}
    logger.info(f"Group 0 members by {srv}: {members}")
    return members
|
|
|
|
|
|
async def get_topology_coordinator(manager: ManagerClient) -> HostID:
    """Get the host ID of the topology coordinator."""
    host = await get_available_host(manager.cql, time.time() + 60)
    address = get_host_api_address(host)
    # A read barrier guarantees the queried node sees the latest Raft state.
    await read_barrier(manager.api, address)
    return await manager.api.get_raft_leader(address)
|
|
|
|
|
|
async def get_topology_version(cql: Session, host: Host) -> int:
    """Return the topology version as currently seen by `host`."""
    result = await cql.run_async(
        "select version from system.topology where key = 'topology'",
        host=host)
    return result[0].version
|
|
|
|
|
|
async def find_server_by_host_id(manager: ManagerClient, servers: List[ServerInfo], host_id: HostID) -> ServerInfo:
    """Return the server from `servers` whose host ID equals `host_id`.

    Raises if no such server exists.
    """
    for candidate in servers:
        candidate_id = await manager.get_host_id(candidate.server_id)
        if candidate_id == host_id:
            return candidate
    raise Exception(f"Host ID {host_id} not found in {servers}")
|
|
|
|
|
|
async def check_token_ring_and_group0_consistency(manager: ManagerClient) -> None:
    """Ensure that the normal token owners and group 0 members match
    according to each currently running server.

    Note that the normal token owners and group 0 members never match
    in the presence of zero-token nodes.
    """
    for srv in await manager.running_servers():
        members = await get_current_group0_config(manager, srv)
        member_ids = {raft_id for raft_id, _ in members}
        ring_ids = await get_token_ring_host_ids(manager, srv)
        assert ring_ids == member_ids
|
|
|
|
|
|
async def wait_for_token_ring_and_group0_consistency(manager: ManagerClient, deadline: float) -> None:
    """
    Weaker version of the above check.

    In the Raft-based topology, a decommissioning node is removed from group 0
    after the decommission request is considered finished (and the token ring
    is updated).

    Moreover, in the gossip-based topology, the token ring is not immediately
    updated after bootstrap/replace/decommission - the normal tokens propagate
    through gossip.

    Take this into account and wait for the equality condition to hold,
    with a timeout.
    """
    for srv in await manager.running_servers():
        async def token_ring_and_group0_match():
            members = await get_current_group0_config(manager, srv)
            member_ids = {m[0] for m in members}
            ring_ids = await get_token_ring_host_ids(manager, srv)
            mismatch = ring_ids ^ member_ids
            if not mismatch:
                return True
            logger.warning(f"Group 0 members and token ring members don't yet match"
                           f" according to {srv}, symmetric difference: {mismatch}")
            return None
        await wait_for(token_ring_and_group0_match, deadline, period=.5)
|
|
|
|
|
|
async def delete_discovery_state_and_group0_id(cql: Session, host: Host) -> None:
    """Wipe the persisted group 0 discovery state and group 0 ID on `host`."""
    await cql.run_async("truncate table system.discovery", host=host)
    await cql.run_async("delete value from system.scylla_local where key = 'raft_group0_id'", host=host)
|
|
|
|
|
|
async def delete_raft_group_data(group_id: str, cql: Session, host: Host) -> None:
    """Delete all persisted Raft state (log, snapshots, snapshot config) of
    the given group on `host`."""
    for table in ("system.raft", "system.raft_snapshots", "system.raft_snapshot_config"):
        await cql.run_async(f'delete from {table} where group_id = {group_id}', host=host)
|
|
|
|
|
|
async def wait_for_cdc_generations_publishing(cql: Session, hosts: list[Host], deadline: float):
    """Wait until every host in `hosts` reports no unpublished CDC generations."""
    for h in hosts:
        async def all_generations_published():
            rows = await cql.run_async("SELECT unpublished_cdc_generations FROM system.topology", host=h)
            assert rows
            pending = rows[0].unpublished_cdc_generations
            if pending is None or len(pending) == 0:
                return True
            return None

        await wait_for(all_generations_published, deadline=deadline)
|
|
|
|
|
|
async def check_system_topology_and_cdc_generations_v3_consistency(manager: ManagerClient, live_hosts: list[Host], cqls: Optional[list[Session]] = None, ignored_hosts: Optional[list[Host]] = None):
    """Verify that system.topology and system.cdc_generations_v3 are consistent
    across all live hosts.

    Each live host must appear in system.topology in the "normal" state with
    all metadata columns populated; every other row must belong to an ignored
    host in the "left" state. Committed CDC generations must be present in
    system.cdc_generations_v3, and all live hosts must report identical
    (live-row) topology contents.

    The cqls parameter is a temporary workaround for testing the recovery mode
    in the presence of live zero-token nodes. A zero-token node requires a
    different cql session not to be ignored by the driver because of empty
    tokens in the system.peers table.
    """
    # Fixed: `ignored_hosts` used a mutable default argument ([]); a shared
    # list default is an anti-pattern, so default to None and normalize here.
    if ignored_hosts is None:
        ignored_hosts = []
    assert len(live_hosts) != 0

    logger.info(f"Nodes that will be ignored by check_system_topology_and_cdc_generations_v3_consistency: {ignored_hosts}")

    if cqls is None:
        cqls = [manager.cql] * len(live_hosts)

    live_host_ids = frozenset(host.host_id for host in live_hosts)
    ignored_host_ids = frozenset(host.host_id for host in ignored_hosts)

    topo_results = await asyncio.gather(*(cql.run_async("SELECT * FROM system.topology", host=host) for cql, host in zip(cqls, live_hosts)))

    for host, topo_res in zip(live_hosts, topo_results):
        logger.info(f"Dumping the state of system.topology as seen by {host}:")
        for row in topo_res:
            logger.info(f"  {row}")

    for cql, host, topo_res in zip(cqls, live_hosts, topo_results):
        assert len(topo_res) != 0

        for row in topo_res:
            num_tokens = 0 if row.tokens is None else len(row.tokens)
            if row.host_id in live_host_ids:
                # A live node must be "normal" with fully populated metadata.
                assert row.datacenter is not None
                assert row.ignore_msb is not None
                assert row.node_state == "normal"
                assert row.num_tokens is not None
                assert row.rack is not None
                assert row.release_version is not None
                assert row.supported_features is not None
                assert row.shard_count is not None
                assert num_tokens == row.num_tokens
            else:
                # Any other row must belong to an ignored host that has left.
                assert row.host_id is not None
                assert row.host_id in ignored_host_ids
                assert row.node_state == "left"
                assert num_tokens == 0

        live_topo_res = [row for row in topo_res if row.host_id in live_host_ids]

        assert live_topo_res[0].committed_cdc_generations is not None
        committed_generations = frozenset(gen[1] for gen in live_topo_res[0].committed_cdc_generations)

        assert live_topo_res[0].fence_version is not None
        assert live_topo_res[0].upgrade_state == "done"

        assert live_host_ids == frozenset(row.host_id for row in live_topo_res)

        # Enabled features must equal the intersection of every live host's
        # supported features.
        computed_enabled_features = functools.reduce(operator.and_, (frozenset(row.supported_features) for row in live_topo_res))
        assert live_topo_res[0].enabled_features is not None
        enabled_features = frozenset(live_topo_res[0].enabled_features)
        assert enabled_features == computed_enabled_features
        assert "SUPPORTS_CONSISTENT_TOPOLOGY_CHANGES" in enabled_features

        cdc_res = await cql.run_async("SELECT * FROM system.cdc_generations_v3", host=host)
        assert len(cdc_res) != 0

        all_generations = frozenset(row.id for row in cdc_res)
        assert committed_generations.issubset(all_generations)

        # Check that the contents fetched from the current host are the same as for other nodes
        # (ignoring the dead rows as the orphan node remover might have removed them
        # at a different time on different nodes).
        assert [row for row in topo_results[0] if row.host_id in live_host_ids] == live_topo_res
|
|
|
|
async def check_node_log_for_failed_mutations(manager: ManagerClient, server: ServerInfo):
    """Assert that `server`'s log contains no "Failed to apply mutation from"
    messages (at TRACE/DEBUG/INFO levels)."""
    # Use the module-level logger for consistency with the rest of this file
    # (the original called the root `logging` module directly).
    logger.info(f"Checking that node {server} had no failed mutations")
    log = await manager.server_open_log(server.server_id)
    occurrences = await log.grep(expr="Failed to apply mutation from", filter_expr="(TRACE|DEBUG|INFO)")
    assert len(occurrences) == 0
|
|
|
|
|
|
async def start_writes(cql: Session, rf: int, cl: ConsistencyLevel, concurrency: int = 3,
                       ks_name: Optional[str] = None, node_shutdowns: bool = False):
    """Start background tasks that repeatedly insert a single row.

    Creates (if needed) a keyspace with NetworkTopologyStrategy and replication
    factor `rf` and a table `tbl`, then spawns `concurrency` workers writing at
    consistency level `cl`.  When `node_shutdowns` is true, ConnectionException
    errors wrapped in NoHostAvailable are tolerated, since they can be raised
    while a node is shutting down.

    Returns an async `finish()` callable that stops the workers and awaits them.
    """
    # Consistency fix: the original mixed the module `logger` with direct root
    # `logging` calls; all logging below goes through the module logger.
    logger.info(f"Starting to asynchronously write, concurrency = {concurrency}")

    stop_event = asyncio.Event()

    if ks_name is None:
        ks_name = unique_name()
    await cql.run_async(f"CREATE KEYSPACE IF NOT EXISTS {ks_name} WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {rf}}}")
    await cql.run_async(f"USE {ks_name}")
    await cql.run_async("CREATE TABLE IF NOT EXISTS tbl (pk int PRIMARY KEY, v int)")

    # In the test we only care about whether operations report success or not
    # and whether they trigger errors in the nodes' logs. Inserting the same
    # value repeatedly is enough for our purposes.
    stmt = SimpleStatement("INSERT INTO tbl (pk, v) VALUES (0, 0)", consistency_level=cl)

    async def do_writes(worker_id: int):
        write_count = 0
        while not stop_event.is_set():
            start_time = time.time()
            try:
                await cql.run_async(stmt)
                write_count += 1
            except NoHostAvailable as e:
                for _, err in e.errors.items():
                    # ConnectionException can be raised when the node is shutting down.
                    if not node_shutdowns or not isinstance(err, ConnectionException):
                        logger.error(f"Write started {time.time() - start_time}s ago failed: {e}")
                        raise
            except Exception as e:
                logger.error(f"Write started {time.time() - start_time}s ago failed: {e}")
                raise
            await asyncio.sleep(0.01)
        logger.info(f"Worker #{worker_id} did {write_count} successful writes")

    tasks = [asyncio.create_task(do_writes(worker_id)) for worker_id in range(concurrency)]

    async def finish():
        logger.info("Stopping write workers")
        stop_event.set()
        await asyncio.gather(*tasks)

    return finish
|
|
|
|
|
|
|
|
async def trigger_snapshot(manager, server: ServerInfo) -> None:
    """Ask `server` to take a Raft snapshot of group 0."""
    # Fixed: the original also computed an unused `host` local from the
    # cluster metadata; only the group 0 ID is actually needed.
    cql = manager.get_cql()
    group0_id = (await cql.run_async(
        "select value from system.scylla_local where key = 'raft_group0_id'"))[0].value
    await manager.api.client.post(f"/raft/trigger_snapshot/{group0_id}", host=server.ip_addr)
|
|
|
|
async def trigger_stepdown(manager, server: ServerInfo) -> None:
    """Ask the Raft leader running on `server` to step down."""
    # Fixed: removed unused locals (`cql` and `host`) — the REST call only
    # needs the server's IP address.
    await manager.api.client.post("/raft/trigger_stepdown", host=server.ip_addr)
|
|
|
|
|
|
|
|
async def get_coordinator_host_ids(manager: ManagerClient) -> list[str]:
    """ Get coordinator host id from history

    Select all records with elected coordinator
    from description column in system.group0_history table and
    return list of coordinator host ids, where
    first element in list is active coordinator
    """
    stm = SimpleStatement("select description from system.group0_history "
                          "where key = 'history' and description LIKE 'Starting new topology coordinator%' ALLOW FILTERING;")

    rows = await manager.get_cql().run_async(stm)
    # Keep only descriptions that actually contain a UUID.
    coordinators_ids = [uuid for row in rows if (uuid := get_uuid_from_str(row.description))]
    assert len(coordinators_ids) > 0, f"No coordinator ids {coordinators_ids} were found"
    return coordinators_ids
|
|
|
|
|
|
async def get_coordinator_host(manager: ManagerClient) -> ServerInfo:
    """Get coordinator ServerInfo"""

    coordinator_host_id = (await get_coordinator_host_ids(manager))[0]
    host_ids = []
    for s_info in await manager.running_servers():
        # get_host_id may fail for an unreachable node; skip it silently.
        with suppress(Exception):
            current_id = await manager.get_host_id(s_info.server_id)
            host_ids.append(current_id)
            if current_id == coordinator_host_id:
                return s_info
    raise AssertionError(f"Node with host id {coordinator_host_id} was not found in cluster host ids {host_ids}")
|
|
|
|
|
|
async def ensure_group0_leader_on(manager: ManagerClient, server: ServerInfo, timeout_seconds = 60):
    """
    Ensure that raft group0 leader runs on a given server, triggering stepdowns if necessary.
    Assumes that servers are not added concurrently.
    """

    deadline = time.time() + timeout_seconds
    servers_by_host = await manager.all_servers_by_host_id()
    desired_host_id = await manager.get_host_id(server.server_id)

    while time.time() <= deadline:
        # Ensure the queried node has the latest Raft state before asking.
        await read_barrier(manager.api, server.ip_addr)
        leader = await manager.api.get_raft_leader(server.ip_addr)
        if leader == desired_host_id:
            return

        if not leader:
            logger.info("no leader")
            continue

        logger.info(f"group0 leader is {leader}, want {desired_host_id}")
        leader_server = servers_by_host[leader]
        logger.info(f"triggering stepdown of {leader}/{leader_server.ip_addr}")
        await manager.api.client.post("/raft/trigger_stepdown", host=leader_server.ip_addr)

    raise RuntimeError("timed out")
|
|
|
|
async def get_non_coordinator_host(manager: ManagerClient) -> ServerInfo | None:
    """Get first non-coordinator ServerInfo, or None if every running server
    is the coordinator."""

    coordinator = await get_coordinator_host(manager=manager)
    for s_info in await manager.running_servers():
        if s_info.server_id != coordinator.server_id:
            return s_info
    return None
|
|
|
|
|
|
def get_uuid_from_str(string: str) -> str:
    """Return the first UUID found in `string`, or "" when none is present."""
    match = UUID_REGEX.search(string)
    return match.group(1) if match else ""
|
|
|
|
|
|
async def wait_new_coordinator_elected(manager: ManagerClient, expected_num_of_elections: int, deadline: float) -> None:
    """Wait new coordinator to be elected

    Wait while the table 'system.group0_history' will have a number of lines
    with the 'new topology coordinator' equal to the expected_num_of_elections number,
    and the latest host_id coordinator differs from the previous one.
    """
    async def new_coordinator_elected():
        ids = await get_coordinator_host_ids(manager)
        # NOTE(review): indexing ids[1] assumes expected_num_of_elections >= 2;
        # a single-entry history would raise IndexError — confirm with callers.
        if len(ids) == expected_num_of_elections and ids[0] != ids[1]:
            return True
        logger.warning("New coordinator was not elected %s", ids)

    await wait_for(new_coordinator_elected, deadline=deadline)
|
|
|
|
async def create_new_test_keyspace(cql: Session, opts, host=None):
    """
    A utility function for creating a new temporary keyspace with given
    options.  Returns the generated keyspace name.
    """
    keyspace = unique_name()
    # "IF NOT EXISTS" works around
    # https://github.com/scylladb/python-driver/issues/317
    await cql.run_async(f"CREATE KEYSPACE IF NOT EXISTS {keyspace} {opts}", host=host)
    return keyspace
|
|
|
|
@asynccontextmanager
async def new_test_keyspace(manager: ManagerClient, opts, host=None):
    """
    A utility function for creating a new temporary keyspace with given
    options. It can be used in a "async with", as:
        async with new_test_keyspace(ManagerClient, '...') as keyspace:

    On error the keyspace is deliberately NOT dropped, to ease investigation.
    """
    keyspace = await create_new_test_keyspace(manager.get_cql(), opts, host)
    try:
        yield keyspace
    except BaseException:
        logger.info(f"Error happened while using keyspace '{keyspace}', the keyspace is left in place for investigation")
        raise
    else:
        await manager.get_cql().run_async("DROP KEYSPACE " + keyspace, host=host)
|
|
|
|
# Pool of table names freed by new_test_table(); reused to avoid the slowdown
# from creating/deleting many uniquely-named tables
# (https://github.com/scylladb/scylla/issues/7620).
previously_used_table_names = []
|
|
@asynccontextmanager
async def new_test_table(manager: ManagerClient, keyspace, schema, extra="", host=None, reuse_tables=True):
    """
    A utility function for creating a new temporary table with a given schema.
    Because Scylla becomes slower when a huge number of uniquely-named tables
    are created and deleted (see https://github.com/scylladb/scylla/issues/7620)
    we keep here a list of previously used but now deleted table names, and
    reuse one of these names when possible.
    This function can be used in a "async with", as:
        async with create_table(cql, test_keyspace, '...') as table:
    """
    # Fixed: removed the redundant `global previously_used_table_names`
    # declaration — the list is only mutated in place (append/pop), never
    # rebound, so no global statement is needed.
    if reuse_tables:
        if not previously_used_table_names:
            previously_used_table_names.append(unique_name())
        table_name = previously_used_table_names.pop()
    else:
        table_name = unique_name()
    table = keyspace + "." + table_name
    await manager.get_cql().run_async("CREATE TABLE " + table + "(" + schema + ")" + extra, host=host)
    try:
        yield table
    finally:
        await manager.get_cql().run_async("DROP TABLE " + table, host=host)
        # Return the name to the pool only after the table is dropped.
        if reuse_tables:
            previously_used_table_names.append(table_name)
|
|
|
|
@asynccontextmanager
async def new_materialized_view(manager: ManagerClient, table, select, pk, where, extra=""):
    """
    A utility function for creating a new temporary materialized view in
    an existing table.  The view lives in the table's keyspace and is
    dropped when the "async with" block exits.
    """
    keyspace = table.split('.')[0]
    mv = keyspace + "." + unique_name()
    cql = manager.get_cql()
    await cql.run_async(f"CREATE MATERIALIZED VIEW {mv} AS SELECT {select} FROM {table} WHERE {where} PRIMARY KEY ({pk}) {extra}")
    try:
        yield mv
    finally:
        await manager.get_cql().run_async(f"DROP MATERIALIZED VIEW {mv}")
|
|
|
|
|
|
async def keyspace_has_tablets(manager: ManagerClient, keyspace: str) -> bool:
    """
    Checks whether the given keyspace uses tablets.
    Adapted from its counterpart in the cqlpy test: cqlpy/util.py::keyspace_has_tablets.
    """
    cql = manager.get_cql()
    rows = list(await cql.run_async(f"SELECT * FROM system_schema.scylla_keyspaces WHERE keyspace_name='{keyspace}'"))
    if not rows:
        return False
    return getattr(rows[0], "initial_tablets", None) is not None
|
|
|
|
|
|
async def get_raft_log_size(cql, host) -> int:
    """Return the number of entries currently in the Raft log on `host`."""
    rows = await cql.run_async('select count("index") from system.raft', host=host)
    return rows[0][0]
|
|
|
|
|
|
async def get_raft_snap_id(cql, host) -> str:
    """Return the snapshot ID stored in the first system.raft row on `host`."""
    rows = await cql.run_async("select snapshot_id from system.raft limit 1", host=host)
    return rows[0].snapshot_id
|
|
|
|
|
|
@contextmanager
def disable_schema_agreement_wait(cql: Session):
    """
    A context manager that temporarily disables the schema agreement wait
    for the given cql session, restoring the previous value on exit.
    """
    cluster = cql.cluster
    assert hasattr(cluster, "max_schema_agreement_wait")
    saved = cluster.max_schema_agreement_wait
    cluster.max_schema_agreement_wait = 0
    try:
        yield
    finally:
        cluster.max_schema_agreement_wait = saved
|
|
|
|
|
|
ReplicationOption = Union[str, List[str]]
ReplicationOptions = dict[str, ReplicationOption]


def parse_replication_options(replication_column) -> ReplicationOptions:
    """
    Parses the value of "replication_v2" or "replication" column from system_schema.keyspaces,
    which is a flattened map of options, into an expanded map.
    Expands a flattened map like {"dc0:0": "r1", "dc0:1": "r2"} into {"dc0": ["r1", "r2"]}.
    See docs/dev/system_schema_keyspace.md
    """
    expanded: ReplicationOptions = {}
    for flat_key, value in replication_column.items():
        if ':' not in flat_key:
            # Plain (non-indexed) option: copy through unchanged.
            expanded[flat_key] = value
            continue
        dc, _, index_str = flat_key.partition(':')
        index = int(index_str)
        slots = expanded.setdefault(dc, [])
        # Grow the per-DC list with placeholders until `index` is addressable.
        while len(slots) <= index:
            slots.append(None)
        if index >= 0:
            slots[index] = value
    return expanded
|
|
|
|
|
|
def get_replication(cql, keyspace) -> ReplicationOptions:
    """
    Returns replication options for a given keyspace.

    Example result: {"dc1": "2", "dc2": ["rack1", "rack2"]}
    """
    query = f"SELECT replication, replication_v2 FROM system_schema.keyspaces WHERE keyspace_name='{keyspace}'"
    row = cql.execute(query).one()
    # Prefer the newer replication_v2 column, falling back to replication.
    return parse_replication_options(row.replication_v2 or row.replication)
|
|
|
|
|
|
def get_replica_count(rf: ReplicationOption) -> int:
    """
    Returns replica count corresponding to the given replication option.

    A list option enumerates racks, contributing one replica per entry;
    a scalar option is the replica count itself.

    Examples:
        get_replica_count(["rack1", "rack2"]) == 2
        get_replica_count("2") == 2
    """
    # Fixed: the original docstring example claimed get_replica_count(["2"]) == 2,
    # but a one-element list yields 1; the scalar form "2" is what yields 2.
    # Also use isinstance instead of `type(rf) is list` so list subclasses work.
    if isinstance(rf, list):
        return len(rf)
    return int(rf)
|