Files
scylladb/test/pylib/util.py
Artsiom Mishuta cd1679934c test/pylib: use exponential backoff in wait_for()
Change wait_for() defaults from period=1s/no backoff to period=0.1s
with 1.5x backoff capped at 1.0s. This catches fast conditions in
100ms instead of 1000ms, benefiting ~100 call sites automatically.

Add completion logging with elapsed time and iteration count.

Tested locally with test/cluster/test_fencing.py::test_fence_hints (dev mode),
log output:

  wait_for(at_least_one_hint_failed) completed in 0.83s (4 iterations)
  wait_for(exactly_one_hint_sent) completed in 1.34s (5 iterations)

Fixes SCYLLADB-738

Closes scylladb/scylladb#29173
2026-03-24 23:49:49 +02:00

436 lines
16 KiB
Python

#
# Copyright (C) 2022-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
from __future__ import annotations
import re
import subprocess
import threading
import time
import asyncio
import logging
import pathlib
import os
import universalasync
from collections.abc import Awaitable, Callable, Coroutine
from functools import cache
import random
import string
from typing import Optional, TypeVar, Any, cast
from cassandra.cluster import NoHostAvailable, Session, Cluster # type: ignore # pylint: disable=no-name-in-module
from cassandra.protocol import InvalidRequest # type: ignore # pylint: disable=no-name-in-module
from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module
from cassandra.query import Statement # type: ignore # pylint: disable=no-name-in-module
from cassandra import DriverException, ConsistencyLevel # type: ignore # pylint: disable=no-name-in-module
from test import BUILD_DIR, TOP_SRC_DIR, MODES_TIMEOUT_FACTOR
from test.pylib.internal_types import ServerInfo
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)
class LogPrefixAdapter(logging.LoggerAdapter):
    """LoggerAdapter that prepends ``[<prefix>] `` to every message.

    The prefix is read from ``extra['prefix']``, i.e. the ``extra`` mapping
    passed to the adapter's constructor. Keyword arguments pass through as-is.
    """

    def process(self, msg, kwargs):
        prefix = self.extra['prefix']
        # !s mirrors the '%s' conversion: both sides are rendered with str().
        return f"[{prefix!s}] {msg!s}", kwargs
# Generic type variable: the result type of wait_for() and the class type in universalasync_typed_wrap().
T = TypeVar('T')
def unique_name(unique_name_prefix = 'test_'):
    """Return ``<prefix><millis>_<5 random lowercase letters>``.

    The millisecond stamp is strictly increasing across calls within this
    process. The last stamp handed out is remembered on the function object
    (``unique_name.last_ms``), and a call that lands in the same or an earlier
    millisecond is bumped to ``last + 1``.
    """
    previous_ms = getattr(unique_name, "last_ms", 0)
    now_ms = int(round(time.time() * 1000))
    # Equivalent to "if previous >= now: now = previous + 1".
    stamp = max(now_ms, previous_ms + 1)
    unique_name.last_ms = stamp
    suffix = ''.join([random.choice(string.ascii_lowercase) for _ in range(5)])
    return unique_name_prefix + str(stamp) + '_' + suffix
