Files
scylladb/test/pylib/util.py
2024-06-11 12:18:17 +02:00

243 lines
8.2 KiB
Python

#
# Copyright (C) 2022-present ScyllaDB
#
# SPDX-License-Identifier: AGPL-3.0-or-later
#
import threading
import time
import asyncio
import logging
import pathlib
import os
import pytest
from typing import Callable, Awaitable, Optional, TypeVar, Any
from cassandra.cluster import NoHostAvailable, Session, Cluster # type: ignore # pylint: disable=no-name-in-module
from cassandra.protocol import InvalidRequest # type: ignore # pylint: disable=no-name-in-module
from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module
from cassandra import DriverException, ConsistencyLevel # type: ignore # pylint: disable=no-name-in-module
from test.pylib.internal_types import ServerInfo
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class LogPrefixAdapter(logging.LoggerAdapter):
    """Logger adapter that prepends `[<prefix>] ` to every message.

    The prefix is read from the `'prefix'` key of the adapter's `extra` mapping.
    """

    def process(self, msg, kwargs):
        """Return the prefixed message together with the untouched `kwargs`."""
        prefix = self.extra['prefix']
        return f'[{prefix}] {msg}', kwargs
# Prefix that unique_name() puts in front of the generated millisecond timestamp.
unique_name_prefix = 'test_'
# Generic result type of the predicate polled by wait_for().
T = TypeVar('T')
def unique_name():
    """Return `unique_name_prefix` followed by a millisecond timestamp.

    The last value handed out is remembered in the `last_ms` attribute of this
    function, so two calls made within the same millisecond still return
    different, increasing names.
    """
    now_ms = int(round(time.time() * 1000))
    previous_ms = getattr(unique_name, "last_ms", 0)
    # Never reuse or go below the previously returned timestamp.
    stamp_ms = max(now_ms, previous_ms + 1)
    unique_name.last_ms = stamp_ms
    return unique_name_prefix + str(stamp_ms)
async def wait_for(
        pred: Callable[[], Awaitable[Optional[T]]],
        deadline: float,
        period: float = 1,
        before_retry: Optional[Callable[[], Any]] = None) -> T:
    """Poll `pred` until it returns a value other than None and return that value.

    Args:
        pred: async callable taking no arguments; returning None means "not
            ready yet", any other value (including falsy ones) ends the wait.
        deadline: absolute time, on the `time.time()` clock, after which no
            further attempt is made.
        period: number of seconds to sleep between attempts.
        before_retry: optional synchronous callable invoked after each sleep,
            right before the next attempt.

    Raises:
        AssertionError: the deadline had already passed when an attempt was
            about to start.
    """
    while True:
        # Raise explicitly rather than via an `assert` statement: asserts are
        # stripped under `python -O`, which would leave this loop without any
        # exit on timeout. Exception type and message are unchanged.
        if not (time.time() < deadline):
            raise AssertionError("Deadline exceeded, failing test.")
        res = await pred()
        if res is not None:
            return res
        await asyncio.sleep(period)
        if before_retry:
            before_retry()
async def wait_for_cql(cql: Session, host: Host, deadline: float) -> None:
    """Wait until a query routed to `host` through `cql` succeeds.

    A `NoHostAvailable` error is treated as "not connected yet" and retried
    via `wait_for` until `deadline`.
    """
    async def probe_host():
        try:
            await cql.run_async("select * from system.local", host=host)
        except NoHostAvailable:
            logging.info(f"Driver not connected to {host} yet")
        else:
            return True
        # Returning None tells wait_for() to try again.
        return None

    await wait_for(probe_host, deadline)
