Files
scylladb/test/cluster/object_store/test_backup.py
Pavel Emelyanov 19820910f8 test: Add test for backup vs migration race
The test runs a regular backup+restore on a smaller cluster, but before
that it starts a tablet migration from one node to another and blocks it
midway with the block_tablet_streaming injection (even though the tablets
hold no data and there is nothing to stream, the injection point is
reached early enough to take effect).

Signed-off-by: Pavel Emelyanov <xemul@scylladb.com>
2026-05-12 10:40:24 +03:00

1203 lines
62 KiB
Python

#!/usr/bin/env python3
import glob
import json
import os
import logging
import asyncio
import subprocess
import tempfile
import itertools
import aiohttp
import pytest
import time
import random
from test.pylib.manager_client import ManagerClient, ServerInfo
from test.cluster.util import wait_for_cql_and_get_hosts, get_replication, new_test_keyspace
from test.pylib.rest_client import read_barrier
from test.pylib.util import unique_name, wait_all
from test.pylib.tablets import get_tablet_replica, get_all_tablet_replicas
from cassandra.cluster import ConsistencyLevel
from collections import defaultdict
from test.pylib.util import wait_for
from test.pylib.rest_client import HTTPError
from test.cluster.tasks.task_manager_client import TaskManagerClient
from test.cluster.util import wait_for_token_ring_and_group0_consistency
import statistics
# Module-level logger used by the tests and helpers in this file.
logger = logging.getLogger(__name__)
async def take_snapshot(ks, servers, manager, logger):
    '''Flush keyspace `ks` and take a snapshot of it on every server in `servers`.

    Returns a pair (snapshot_name, toc_lists) where toc_lists maps each server to
    the names of the TOC.txt files found in that server's snapshot directory.
    Only the first table directory listed under the keyspace data dir is inspected.
    '''
    logger.info(f'Take snapshot and collect sstables lists')
    snap_name = unique_name('backup_')
    await asyncio.gather(*(manager.api.flush_keyspace(srv.ip_addr, ks) for srv in servers))
    await asyncio.gather(*(manager.api.take_snapshot(srv.ip_addr, ks, snap_name) for srv in servers))

    toc_lists = {}
    for srv in servers:
        workdir = await manager.server_get_workdir(srv.server_id)
        table_dir = os.listdir(f'{workdir}/data/{ks}')[0]
        snapshot_dir = f'{workdir}/data/{ks}/{table_dir}/snapshots/{snap_name}'
        tocs = []
        with os.scandir(snapshot_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith('TOC.txt'):
                    tocs.append(entry.name)
        logger.info(f'Collected sstables from {srv.ip_addr}:{table_dir}/snapshots/{snap_name}: {tocs}')
        toc_lists[srv] = tocs
    return snap_name, toc_lists
async def take_snapshot_on_one_server(ks, server, manager, logger):
    '''Snapshot `ks` on a single server.

    Returns (snapshot_name, [TOC file names in that server's snapshot]).
    '''
    name, tocs_by_server = await take_snapshot(ks, [server], manager, logger)
    server_tocs = tocs_by_server[server]
    return name, server_tocs
@pytest.mark.asyncio
@pytest.mark.parametrize("move_files", [False, True])
async def test_simple_backup(manager: ManagerClient, object_storage, move_files):
    '''check that backing up a snapshot for a keyspace works

    With move_files=True the snapshot directory must be empty after the backup
    (everything was moved into the bucket). Otherwise the snapshot's sstables must
    still be present locally. In both cases every TOC of the snapshot must be in the
    bucket under the backup prefix, and the task must run in the 'bckp' scheduling
    group.
    '''
    objconf = object_storage.create_endpoint_conf()
    cfg = {'enable_user_defined_functions': False,
           'object_storage_endpoints': objconf,
           'experimental_features': ['keyspace-storage-options'],
           'task_ttl_in_seconds': 300
           }
    cmd = ['--logger-log-level', 'snapshots=trace:task_manager=trace:api=info']
    server = await manager.server_add(config=cfg, cmdline=cmd)

    cql = manager.get_cql()
    cf = 'test_cf'
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.{cf} ( name text primary key, value text );")
        await asyncio.gather(*(cql.run_async(f"INSERT INTO {ks}.{cf} ( name, value ) VALUES ('{name}', '{value}');") for name, value in [('0', 'zero'), ('1', 'one'), ('2', 'two')]))
        snap_name, files = await take_snapshot_on_one_server(ks, server, manager, logger)
        assert len(files) > 0

        workdir = await manager.server_get_workdir(server.server_id)
        cf_dir = os.listdir(f'{workdir}/data/{ks}')[0]
        snapshot_dir = f'{workdir}/data/{ks}/{cf_dir}/snapshots/{snap_name}'

        print('Backup snapshot')
        prefix = f'{cf}/backup'
        tid = await manager.api.backup(server.ip_addr, ks, cf, snap_name, object_storage.address, object_storage.bucket_name, prefix, move_files=move_files)
        print(f'Started task {tid}')
        status = await manager.api.get_task_status(server.ip_addr, tid)
        print(f'Status: {status}, waiting to finish')
        status = await manager.api.wait_task(server.ip_addr, tid)
        assert (status is not None) and (status['state'] == 'done')
        assert (status['progress_total'] > 0) and (status['progress_completed'] == status['progress_total'])

        # Keep the two cases as separate statements. A one-liner such as
        # `assert n == 0 if move_files else k` parses as `assert (n == 0) if move_files else k`,
        # so in the copy case it would only check that k is truthy.
        remaining = os.listdir(snapshot_dir)
        if move_files:
            # all components in the "backup" snapshot should have been moved into the bucket
            assert len(remaining) == 0, f'snapshot directory is not empty after a moving backup: {remaining}'
        else:
            # a copying backup must leave the snapshot's sstables in place
            missing = set(files) - set(remaining)
            assert not missing, f'TOC files disappeared from the snapshot after a copying backup: {missing}'

        objects = {o.key for o in object_storage.get_resource().Bucket(object_storage.bucket_name).objects.all()}
        for f in files:
            print(f'Check {f} is in backup')
            assert f'{prefix}/{f}' in objects

        # Check that task runs in the backup sched group
        log = await manager.server_open_log(server.server_id)
        res = await log.grep(r'INFO.*\[shard [0-9]:([a-z]+)\] .* Backup sstables from .* to')
        assert len(res) == 1 and res[0][1].group(1) == 'bckp'
@pytest.mark.asyncio
@pytest.mark.parametrize("ne_parameter", [ "endpoint", "bucket", "snapshot" ])
async def test_backup_with_non_existing_parameters(manager: ManagerClient, object_storage, ne_parameter):
    '''backup should fail when the endpoint, the bucket or the snapshot does not exist

    `ne_parameter` names the backup argument that is replaced by a non-existing
    value. The task must end up "failed"; for a bad endpoint the exact error text
    is checked as well.
    '''
    cfg = {
        'enable_user_defined_functions': False,
        'object_storage_endpoints': object_storage.create_endpoint_conf(),
        'experimental_features': ['keyspace-storage-options'],
        'task_ttl_in_seconds': 300,
    }
    server = await manager.server_add(config=cfg, cmdline=['--logger-log-level', 'snapshots=trace:task_manager=trace:api=info'])

    cql = manager.get_cql()
    cf = 'test_cf'
    rows = [('0', 'zero'), ('1', 'one'), ('2', 'two')]
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.{cf} ( name text primary key, value text );")
        await asyncio.gather(*(cql.run_async(f"INSERT INTO {ks}.{cf} ( name, value ) VALUES ('{name}', '{value}');") for name, value in rows))
        backup_snap_name, files = await take_snapshot_on_one_server(ks, server, manager, logger)
        assert len(files) > 0

        def choose(param, valid_value, invalid_value):
            # substitute the non-existing value only for the parameter under test
            return invalid_value if ne_parameter == param else valid_value

        tid = await manager.api.backup(server.ip_addr, ks, cf,
                                       choose('snapshot', backup_snap_name, 'no-such-snapshot'),
                                       choose('endpoint', object_storage.address, 'no-such-endpoint'),
                                       choose('bucket', object_storage.bucket_name, 'no-such-bucket'),
                                       f'{cf}/backup')
        status = await manager.api.wait_task(server.ip_addr, tid)
        assert status is not None
        assert status['state'] == 'failed'
        if ne_parameter == 'endpoint':
            assert status['error'] == 'std::invalid_argument (endpoint no-such-endpoint not found)'
@pytest.mark.asyncio
async def test_backup_endpoint_config_is_live_updateable(manager: ManagerClient, object_storage):
    '''A backup to an endpoint that is missing from the configuration must fail.

    After the endpoint is added through a live config update (no restart), the
    very same backup must succeed.
    '''
    server = await manager.server_add(
        config={'enable_user_defined_functions': False,
                'experimental_features': ['keyspace-storage-options'],
                'task_ttl_in_seconds': 300},
        cmdline=['--logger-log-level', 'sstables_manager=debug'])

    cql = manager.get_cql()
    cf = 'test_cf'
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.{cf} ( name text primary key, value text );")
        await asyncio.gather(*(cql.run_async(f"INSERT INTO {ks}.{cf} ( name, value ) VALUES ('{name}', '{value}');") for name, value in [('0', 'zero'), ('1', 'one'), ('2', 'two')]))
        snap_name, _ = await take_snapshot_on_one_server(ks, server, manager, logger)
        prefix = f'{cf}/backup'

        async def run_backup():
            task_id = await manager.api.backup(server.ip_addr, ks, cf, snap_name, object_storage.address, object_storage.bucket_name, prefix)
            result = await manager.api.wait_task(server.ip_addr, task_id)
            assert result is not None
            return result

        # No endpoint is configured yet, so the backup has to fail
        status = await run_backup()
        assert status['state'] == 'failed'
        assert status['error'] == f'std::invalid_argument (endpoint {object_storage.address} not found)'

        objconf = object_storage.create_endpoint_conf()
        await manager.server_update_config(server.server_id, 'object_storage_endpoints', objconf)

        async def endpoint_appeared_in_config():
            await read_barrier(manager.api, server.ip_addr)
            resp = await manager.api.get_config(server.ip_addr, 'object_storage_endpoints')
            # wait_for keeps polling while this returns None
            return True if all(ep['name'] in resp for ep in objconf) else None

        await wait_for(endpoint_appeared_in_config, deadline=time.time() + 60)

        status = await run_backup()
        assert status['state'] == 'done'
async def do_test_backup_helper(manager: ManagerClient, object_storage,
                                breakpoint_name, handler, num_servers: int = 1):
    '''Start a backup that stops at a one-shot injection point, then hand it to `handler`.

    Creates `num_servers` nodes and a one-table keyspace with three rows, and
    snapshots it on the first node. It then enables `breakpoint_name` on that node,
    starts a backup to a unique prefix and waits until the injection logs that it
    is waiting. Finally it awaits handler(server, prefix, toc_files, task_id); the
    handler presumably releases the injection and checks the outcome.
    '''
    cfg = {'enable_user_defined_functions': False,
           'object_storage_endpoints': object_storage.create_endpoint_conf(),
           'experimental_features': ['keyspace-storage-options'],
           'task_ttl_in_seconds': 300
           }
    cmd = ['--logger-log-level', 'snapshots=trace:task_manager=trace:api=info']
    added = await manager.servers_add(num_servers, config=cfg, cmdline=cmd)
    server = added[0]

    cql = manager.get_cql()
    cf = 'test_cf'
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.{cf} ( name text primary key, value text );")
        await asyncio.gather(*(cql.run_async(f"INSERT INTO {ks}.{cf} ( name, value ) VALUES ('{name}', '{value}');") for name, value in [('0', 'zero'), ('1', 'one'), ('2', 'two')]))
        snap_name, files = await take_snapshot_on_one_server(ks, server, manager, logger)

        await manager.api.enable_injection(server.ip_addr, breakpoint_name, one_shot=True)
        log = await manager.server_open_log(server.server_id)
        mark = await log.mark()

        print('Backup snapshot')
        # Several tests share the same minio bucket and ks/cf names, so use a unique
        # prefix: with a fixed {cf}/backup, leftovers such as "schema.cql" and
        # "manifest.json" from earlier test cases would be counted by mistake.
        prefix = unique_name('backup_')
        tid = await manager.api.backup(server.ip_addr, ks, cf, snap_name, object_storage.address, object_storage.bucket_name, prefix)
        print(f'Started task {tid}, aborting it early')
        await log.wait_for(f'{breakpoint_name}: waiting', from_mark=mark)
        await handler(server, prefix, files, tid)
async def do_test_backup_abort(manager: ManagerClient, object_storage,
                               breakpoint_name, min_files, max_files = None):
    '''Abort a backup paused at `breakpoint_name` and verify the outcome.

    The task must end in state "failed" with a seastar abort_requested_exception.
    The number of snapshot TOC files that reached the bucket must satisfy
    min_files <= uploaded < len(files), and uploaded < max_files when max_files is given.
    '''
    async def abort_and_check(server, prefix, files, tid):
        assert len(files) > 1
        api = manager.api
        await api.abort_task(server.ip_addr, tid)
        # unblock the task that is waiting at the injection point
        await api.message_injection(server.ip_addr, breakpoint_name)
        status = await api.wait_task(server.ip_addr, tid)
        print(f'Status: {status}')
        assert (status is not None) and (status['state'] == 'failed')
        assert "seastar::abort_requested_exception (abort requested)" in status['error']

        bucket = object_storage.get_resource().Bucket(object_storage.bucket_name)
        objects = {obj.key for obj in bucket.objects.all()}
        uploaded_count = 0
        for toc in files:
            in_backup = f'{prefix}/{toc}' in objects
            print(f'Check {toc} is in backup: {in_backup}')
            uploaded_count += int(in_backup)
        # The s3 client is abortable and runs asynchronously, so the abort may hit even
        # the very first upload, whether it was requested before or after that upload
        # started. Hence only a range of uploaded files can be asserted.
        assert min_files <= uploaded_count < len(files)
        assert max_files is None or uploaded_count < max_files

    await do_test_backup_helper(manager, object_storage, breakpoint_name, abort_and_check)
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_backup_is_abortable(manager: ManagerClient, object_storage):
    '''check that a backup task paused at the backup_task_pause injection can be aborted
    and does not upload the whole snapshot'''
    await do_test_backup_abort(manager, object_storage,
                               breakpoint_name="backup_task_pause",
                               min_files=0)
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_backup_is_abortable_in_s3_client(manager: ManagerClient, object_storage):
    '''check that a backup task paused right before an upload (backup_task_pre_upload)
    can be aborted inside the s3 client, so that no sstable reaches the bucket'''
    await do_test_backup_abort(manager, object_storage,
                               breakpoint_name="backup_task_pre_upload",
                               min_files=0,
                               max_files=1)
@pytest.mark.asyncio
@pytest.mark.parametrize(("do_encrypt", "do_abort"), [(False, False), (False, True), (True, False)])
async def test_simple_backup_and_restore(manager: ManagerClient, object_storage, tmpdir, do_encrypt, do_abort):
    '''check that restoring from backed up snapshot for a keyspace:table works

    The table on a single node is backed up, truncated and restored from object
    storage.
    - Without do_abort, the restore must complete and bring back exactly the
      original rows.
    - With do_abort, the restore task is aborted right after it starts, and only
      upper bounds on the restored sstables are checked.
    In both cases the restore must leave the backed-up objects untouched (#20938).
    do_encrypt turns on user_info_encryption with local-file system keys.
    '''
    objconf = object_storage.create_endpoint_conf()
    cfg = {'enable_user_defined_functions': False,
           'object_storage_endpoints': objconf,
           'experimental_features': ['keyspace-storage-options'],
           'task_ttl_in_seconds': 300
           }
    if do_encrypt:
        d = tmpdir / "system_keys"
        d.mkdir()
        cfg = cfg | {
            'system_key_directory': str(d),
            'user_info_encryption': { 'enabled': True, 'key_provider': 'LocalFileSystemKeyProviderFactory' }
        }
    cmd = ['--logger-log-level', 'sstables_loader=debug:sstable_directory=trace:snapshots=trace:s3=trace:sstable=debug:http=debug:encryption=debug:api=info']
    server = await manager.server_add(config=cfg, cmdline=cmd)

    cql = manager.get_cql()
    workdir = await manager.server_get_workdir(server.server_id)

    # This test must not share bucket objects with any other test that may run in
    # parallel. The backup prefix below is therefore derived from the (unique)
    # snapshot name.
    cf = 'test_cf'
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.{cf} ( name text primary key, value text );")
        await asyncio.gather(*(cql.run_async(f"INSERT INTO {ks}.{cf} ( name, value ) VALUES ('{name}', '{value}');") for name, value in [('0', 'zero'), ('1', 'one'), ('2', 'two')]))
        snap_name, toc_names = await take_snapshot_on_one_server(ks, server, manager, logger)

        cf_dir = os.listdir(f'{workdir}/data/{ks}')[0]

        def list_sstables():
            # regular files currently in the table's data directory
            return [f for f in os.scandir(f'{workdir}/data/{ks}/{cf_dir}') if f.is_file()]

        orig_res = cql.execute(f"SELECT * FROM {ks}.{cf}")
        orig_rows = {x.name: x.value for x in orig_res}

        # Include a "suffix" in the key to mimic how scylla-manager works:
        # 1. it backs up the sstables of several snapshots and deduplicates them by
        #    uploading only the new sstables;
        # 2. it restores a given snapshot by collecting all of that snapshot's
        #    sstables, possibly from several places.
        #
        # In this test we:
        # 1. upload:
        #    prefix: {some}/{objects}/{path}
        #    sstables:
        #    - 1-TOC.txt
        #    - 2-TOC.txt
        #    - ...
        # 2. download:
        #    prefix = {some}/{objects}/{path}
        #    sstables:
        #    - 1-TOC.txt
        #    - 2-TOC.txt
        #    - ...
        prefix = f'{cf}/{snap_name}'
        tid = await manager.api.backup(server.ip_addr, ks, cf, snap_name, object_storage.address, object_storage.bucket_name, prefix)
        status = await manager.api.wait_task(server.ip_addr, tid)
        assert (status is not None) and (status['state'] == 'done')

        print('Drop the table data and validate it\'s gone')
        cql.execute(f"TRUNCATE TABLE {ks}.{cf};")
        files = list_sstables()
        assert len(files) == 0
        res = cql.execute(f"SELECT * FROM {ks}.{cf};")
        assert not res
        objects = {o.key for o in object_storage.get_resource().Bucket(object_storage.bucket_name).objects.filter(Prefix=prefix)}
        assert len(objects) > 0

        print('Try to restore')
        tid = await manager.api.restore(server.ip_addr, ks, cf, object_storage.address, object_storage.bucket_name, prefix, toc_names)
        if do_abort:
            await manager.api.abort_task(server.ip_addr, tid)
        status = await manager.api.wait_task(server.ip_addr, tid)
        if not do_abort:
            assert status is not None
            assert status['state'] == 'done'
            assert status['progress_units'] == 'batches'
            assert status['progress_completed'] == status['progress_total']
            assert status['progress_completed'] > 0

        print('Check that sstables came back')
        files = list_sstables()
        if do_abort:
            # After an abort any number of sstables may have been restored, so only
            # upper bounds are checked. These checks are somewhat dubious: restore
            # currently works mostly on a mutation basis, so nothing guarantees the
            # same number of sstables as in the original backup. But the server is
            # not stressed here (no memtable flushes are provoked), so it should never
            # produce _more_ sstables than the backup contained.
            sstable_names = [entry.name for entry in files if entry.name.endswith('.db')]
            db_objects = [key for key in objects if key.endswith('.db')]
            tocs = [entry.name for entry in files if entry.name.endswith('TOC.txt')]
            assert len(toc_names) >= len(tocs)
            assert len(sstable_names) <= len(db_objects)
        else:
            assert len(files) > 0
            print(f'Check that data came back too')
            res = cql.execute(f"SELECT * FROM {ks}.{cf};")
            rows = { x.name: x.value for x in res }
            assert rows == orig_rows, "Unexpected table contents after restore"

        print('Check that backup files are still there')  # regression test for #20938
        post_objects = {o.key for o in object_storage.get_resource().Bucket(object_storage.bucket_name).objects.filter(Prefix=prefix)}
        assert objects == post_objects
async def do_abort_restore(manager: ManagerClient, object_storage):
    '''Abort restore tasks that are blocked while streaming, and expect at least one to fail.

    Steps:
    - build a 3-node cluster in dc1 (auto_rack_dc) and write 10000 rows to an RF=3 table;
    - snapshot the table and back up every node to object storage;
    - truncate the table and start a restore on each node with the one-shot
      "stream_mutation_fragments" injection enabled;
    - wait until every node is blocked at the injection, abort all restore tasks,
      release the injection and check the final task states.
    '''
    config = {'enable_user_defined_functions': False,
              'object_storage_endpoints': object_storage.create_endpoint_conf(),
              'experimental_features': ['keyspace-storage-options'],
              'task_ttl_in_seconds': 300,
              }
    servers = await manager.servers_add(servers_num=3, config=config, auto_rack_dc='dc1')
    cql = manager.get_cql()

    logger.info("Creating keyspace and table, then inserting data...")
    table = 'test_cf'
    async with new_test_keyspace(manager,
                                 "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 3}") as keyspace:
        await cql.run_async(f"CREATE TABLE {keyspace}.{table} (name text PRIMARY KEY, value text);")
        insert_stmt = cql.prepare(f"INSERT INTO {keyspace}.{table} (name, value) VALUES (?, ?)")
        insert_stmt.consistency_level = ConsistencyLevel.ALL
        num_keys = 10000
        await asyncio.gather(*(cql.run_async(insert_stmt, (str(i), str(i))) for i in range(num_keys)))

        snapshot_name, sstables = await take_snapshot(keyspace, servers, manager, logger)

        # Back up each node's snapshot under the same prefix, one node at a time
        prefix = f"{table}/{snapshot_name}"
        logger.info(f"Backing up keyspace using prefix '{prefix}' on all servers...")
        for server in servers:
            backup_tid = await manager.api.backup(server.ip_addr, keyspace, table, snapshot_name,
                                                  object_storage.address, object_storage.bucket_name, prefix)
            backup_status = await manager.api.wait_task(server.ip_addr, backup_tid)
            assert backup_status is not None and backup_status.get('state') == 'done', \
                f"Backup task failed on server {server.server_id}"

        logger.info("Dropping table data...")
        await cql.run_async(f"TRUNCATE TABLE {keyspace}.{table};")

        logger.info("Initiating restore operations...")
        logs = [await manager.server_open_log(server.server_id) for server in servers]
        # one-shot injection that parks streaming; "block_load_and_stream" was noted as an alternative
        injection = "stream_mutation_fragments"
        await asyncio.gather(*(manager.api.enable_injection(s.ip_addr, injection, True) for s in servers))
        restore_task_ids = {}
        for server in servers:
            restore_task_ids[server.server_id] = await manager.api.restore(
                server.ip_addr, keyspace, table, object_storage.address,
                object_storage.bucket_name, prefix, sstables[server])
        await wait_all([log.wait_for(f"{injection}: waiting", timeout=10) for log in logs])

        logger.info("Aborting restore tasks...")
        await asyncio.gather(*(manager.api.abort_task(server.ip_addr, restore_task_ids[server.server_id]) for server in servers))
        await asyncio.gather(*(manager.api.message_injection(s.ip_addr, injection) for s in servers))

        final_states = []
        for server in servers:
            final_status = await manager.api.wait_task(server.ip_addr, restore_task_ids[server.server_id])
            logger.info(f"Restore task status on server {server.server_id}: {final_status}")
            assert (final_status is not None)
            final_states.append(final_status['state'])
        assert 'failed' in final_states, "Expected at least one restore task to fail after aborting"
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_abort_restore_with_rpc_error(manager: ManagerClient, object_storage):
    '''Aborting restore tasks that are blocked while streaming must make at least one of them fail.'''
    scenario = do_abort_restore(manager, object_storage)
    await scenario
# Helper class describing a cluster layout, used to parametrize the restore tests below
class topo:
    '''Cluster topology: replication factor plus node, rack and datacenter counts.

    rf is used as the per-DC 'replication_factor' of a NetworkTopologyStrategy
    keyspace, so a key is expected on rf * dcs replicas.
    '''
    def __init__(self, rf, nodes, racks, dcs):
        self.rf = rf
        self.nodes = nodes
        self.racks = racks
        self.dcs = dcs

    def __repr__(self):
        # Readable form for assertion messages such as f"{topology=}"
        return f'topo(rf={self.rf}, nodes={self.nodes}, racks={self.racks}, dcs={self.dcs})'
async def create_cluster(topology, manager, logger, object_storage=None):
    '''Start topology.nodes servers and return (servers, {server_id: host_id}).

    Nodes are placed round-robin: first across datacenters, then across racks, i.e.
    node i goes to dc{i % dcs}, rack{(i // dcs) % racks}. rf_rack_valid_keyspaces is
    enabled only when rf <= racks. If object_storage is given, its endpoint
    configuration is added to every node's config.
    '''
    rf_rack_valid_keyspaces = (topology.rf <= topology.racks)
    logger.info(f'Start cluster with {topology.nodes} nodes in {topology.dcs} DCs, {topology.racks} racks, rf_rack_valid_keyspaces: {rf_rack_valid_keyspaces}')
    cfg = {'task_ttl_in_seconds': 300, 'rf_rack_valid_keyspaces': rf_rack_valid_keyspaces}
    if object_storage:
        cfg['object_storage_endpoints'] = object_storage.create_endpoint_conf()
    cmd = [ '--logger-log-level', 'sstables_loader=debug:sstable_directory=trace:snapshots=trace:s3=trace:sstable=debug:http=debug:api=info' ]

    servers = []
    host_ids = {}
    dc_idx = rack_idx = 0
    for _ in range(topology.nodes):
        dc, rack = f"dc{dc_idx}", f"rack{rack_idx}"
        # Next node goes to the next DC; after a full sweep of DCs, move on to the next rack
        dc_idx += 1
        if dc_idx >= topology.dcs:
            dc_idx = 0
            rack_idx += 1
            if rack_idx >= topology.racks:
                rack_idx = 0
        server = await manager.server_add(config=cfg, cmdline=cmd, property_file={'dc': dc, 'rack': rack})
        logger.info(f'Created node {server.ip_addr} in {dc}.{rack}')
        servers.append(server)
        host_ids[server.server_id] = await manager.get_host_id(server.server_id)
    return servers, host_ids
async def do_restore_server(manager, logger, ks, cf, s, toc_names, scope, primary_replica_only, prefix, object_storage):
    '''Restore the sstables `toc_names` stored under `prefix` in object storage into
    ks.cf through server `s`, using the given streaming scope. Assert that the
    restore task finishes in state "done".'''
    logger.info(f'Restore {s.ip_addr} with {toc_names}, scope={scope}')
    api = manager.api
    task = await api.restore(s.ip_addr, ks, cf, object_storage.address, object_storage.bucket_name, prefix, toc_names, scope, primary_replica_only=primary_replica_only)
    result = await api.wait_task(s.ip_addr, task)
    assert result is not None and result['state'] == 'done'
async def check_streaming_directions(logger, servers, topology, host_ids, scope, primary_replica_only, log_marks):
    '''Validate, from the nodes' logs, where each node streamed restored data.

    For every server, the sstables_loader "load_and_stream" log lines written since
    that server's mark are parsed, and every target node must lie within `scope`
    ('all', 'dc', 'rack' or 'node').

    When topology.rf == topology.racks, two more checks follow:
    - the set of targets must match the replica set expected for the scope;
    - the number of streams per target must be balanced: exactly equal when
      primary_replica_only is False, and within 10% of the mean otherwise.

    host_ids maps server_id -> host id. log_marks maps server_id -> (log, mark),
    as returned by mark_all_logs().
    '''
    # Index host ids per DC and per (DC, rack), and map host ids back to servers
    host_ids_per_dc = defaultdict(list)
    host_ids_per_dc_rack = dict()
    servers_by_host_id = dict()
    for s in servers:
        host = host_ids[s.server_id]
        host_ids_per_dc[s.datacenter].append(host)
        host_ids_per_dc_rack.setdefault(s.datacenter, defaultdict(list))[s.rack].append(host)
        servers_by_host_id[host] = s

    logger.info(f'Validate streaming directions')
    for s in servers:
        # target host id -> number of load_and_stream log lines naming it
        streamed_to = defaultdict(int)
        log, mark = log_marks[s.server_id]
        # Lines reporting sstables added to the table directly rather than streamed;
        # only consulted below when a node-scoped restore streamed nothing.
        direct_downloads = await log.grep('sstables_loader - Adding downloaded SSTables to the table', from_mark=mark)
        res = await log.grep(r'sstables_loader - load_and_stream:.*target_node=(?P<target_host_id>[0-9a-f-]+)', from_mark=mark)
        for r in res:
            target_host_id = r[1].group('target_host_id')
            # every target must lie within the requested scope
            if scope == 'all':
                assert target_host_id in servers_by_host_id.keys()
            elif scope == 'dc':
                assert servers_by_host_id[target_host_id].datacenter == s.datacenter
            elif scope == 'rack':
                assert servers_by_host_id[target_host_id].datacenter == s.datacenter
                assert servers_by_host_id[target_host_id].rack == s.rack
            elif scope == 'node':
                assert target_host_id == str(host_ids[s.server_id])
            streamed_to[target_host_id] += 1

        # validate balance only when rf == #racks
        if topology.rf != topology.racks:
            logger.info(f'Skipping balance checks since rf != racks ({topology.rf} != {topology.racks})')
            continue

        # The set of targets must be exactly the replica set expected for the scope
        if scope == 'all':
            assert set(streamed_to.keys()).issubset(set(host_ids.values()))
            assert len(streamed_to) == topology.rf * topology.dcs
        elif scope == 'dc':
            # it's guaranteed the node replicated only within the datacenter by asserts above
            assert set(streamed_to.keys()).issubset(set(host_ids_per_dc[s.datacenter]))
            assert len(streamed_to) == topology.rf
        elif scope == 'rack' and topology.rf == topology.racks:
            # (the rf == racks condition always holds here because of the `continue` above)
            assert set(streamed_to.keys()).issubset(set(host_ids_per_dc_rack[s.datacenter][s.rack]))
            assert len(streamed_to) == 1

        # assess balance
        streamed_to_counts = streamed_to.values()
        if scope == 'node' and len(streamed_to_counts) == 0:
            # node scope may load everything locally without streaming
            assert len(direct_downloads) > 0
        else:
            assert len(streamed_to_counts) > 0
            mean_count = statistics.mean(streamed_to_counts)
            max_deviation = max(abs(count - mean_count) for count in streamed_to_counts)
            if not primary_replica_only:
                assert max_deviation == 0, f'if primary_replica_only is False, streaming should be perfectly balanced: {streamed_to}'
                continue
            # primary-replica-only streaming is allowed a small imbalance (< 10% of the mean)
            assert max_deviation < 0.1 * mean_count, f'node {s.ip_addr} streaming to primary replicas was unbalanced: {streamed_to}'
def distribute_sstables(sstables, servers, topology, scope):
    '''Pick the sstables each server should load for a restore with the given scope.

    `sstables` maps each server (an object with .datacenter and .rack) to the list
    of sstable TOC names it snapshotted; `servers` lists the cluster's servers.
    Returns a defaultdict(list) mapping server -> sstables to load:
    * 'all' and 'dc' (and 'rack' when rf != racks): each server loads its own sstables;
    * 'node' with rf != racks: each server loads all sstables of its datacenter;
    * rf == racks: 'rack' loads the server's own sstables, while 'node' loads all
      sstables of the server's rack (every rack then holds a replica of each mutation).

    Raises ValueError for any scope other than 'all', 'dc', 'rack' or 'node'.
    '''
    # Validate up front. Raising a plain string is itself a TypeError in Python 3,
    # and an unknown scope must not be silently accepted when rf != racks.
    if scope not in ('all', 'dc', 'rack', 'node'):
        raise ValueError(f"distribute_sstables: {scope=} not supported")

    sstables_per_server = defaultdict(list)
    # rf_rack_valid can be True also with rack lists
    rf_rack_valid = topology.rf == topology.racks

    if scope in ('all', 'dc') or not rf_rack_valid:
        sstables_per_dc = defaultdict(list)
        for s, sstables_list in sstables.items():
            sstables_per_dc[s.datacenter].extend(sstables_list)
        servers_per_dc = defaultdict(list)
        for s in servers:
            servers_per_dc[s.datacenter].append(s)
        for dc, sstables_in_dc in sstables_per_dc.items():
            for s in servers_per_dc[dc]:
                if scope == 'node':
                    # Without rf_rack_valid, each node must load every sstable of its DC.
                    # (With it, each node loads only its rack's sstables; see the branch below.)
                    sstables_per_server[s] = sstables_in_dc
                else:
                    sstables_per_server[s] = sstables[s]
    else:
        # Here scope is 'rack' or 'node' and rf == racks
        servers_per_dc_rack = dict()
        sstables_per_dc_rack = dict()
        for s, sstables_list in sstables.items():
            servers_per_dc_rack.setdefault(s.datacenter, defaultdict(list))[s.rack].append(s)
            sstables_per_dc_rack.setdefault(s.datacenter, defaultdict(list))[s.rack].extend(sstables_list)
        for dc, racks in sstables_per_dc_rack.items():
            for rack, sstables_in_rack in racks.items():
                for s in servers_per_dc_rack[dc][rack]:
                    sstables_per_server[s] = sstables[s] if scope == 'rack' else sstables_in_rack
    return sstables_per_server
async def do_backup(s, snap_name, prefix, ks, cf, object_storage, manager, logger):
    '''Back up snapshot `snap_name` of ks.cf on server `s` under `prefix` in object
    storage, and assert that the backup task finishes in state "done".'''
    logger.info(f'Backup to {snap_name}')
    api = manager.api
    task = await api.backup(s.ip_addr, ks, cf, snap_name, object_storage.address, object_storage.bucket_name, prefix)
    result = await api.wait_task(s.ip_addr, task)
    assert result is not None and result['state'] == 'done'
async def collect_mutations(cql, server, manager, ks, cf):
    '''Read MUTATION_FRAGMENTS(ks.cf) from `server` and group the fragments by partition key.

    Returns {pk: [{'mutation_source': ..., 'partition_region': ..., 'node': server.ip_addr}, ...]}.
    '''
    hosts = await wait_for_cql_and_get_hosts(cql, [server], time.time() + 30)
    await read_barrier(manager.api, server.ip_addr)  # scylladb/scylladb#18199
    fragments = await cql.run_async(f"SELECT * FROM MUTATION_FRAGMENTS({ks}.{cf})", host=hosts[0])
    by_pk = defaultdict(list)
    for fragment in fragments:
        entry = {
            'mutation_source': fragment.mutation_source,
            'partition_region': fragment.partition_region,
            'node': server.ip_addr,
        }
        by_pk[fragment.pk].append(entry)
    return by_pk
async def check_mutation_replicas(cql, manager, servers, keys, topology, logger, ks, cf, expected_replicas = None, sample_size: int = 10):
    '''Check that each mutation is replicated to the expected number of replicas

    Collects MUTATION_FRAGMENTS of ks.cf from every server. It then checks a random
    sample of up to `sample_size` keys: each key must be present and held by exactly
    `expected_replicas` servers (default: topology.rf * topology.dcs).

    `keys` must be a sequence (e.g. a range), because it is passed to random.sample.
    '''
    if expected_replicas is None:
        expected_replicas = topology.rf * topology.dcs
    mutations = defaultdict(list)
    by_node = await asyncio.gather(*(collect_mutations(cql, s, manager, ks, cf) for s in servers))
    for node_frags in by_node:
        for pk, frags in node_frags.items():
            # one entry per server that holds the partition
            mutations[pk].append(frags)
    # random.sample raises ValueError when asked for more items than `keys` has
    for k in random.sample(keys, min(sample_size, len(keys))):
        pk = str(k)
        if pk not in mutations:
            logger.info(f'Mutations: {mutations}')
            assert False, f"Key '{k}' not found in mutations. {topology=}"
        if len(mutations[pk]) != expected_replicas:
            logger.info(f'Mutations: {mutations}')
            assert False, f"'{k}' is replicated {len(mutations[pk])} times, expected {expected_replicas}"
async def mark_all_logs(manager, servers):
    '''Open each server's log and record its current position.

    Returns {server_id: (log, mark)}, so later greps can start from the mark.
    Servers are processed one after another, in the given order.
    '''
    async def open_and_mark(server):
        server_log = await manager.server_open_log(server.server_id)
        return server_log, await server_log.mark()

    return {server.server_id: await open_and_mark(server) for server in servers}
class SSTablesOnObjectStorage:
    '''sstables storage backend for do_test_streaming_scopes that goes through object storage.

    save() backs up each server's snapshot to the bucket; restore() loads the
    sstables back through the restore API, one task per server.
    '''
    def __init__(self, object_storage):
        self.object_storage = object_storage

    async def save(self, manager, servers, snap_name, prefix, ks, cf, logger):
        '''Back up snapshot `snap_name` of ks.cf on all `servers` concurrently.'''
        backups = [do_backup(server, snap_name, prefix, ks, cf, self.object_storage, manager, logger)
                   for server in servers]
        await asyncio.gather(*backups)

    async def restore(self, manager, sstables_per_server, prefix, ks, cf, scope, primary_replica_only, logger):
        '''Restore, concurrently, each server's assigned sstables into ks.cf with the given scope.'''
        restores = [do_restore_server(manager, logger, ks, cf, server, toc_names, scope, primary_replica_only, prefix, self.object_storage)
                    for server, toc_names in sstables_per_server.items()]
        await asyncio.gather(*restores)
@pytest.mark.asyncio
@pytest.mark.parametrize("topology", [
        topo(rf = 1, nodes = 3, racks = 1, dcs = 1),
        topo(rf = 3, nodes = 5, racks = 1, dcs = 1),
        topo(rf = 1, nodes = 4, racks = 2, dcs = 1),
        topo(rf = 3, nodes = 6, racks = 2, dcs = 1),
        topo(rf = 2, nodes = 8, racks = 4, dcs = 2)
    ])
async def test_restore_with_streaming_scopes(build_mode: str, manager: ManagerClient, object_storage, topology):
    '''Check that restoring of a cluster with stream scopes works

    This runs do_test_streaming_scopes with sstables saved to object storage
    (via backup) and loaded back through the restore API.
    '''
    storage = SSTablesOnObjectStorage(object_storage)
    await do_test_streaming_scopes(build_mode, manager, topology, storage)
async def do_test_streaming_scopes(build_mode: str, manager: ManagerClient, topology, sstables_storage):
    '''
    Create a cluster shaped by `topology` (replication factor and number of nodes,
    racks and datacenters) with tablet balancing disabled. Write a small dataset,
    snapshot it on every node and save the snapshot's sstables through
    `sstables_storage`.

    Then iterate over combinations of streaming scope, primary-replica-only flag
    and restored min_tablet_count. Skipped combinations: 'node' with
    primary_replica_only, and 'rack' when rf != racks. For each remaining
    combination, create a fresh keyspace/table, distribute the saved sstables to
    the servers according to the scope and restore them. With
    primary_replica_only, a tablet repair runs afterwards. Two kinds of checks
    follow:
    1) The data is back: a random sample of keys is present, and each key is
       replicated RF * DCs times.
    2) When the restored table has the original min_tablet_count, the node logs
       show that data was streamed only within the configured scope (balance is
       checked only when RF equals the number of racks).
    Debug builds, and storages without object storage, test fewer combinations.
    '''
    servers, host_ids = await create_cluster(topology, manager, logger, sstables_storage.object_storage)

    await manager.disable_tablet_balancing()
    cql = manager.get_cql()

    num_keys = 10
    original_min_tablet_count = 5
    scopes = ['rack', 'dc'] if build_mode == 'debug' else ['all', 'dc', 'rack', 'node']
    pros = [ True, False ] # Primary Replica Only
    restored_min_tablet_counts = [original_min_tablet_count] if (build_mode == 'debug' or sstables_storage.object_storage is None) else [2, original_min_tablet_count, 10]
    ks_options = f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {topology.rf}}}"

    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {original_min_tablet_count}}};")
        insert_stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, value) VALUES (?, ?)")
        insert_stmt.consistency_level = ConsistencyLevel.ALL
        await asyncio.gather(*(cql.run_async(insert_stmt, (str(i), i)) for i in range(num_keys)))

        # validate replicas assertions hold on fresh dataset
        await check_mutation_replicas(cql, manager, servers, range(num_keys), topology, logger, ks, 'test')

        snap_name, sstables = await take_snapshot(ks, servers, manager, logger)
        prefix = f'test/{snap_name}'
        await sstables_storage.save(manager, servers, snap_name, prefix, ks, 'test', logger)

        for scope, pro, restored_min_tablet_count in itertools.product(scopes, pros, restored_min_tablet_counts):
            # primary_replica_only is not combined with the node scope
            if scope == 'node' and pro:
                continue
            # Rack-aware restore could be supported with rack lists, by restoring the
            # per-DC rack list that was in effect at backup time. With a numeric
            # replication_factor, however, an arbitrary subset of racks is picked when
            # the keyspace is first created, and another arbitrary subset at restore time.
            if scope == 'rack' and topology.rf != topology.racks:
                logger.info(f'Skipping scope={scope} test since rf={topology.rf} != racks={topology.racks} and it cannot be supported with numeric replication_factor')
                continue

            # Restore into a fresh keyspace (distinct name to avoid shadowing the source one)
            async with new_test_keyspace(manager, ks_options) as restore_ks:
                await cql.run_async(f"CREATE TABLE {restore_ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {restored_min_tablet_count}}};")
                log_marks = await mark_all_logs(manager, servers)
                logger.info(f'Loading {servers=} with {sstables=} scope={scope}')
                sstables_per_server = distribute_sstables(sstables, servers, topology, scope)
                await sstables_storage.restore(manager, sstables_per_server, prefix, restore_ks, 'test', scope, pro, logger)
                if pro:
                    # only primary replicas were loaded; repair brings the other replicas in sync
                    await manager.api.tablet_repair(servers[0].ip_addr, restore_ks, 'test', 'all', timeout=600)
                await check_mutation_replicas(cql, manager, servers, range(num_keys), topology, logger, restore_ks, 'test')
                if restored_min_tablet_count == original_min_tablet_count:
                    await check_streaming_directions(logger, servers, topology, host_ids, scope, pro, log_marks)