async def wait_for(
        pred: Callable[[], Awaitable[Optional[T]]],
        deadline: float,
        period: float = 0.1,
        before_retry: Optional[Callable[[], Any]] = None,
        backoff_factor: float = 1.5,
        max_period: Optional[float] = 1.0,
        label: Optional[str] = None) -> T:
    """Poll `pred` until it returns a non-None value, and return that value.

    Any non-None result ends the wait, including falsy ones such as 0 or False.
    Between attempts the coroutine sleeps `period` seconds. After each sleep,
    `period` is multiplied by `backoff_factor` and capped at `max_period`.

    Args:
        pred: async callable; returning None means "not ready yet".
        deadline: absolute time, in `time.time()` seconds, after which waiting fails.
        period: initial sleep between attempts, in seconds.
        before_retry: optional callable invoked (not awaited) after each sleep,
            before the next attempt. Its return value is ignored.
        backoff_factor: multiplier applied to the sleep period after each retry.
        max_period: upper bound for the sleep period, or None for no bound.
        label: name used in log and error messages; defaults to `pred.__name__`.

    Returns:
        The first non-None value returned by `pred`.

    Raises:
        AssertionError: if `deadline` passes before `pred` returns non-None.
    """
    tag = label or getattr(pred, '__name__', 'unlabeled')
    start = time.time()
    retries = 0
    while True:
        assert time.time() < deadline, \
            f"wait_for({tag}) timed out after {time.time() - start:.2f}s ({retries} retries)"
        res = await pred()
        if res is not None:
            if retries > 0:
                # Measured after pred() so the reported time includes the final attempt.
                logger.debug(f"wait_for({tag}) completed "
                             f"in {time.time() - start:.2f}s ({retries} retries)")
            return res
        retries += 1
        # Never sleep past the deadline. With backoff the period grows up to
        # max_period, and oversleeping would only delay reporting the timeout.
        await asyncio.sleep(max(0.0, min(period, deadline - time.time())))
        period *= backoff_factor
        if max_period is not None:
            period = min(period, max_period)
        if before_retry:
            before_retry()
async def wait_for_cql(cql: Session, host: Host, deadline: float) -> None:
    """Wait until `host` answers a trivial CQL query through `cql`.

    Polls ``select * from system.local`` against `host`, starting with a 0.1s
    retry period, until `deadline` (see wait_for()). NoHostAvailable is treated
    as "not connected yet". Any other exception from the query propagates.
    """
    async def cql_ready():
        try:
            await cql.run_async("select * from system.local", host=host)
        except NoHostAvailable:
            # Use the module logger, like the rest of this module, rather than the root logger.
            logger.info(f"Driver not connected to {host} yet")
            return None
        return True
    await wait_for(cql_ready, deadline, period=0.1)
async def wait_for_cql_and_get_hosts(cql: Session, servers: list[ServerInfo], deadline: float) \
        -> list[Host]:
    """Wait until every server in `servers` is available through `cql`
    and translate `servers` to a list of Cassandra `Host`s.

    First waits until the driver's metadata contains a host for every server's
    rpc_address, asking the driver to refresh its node list between attempts.
    Then waits until each of those hosts answers CQL. The returned list has the
    same order as `servers`: index i refers to the same Scylla node in both.
    """
    # Both lookups below are keyed by str(rpc_address). A host that passes the
    # `ip_set` filter is therefore guaranteed to be found in `servers_by_ip`.
    ip_set = {str(srv.rpc_address) for srv in servers}

    async def get_hosts() -> Optional[list[Host]]:
        hosts = cql.cluster.metadata.all_hosts()
        remaining = ip_set - {h.address for h in hosts}
        if not remaining:
            return hosts
        logger.info(f"Driver hasn't yet learned about hosts: {remaining}")
        return None

    def try_refresh_nodes():
        try:
            cql.cluster.refresh_nodes(force_token_rebuild=True)
        except DriverException:
            # Silence the exception, which might get thrown if we call this in the middle of
            # driver reconnect (scylladb/scylladb#17616). `wait_for` will retry anyway and it's enough
            # if we succeed only one `get_hosts()` attempt before timing out.
            pass

    hosts = await wait_for(
        pred=get_hosts,
        deadline=deadline,
        before_retry=try_refresh_nodes,
    )
    # Take only hosts from `ip_set` (there may be more)
    hosts = [h for h in hosts if h.address in ip_set]
    # Make sure `hosts` has same order as `servers`, that is: a given index will
    # refer to the same underlying Scylla instance in both `servers` and `hosts`.
    servers_by_ip = {str(srv.rpc_address): i for i, srv in enumerate(servers)}
    hosts.sort(key=lambda x: servers_by_ip[x.address])
    await asyncio.gather(*(wait_for_cql(cql, h, deadline) for h in hosts))
    return hosts
