Change wait_for() defaults from period=1s with no backoff to period=0.1s with 1.5x backoff capped at 1.0s. This catches fast conditions in 100ms instead of 1000ms, benefiting ~100 call sites automatically. Add completion logging with elapsed time and iteration count. Tested locally with test/cluster/test_fencing.py::test_fence_hints (dev mode); log output: wait_for(at_least_one_hint_failed) completed in 0.83s (4 iterations); wait_for(exactly_one_hint_sent) completed in 1.34s (5 iterations). Fixes SCYLLADB-738. Closes scylladb/scylladb#29173.
436 lines
16 KiB
Python
436 lines
16 KiB
Python
#
|
|
# Copyright (C) 2022-present ScyllaDB
|
|
#
|
|
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
|
#
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
import asyncio
|
|
import logging
|
|
import pathlib
|
|
import os
|
|
import universalasync
|
|
from collections.abc import Awaitable, Callable, Coroutine
|
|
from functools import cache
|
|
|
|
import random
|
|
import string
|
|
|
|
from typing import Optional, TypeVar, Any, cast
|
|
|
|
from cassandra.cluster import NoHostAvailable, Session, Cluster # type: ignore # pylint: disable=no-name-in-module
|
|
from cassandra.protocol import InvalidRequest # type: ignore # pylint: disable=no-name-in-module
|
|
from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module
|
|
from cassandra.query import Statement # type: ignore # pylint: disable=no-name-in-module
|
|
from cassandra import DriverException, ConsistencyLevel # type: ignore # pylint: disable=no-name-in-module
|
|
|
|
from test import BUILD_DIR, TOP_SRC_DIR, MODES_TIMEOUT_FACTOR
|
|
from test.pylib.internal_types import ServerInfo
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LogPrefixAdapter(logging.LoggerAdapter):
    """Logger adapter that prepends ``[<prefix>] `` to every message.

    The prefix is read from ``self.extra['prefix']`` on each call.
    """

    def process(self, msg, kwargs):
        """Return the prefixed message together with the untouched kwargs."""
        prefix = self.extra['prefix']
        prefixed = '[%s] %s' % (prefix, msg)
        return prefixed, kwargs
|
|
|
|
|
|
T = TypeVar('T')
|
|
|
|
|
|
def unique_name(unique_name_prefix = 'test_'):
    """Return ``<prefix><milliseconds>_<5 random lowercase letters>``.

    The millisecond stamp is kept strictly increasing across calls by
    remembering the last value used in the function attribute
    ``unique_name.last_ms``.
    """
    now_ms = int(round(time.time() * 1000))
    previous_ms = getattr(unique_name, "last_ms", 0)
    # If unique_name() is called twice in the same millisecond (or the clock
    # did not advance past the last stamp), step one past the previous stamp.
    stamp_ms = max(now_ms, previous_ms + 1)
    unique_name.last_ms = stamp_ms
    suffix_letters = [random.choice(string.ascii_lowercase) for _ in range(5)]
    return unique_name_prefix + str(stamp_ms) + '_' + ''.join(suffix_letters)
|
|
|
|
|
|
async def wait_for(
        pred: Callable[[], Awaitable[Optional[T]]],
        deadline: float,
        period: float = 0.1,
        before_retry: Optional[Callable[[], Any]] = None,
        backoff_factor: float = 1.5,
        max_period: Optional[float] = 1.0,
        label: Optional[str] = None) -> T:
    """Poll `pred` until it returns something other than None and return that value.

    Parameters:
        pred: zero-argument callable returning an awaitable; a result of None
            means "not ready yet", any other result ends the wait.
        deadline: absolute point in time, compared against ``time.time()``,
            after which no new attempt is started.
        period: initial sleep, in seconds, between two attempts.
        before_retry: optional synchronous callable invoked after each sleep,
            right before the next attempt.
        backoff_factor: multiplier applied to the sleep after every attempt.
        max_period: upper bound for the sleep; ``None`` disables the cap.
        label: name used in messages; defaults to ``pred.__name__`` or
            ``'unlabeled'`` when the callable has no name.

    Raises:
        AssertionError: the deadline had passed when an attempt was about to
            start. The deadline is only checked at the start of an attempt, so
            a running `pred` or an ongoing sleep is not interrupted.
    """
    tag = label or getattr(pred, '__name__', 'unlabeled')
    start = time.time()
    retries = 0
    while True:
        now = time.time()
        # Raise explicitly instead of using an `assert` statement: asserts are
        # removed under `python -O`, which would turn a timeout into an endless loop.
        if now >= deadline:
            raise AssertionError(
                f"wait_for({tag}) timed out after {now - start:.2f}s ({retries} retries)")
        res = await pred()
        if res is not None:
            if retries > 0:
                # Measure after pred() returned so the time spent in the final
                # attempt is included in the reported duration.
                elapsed = time.time() - start
                logger.debug(f"wait_for({tag}) completed "
                             f"in {elapsed:.2f}s ({retries} retries)")
            return res
        retries += 1
        await asyncio.sleep(period)
        period *= backoff_factor
        if max_period is not None:
            period = min(period, max_period)
        if before_retry:
            before_retry()
