Files
scylladb/test/cluster/lwt/lwt_common.py

368 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (C) 2024-present ScyllaDB
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
"""
Common functionality for LWT (Lightweight Transaction) tests.
This module provides shared components for testing LWT behavior during various tablet operations
such as migrations, splits and merges.
"""
import asyncio
import logging
import random
import re
import time
from collections import defaultdict
from functools import cached_property
from functools import wraps
from typing import List, Dict, Callable
from cassandra import ConsistencyLevel
from cassandra import WriteTimeout, ReadTimeout, OperationTimedOut
from cassandra.query import SimpleStatement, PreparedStatement
from test.pylib.manager_client import ManagerClient
from test.pylib.tablets import get_tablet_count
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Test constants (the concrete values are arbitrary).
DEFAULT_WORKERS = 20
DEFAULT_NUM_KEYS = 50
DEFAULT_BACKOFF_BASE = 0.02

# Case-insensitive matcher for the message of a CAS timeout whose outcome is unknown.
UNCERTAINTY_RE = re.compile(r"write timeout due to uncertainty", flags=re.IGNORECASE)
def backoff_on_exception(timeout, between_sleep, retry_exceptions, should_retry):
    """
    Decorator for retrying async functions with a fixed pause on specified exceptions.

    Args:
        timeout: Retry budget in seconds, measured on the event loop's clock
            from the start of the first attempt. It is only consulted after a
            failed attempt, so a single attempt may run longer than the budget.
        between_sleep: Seconds to sleep between two consecutive attempts.
        retry_exceptions: Exception class (or tuple of classes) eligible for a
            retry. A falsy value makes every ``Exception`` eligible.
        should_retry: Predicate called with the caught exception; a falsy
            result re-raises it unchanged. A falsy ``should_retry`` skips the check.

    Returns:
        A decorator that wraps a coroutine function in the retry loop.

    Raises:
        TimeoutError: chained to the last failure, when a retryable failure
            is observed after more than ``timeout`` seconds have elapsed.
    """
    def decorator(target):
        @wraps(target)
        async def wrapper(*args, **kwargs):
            # The wrapper only ever executes inside a coroutine, so the running
            # loop is the right clock source; get_event_loop() is the legacy
            # spelling and may create or guess a loop in other contexts.
            event_loop = asyncio.get_running_loop()
            start_time = event_loop.time()
            attempt = 1
            while True:
                try:
                    return await target(*args, **kwargs)
                except Exception as e:
                    # Anything outside the allow-list propagates untouched.
                    if retry_exceptions and not isinstance(e, retry_exceptions):
                        raise
                    if should_retry and not should_retry(e):
                        raise
                    elapsed = event_loop.time() - start_time
                    if elapsed > timeout:
                        raise TimeoutError(f"Operation timed out after {timeout} seconds and {attempt} attempts") from e
                    await asyncio.sleep(between_sleep)
                    attempt += 1
        return wrapper
    return decorator