@pytest.mark.asyncio
@pytest.mark.parametrize("topology", [
    topo(rf = 1, nodes = 3, racks = 1, dcs = 1),
    topo(rf = 2, nodes = 2, racks = 2, dcs = 1),
])
async def test_restore_tablets(build_mode: str, manager: ManagerClient, object_storage, topology):
    '''Back up a tablet table and bring it back with the tablet-aware restore API.

    The source table is created with min_tablet_count=5; the restore target pins its
    tablet count (min and max) to that value rounded up to a power of two. After the
    restore task reports 'done', every key must be found on the expected replicas.
    '''
    servers, _ = await create_cluster(topology, manager, logger, object_storage)
    await manager.disable_tablet_balancing()
    cql = manager.get_cql()

    num_keys = 10
    tablet_count = 5
    # Pinned tablet count of the restored table: tablet_count rounded up to a power of two.
    tablet_count_for_restore = 8
    ks_options = f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {topology.rf}}}"

    # Populate the source table, snapshot it and back up every node's snapshot.
    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count}}};")
        insert_stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, value) VALUES (?, ?)")
        insert_stmt.consistency_level = ConsistencyLevel.ALL
        await asyncio.gather(*[cql.run_async(insert_stmt, (str(key), key)) for key in range(num_keys)])

        snap_name, _ = await take_snapshot(ks, servers, manager, logger)
        backups = [
            do_backup(srv, snap_name, f'{srv.server_id}/{snap_name}', ks, 'test', object_storage, manager, logger)
            for srv in servers
        ]
        await asyncio.gather(*backups)

    # Restore into a fresh keyspace, driving the task through the second node.
    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count_for_restore}, 'max_tablet_count': {tablet_count_for_restore}}};")
        api_node = servers[1]
        logger.info(f'Restore cluster via {api_node.ip_addr}')
        manifests = [f'{srv.server_id}/{snap_name}/manifest.json' for srv in servers]
        task_id = await manager.api.restore_tablets(api_node.ip_addr, ks, 'test', snap_name, servers[0].datacenter,
                                                    object_storage.address, object_storage.bucket_name, manifests)
        status = await manager.api.wait_task(api_node.ip_addr, task_id)
        assert status is not None and status['state'] == 'done'

        await check_mutation_replicas(cql, manager, servers, range(num_keys), topology, logger, ks, 'test')
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_restore_tablets_vs_migration(build_mode: str, manager: ManagerClient, object_storage):
    '''Check that restore handles tablets migrating around.

    Before the tablet-aware restore is started, one tablet of the target table is put
    into migration towards the other node and held by the "block_tablet_streaming"
    injection. The injection is messaged once the restore task has been submitted; the
    restore task must then finish as 'done', the migration must complete, and the data
    must be found on the expected replicas.
    '''
    topology = topo(rf = 1, nodes = 2, racks = 1, dcs = 1)
    servers, host_ids = await create_cluster(topology, manager, logger, object_storage)
    await manager.disable_tablet_balancing()
    cql = manager.get_cql()

    num_keys = 10
    tablet_count = 4
    tablet_count_for_restore = 4
    ks_options = f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {topology.rf}}}"

    # Populate the source table, snapshot it and back up every node's snapshot.
    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count}}};")
        insert_stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, value) VALUES (?, ?)")
        insert_stmt.consistency_level = ConsistencyLevel.ALL
        await asyncio.gather(*(cql.run_async(insert_stmt, (str(i), i)) for i in range(num_keys)))
        snap_name, sstables = await take_snapshot(ks, servers, manager, logger)
        await asyncio.gather(*(do_backup(s, snap_name, f'{s.server_id}/{snap_name}', ks, 'test', object_storage, manager, logger) for s in servers))

    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count_for_restore}, 'max_tablet_count': {tablet_count_for_restore}}};")

        # Pick the first tablet and move its replica (the only one, rf=1) to the other node.
        s0_host_id = await manager.get_host_id(servers[0].server_id)
        s1_host_id = await manager.get_host_id(servers[1].server_id)
        tablet = (await get_all_tablet_replicas(manager, servers[0], ks, 'test'))[0]
        current = tablet.replicas[0]  # (host_id, shard) pair, judging by the move_tablet() call below
        target = s1_host_id if current[0] == s0_host_id else s0_host_id

        # The third argument (False) is presumably one_shot: the injection keeps blocking
        # until it is explicitly messaged.
        await asyncio.gather(*[manager.api.enable_injection(s.ip_addr, "block_tablet_streaming", False, parameters={'keyspace': ks, 'table': 'test'}) for s in servers])
        migration_task = asyncio.create_task(manager.api.move_tablet(servers[0].ip_addr, ks, "test", current[0], current[1], target, 0, tablet.last_token))

        async def release_migration():
            """Message the blocking injection on every node so the migration can proceed."""
            await asyncio.gather(*[manager.api.message_injection(s.ip_addr, "block_tablet_streaming") for s in servers])

        released = False
        try:
            logger.info(f'Restore cluster via {servers[1].ip_addr}')
            manifests = [ f'{s.server_id}/{snap_name}/manifest.json' for s in servers ]
            tid = await manager.api.restore_tablets(servers[1].ip_addr, ks, 'test', snap_name, servers[0].datacenter, object_storage.address, object_storage.bucket_name, manifests)
            await release_migration()
            released = True
            status = await manager.api.wait_task(servers[1].ip_addr, tid)
            assert (status is not None) and (status['state'] == 'done')
        finally:
            # If anything above failed before the release, still unblock the migration
            # so it does not stay stuck on the injection. Always await the task so its
            # outcome (or exception) is not left unretrieved.
            if not released:
                await release_migration()
            await migration_task

        await check_mutation_replicas(cql, manager, servers, range(num_keys), topology, logger, ks, 'test')
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_restore_tablets_download_failure(build_mode: str, manager: ManagerClient, object_storage):
    '''Verify that an sstable download failure during restore is reported through the task API.

    A one-shot "fail_download_sstable" injection is armed on servers[1]. The tablet-aware
    restore is then driven through servers[0], and its task must end in the 'failed' state
    with an error mentioning "Failed to download".
    '''
    topology = topo(rf = 1, nodes = 2, racks = 1, dcs = 1)
    servers, _ = await create_cluster(topology, manager, logger, object_storage)
    await manager.disable_tablet_balancing()
    cql = manager.get_cql()

    num_keys, tablet_count = 12, 4
    replication = f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {topology.rf}}}"

    # Source data: populate, snapshot, back up from every node.
    async with new_test_keyspace(manager, replication) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count}}};")
        stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, value) VALUES (?, ?)")
        stmt.consistency_level = ConsistencyLevel.ALL
        await asyncio.gather(*[cql.run_async(stmt, (str(n), n)) for n in range(num_keys)])
        snap_name, _ = await take_snapshot(ks, servers, manager, logger)
        backups = [do_backup(srv, snap_name, f'{srv.server_id}/{snap_name}', ks, 'test', object_storage, manager, logger) for srv in servers]
        await asyncio.gather(*backups)

    # Arm the (one-shot) download failure before the restore starts.
    await manager.api.enable_injection(servers[1].ip_addr, "fail_download_sstable", one_shot=True)

    async with new_test_keyspace(manager, replication) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count}, 'max_tablet_count': {tablet_count}}};")
        entry = servers[0]
        logger.info(f'Restore cluster via {entry.ip_addr}')
        manifests = [f'{srv.server_id}/{snap_name}/manifest.json' for srv in servers]
        task_id = await manager.api.restore_tablets(entry.ip_addr, ks, 'test', snap_name, entry.datacenter,
                                                    object_storage.address, object_storage.bucket_name, manifests)
        status = await manager.api.wait_task(entry.ip_addr, task_id)
        assert 'state' in status and status['state'] == 'failed'
        assert 'error' in status and 'Failed to download' in status['error']