def read_last_line(file_path: pathlib.Path, max_line_bytes = 512):
    """Return the last line of `file_path`, reading only its trailing `max_line_bytes` bytes.

    The tail is decoded as UTF-8, and undecodable bytes are dropped. One
    trailing os.linesep is ignored. If no separator occurs in the tail and the
    file is larger than the window, the partial line is returned prefixed with
    '...' to mark the truncation.
    """
    size = os.stat(file_path).st_size
    with file_path.open('rb') as fh:
        fh.seek(max(0, size - max_line_bytes), os.SEEK_SET)
        tail = fh.read().decode('utf-8', errors='ignore')
    sep = os.linesep
    if tail.endswith(sep):
        tail = tail[:-len(sep)]
    _, found_sep, last = tail.rpartition(sep)
    if found_sep:
        return last
    return '...' + tail if size > max_line_bytes else tail
async def get_available_host(cql: Session, deadline: float) -> Host:
    """Return the first host known to the driver that answers a CQL query.

    Each attempt walks `cql.cluster.metadata.all_hosts()` in order and returns
    the first host on which a query against system.local succeeds. Hosts that
    raise NoHostAvailable are skipped. Attempts repeat via wait_for() until
    `deadline`.
    """
    async def find_host():
        # Re-read the host list on every attempt. A snapshot taken once up front
        # would never include hosts the driver learns about while we wait.
        for h in cql.cluster.metadata.all_hosts():
            try:
                await cql.run_async(
                    "select key from system.local where key = 'local'", host=h)
            except NoHostAvailable:
                logger.debug(f"get_available_host: {h} not available")
                continue
            return h
        return None
    return await wait_for(find_host, deadline)
async def wait_for_feature(feature: str, cql: Session, host: Host, deadline: float) -> None:
    """Wait for the given feature to be enabled.

    Polls get_enabled_features() for `host` via wait_for() until `feature` is
    in the returned set, or `deadline` passes.
    """
    # The inner name is kept as-is: wait_for() uses it as its log/timeout label.
    async def feature_is_enabled():
        if feature in await get_enabled_features(cql, host):
            return True
        return None

    await wait_for(feature_is_enabled, deadline)
async def get_supported_features(cql: Session, host: Host) -> set[str]:
    """Return the set of cluster features that `host` advertises support for.

    Reads the comma-separated `supported_features` column of the node's
    system.local row and splits it into individual names.
    """
    query = "SELECT supported_features FROM system.local WHERE key = 'local'"
    rows = await cql.run_async(query, host=host)
    raw = rows[0].supported_features
    return {name for name in raw.split(",")}
async def get_enabled_features(cql: Session, host: Host) -> set[str]:
    """Return the set of cluster features that `host` considers enabled.

    Reads the comma-separated `enabled_features` value from the node's
    system.scylla_local table and splits it into individual names.
    """
    query = "SELECT value FROM system.scylla_local WHERE key = 'enabled_features'"
    rows = await cql.run_async(query, host=host)
    raw = rows[0].value
    return {name for name in raw.split(",")}
class KeyGenerator:
    """Hands out consecutive integer keys 0, 1, 2, ... under a threading.Lock."""

    def __init__(self):
        # Most recently handed-out key; None until next_pk() is first called.
        self.pk = None
        self.pk_lock = threading.Lock()

    def next_pk(self):
        """Advance to and return the next key (0 on the first call)."""
        with self.pk_lock:
            self.pk = 0 if self.pk is None else self.pk + 1
            return self.pk

    def last_pk(self):
        """Return the most recently handed-out key, or None if none yet."""
        with self.pk_lock:
            return self.pk