class Worker:
    """
    One LWT writer that owns a single column, `s{worker_id}`.

    Every loop iteration bumps that column by one with a conditional update:
        UPDATE .. SET s{i}=? WHERE pk=? IF <guards on other cols> AND s{i}=?
    A definite "not applied" answer is treated as a failure; an "uncertainty"
    timeout is resolved by re-reading the row, and the increment is retried in a
    later iteration if it turns out not to have been applied.
    """

    def __init__(
        self,
        worker_id: int,
        cql,
        pks: List[int],
        select_statement: PreparedStatement,
        update_statement: PreparedStatement,
        other_columns: List[int],
        get_lower_bound: Callable[[int, int], int],
        on_applied: Callable[[int, int, int], None],
        stop_event: asyncio.Event
    ):
        # Identity, plus an RNG seeded with it so each worker's key sequence is reproducible.
        self.worker_id = worker_id
        self.rng = random.Random(worker_id)

        # Session and the statements this worker executes.
        self.cql = cql
        self.select_statement = select_statement
        self.update_statement = update_statement

        # Keys to work on and the columns used as guards in the IF clause.
        self.pks = pks
        self.other_columns = other_columns

        # Callbacks supplied by the creator of the worker.
        self.get_lower_bound = get_lower_bound
        self.on_applied = on_applied

        # Shared stop flag and the per-key count of increments known to be applied.
        self.stop_event = stop_event
        self.success_counts: Dict[int, int] = dict.fromkeys(pks, 0)

    async def verify_update_through_select(self, pk, new_val, prev_val):
        """
        Resolve an "uncertainty timeout" of a CAS, after which it is unknown
        whether the update took effect.

        Reads the row again at LOCAL_SERIAL consistency, asserts that this
        worker's column holds either `new_val` or `prev_val`, and returns True
        exactly when it holds `new_val`.
        """
        serial_read = self.select_statement.bind([pk])
        serial_read.consistency_level = ConsistencyLevel.LOCAL_SERIAL
        result = await self.cql.run_async(serial_read)
        observed = getattr(result[0], f"s{self.worker_id}")
        landed = observed == new_val
        assert landed or observed == prev_val
        return landed

    def stop(self) -> None:
        """Set the stop event so that the worker loop ends."""
        self.stop_event.set()

    async def _run_iteration(self):
        """Perform one read / guarded-CAS / bookkeeping cycle on a random key."""
        own_column = f"s{self.worker_id}"
        pk = self.rng.choice(self.pks)

        # Read current values
        read = self.select_statement.bind([pk])
        read.consistency_level = ConsistencyLevel.LOCAL_QUORUM
        row = (await self.cql.run_async(read))[0]

        prev_val = getattr(row, own_column)
        expected = self.success_counts[pk]
        new_val = expected + 1

        # The stored value must agree with the local tracker before we write.
        assert prev_val == expected, (
            f"Consistency mismatch: pk={pk} s{self.worker_id} row={prev_val} tracker={expected}"
        )

        # Guard on every other column: the larger of the value just read and
        # the lower bound reported by the callback.
        guard_vals = []
        for col_idx in self.other_columns:
            seen = getattr(row, f"s{col_idx}")
            guard_vals.append(max(seen, self.get_lower_bound(pk, col_idx)))

        # Prepare conditional update
        cas = self.update_statement.bind([new_val, pk, *guard_vals, prev_val])
        cas.consistency_level = ConsistencyLevel.LOCAL_QUORUM
        cas.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL

        try:
            res = await self.cql.run_async(cas)
        except (WriteTimeout, OperationTimedOut, ReadTimeout) as e:
            if not is_uncertainty_timeout(e):
                raise
            # Outcome unknown: find out by reading the row back.
            applied = await self.verify_update_through_select(pk, new_val, prev_val)
        else:
            applied = bool(res and res[0].applied)
            if not applied:
                # A definite rejection is a test failure.
                logger.error(
                    "LWT_NOT_APPLIED pk=%r worker=s%d new=%r guard=%r prev=%r",
                    pk,
                    self.worker_id,
                    new_val,
                    guard_vals,
                    prev_val,
                )
                raise AssertionError(
                    f"LWT not applied: pk={pk} s{self.worker_id} new={new_val} guard={guard_vals} prev={prev_val}"
                )

        if applied:
            self.on_applied(pk, self.worker_id, new_val)
            self.success_counts[pk] += 1

        await asyncio.sleep(0.1)

    async def __call__(self):
        """
        Worker loop, repeated until the stop event is set:
        - pick random pk
        - read current row
        - compute guards
        - execute CAS update
        - handle applied/uncertainty/timeout cases
        Any exception sets the stop event before propagating.
        """
        while not self.stop_event.is_set():
            try:
                await self._run_iteration()
            except Exception:
                self.stop()
                raise
        logger.info("Worker finished")