@pytest.mark.asyncio
@pytest.mark.parametrize("target", ['coordinator', 'replica', 'api'])
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_restore_tablets_node_loss_resiliency(build_mode: str, manager: ManagerClient, object_storage, target):
    '''Check how restore handles node loss in the middle of the operation.

    The restore is started through servers[1] and paused on servers[2] by the one-shot
    "pause_tablet_restore" injection. Then, depending on `target`, one node is stopped:

    * 'api'         -- servers[1], where the restore task was started; polling the task
                       there must fail with a connection error.
    * 'coordinator' -- servers[0], on which the test first waited for the topology
                       coordinator fiber to start; servers[2] is then released.
    * 'replica'     -- servers[2], the paused node itself.

    In the last two cases the restore task must finish (in any state) within 60 seconds.
    '''
    topology = topo(rf = 2, nodes = 4, racks = 2, dcs = 1)
    servers, host_ids = await create_cluster(topology, manager, logger, object_storage)
    # Wait until servers[0] has started the topology coordinator fiber.
    coordinator_log = await manager.server_open_log(servers[0].server_id)
    await coordinator_log.wait_for("raft_topology - start topology coordinator fiber", timeout=10)
    await manager.disable_tablet_balancing()
    cql = manager.get_cql()

    num_keys = 24
    tablet_count = 8
    ks_options = f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {topology.rf}}}"

    # Populate the source table, snapshot it and back up every node's snapshot.
    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count}}};")
        insert_stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, value) VALUES (?, ?)")
        insert_stmt.consistency_level = ConsistencyLevel.ALL
        await asyncio.gather(*(cql.run_async(insert_stmt, (str(i), i)) for i in range(num_keys)))
        snap_name, sstables = await take_snapshot(ks, servers, manager, logger)
        await asyncio.gather(*(do_backup(s, snap_name, f'{s.server_id}/{snap_name}', ks, 'test', object_storage, manager, logger) for s in servers))

    async with new_test_keyspace(manager, ks_options) as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test ( pk text primary key, value int ) WITH tablets = {{'min_tablet_count': {tablet_count}, 'max_tablet_count': {tablet_count}}};")

        # Hold the restore on servers[2] so a node can be stopped mid-operation.
        await manager.api.enable_injection(servers[2].ip_addr, "pause_tablet_restore", one_shot=True)
        paused_log = await manager.server_open_log(servers[2].server_id)
        mark = await paused_log.mark()
        manifests = [ f'{s.server_id}/{snap_name}/manifest.json' for s in servers ]
        tid = await manager.api.restore_tablets(servers[1].ip_addr, ks, 'test', snap_name, servers[0].datacenter, object_storage.address, object_storage.bucket_name, manifests)
        await paused_log.wait_for("pause_tablet_restore: waiting for message", from_mark=mark)

        if target == 'api':
            # The node the task was started on is gone, so the task can no longer be polled there.
            await manager.server_stop(servers[1].server_id)
            with pytest.raises(aiohttp.client_exceptions.ClientConnectorError):
                await manager.api.wait_task(servers[1].ip_addr, tid)
        else:
            if target == 'coordinator':
                await manager.server_stop(servers[0].server_id)
                await manager.api.message_injection(servers[2].ip_addr, "pause_tablet_restore")
            elif target == 'replica':
                await manager.server_stop(servers[2].server_id)
            # Sometimes killed nodes manage to restore their tablets before being killed,
            # so the best thing to do is to make sure the restore task finishes at all.
            await asyncio.wait_for(manager.api.wait_task(servers[1].ip_addr, tid), timeout=60)