|
|
|
|
|
|
async def wait_for_cql(cql: Session, host: Host, deadline: float) -> None:
    """Wait until a trivial query against `host` succeeds through `cql`.

    The probe is ``select * from system.local`` sent to `host`. A
    `NoHostAvailable` error counts as "not ready yet" and the probe is retried
    via `wait_for` until `deadline`; any other exception propagates.
    """
    async def cql_ready():
        try:
            await cql.run_async("select * from system.local", host=host)
        except NoHostAvailable:
            # Use the module logger rather than the root-logger helper
            # `logging.info`, which bypasses this module's logger name and may
            # implicitly configure the root logger.
            logger.info(f"Driver not connected to {host} yet")
            return None
        return True

    await wait_for(cql_ready, deadline, period=0.1)
|
|
|
|
|
|
async def wait_for_cql_and_get_hosts(cql: Session, servers: list[ServerInfo], deadline: float) \
        -> list[Host]:
    """Wait until every server in `servers` is available through `cql`
       and translate `servers` to a list of Cassandra `Host`s.

       The returned list has the same order as `servers`. Hosts known to the
       driver that are not in `servers` are left out.
    """
    ip_set = set(str(srv.rpc_address) for srv in servers)

    async def get_hosts() -> Optional[list[Host]]:
        hosts = cql.cluster.metadata.all_hosts()
        remaining = ip_set - {h.address for h in hosts}
        if not remaining:
            return hosts

        # Module logger instead of the root-logger helper `logging.info`.
        logger.info(f"Driver hasn't yet learned about hosts: {remaining}")
        return None

    def try_refresh_nodes():
        try:
            cql.cluster.refresh_nodes(force_token_rebuild=True)
        except DriverException:
            # Silence the exception, which might get thrown if we call this in the middle of
            # driver reconnect (scylladb/scylladb#17616). `wait_for` will retry anyway and it's enough
            # if we succeed only one `get_hosts()` attempt before timing out.
            pass

    hosts = await wait_for(
        pred=get_hosts,
        deadline=deadline,
        before_retry=try_refresh_nodes,
    )

    # Take only hosts from `ip_set` (there may be more)
    hosts = [h for h in hosts if h.address in ip_set]

    # Make sure `hosts` has same order as `servers`, that is: a given index will
    # refer to the same underlying Scylla instance in both `servers` and `hosts`.
    # Key the index by `str(rpc_address)`, exactly like `ip_set`, so that the
    # lookup by `h.address` uses the same representation as the filter above.
    servers_by_ip = {str(srv.rpc_address): i for i, srv in enumerate(servers)}
    hosts.sort(key=lambda x: servers_by_ip[x.address])

    # Every selected host must also answer a query before we return.
    await asyncio.gather(*(wait_for_cql(cql, h, deadline) for h in hosts))

    return hosts