async def start_writes(cql: Session, keyspace: str, table: str, concurrency: int = 3, ignore_errors=False):
    """Start background INSERT + read-back traffic against `keyspace.table`.

    Spawns `concurrency` worker tasks. Each worker repeatedly takes the next key
    from a shared KeyGenerator, writes row (pk, c=pk) at QUORUM, then reads it
    back at QUORUM and asserts exactly that row is returned. The table is
    expected to have columns `pk` and `c`.

    A key, once handed out, is retried until it succeeds. Errors are logged and
    retried when `ignore_errors` is True. Otherwise they are logged and re-raised,
    which ends that worker.

    This coroutine returns after key `128 // concurrency` has been written and
    read back successfully.

    Returns:
        An async `finish()` callable. It sets the stop flag and awaits all workers
        with asyncio.gather, so a worker exception is re-raised. It returns
        last_pk() + 1, or 0 if no key was ever handed out.

    Raises:
        The exception of a worker that failed before warmup completed.
        asyncio.TimeoutError: if warmup did not complete within 60 seconds.
    """
    logger.info(f"Starting to asynchronously write, concurrency = {concurrency}")
    stop_event = asyncio.Event()
    warmup_writes = 128 // concurrency
    warmup_event = asyncio.Event()
    stmt = cql.prepare(f"INSERT INTO {keyspace}.{table} (pk, c) VALUES (?, ?)")
    stmt.consistency_level = ConsistencyLevel.QUORUM
    rd_stmt = cql.prepare(f"SELECT * FROM {keyspace}.{table} WHERE pk = ?")
    rd_stmt.consistency_level = ConsistencyLevel.QUORUM
    key_gen = KeyGenerator()

    async def do_writes(worker_id: int):
        async def run_retry_async(cql, *args, **kwargs):
            # With ignore_errors the outer loop retries anyway, so only one attempt here.
            retry_attempts = 1 if ignore_errors else 3
            sleep_time = 0.05  # 50ms
            error = None
            for _ in range(retry_attempts):
                try:
                    return await cql.run_async(*args, **kwargs)
                except Exception as e:
                    error = e
                    logger.debug(f"Retrying in {sleep_time} second(s)")
                    await asyncio.sleep(sleep_time)
                    sleep_time *= 2  # Exponential backoff
            raise error

        write_count = 0
        while not stop_event.is_set():
            pk = key_gen.next_pk()
            # Once next_pk() is produced, key_gen.last_pk() is assumed to be in the database
            # hence we can't give up on it.
            while True:
                start_time = time.time()
                try:
                    await cql.run_async(stmt, [pk, pk])
                    # Check read-your-writes
                    rows = await run_retry_async(cql, rd_stmt, [pk])
                    assert(len(rows) == 1)
                    assert(rows[0].c == pk)
                    write_count += 1
                    break
                except Exception as e:
                    if ignore_errors:
                        logger.debug(f"Suppressed exception: {e} (expected when node is brought down temporarily)")
                    else:
                        logger.error(f"Exception occurred during read/write operation for pk={pk} started {time.time() - start_time}s ago: {e}")
                        raise
            if pk == warmup_writes:
                warmup_event.set()
        logger.info(f"Worker #{worker_id} did {write_count} successful writes")

    tasks = [asyncio.create_task(do_writes(worker_id)) for worker_id in range(concurrency)]

    # Wake up on warmup OR on any worker exiting. Before finish() exists nothing
    # sets stop_event, so an early exit means the worker raised. Waiting blindly
    # for 60s would hide that error behind a timeout and leave the remaining
    # workers writing in the background.
    warmup_waiter = asyncio.create_task(warmup_event.wait())
    await asyncio.wait([warmup_waiter, *tasks], timeout=60, return_when=asyncio.FIRST_COMPLETED)
    if not warmup_event.is_set():
        warmup_waiter.cancel()
        for task in tasks:
            task.cancel()
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            # CancelledError is a BaseException, so only genuine worker failures match.
            if isinstance(result, Exception):
                raise result
        raise asyncio.TimeoutError("start_writes: warmup writes did not complete within 60s")

    async def finish():
        logger.info("Stopping workers")
        stop_event.set()
        await asyncio.gather(*tasks)
        last = key_gen.last_pk()
        if last is not None:
            return last + 1
        return 0

    return finish