@pytest.mark.asyncio
async def test_restore_with_non_existing_sstable(manager: ManagerClient, object_storage):
    '''Restoring an sstable that is absent from object storage must fail the task cleanly.

    The restore task is pointed at an empty prefix and a made-up TOC name. The task is
    expected to end in the 'failed' state with a "Not Found" error.
    '''
    cfg = {
        'enable_user_defined_functions': False,
        'object_storage_endpoints': object_storage.create_endpoint_conf(),
        'experimental_features': ['keyspace-storage-options'],
        'task_ttl_in_seconds': 300,
    }
    cmd = ['--logger-log-level', 'snapshots=trace:task_manager=trace:api=info']
    server = await manager.server_add(config=cfg, cmdline=cmd)
    cql = manager.get_cql()

    print('Create keyspace')
    table = 'test_cf'
    missing_toc = 'me-3gou_0fvw_4r94g2h8nw60b8ly4c-big-TOC.txt'
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.{table} ( name text primary key, value text );")

        task_id = await manager.api.restore(server.ip_addr, ks, table, object_storage.address,
                                            object_storage.bucket_name, 'no_such_prefix', [missing_toc])
        status = await manager.api.wait_task(server.ip_addr, task_id)
        print(f'Status: {status}')
        assert 'state' in status and status['state'] == 'failed'
        assert 'error' in status and 'Not Found' in status['error']