class BaseLWTTester:
    """
    Orchestrates a multi-column LWT workload.
    Creates schema, spawns workers, collects results, verifies consistency.

    The table has one `int` column `s{i}` per worker. Besides the workers, the
    tester keeps per-key lower bounds for every column (fed by the workers'
    `on_applied` callback) and counts applied operations per workload phase.
    """

    def __init__(
        self, manager: ManagerClient, ks: str, tbl: str,
        num_workers: int = DEFAULT_WORKERS, num_keys: int = DEFAULT_NUM_KEYS
    ):
        self.ks = ks
        self.tbl = tbl
        self.cql = manager.get_cql()
        self.num_workers = num_workers
        # Partition keys are 1..num_keys.
        self.pks = list(range(1, num_keys + 1))
        # Column list "s0, s1, ..." shared by the SELECT and INSERT statements.
        self.select_cols = ", ".join(f"s{i}" for i in range(self.num_workers))
        self.workers: List[Worker] = []
        self._tasks: List[asyncio.Task] = []
        # pk -> highest value reported as applied so far, per column index.
        self.lb_counts: Dict[int, List[int]] = {pk: [0] * self.num_workers for pk in self.pks}
        # pk -> token, filled by initialize_rows().
        self.pk_to_token: Dict[int, int] = {}
        # Not modified anywhere in this class; presumably maintained by callers.
        self.migrations = 0
        self.phase = "warmup"  # "warmup" -> "migrating" -> "post"
        # phase name -> number of applied operations recorded while in that phase.
        self.phase_ops = defaultdict(int)

    def _get_lower_bound(self, pk: int, col_idx: int) -> int:
        """Return the recorded lower bound for column `col_idx` of row `pk`."""
        return self.lb_counts[pk][col_idx]

    def _on_applied(self, pk: int, col_idx: int, new_val: int) -> None:
        """Record an applied update: raise the lower bound and count it for the current phase."""
        self.lb_counts[pk][col_idx] = max(self.lb_counts[pk][col_idx], new_val)
        self.phase_ops[self.phase] += 1

    def set_phase(self, phase: str) -> None:
        """Switch the phase that subsequent applied operations are counted under."""
        self.phase = phase

    def get_phase_ops(self, phase: str) -> int:
        """Return the number of applied operations counted for `phase` (0 if none)."""
        # .get() rather than indexing, so that a lookup does not insert a key
        # into the defaultdict.
        return self.phase_ops.get(phase, 0)

    async def wait_for_phase_ops(self, stop_event, phase: str, target: int, timeout: float = 120.0, poll: float = 0.2):
        """
        Wait until at least `target` operations have been counted for `phase`.

        The counter is checked before the stop/deadline conditions, so a target
        that has already been reached always succeeds, even with a zero timeout
        or an already-set `stop_event`.

        Args:
            stop_event: Event whose being set aborts the wait.
            phase: Phase name whose counter is polled.
            target: Required number of operations.
            timeout: Maximum time to wait, in seconds.
            poll: Sleep between two checks, in seconds.

        Raises:
            asyncio.TimeoutError: if the target has not been reached when the
                deadline passes or `stop_event` is set.
        """
        # Monotonic clock: the deadline must not move with wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while True:
            if self.get_phase_ops(phase) >= target:
                return
            if stop_event.is_set() or time.monotonic() >= deadline:
                raise asyncio.TimeoutError(
                    f"phase '{phase}' did not reach {target} ops in time (have {self.get_phase_ops(phase)})")
            await asyncio.sleep(poll)

    @cached_property
    def select_statement(self) -> PreparedStatement:
        """Prepared SELECT of all `s{i}` columns for one pk; prepared once and cached."""
        return self.cql.prepare(
            f"SELECT {self.select_cols} FROM {self.ks}.{self.tbl} WHERE pk = ?"
        )

    def create_workers(self, stop_event) -> List[Worker]:
        """
        Build one Worker per column.

        Worker `i` gets its own prepared UPDATE that sets `s{i}` under the
        condition `s{j} >= ?` for every other column `j` and `s{i} = ?`.
        All workers share `stop_event`, the SELECT statement and this tester's
        lower-bound / on-applied callbacks.
        """
        workers: List[Worker] = []
        for i in range(self.num_workers):
            other_columns = [j for j in range(self.num_workers) if j != i]
            # Bind order of the IF clause: guards for the other columns first,
            # then the equality check on the worker's own column.
            cond = " AND ".join([*(f"s{j} >= ?" for j in other_columns), f"s{i} = ?"])
            query = f"UPDATE {self.ks}.{self.tbl} SET s{i} = ? WHERE pk = ? IF {cond}"
            worker = Worker(
                stop_event=stop_event,
                worker_id=i,
                cql=self.cql,
                pks=self.pks,
                select_statement=self.select_statement,
                update_statement=self.cql.prepare(query),
                other_columns=other_columns,
                get_lower_bound=self._get_lower_bound,
                on_applied=self._on_applied,
            )
            workers.append(worker)
        return workers

    async def create_schema(self):
        """Create the table: an int primary key `pk` plus one int column `s{i}` per worker."""
        cols_def = ", ".join(f"s{i} int" for i in range(self.num_workers))
        await self.cql.run_async(
            f"CREATE TABLE {self.ks}.{self.tbl} (pk int PRIMARY KEY, {cols_def})"
        )
        logger.info("Created table %s.%s with %d columns", self.ks, self.tbl, self.num_workers)

    async def initialize_rows(self):
        """
        Insert initial rows with all columns set to 0.
        Then precompute and cache pk->token mapping.
        The inserts run concurrently, and so do the token lookups that follow them.
        """
        zeros = ", ".join("0" for _ in range(self.num_workers))
        ps = self.cql.prepare(
            f"INSERT INTO {self.ks}.{self.tbl} (pk, {self.select_cols}) VALUES (?, {zeros})"
        )
        insert_tasks = [self.cql.run_async(ps.bind([pk])) for pk in self.pks]
        await asyncio.gather(*insert_tasks)
        # Token lookups start only after every insert has completed.
        token_tasks = [get_token_for_pk(self.cql, self.ks, self.tbl, pk) for pk in self.pks]
        tokens = await asyncio.gather(*token_tasks)
        for pk, token in zip(self.pks, tokens):
            self.pk_to_token[pk] = token

    async def start_workers(self, stop_event):
        """Create the workers and schedule each one as an asyncio task."""
        self.workers = self.create_workers(stop_event)
        self._tasks = [asyncio.create_task(worker()) for worker in self.workers]
        logger.info("Started %d LWT workers", len(self._tasks))

    async def stop_workers(self):
        """Stop workload workers, wait for their tasks and surface their exceptions."""
        for worker in self.workers:
            worker.stop()
        if self._tasks:
            # Collect failures instead of aborting on the first one so that all
            # of them end up in the assertion message.
            results = await asyncio.gather(*self._tasks, return_exceptions=True)
            errs = [e for e in results if isinstance(e, Exception)]
            self._tasks.clear()
            assert not errs, f"worker errors: {errs}"
        logger.info("All workers stopped")

    async def verify_consistency(self):
        """Ensure every (pk, column) reflects the number of successful CAS writes."""
        # Run SELECTs for all PKs in parallel using prepared statement
        tasks = []
        for pk in self.pks:
            stmt = self.select_statement.bind([pk])
            stmt.consistency_level = ConsistencyLevel.LOCAL_QUORUM
            tasks.append(self.cql.run_async(stmt))
        results = await asyncio.gather(*tasks)

        # Compare each stored column with the owning worker's success counter.
        mismatches = []
        for pk, rows in zip(self.pks, results):
            row = rows[0]
            for i, worker in enumerate(self.workers):
                actual = getattr(row, f"s{i}")
                expected = worker.success_counts[pk]
                if actual != expected:
                    mismatches.append(f"pk={pk} s{i}={actual}, expected {expected}")
        assert not mismatches, "Consistency violations: " + "; ".join(mismatches)
        total_ops = sum(sum(w.success_counts.values()) for w in self.workers)
        logger.info("Consistency verified %d total successful CAS operations", total_ops)