async def wait_for_view_v1(cql: Session, name: str, node_count: int, timeout: int = 120):
    """Wait until `node_count` nodes report view `name` as built.

    Polls system_distributed.view_build_status, counting rows with status
    'SUCCESS' for the view, until the count equals `node_count`. Fails (via
    wait_for) after `timeout` seconds.
    """
    query = f"SELECT COUNT(*) FROM system_distributed.view_build_status WHERE status = 'SUCCESS' AND view_name = '{name}' ALLOW FILTERING"

    async def view_is_built():
        rows = await cql.run_async(query)
        return True if rows[0][0] == node_count else None

    await wait_for(view_is_built, time.time() + timeout, label=f"view_v1_{name}")
async def wait_for_view(cql: Session, name: str, node_count: int, timeout: int = 120):
    """Wait until `node_count` nodes report view `name` as built.

    Polls system.view_build_status_v2, counting rows with status 'SUCCESS' for
    the view, until the count equals `node_count`. Fails (via wait_for) after
    `timeout` seconds.
    """
    query = f"SELECT COUNT(*) FROM system.view_build_status_v2 WHERE status = 'SUCCESS' AND view_name = '{name}' ALLOW FILTERING"

    async def view_is_built():
        rows = await cql.run_async(query)
        return True if rows[0][0] == node_count else None

    await wait_for(view_is_built, time.time() + timeout, label=f"view_{name}")
async def wait_for_first_completed(coros: list[Coroutine], timeout: int|None = None):
    """Run `coros` concurrently and return the result of one that finishes first.

    When several finish in the same step, the earliest one in `coros` order is
    chosen. All other tasks are cancelled, and every task is awaited before this
    function returns or raises. No task is left running and no exception goes
    unretrieved. If the chosen task raised, its exception is re-raised.

    Raises:
        asyncio.TimeoutError: if nothing finishes within `timeout` seconds.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
    if not done:
        # Timeout occurred, cancel all
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise asyncio.TimeoutError("No task completed within timeout")
    # Cancel pending tasks
    for task in pending:
        task.cancel()
    first_task = next(t for t in tasks if t in done)
    # Clean up BEFORE reading the result. If first_task raised, reading it first
    # would skip this, leaving cancelled tasks un-awaited and other failures unretrieved.
    await asyncio.gather(*(t for t in tasks if t is not first_task), return_exceptions=True)
    return first_task.result()
async def wait_all(coros: list[Coroutine], timeout: int|None = None):
    """Run `coros` concurrently and wait for all of them to finish.

    Every task is awaited before this function returns or raises. If any task
    raised, the exception of the earliest such task (in `coros` order) is
    re-raised. Otherwise the result of the first coroutine in `coros` is
    returned.

    Raises:
        asyncio.TimeoutError: if not every task finishes within `timeout` seconds.
            The unfinished tasks are cancelled and awaited first.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED, timeout=timeout)
    if pending:
        # Timeout occurred. Some, but possibly not all, tasks are still running:
        # `done` may be non-empty here, so test `pending`, not `done`.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise asyncio.TimeoutError("Not all tasks completed within timeout")
    # Retrieve every outcome so no exception is left unretrieved, then surface the first failure.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for result in results:
        if isinstance(result, BaseException):
            raise result
    return results[0]
def ninja(target: str) -> str:
    """Build `target` with ninja and return ninja's decoded stdout.

    ninja runs in TOP_SRC_DIR. When TOP_SRC_DIR has no build.ninja, ``-C
    BUILD_DIR`` is added. stderr is not captured, and ninja's exit status is
    not checked.
    """
    # cmake places build.ninja in build/, traditional is in ./.
    # We choose to test for traditional, not cmake, because IDEs may
    # invoke cmake to learn the configuration and generate false positives
    command = ["ninja"]
    if not TOP_SRC_DIR.joinpath("build.ninja").exists():
        command += ["-C", str(BUILD_DIR)]
    command.append(target)
    completed = subprocess.run(command, stdout=subprocess.PIPE, cwd=TOP_SRC_DIR)
    return completed.stdout.decode()