@pytest.mark.asyncio
async def test_backup_broken_streaming(manager: ManagerClient, s3_storage):
    '''Restore externally generated sstables that are fully or partially contained in tablets.

    SSTables are built with `scylla sstable write` from the JSON files in
    test/resource/sstables/fully_partially_contained_ssts. They are uploaded to S3 and
    restored (passing "node", presumably the restore scope) into a table with
    min_tablet_count 16. The restored row count must equal the sum of the per-sstable
    counts. The sstables_loader log must mention both fully and partially contained
    SSTables.
    '''
    # Define configuration for the servers.
    objconf = s3_storage.create_endpoint_conf()
    config = {
        'enable_user_defined_functions': False,
        'object_storage_endpoints': objconf,
        'experimental_features': ['keyspace-storage-options'],
        'task_ttl_in_seconds': 300,
    }
    cmd = ['--smp', '1', '--logger-log-level', 'sstables_loader=debug:sstable=debug']
    server = await manager.server_add(config=config, cmdline=cmd)
    # Obtain the CQL interface from the manager.
    cql = manager.get_cql()
    scylla_path = await manager.server_get_exe(server.server_id)

    async with new_test_keyspace(manager,
                                 "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}") as keyspace:
        table = 'test_cf'
        create_table_query = (
            f"CREATE TABLE {keyspace}.{table} (name text PRIMARY KEY, value text) "
            f"WITH tablets = {{'min_tablet_count': '16'}};"
        )
        cql.execute(create_table_query)

        expected_rows = 0
        with tempfile.TemporaryDirectory() as tmp_dir:
            resource_dir = "test/resource/sstables/fully_partially_contained_ssts"
            # scylla-sstable needs the table schema to write and query the sstables.
            schema_file = os.path.join(tmp_dir, "schema.cql")
            with open(schema_file, "w") as f:
                f.write(f"CREATE TABLE {keyspace}.{table} (name text PRIMARY KEY, value text)")

            # Turn every JSON resource into an sstable in tmp_dir and count its rows.
            for root, _, files in os.walk(resource_dir):
                for file in files:
                    local_path = os.path.join(root, file)
                    print("Processing file:", local_path)
                    sst_generation = subprocess.check_output(
                        [scylla_path, "sstable", "write", "--schema-file", schema_file, "--input-format", "json",
                         "--output-dir", tmp_dir, "--input-file", local_path]).decode().strip()
                    sst_path = glob.glob(f"{tmp_dir}/??-{sst_generation}-???-TOC.txt")[0]
                    expected_rows += json.loads(subprocess.check_output(
                        [scylla_path, "sstable", "query", "-q", f"SELECT COUNT(*) FROM scylla_sstable.{table}",
                         "--output-format", "json", "--sstables",
                         sst_path]).decode())[0]['count']

            prefix = unique_name('/test/streaming_')
            s3_resource = s3_storage.get_resource()
            bucket = s3_resource.Bucket(s3_storage.bucket_name)
            sstables = []
            print(f"Uploading files from '{tmp_dir}' to prefix '{prefix}':")
            for root, _, files in os.walk(tmp_dir):
                for file in files:
                    local_path = os.path.join(root, file)
                    # The schema file is only an input for scylla-sstable, not an sstable
                    # component, so keep it out of the sstables prefix.
                    if local_path == schema_file:
                        continue
                    if file.endswith("-TOC.txt"):
                        sstables.append(file)
                    s3_key = f"{prefix}/{file}"
                    print(f" - Uploading {local_path} to {s3_key}")
                    bucket.upload_file(local_path, s3_key)

        restore_task_id = await manager.api.restore(
            server.ip_addr, keyspace, table,
            s3_storage.address, s3_storage.bucket_name,
            prefix, sstables, "node"
        )
        status = await manager.api.wait_task(server.ip_addr, restore_task_id)
        assert status and status.get(
            'state') == 'done', f"Restore task failed on server {server.server_id}. Reason {status}"

        res = cql.execute(f"SELECT COUNT(*) FROM {keyspace}.{table} BYPASS CACHE USING TIMEOUT 600s;")
        row = res.one()
        assert row.count == expected_rows, \
            f"number of rows after restore is incorrect: {row.count} (expected {expected_rows})"

        log = await manager.server_open_log(server.server_id)
        await log.wait_for("fully contained SSTables to local node from object storage", timeout=10)
        # just make sure we had partially contained sstables as well
        await log.wait_for("partially contained SSTables", timeout=10)