async def wait_for_cql_and_get_hosts(cql: Session, servers: list[ServerInfo], deadline: float) \
        -> list[Host]:
    """Wait until every server in `servers` is reachable through `cql`.

    First waits for the driver metadata to list a host for the `rpc_address`
    of each server, then waits for a query to succeed on each of those hosts.
    Returns the `Host` objects that correspond to `servers`.
    """
    wanted_addresses = {str(server.rpc_address) for server in servers}

    async def all_servers_known() -> Optional[list[Host]]:
        known_hosts = cql.cluster.metadata.all_hosts()
        missing = wanted_addresses - {known.address for known in known_hosts}
        if missing:
            logging.info(f"Driver hasn't yet learned about hosts: {missing}")
            return None
        return known_hosts

    def refresh_driver_metadata():
        try:
            cql.cluster.refresh_nodes(force_token_rebuild=True)
        except DriverException:
            # Deliberately ignored: the exception may be thrown when this runs in the
            # middle of a driver reconnect (scylladb/scylladb#17616). `wait_for` retries
            # anyway, and a single successful attempt before the deadline is enough.
            pass

    known_hosts = await wait_for(
        pred=all_servers_known,
        deadline=deadline,
        before_retry=refresh_driver_metadata,
    )
    # The driver may know about more hosts than were asked for; keep only the requested ones.
    matching_hosts = [host for host in known_hosts if host.address in wanted_addresses]
    await asyncio.gather(*(wait_for_cql(cql, host, deadline) for host in matching_hosts))
    return matching_hosts
def read_last_line(file_path: pathlib.Path, max_line_bytes = 512):
    """Return the last line of the file at `file_path`, without its line separator.

    Only the final `max_line_bytes` bytes of the file are read and decoded as
    UTF-8 (undecodable bytes are dropped). If that tail contains no line
    separator and the file is larger than `max_line_bytes`, the result is
    prefixed with '...' to show that the line was cut.
    """
    total_bytes = os.stat(file_path).st_size
    tail_offset = max(0, total_bytes - max_line_bytes)
    with file_path.open('rb') as stream:
        stream.seek(tail_offset, os.SEEK_SET)
        tail = stream.read().decode('utf-8', errors='ignore')

    newline = os.linesep
    # Drop a single trailing separator so it is not mistaken for the line boundary.
    if tail.endswith(newline):
        tail = tail[:-len(newline)]

    _, separator, last_line = tail.rpartition(newline)
    if separator:
        return last_line
    if total_bytes > max_line_bytes:
        return '...' + tail
    return tail
async def get_available_host(cql: Session, deadline: float) -> Host:
    """Return the first host known to the driver that answers a query.

    The host list is taken from the driver metadata once, up front. The hosts
    are then probed in order, and the whole pass is repeated via `wait_for`
    until one of them responds or `deadline` passes.
    """
    candidates = cql.cluster.metadata.all_hosts()

    async def first_responsive():
        for candidate in candidates:
            try:
                await cql.run_async(
                    "select key from system.local where key = 'local'", host=candidate)
            except NoHostAvailable:
                logging.debug(f"get_available_host: {candidate} not available")
            else:
                return candidate
        # No host answered in this pass; wait_for() will retry.
        return None

    return await wait_for(first_responsive, deadline)
async def read_barrier(cql: Session, host: Host):
    """Issue a read barrier on `host`.

    Attempting to drop a non-existing table is sufficient to trigger one.
    `if exists` is required: without it the statement would fail at the
    prepare/validate step, which happens before the read barrier is performed.
    """
    statement = "drop table if exists nosuchkeyspace.nosuchtable"
    await cql.run_async(statement, host=host)
async def wait_for_feature(feature: str, cql: Session, host: Host, deadline: float) -> None:
    """Wait for the given feature to be enabled.

    Repeatedly fetches the enabled features reported by `host` until `feature`
    is among them or `deadline` passes.
    """
    async def enabled_or_none():
        currently_enabled = await get_enabled_features(cql, host)
        if feature in currently_enabled:
            return True
        # None tells wait_for() to keep polling.
        return None

    await wait_for(enabled_or_none, deadline)
async def get_supported_features(cql: Session, host: Host) -> set[str]:
    """Return the set of cluster features that a node advertises support for.

    Reads the comma-separated `supported_features` column of the node's
    `system.local` row.
    """
    query = "SELECT supported_features FROM system.local WHERE key = 'local'"
    rows = await cql.run_async(query, host=host)
    advertised = rows[0].supported_features
    return set(advertised.split(","))
async def get_enabled_features(cql: Session, host: Host) -> set[str]:
    """Return the set of cluster features that a node considers to be enabled.

    Reads the comma-separated value stored under the `enabled_features` key of
    the node's `system.scylla_local` table.
    """
    query = "SELECT value FROM system.scylla_local WHERE key = 'enabled_features'"
    rows = await cql.run_async(query, host=host)
    enabled = rows[0].value
    return set(enabled.split(","))