@cache
def get_configured_modes() -> list[str]:
    """Return the build modes listed by ``ninja mode_list``, e.g. ['debug', 'release', 'dev'].

    The expected ninja output looks like::

        [1/1] List configured modes
        debug release dev

    The result is cached for the life of the process. Output containing no
    mode names yields an empty list (e.g. when ninja prints nothing on stdout),
    so callers can detect "nothing configured" with a truthiness check.
    """
    out = ninja('mode_list')
    last_line = re.sub(r'.* List configured modes\n(.*)\n', r'\1',
                       out, count=1, flags=re.DOTALL).split('\n')[-1]
    # str.split() with no separator drops empty fields. Blank output becomes []
    # rather than [''], which would be truthy and defeat the caller's check.
    return last_line.split()
def get_modes_to_run(config) -> list[str]:
    """Return the build modes to test.

    Uses the `modes` option from `config` if set; otherwise falls back to
    every configured mode.

    Raises:
        RuntimeError: if no modes were requested and none are configured.
    """
    requested = config.getoption('modes')
    if requested:
        return requested
    configured = get_configured_modes()
    if not configured:
        raise RuntimeError('No modes configured. Please run ./configure.py first')
    return configured
def scale_timeout_by_mode(mode: str, timeout: int | float) -> int | float:
    """Multiply `timeout` by the test.py timeout factor for build `mode`.

    Factors come from MODES_TIMEOUT_FACTOR (documented as 3 for debug and
    sanitize, 2 for dev). A mode missing from that table uses factor 1, i.e.
    `timeout` is left unchanged.
    """
    factor = MODES_TIMEOUT_FACTOR.get(mode, 1)
    return factor * timeout
async def gather_safely(*awaitables: Awaitable):
    """Await every awaitable to completion, then re-raise the first failure, if any.

    With the default return_exceptions=False, asyncio.gather() raises as soon
    as one awaitable fails. The others may keep running in the background,
    which surprises code that uses gather() to make sure a set of background
    tasks has finished. This helper always lets every awaitable complete,
    successfully or not, and only then raises the first exception (in argument
    order, with exception context suppressed). Otherwise it returns the list of
    results.
    """
    outcomes = await asyncio.gather(*awaitables, return_exceptions=True)
    first_error = next((o for o in outcomes if isinstance(o, BaseException)), None)
    if first_error is not None:
        raise first_error from None
    return outcomes
def get_xdist_worker_id() -> str | None:
    """Return the PYTEST_XDIST_WORKER environment variable, or None if it is unset."""
    return os.getenv("PYTEST_XDIST_WORKER")
def execute_with_tracing(cql : Session, statement : str | Statement, log : bool = False, *cql_execute_extra_args, **cql_execute_extra_kwargs):
    """Execute `statement` with tracing enabled and return its trace events.

    Extra positional and keyword arguments are forwarded to `cql.execute()`;
    ``trace=True`` is always forced. The return value holds one `trace.events`
    entry per trace from `get_all_query_traces()`. When `log` is True, the
    events are also written to the module logger at DEBUG level, one block per
    trace.
    """
    cql_execute_extra_kwargs['trace'] = True
    result = cql.execute(statement, *cql_execute_extra_args, **cql_execute_extra_kwargs)
    traces = result.get_all_query_traces(max_wait_sec_per=900)
    events_per_trace = [trace.events for trace in traces]
    if log:
        rendered = [
            "\n".join(f" {event.source} {event.source_elapsed} {event.description}" for event in events)
            for events in events_per_trace
        ]
        logger.debug("Tracing {}:\n{}\n".format(statement, "\n".join(rendered)))
    return events_per_trace
def universalasync_typed_wrap(cls: T) -> T:
    """Apply `universalasync.wrap` to `cls`, keeping its static type for type checkers."""
    wrapped = universalasync.wrap(cls)
    return cast(T, wrapped)