@pytest.mark.asyncio
@pytest.mark.parametrize("domain", ['rack', 'dc'])
@pytest.mark.parametrize("scope_is_same", [True, False])
async def test_restore_primary_replica(manager: ManagerClient, object_storage, domain, scope_is_same):
    '''Check that restoring with primary_replica_only streams to the correct primary replica(s) depending on scope.

    When scope matches the node's own domain (scope_is_same=True):
    - scope equals the domain itself, so streaming is confined within the same rack/DC.
    - Each mutation exists exactly 2 times in the cluster, once per domain.
    - Each streaming operation targets exactly one node, which must be within the same domain.

    When scope is wider than the node's own domain (scope_is_same=False):
    - scope is set to "dc" (for rack domain) or "all" (for dc domain), allowing cross-domain streaming.
    - Each mutation exists exactly 1 time in the cluster.
    - Each restoring node streams to exactly 2 distinct nodes, as the primary replica may fall in either domain.
    '''
    dcs = 1 if domain == 'rack' else 2
    if scope_is_same:
        topology = topo(rf = 4, nodes = 8, racks = 2, dcs = dcs)
        scope = domain
        expected_replicas = 2
    else:
        if domain == 'rack':
            topology = topo(rf = 2, nodes = 2, racks = 2, dcs = dcs)
            scope = "dc"
        else:
            topology = topo(rf = 1, nodes = 2, racks = 1, dcs = dcs)
            scope = "all"
        expected_replicas = 1

    cf = 'cf'
    keys = range(256)
    replication_str = f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {topology.rf}}}"

    servers, host_ids = await create_cluster(topology, manager, logger, object_storage)
    await manager.disable_tablet_balancing()
    cql = manager.get_cql()

    async with new_test_keyspace(manager, replication_str) as ks:
        cql.execute(f"CREATE TABLE {ks}.{cf} ( pk text primary key, value int );")
        stmt = cql.prepare(f"INSERT INTO {ks}.{cf} ( pk, value ) VALUES (?, ?)")
        stmt.consistency_level = ConsistencyLevel.ALL
        await asyncio.gather(*(cql.run_async(stmt, (str(k), k)) for k in keys))
        # validate replicas assertions hold on fresh dataset
        await check_mutation_replicas(cql, manager, servers, keys, topology, logger, ks, cf)
        snap_name, sstables = await take_snapshot(ks, servers, manager, logger)
        prefix = f'{cf}/{snap_name}'
        await asyncio.gather(*(do_backup(s, snap_name, prefix, ks, cf, object_storage, manager, logger) for s in servers))

    async with new_test_keyspace(manager, replication_str) as ks:
        cql.execute(f"CREATE TABLE {ks}.{cf} ( pk text primary key, value int );")
        # Every node restores its own sstables with primary_replica_only=True.
        await asyncio.gather(*(do_restore_server(manager, logger, ks, cf, s, sstables[s], scope, True, prefix, object_storage) for s in servers))
        await check_mutation_replicas(cql, manager, servers, keys, topology, logger, ks, cf, expected_replicas=expected_replicas)

    def same_domain(s1, s2):
        """Whether two servers share the rack (domain='rack') or the DC (domain='dc')."""
        if domain == 'rack':
            return s1.rack == s2.rack
        return s1.datacenter == s2.datacenter

    logger.info('Validate streaming directions')
    for server in servers:
        log = await manager.server_open_log(server.server_id)
        res = await log.grep(r'INFO.*sstables_loader - load_and_stream: ops_uuid=([0-9a-z-]+).*target_node=([0-9a-z-]+),.*num_bytes_sent=([0-9]+)')
        # ops_uuid -> list of target host ids streamed to by that operation
        nodes_by_operation = defaultdict(list)
        for r in res:
            match = r[1]
            nodes_by_operation[match.group(1)].append(match.group(2))

        if not scope_is_same:
            streamed_to = {node for nodes in nodes_by_operation.values() for node in nodes}
            logger.info(f'{server.ip_addr} {host_ids[server.server_id]} streamed to {streamed_to}')
            assert len(streamed_to) == 2, \
                f"Expected {server.ip_addr} to stream to 2 distinct nodes, got {streamed_to}"
        else:
            scope_nodes = {str(host_ids[other.server_id]) for other in servers if same_domain(other, server)}
            for op, nodes in nodes_by_operation.items():
                logger.info(f'Operation {op} streamed to nodes {nodes}')
                assert len(nodes) == 1, "Each streaming operation should stream to exactly one primary replica"
                assert nodes[0] in scope_nodes, f"Primary replica should be within the scope {scope}"