async def get_token_for_pk(cql, ks: str, tbl: str, pk: int) -> int:
    """
    Return `token(pk)` for the row with the given primary key.

    The value is read from `ks.tbl` with a QUORUM select; the first returned
    row is used, so the row is expected to exist.
    """
    query = f"SELECT token(pk) AS tk FROM {ks}.{tbl} WHERE pk = %s"
    stmt = SimpleStatement(query, consistency_level=ConsistencyLevel.QUORUM)
    rows = await cql.run_async(stmt, [pk])
    first_row = rows[0]
    return first_row.tk
async def get_host_map(manager: ManagerClient, servers):
    """
    Build a dict that maps each server's host ID to the server object.

    Host IDs are fetched concurrently through `manager.get_host_id`, keyed by
    each server's `server_id`.
    """
    lookups = (manager.get_host_id(server.server_id) for server in servers)
    host_ids = await asyncio.gather(*lookups)
    return dict(zip(host_ids, servers))
async def pick_non_replica_server(manager: ManagerClient, servers, replica_host_ids):
    """
    Return a randomly chosen server whose host ID is not in `replica_host_ids`.

    `random.choice` raises IndexError when every server is a replica.
    """
    candidates = []
    for host_id, server in (await get_host_map(manager, servers)).items():
        if host_id in replica_host_ids:
            continue
        candidates.append(server)
    return random.choice(candidates)
async def wait_for_tablet_count(
    manager: ManagerClient, server, ks: str, tbl: str,
    predicate, target: int, timeout_s: int = 180, poll_s: float = 1.0
):
    """
    Wait for tablet count to match predicate.

    Args:
        manager, server, ks, tbl: Passed through to `get_tablet_count`.
        predicate: callable like lambda c: c >= target (for split) or
            lambda c: c <= target (for merge).
        target: Only used in the timeout message; the decision is made by `predicate`.
        timeout_s: Maximum time to wait, in seconds.
        poll_s: Sleep between two polls, in seconds.

    Returns:
        The first tablet count for which `predicate` returned a truthy value.

    Raises:
        TimeoutError: if no polled count satisfied `predicate` before the deadline.
    """
    # Monotonic clock: the deadline must not move with wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    last = None  # most recent count, reported if the wait times out
    while time.monotonic() < deadline:
        last = await get_tablet_count(manager, server, ks, tbl)
        if predicate(last):
            return last
        await asyncio.sleep(poll_s)
    raise TimeoutError(f"tablet count wait timed out (last={last}, target={target})")
def is_uncertainty_timeout(exc: Exception) -> bool:
    """
    Tell whether `exc` is a WriteTimeout or ReadTimeout whose text matches UNCERTAINTY_RE.

    Any other exception type yields False regardless of its message.
    """
    if not isinstance(exc, (WriteTimeout, ReadTimeout)):
        return False
    return UNCERTAINTY_RE.search(str(exc)) is not None