#
# Copyright (C) 2022-present ScyllaDB
#
# SPDX-License-Identifier: AGPL-3.0-or-later
#
import threading
import time
import asyncio
import logging
import pathlib
import os
import pytest
from typing import Callable, Awaitable, Optional, TypeVar, Any

from cassandra.cluster import NoHostAvailable, Session, Cluster  # type: ignore # pylint: disable=no-name-in-module
from cassandra.protocol import InvalidRequest  # type: ignore # pylint: disable=no-name-in-module
from cassandra.pool import Host  # type: ignore # pylint: disable=no-name-in-module
from cassandra import DriverException, ConsistencyLevel  # type: ignore # pylint: disable=no-name-in-module

from test.pylib.internal_types import ServerInfo


logger = logging.getLogger(__name__)


class LogPrefixAdapter(logging.LoggerAdapter):
    """LoggerAdapter that prefixes every message with `extra['prefix']`."""
    def process(self, msg, kwargs):
        return '[%s] %s' % (self.extra['prefix'], msg), kwargs


unique_name_prefix = 'test_'
T = TypeVar('T')


def unique_name():
    """Return a process-unique name of the form `test_<milliseconds>`.

    Uses the current wall-clock time in milliseconds, bumped by one when
    called twice within the same millisecond, so successive calls never
    collide.
    """
    if not hasattr(unique_name, "last_ms"):
        unique_name.last_ms = 0
    current_ms = int(round(time.time() * 1000))
    # If unique_name() is called twice in the same millisecond, we need to
    # bump past the previously returned value to stay unique.
    if unique_name.last_ms >= current_ms:
        current_ms = unique_name.last_ms + 1
    unique_name.last_ms = current_ms
    return unique_name_prefix + str(current_ms)


async def wait_for(
        pred: Callable[[], Awaitable[Optional[T]]],
        deadline: float,
        period: float = 1,
        before_retry: Optional[Callable[[], Any]] = None) -> T:
    """Poll `pred` every `period` seconds until it returns a non-None value.

    Returns that value, or fails the test (AssertionError) once `deadline`
    (an absolute `time.time()` timestamp) is exceeded. If `before_retry` is
    given, it is invoked (synchronously) after each failed attempt.
    """
    while True:
        assert(time.time() < deadline), "Deadline exceeded, failing test."
        res = await pred()
        if res is not None:
            return res
        await asyncio.sleep(period)
        if before_retry:
            before_retry()


async def wait_for_cql(cql: Session, host: Host, deadline: float) -> None:
    """Wait until `host` answers a trivial CQL query through `cql`."""
    async def cql_ready():
        try:
            await cql.run_async("select * from system.local", host=host)
        except NoHostAvailable:
            # Consistently use the module logger rather than the root logger.
            logger.info(f"Driver not connected to {host} yet")
            return None
        return True
    await wait_for(cql_ready, deadline)


async def wait_for_cql_and_get_hosts(cql: Session, servers: list[ServerInfo], deadline: float) \
        -> list[Host]:
    """Wait until every server in `servers` is available through `cql`
       and translate `servers` to a list of Cassandra `Host`s.
    """
    ip_set = set(str(srv.rpc_address) for srv in servers)

    async def get_hosts() -> Optional[list[Host]]:
        hosts = cql.cluster.metadata.all_hosts()
        remaining = ip_set - {h.address for h in hosts}
        if not remaining:
            return hosts
        logger.info(f"Driver hasn't yet learned about hosts: {remaining}")
        return None

    def try_refresh_nodes():
        try:
            cql.cluster.refresh_nodes(force_token_rebuild=True)
        except DriverException:
            # Silence the exception, which might get thrown if we call this in the middle of
            # driver reconnect (scylladb/scylladb#17616). `wait_for` will retry anyway and it's enough
            # if we succeed only one `get_hosts()` attempt before timing out.
            pass

    hosts = await wait_for(
        pred=get_hosts,
        deadline=deadline,
        before_retry=try_refresh_nodes,
    )

    # Take only hosts from `ip_set` (there may be more)
    hosts = [h for h in hosts if h.address in ip_set]
    await asyncio.gather(*(wait_for_cql(cql, h, deadline) for h in hosts))

    return hosts


def read_last_line(file_path: pathlib.Path, max_line_bytes: int = 512):
    """Return the last line of `file_path`, reading at most `max_line_bytes`
    from the end of the file.

    A trailing line separator is stripped. If the last line was longer than
    `max_line_bytes`, the truncated result is prefixed with '...'.
    """
    file_size = os.stat(file_path).st_size
    with file_path.open('rb') as f:
        f.seek(max(0, file_size - max_line_bytes), os.SEEK_SET)
        line_bytes = f.read()
    line_str = line_bytes.decode('utf-8', errors='ignore')
    linesep = os.linesep
    if line_str.endswith(linesep):
        line_str = line_str[:-len(linesep)]
    linesep_index = line_str.rfind(linesep)
    if linesep_index != -1:
        # Keep only the text after the last separator — i.e. the last line.
        line_str = line_str[linesep_index + len(linesep):]
    elif file_size > max_line_bytes:
        # No separator found in the tail window: the line was truncated.
        line_str = '...' + line_str
    return line_str