|
|
|
|
def read_last_line(file_path: pathlib.Path, max_line_bytes = 512):
    """Return the last line of `file_path`, looking only at its final `max_line_bytes` bytes.

    The tail is decoded as UTF-8 with undecodable bytes dropped. One trailing
    `os.linesep` is removed, then the text after the last remaining
    `os.linesep` is returned. If the tail holds no line separator and the file
    is longer than `max_line_bytes`, the tail is returned prefixed with '...'.
    """
    file_size = os.stat(file_path).st_size
    tail_offset = max(0, file_size - max_line_bytes)
    with file_path.open('rb') as f:
        f.seek(tail_offset, os.SEEK_SET)
        raw_tail = f.read()
    tail = raw_tail.decode('utf-8', errors='ignore')

    newline = os.linesep
    if tail.endswith(newline):
        tail = tail[:-len(newline)]

    _, separator, last_line = tail.rpartition(newline)
    if separator:
        return last_line
    if file_size > max_line_bytes:
        # The line started before the window we read; mark the cut.
        return '...' + tail
    return tail
|
|
|
|
|
|
async def get_available_host(cql: Session, deadline: float) -> Host:
    """Return the first host known to the driver that answers a query.

    The host list is read once from ``cql.cluster.metadata.all_hosts()`` when
    the function is called. Each polling round then probes those hosts in
    order with a ``system.local`` lookup; a `NoHostAvailable` error moves on to
    the next host. Rounds are repeated via `wait_for` until `deadline`.
    """
    hosts = cql.cluster.metadata.all_hosts()

    async def find_host():
        for h in hosts:
            try:
                await cql.run_async(
                    "select key from system.local where key = 'local'", host=h)
            except NoHostAvailable:
                # Module logger instead of the root-logger helper `logging.debug`.
                logger.debug(f"get_available_host: {h} not available")
                continue
            return h
        return None

    return await wait_for(find_host, deadline)
|
|
|
|
|
|
async def wait_for_feature(feature: str, cql: Session, host: Host, deadline: float) -> None:
    """Wait for the given feature to be enabled.

    Polls `get_enabled_features(cql, host)` through `wait_for` until `feature`
    appears in the returned set or `deadline` is reached.
    """
    async def feature_is_enabled():
        currently_enabled = await get_enabled_features(cql, host)
        # `wait_for` treats None as "keep polling".
        if feature in currently_enabled:
            return True
        return None

    await wait_for(feature_is_enabled, deadline)
|
|
|
|
|
|
async def get_supported_features(cql: Session, host: Host) -> set[str]:
    """Return the set of cluster features that `host` advertises support for.

    Reads the comma-separated `supported_features` column of the local row in
    `system.local`, querying `host` specifically.
    """
    query = "SELECT supported_features FROM system.local WHERE key = 'local'"
    rows = await cql.run_async(query, host=host)
    advertised = rows[0].supported_features
    return set(advertised.split(","))
|
|
|
|
|
|
async def get_enabled_features(cql: Session, host: Host) -> set[str]:
    """Return the set of cluster features that `host` considers to be enabled.

    Reads the comma-separated value stored under the `enabled_features` key of
    `system.scylla_local`, querying `host` specifically.
    """
    query = "SELECT value FROM system.scylla_local WHERE key = 'enabled_features'"
    rows = await cql.run_async(query, host=host)
    enabled = rows[0].value
    return set(enabled.split(","))
|
|
|
|
|
|
class KeyGenerator:
    """Hands out consecutive integer keys starting at 0, guarded by a lock."""

    def __init__(self):
        # None until the first key has been handed out.
        self.pk = None
        self.pk_lock = threading.Lock()

    def next_pk(self):
        """Advance to the next key and return it (0 on the first call)."""
        with self.pk_lock:
            self.pk = 0 if self.pk is None else self.pk + 1
            return self.pk

    def last_pk(self):
        """Return the most recently issued key, or None if none was issued."""
        with self.pk_lock:
            return self.pk