class KeyGenerator:
    """Hands out consecutive integer keys starting from 0.

    Every access to the current key goes through `pk_lock`.
    """

    def __init__(self):
        # None until the first key has been generated.
        self.pk = None
        self.pk_lock = threading.Lock()

    def next_pk(self):
        """Advance to the next key and return it; the first call returns 0."""
        with self.pk_lock:
            self.pk = 0 if self.pk is None else self.pk + 1
            return self.pk

    def last_pk(self):
        """Return the most recently generated key, or None if there is none yet."""
        with self.pk_lock:
            return self.pk
async def start_writes(cql: Session, keyspace: str, table: str, concurrency: int = 3, ignore_errors=False):
    """Start background workers that keep writing consecutive keys to a table.

    Spawns `concurrency` asyncio tasks. Each task repeatedly takes the next key
    from a shared `KeyGenerator`, inserts the row `(pk, c) = (key, key)` at
    consistency level QUORUM and reads it back at QUORUM to check
    read-your-writes. This coroutine returns once the key equal to
    `128 // concurrency` has been written (warm-up).

    Args:
        cql: session used to prepare and execute the statements.
        keyspace: keyspace of the target table.
        table: target table; the statements use its `pk` and `c` columns.
        concurrency: number of concurrent writer tasks.
        ignore_errors: when true, any exception raised while writing or
            verifying a key makes the worker retry that same key; when false
            the exception terminates the worker and is raised by `finish()`.

    Returns:
        An async callable `finish()`. Awaiting it stops the workers, waits for
        them to exit and returns the last generated key plus one (0 if no key
        was generated).

    Raises:
        asyncio.TimeoutError: warm-up did not complete within 60 seconds. The
            workers are stopped and cancelled before the error propagates.
    """
    logger.info(f"Starting to asynchronously write, concurrency = {concurrency}")

    stop_event = asyncio.Event()

    warmup_writes = 128 // concurrency
    warmup_event = asyncio.Event()

    stmt = cql.prepare(f"INSERT INTO {keyspace}.{table} (pk, c) VALUES (?, ?)")
    stmt.consistency_level = ConsistencyLevel.QUORUM
    rd_stmt = cql.prepare(f"SELECT * FROM {keyspace}.{table} WHERE pk = ?")
    rd_stmt.consistency_level = ConsistencyLevel.QUORUM

    key_gen = KeyGenerator()

    async def do_writes(worker_id: int):
        write_count = 0
        while not stop_event.is_set():
            pk = key_gen.next_pk()

            # Once next_pk() is produced, key_gen.last_pk() is assumed to be in the database
            # hence we can't give up on it.
            while True:
                try:
                    await cql.run_async(stmt, [pk, pk])
                    # Check read-your-writes
                    rows = await cql.run_async(rd_stmt, [pk])
                    assert(len(rows) == 1)
                    assert(rows[0].c == pk)
                    write_count += 1
                    break
                except Exception:
                    if ignore_errors:
                        pass # Expected when node is brought down temporarily
                    else:
                        raise

            if pk == warmup_writes:
                warmup_event.set()

        logger.info(f"Worker #{worker_id} did {write_count} successful writes")

    tasks = [asyncio.create_task(do_writes(worker_id)) for worker_id in range(concurrency)]

    try:
        await asyncio.wait_for(warmup_event.wait(), timeout=60)
    except BaseException:
        # Warm-up timed out or this coroutine was cancelled. The caller never
        # receives `finish`, so nobody else could stop the workers: shut them
        # down here instead of leaving them running in the background.
        stop_event.set()
        for task in tasks:
            task.cancel()
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        # Report workers that had already failed; their error is usually the
        # real reason why warm-up never completed.
        for worker_id, outcome in enumerate(outcomes):
            if isinstance(outcome, Exception):
                logger.error(f"Worker #{worker_id} failed during warmup: {outcome!r}")
        raise

    async def finish():
        logger.info("Stopping workers")
        stop_event.set()
        await asyncio.gather(*tasks)

        last = key_gen.last_pk()
        if last is not None:
            return last + 1
        return 0

    return finish