async def get_available_host(cql: Session, deadline: float) -> Host:
    """Return some host from the driver's metadata that currently answers a
    trivial CQL query, waiting until `deadline` for one to become available.
    """
    hosts = cql.cluster.metadata.all_hosts()
    async def find_host():
        for h in hosts:
            try:
                await cql.run_async(
                    "select key from system.local where key = 'local'", host=h)
            except NoHostAvailable:
                # Consistently use the module logger rather than the root logger.
                logger.debug(f"get_available_host: {h} not available")
                continue
            return h
        return None
    return await wait_for(find_host, deadline)


async def read_barrier(cql: Session, host: Host):
    """To issue a read barrier it is sufficient to attempt dropping a
    non-existing table. We need to use `if exists`, otherwise the statement
    would fail on prepare/validate step which happens before a read barrier
    is performed.
    """
    await cql.run_async("drop table if exists nosuchkeyspace.nosuchtable", host = host)


# Wait for the given feature to be enabled.
async def wait_for_feature(feature: str, cql: Session, host: Host, deadline: float) -> None:
    """Wait until `host` reports `feature` among its enabled cluster features."""
    async def feature_is_enabled():
        enabled_features = await get_enabled_features(cql, host)
        # `wait_for` retries while the predicate returns None, so map
        # "not enabled yet" (False) to None.
        return feature in enabled_features or None
    await wait_for(feature_is_enabled, deadline)


async def get_supported_features(cql: Session, host: Host) -> set[str]:
    """Returns a set of cluster features that a node advertises support for."""
    # Plain string literal: there is nothing to interpolate here.
    rs = await cql.run_async("SELECT supported_features FROM system.local WHERE key = 'local'", host=host)
    return set(rs[0].supported_features.split(","))


async def get_enabled_features(cql: Session, host: Host) -> set[str]:
    """Returns a set of cluster features that a node considers to be enabled."""
    # Plain string literal: there is nothing to interpolate here.
    rs = await cql.run_async("SELECT value FROM system.scylla_local WHERE key = 'enabled_features'", host=host)
    return set(rs[0].value.split(","))


class KeyGenerator:
    """Thread-safe generator of monotonically increasing primary keys,
    starting at 0."""

    def __init__(self):
        # Last key handed out; None until next_pk() is first called.
        self.pk = None
        self.pk_lock = threading.Lock()

    def next_pk(self):
        """Return the next key (0 on the first call)."""
        with self.pk_lock:
            if self.pk is not None:
                self.pk += 1
            else:
                self.pk = 0
            return self.pk

    def last_pk(self):
        """Return the last key handed out, or None if none was yet."""
        with self.pk_lock:
            return self.pk


async def start_writes(cql: Session, keyspace: str, table: str, concurrency: int = 3, ignore_errors=False):
    """Start `concurrency` background workers writing sequential keys to
    `keyspace.table` (schema: pk, c) at QUORUM, verifying read-your-writes
    after each insert.

    Waits until a warmup number of writes succeeded, then returns an async
    `finish()` callable which stops the workers and returns the total number
    of keys assumed to be in the database. With `ignore_errors`, write
    failures are retried forever instead of raised (useful while nodes are
    being restarted).
    """
    logger.info(f"Starting to asynchronously write, concurrency = {concurrency}")

    stop_event = asyncio.Event()

    warmup_writes = 128 // concurrency
    warmup_event = asyncio.Event()

    stmt = cql.prepare(f"INSERT INTO {keyspace}.{table} (pk, c) VALUES (?, ?)")
    stmt.consistency_level = ConsistencyLevel.QUORUM
    rd_stmt = cql.prepare(f"SELECT * FROM {keyspace}.{table} WHERE pk = ?")
    rd_stmt.consistency_level = ConsistencyLevel.QUORUM

    key_gen = KeyGenerator()

    async def do_writes(worker_id: int):
        write_count = 0
        while not stop_event.is_set():
            pk = key_gen.next_pk()

            # Once next_pk() is produced, key_gen.last_key() is assumed to be in the database
            # hence we can't give up on it.
            while True:
                try:
                    await cql.run_async(stmt, [pk, pk])
                    # Check read-your-writes
                    rows = await cql.run_async(rd_stmt, [pk])
                    assert(len(rows) == 1)
                    assert(rows[0].c == pk)
                    write_count += 1
                    break
                except Exception:
                    if ignore_errors:
                        pass  # Expected when node is brought down temporarily
                    else:
                        # Bare `raise` preserves the original traceback.
                        raise

            if pk == warmup_writes:
                warmup_event.set()

        logger.info(f"Worker #{worker_id} did {write_count} successful writes")

    tasks = [asyncio.create_task(do_writes(worker_id)) for worker_id in range(concurrency)]

    await asyncio.wait_for(warmup_event.wait(), timeout=60)

    async def finish():
        logger.info("Stopping workers")
        stop_event.set()
        await asyncio.gather(*tasks)
        last = key_gen.last_pk()
        if last is not None:
            return last + 1
        return 0

    return finish