|
|
|
|
async def start_writes(cql: Session, keyspace: str, table: str, concurrency: int = 3, ignore_errors=False):
    """Start background workers that keep writing to `keyspace`.`table` and return a stopper.

    `concurrency` asyncio tasks are created. Each one repeatedly takes the next
    key from a shared `KeyGenerator`, inserts the row ``(pk, c=pk)`` and reads
    it back, both at QUORUM, asserting that exactly one row with ``c == pk``
    is returned.

    With `ignore_errors` set, any exception raised during an attempt is logged
    at debug level and the same key is retried; otherwise it is logged at error
    level and re-raised, which ends that worker.

    Before returning, this waits up to 60 seconds for the warmup event, which
    is set by the worker that completes the key equal to ``128 // concurrency``.

    Returns:
        The coroutine function `finish`. Awaiting ``finish()`` signals the
        workers to stop, waits for all of them and returns the last issued key
        plus one, or 0 if no key was ever issued.
    """
    logger.info(f"Starting to asynchronously write, concurrency = {concurrency}")

    stop_event = asyncio.Event()

    warmup_writes = 128 // concurrency
    warmup_event = asyncio.Event()

    stmt = cql.prepare(f"INSERT INTO {keyspace}.{table} (pk, c) VALUES (?, ?)")
    stmt.consistency_level = ConsistencyLevel.QUORUM
    rd_stmt = cql.prepare(f"SELECT * FROM {keyspace}.{table} WHERE pk = ?")
    rd_stmt.consistency_level = ConsistencyLevel.QUORUM

    key_gen = KeyGenerator()

    async def do_writes(worker_id: int):
        # Run a statement with 1 attempt (ignore_errors) or 3 attempts, sleeping
        # with a doubling delay after every failed attempt, including the last
        # one; the last error is raised once the attempts are used up.
        async def run_retry_async(cql, *args, **kwargs):
            retry_attempts = 1 if ignore_errors else 3
            sleep_time = 0.05  # 50ms
            error = None
            for _ in range(retry_attempts):
                try:
                    return await cql.run_async(*args, **kwargs)
                except Exception as e:
                    error = e

                logger.debug(f"Retrying in {sleep_time} second(s)")
                await asyncio.sleep(sleep_time)
                sleep_time *= 2  # Exponential backoff
            raise error

        write_count = 0
        while not stop_event.is_set():
            pk = key_gen.next_pk()

            # Once next_pk() is produced, key_gen.last_pk() is assumed to be in the database
            # hence we can't give up on it.
            while True:
                start_time = time.time()
                try:
                    await cql.run_async(stmt, [pk, pk])
                    # Check read-your-writes
                    rows = await run_retry_async(cql, rd_stmt, [pk])
                    assert(len(rows) == 1)
                    assert(rows[0].c == pk)
                    write_count += 1
                    break
                except Exception as e:
                    # Failed asserts above are AssertionError, so they take this path too.
                    if ignore_errors:
                        logger.debug(f"Suppressed exception: {e} (expected when node is brought down temporarily)")
                    else:
                        logger.error(f"Exception occurred during read/write operation for pk={pk} started {time.time() - start_time}s ago: {e}")
                        raise e

            # Keys are shared between workers, so exactly one worker sees this key.
            if pk == warmup_writes:
                warmup_event.set()

        logger.info(f"Worker #{worker_id} did {write_count} successful writes")

    tasks = [asyncio.create_task(do_writes(worker_id)) for worker_id in range(concurrency)]

    # Raises asyncio.TimeoutError if the warmup key is not completed in time.
    await asyncio.wait_for(warmup_event.wait(), timeout=60)

    async def finish():
        logger.info("Stopping workers")
        stop_event.set()
        await asyncio.gather(*tasks)

        last = key_gen.last_pk()
        if last is not None:
            return last + 1
        return 0

    return finish
|
|
|
|
async def wait_for_view_v1(cql: Session, name: str, node_count: int, timeout: int = 120):
    """Wait until `node_count` rows report SUCCESS for view `name`.

    Polls `system_distributed.view_build_status` through `wait_for`, giving up
    `timeout` seconds after this function is called.
    """
    deadline = time.time() + timeout

    async def view_is_built():
        rows = await cql.run_async(f"SELECT COUNT(*) FROM system_distributed.view_build_status WHERE status = 'SUCCESS' AND view_name = '{name}' ALLOW FILTERING")
        success_count = rows[0][0]
        # `wait_for` treats None as "keep polling".
        if success_count == node_count:
            return True
        return None

    await wait_for(view_is_built, deadline, label=f"view_v1_{name}")
|
|
|
|
async def wait_for_view(cql: Session, name: str, node_count: int, timeout: int = 120):
    """Wait until `node_count` rows report SUCCESS for view `name`.

    Polls `system.view_build_status_v2` through `wait_for`, giving up
    `timeout` seconds after this function is called.
    """
    deadline = time.time() + timeout

    async def view_is_built():
        rows = await cql.run_async(f"SELECT COUNT(*) FROM system.view_build_status_v2 WHERE status = 'SUCCESS' AND view_name = '{name}' ALLOW FILTERING")
        success_count = rows[0][0]
        # `wait_for` treats None as "keep polling".
        if success_count == node_count:
            return True
        return None

    await wait_for(view_is_built, deadline, label=f"view_{name}")