@pytest.mark.asyncio
@pytest.mark.skip_mode(mode='release', reason='error injections are not supported in release mode')
async def test_decommision_waits_for_backup(manager: ManagerClient, object_storage):
    '''check that backing up a snapshot for a keyspace blocks decommission'''
    async def decommission_and_check(server: ServerInfo, prefix: str, files, tid):
        """Decommission `server` concurrently with releasing its backup task `tid`.

        The backup is released via the "backup_task_pre_upload" injection only after the
        node logs that it is waiting for snapshot/backup tasks to finish. The backup must
        then complete as 'done', every entry of `files` must exist under `prefix` in the
        bucket, and the node must then log that backup and snapshots were disabled.
        """
        log = await manager.server_open_log(server.server_id)
        mark = await log.mark()

        async def finish_backup():
            # Wait until the node reports it is waiting for snapshot/backup tasks.
            await log.wait_for("Waiting for snapshot/backup tasks to finish", from_mark=mark)
            mark2 = await log.mark()
            # let the backup run and finish
            await manager.api.message_injection(server.ip_addr, "backup_task_pre_upload")
            status = await manager.api.wait_task(server.ip_addr, tid)
            assert (status is not None) and (status['state'] == 'done')

            objects = set(o.key for o in object_storage.get_resource().Bucket(object_storage.bucket_name).objects.all())
            # All files should be uploaded. Note: `files` can be empty here, presumably
            # because with two nodes this one may hold no sstables.
            missing = []
            for f in files:
                in_backup = f'{prefix}/{f}' in objects
                print(f'Check {f} is in backup: {in_backup}')
                if not in_backup:
                    missing.append(f)
            assert not missing, f'Files missing from backup under {prefix}: {missing}'

            # Now wait for decommission to finish
            await log.wait_for("DECOMMISSIONING: disabled backup and snapshots", from_mark=mark2)

        await asyncio.gather(manager.decommission_node(server.server_id), finish_backup())

    await do_test_backup_helper(manager, object_storage, "backup_task_pre_upload", decommission_and_check, 2)