|
|
|
|
|
|
async def wait_for_first_completed(coros: list[Coroutine], timeout: int|None = None):
    """Run `coros` concurrently and return the result of one that finishes first.

    Every coroutine is wrapped in a task. As soon as at least one task is
    done, the still-pending tasks are cancelled. If several tasks are done at
    that moment, the one returned is whichever comes first when the unordered
    `done` set is listed.

    Parameters:
        coros: coroutines to run.
        timeout: seconds to wait for the first completion; None waits forever.

    Returns:
        The result of the selected task.

    Raises:
        asyncio.TimeoutError: no task finished within `timeout`; all tasks are
            cancelled and awaited first.
        Exception: whatever the selected task raised.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, timeout=timeout)
    if not done:
        # Timeout occurred, cancel all
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise asyncio.TimeoutError("No task completed within timeout")

    # Cancel pending tasks
    for task in pending:
        task.cancel()

    # Get first result
    list_done = list(done)
    first_task = list_done.pop(0)
    try:
        return await first_task
    finally:
        # Clean up even when the selected task raised: wait for the cancelled
        # tasks to actually finish and retrieve the outcome of any other
        # finished task so that its exception is not left unobserved.
        cleanup = list(pending) + list_done
        if cleanup:
            await asyncio.gather(*cleanup, return_exceptions=True)
|
|
|
|
|
|
async def wait_all(coros: list[Coroutine], timeout: int|None = None):
    """Run `coros` concurrently and wait until all of them have finished.

    Parameters:
        coros: coroutines to run; each is wrapped in a task.
        timeout: seconds to wait for all tasks; None waits forever.

    Returns:
        The result of a single task: whichever comes first when the unordered
        `done` set is listed. The outcomes of the other tasks are awaited but
        discarded.

    Raises:
        asyncio.TimeoutError: at least one task was still running when
            `timeout` expired; the unfinished tasks are cancelled and all
            tasks are awaited before raising.
        Exception: whatever the selected task raised.
    """
    tasks = [asyncio.create_task(c) for c in coros]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED, timeout=timeout)
    if pending:
        # asyncio.wait() does not raise on timeout; it just returns the tasks
        # that are still running. That includes the case where only some tasks
        # finished, so test `pending` rather than `not done`.
        for task in pending:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise asyncio.TimeoutError("Not all tasks completed within timeout")

    # Get first result
    list_done = list(done)
    first_task = list_done.pop(0)
    try:
        return await first_task
    finally:
        # Clean up even when the selected task raised, so that exceptions of
        # the remaining tasks are retrieved instead of being left unobserved.
        if list_done:
            await asyncio.gather(*list_done, return_exceptions=True)
|
|
|
|
|
|
def ninja(target: str) -> str:
    """Build specified target using ninja and return its decoded standard output.

    The command runs with `TOP_SRC_DIR` as working directory. Only stdout is
    captured; the exit status is not checked.
    """
    command = ["ninja"]
    # cmake places build.ninja in build/, traditional is in ./.
    # We choose to test for traditional, not cmake, because IDEs may
    # invoke cmake to learn the configuration and generate false positives
    if not TOP_SRC_DIR.joinpath("build.ninja").exists():
        command += ["-C", str(BUILD_DIR)]
    command.append(target)

    process = subprocess.Popen(
        args=command,
        stdout=subprocess.PIPE,
        cwd=TOP_SRC_DIR,
    )
    captured_stdout, _ = process.communicate()
    return captured_stdout.decode()
|
|
|
|
|
|
@cache
def get_configured_modes() -> list[str]:
    """Return the build modes reported by ``ninja mode_list``.

    The result is memoized by `functools.cache`, so ninja runs at most once
    per process.

    Returns:
        The whitespace-separated words on the last line of the mode listing,
        or an empty list when that line is empty (for example when ninja
        produced no output).
    """
    out = ninja('mode_list')
    # Expected output:
    # [1/1] List configured modes
    # debug release dev
    listing = re.sub(r'.* List configured modes\n(.*)\n', r'\1',
                     out, count=1, flags=re.DOTALL)
    last_line = listing.split('\n')[-1]
    # split() without a separator returns [] for an empty line, whereas
    # split(' ') returns [''], a truthy list that callers checking
    # `if not modes` would mistake for a configured mode.
    return last_line.split()
|
|
|
|
|
|
def get_modes_to_run(config) -> list[str]:
    """Return the modes to run: the 'modes' option if set, else the configured modes.

    `config` must provide ``getoption('modes')`` (presumably the pytest config
    object; confirm against the callers).

    Raises:
        RuntimeError: neither the option nor `get_configured_modes()` yields
            any mode.
    """
    selected = config.getoption('modes') or get_configured_modes()
    if selected:
        return selected
    raise RuntimeError('No modes configured. Please run ./configure.py first')
|
|
|
|
|
|
def scale_timeout_by_mode(mode: str, timeout: int | float) -> int | float:
    """Scale timeout according to test.py mode semantics.

    The timeout is multiplied by the factor stored for `mode` in
    `MODES_TIMEOUT_FACTOR`; modes missing from that table use a factor of 1,
    i.e. the timeout is left unchanged. The table was documented here as
    multiplying debug and sanitize modes by 3 and dev by 2; it is defined
    outside this module, so check `MODES_TIMEOUT_FACTOR` for the authoritative
    values.
    """
    factor = MODES_TIMEOUT_FACTOR.get(mode, 1)
    return factor * timeout
|
|
|
|
|
|
async def gather_safely(*awaitables: Awaitable):
    """Wait for ALL given awaitables to finish, then return their results or raise.

    Developers using asyncio.gather() often assume that it waits for all the
    awaitables it is given. That is not the case with the default
    return_exceptions=False: as soon as one awaitable fails, gather() raises
    that exception immediately while the others may keep running in the
    background. This is bad for code that uses gather() to make sure a set of
    background tasks has completed.

    This helper therefore calls asyncio.gather() with return_exceptions=True,
    so every awaitable completes, successfully or not, before it returns. If
    any result is an exception instance, the first one in argument order is
    raised (with its context suppressed); otherwise the list of results is
    returned.
    """
    results = await asyncio.gather(*awaitables, return_exceptions=True)
    first_failure = next((item for item in results if isinstance(item, BaseException)), None)
    if first_failure is not None:
        raise first_failure from None
    return results
|
|
|
|
|
|
def get_xdist_worker_id() -> str | None:
    """Return the value of the PYTEST_XDIST_WORKER environment variable, or None if it is unset."""
    worker_id = os.environ.get("PYTEST_XDIST_WORKER")
    return worker_id
|
|
|
|
|
|
def execute_with_tracing(cql : Session, statement : str | Statement, log : bool = False, *cql_execute_extra_args, **cql_execute_extra_kwargs):
    """Execute `statement` through `cql` with tracing enabled and return the trace events.

    ``trace=True`` is forced into the keyword arguments passed to
    ``cql.execute``; the extra positional and keyword arguments are forwarded
    unchanged. All query traces are then fetched from the result.

    Returns:
        A list with one entry per trace, each entry being that trace's
        `events`.

    When `log` is true, the events of every trace are also written to the
    module logger at debug level, one event per line.
    """
    cql_execute_extra_kwargs['trace'] = True
    query_result = cql.execute(statement, *cql_execute_extra_args, **cql_execute_extra_kwargs)

    tracing = query_result.get_all_query_traces(max_wait_sec_per=900)

    ret = [trace.events for trace in tracing]

    if log:
        # One text block per trace, one line per event.
        page_traces = [
            "\n".join(f" {event.source} {event.source_elapsed} {event.description}" for event in events)
            for events in ret
        ]
        logger.debug("Tracing {}:\n{}\n".format(statement, "\n".join(page_traces)))

    return ret
|
|
|
|
|
|
def universalasync_typed_wrap(cls: T) -> T:
    """Apply `universalasync.wrap` to `cls` while keeping its static type.

    The wrapped object is passed through `typing.cast`, so type checkers keep
    seeing the type of the argument.
    """
    wrapped = universalasync.wrap(cls)
    return cast(T, wrapped)
|