async def test_aborted_decommision_reenables_snapshot(manager: ManagerClient, object_storage):
"""
Tests that an aborted decommission will still allow snapshots
"""
num_servers = 2
objconf = object_storage.create_endpoint_conf()
cfg = {'enable_user_defined_functions': False,
'object_storage_endpoints': objconf,
'experimental_features': ['keyspace-storage-options'],
'task_ttl_in_seconds': 300
}
cmd = ['--logger-log-level', 'snapshots=trace:task_manager=trace:api=info']
servers = (await manager.servers_add(num_servers, config=cfg, cmdline=cmd))
cql = manager.get_cql()
cf = 'test_cf'
async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': '1'}") as ks:
await cql.run_async(f"CREATE TABLE {ks}.{cf} ( name text primary key, value text );")
await asyncio.gather(*(cql.run_async(f"INSERT INTO {ks}.{cf} ( name, value ) VALUES ('{name}', '{value}');") for name, value in [('0', 'zero'), ('1', 'one'), ('2', 'two')]))
await manager.server_sees_others(servers[1].server_id, 1)
async def abort_decommission():
tm = TaskManagerClient(manager.api)
while True:
logger.info("Listing tasks in %s", servers[1])
tasks = await tm.list_tasks(servers[1].ip_addr, "node_ops")
for t in tasks:
if t.type == 'decommission':
logger.debug("Found decommission task. Aborting...")
await tm.abort_task(servers[1].ip_addr, t.task_id)
for s in servers:
await manager.api.message_injection(s.ip_addr, "topology_coordinator_before_leave")
try:
logger.debug("Checking decommission task status")
status = await tm.wait_for_task(servers[1].ip_addr, t.task_id)
logger.debug("Task status %s", status)
return status.state != "done"
except:
return False
await asyncio.sleep(.1)
async def decommission():
try:
logger.info("Decommissioning %s", servers[0])
await manager.api.decommission_node(servers[0].ip_addr, 1000)
except Exception as e:
logger.error("Exception in decommission %s", e)
pass
for s in servers:
await manager.api.enable_injection(s.ip_addr, "topology_coordinator_before_leave", one_shot=True)
_, aborted = await asyncio.gather(decommission(), abort_decommission())
assert aborted, "Injection point sync should ensure we abort decommission"
logger.info("Decommissioned was aborted. Creating snapshot")
await wait_for_token_ring_and_group0_consistency(manager, time.time() + 30)
await take_snapshot_on_one_server(ks, servers[0], manager, logger)