diff --git a/raft/server.cc b/raft/server.cc index 16c149e4dc..950badd131 100644 --- a/raft/server.cc +++ b/raft/server.cc @@ -1160,6 +1160,10 @@ future<> server_impl::applier_fiber() { // of taking snapshots ourselves but comparing our last index directly with what's currently in _fsm. auto last_snap_idx = _fsm->log_last_snapshot_idx(); + // Error injection to be set with one_shot + utils::get_local_injector().inject("raft_server_snapshot_reduce_threshold", + [this] { _config.snapshot_threshold = 3; _config.snapshot_trailing = 1; }); + if (_applied_idx > last_snap_idx && (_applied_idx - last_snap_idx >= _config.snapshot_threshold || _fsm->log_memory_usage() >= _config.snapshot_threshold_log_size)) diff --git a/test/pylib/rest_client.py b/test/pylib/rest_client.py index 50b1cb74f4..f7634caad0 100644 --- a/test/pylib/rest_client.py +++ b/test/pylib/rest_client.py @@ -217,12 +217,13 @@ class ScyllaRESTAPIClient(): @asynccontextmanager -async def inject_error(api: ScyllaRESTAPIClient, node_ip: IPAddress, injection: str, - one_shot: bool): +async def inject_error(api: ScyllaRESTAPIClient, node_ip: IPAddress, injection: str): """Attempts to inject an error. Works only in specific build modes: debug,dev,sanitize. It will trigger a test to be skipped if attempting to enable an injection has no effect. + This is a context manager for enabling and disabling when done, therefore it can't be + used for one shot. """ - await api.enable_injection(node_ip, injection, one_shot) + await api.enable_injection(node_ip, injection, False) enabled = await api.get_enabled_injections(node_ip) logging.info(f"Error injections enabled on {node_ip}: {enabled}") if not enabled: @@ -232,3 +233,15 @@ async def inject_error(api: ScyllaRESTAPIClient, node_ip: IPAddress, injection: finally: logger.info(f"Disabling error injection {injection}") await api.disable_injection(node_ip, injection) + + +async def inject_error_one_shot(api: ScyllaRESTAPIClient, node_ip: IPAddress, injection: str): + """Attempts to inject an error. Works only in specific build modes: debug,dev,sanitize. + It will trigger a test to be skipped if attempting to enable an injection has no effect. + This is a one-shot injection enable. + """ + await api.enable_injection(node_ip, injection, True) + enabled = await api.get_enabled_injections(node_ip) + logging.info(f"Error injections enabled on {node_ip}: {enabled}") + if not enabled: + pytest.skip("Error injection not enabled in Scylla - try compiling in dev/debug/sanitize mode") diff --git a/test/topology/test_snapshot.py b/test/topology/test_snapshot.py new file mode 100644 index 0000000000..20443f99c6 --- /dev/null +++ b/test/topology/test_snapshot.py @@ -0,0 +1,60 @@ +# +# Copyright (C) 2023-present ScyllaDB +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +""" +Test snapshot transfer by forcing threshold and performing schema changes +""" +import asyncio +import logging +from test.pylib.rest_client import inject_error_one_shot, inject_error +import pytest +from cassandra.query import SimpleStatement # type: ignore # pylint: disable=no-name-in-module + + +logger = logging.getLogger(__name__) + + +@pytest.mark.asyncio +async def test_snapshot(manager, random_tables): + """ + Cluster A, B, C + with reduced snapshot threshold create table, do several schema changes. + Start a new server D and it should get a snapshot on bootstrap. + Then stop A B C and query D to check it sees the correct table schema (verify_schema). + """ + server_a, server_b, server_c = await manager.running_servers() + await manager.mark_dirty() + # Reduce the snapshot thresholds + errs = [inject_error_one_shot(manager.api, s.ip_addr, 'raft_server_snapshot_reduce_threshold') + for s in [server_a, server_b, server_c]] + await asyncio.gather(*errs) + + t = await random_tables.add_table(ncolumns=5, pks=1) + + for i in range(3): + await t.add_column() + + manager.driver_close() + server_d = await manager.server_add() + logger.info("Started D %s", server_d) + + logger.info("Stopping A %s, B %s, and C %s", server_a, server_b, server_c) + await asyncio.gather(*[manager.server_stop_gracefully(s.server_id) + for s in [server_a, server_b, server_c]]) + + logger.info("Driver connecting to D %s", server_d) + await manager.driver_connect() + + await random_tables.verify_schema() + + # Start servers to have quorum for post-test checkup + # TODO: remove once there's a way to disable post-test checkup + manager.driver_close() + logger.info("Starting A %s", server_a) + await manager.server_start(server_a.server_id) + logger.info("Starting B %s", server_b) + await manager.server_start(server_b.server_id) + await manager.driver_connect() + logger.info("Test DONE") diff --git a/test/topology/test_topology.py b/test/topology/test_topology.py index ff61420241..257788002f 100644 --- a/test/topology/test_topology.py +++ b/test/topology/test_topology.py @@ -19,7 +19,7 @@ from test.pylib.scylla_cluster import ReplaceConfig from test.pylib.manager_client import ManagerClient from cassandra.cluster import Session from test.pylib.random_tables import RandomTables -from test.pylib.rest_client import inject_error +from test.pylib.rest_client import inject_error_one_shot logger = logging.getLogger(__name__) @@ -256,14 +256,14 @@ async def test_remove_garbage_group0_members(manager: ManagerClient, random_tabl logging.info(f'removenode {servers[0]} using {servers[1]}') # removenode will fail after removing the server from the token ring, # but before removing it from group 0 - async with inject_error(manager.api, servers[1].ip_addr, - 'removenode_fail_before_remove_from_group0', one_shot=True): - try: - await manager.remove_node(servers[1].server_id, servers[0].server_id) - except Exception: - # Note: the exception returned here is only '500 internal server error', - # need to look in test.py log for the actual message coming from Scylla. - logging.info(f'expected exception during injection') + await inject_error_one_shot(manager.api, servers[1].ip_addr, + 'removenode_fail_before_remove_from_group0') + try: + await manager.remove_node(servers[1].server_id, servers[0].server_id) + except Exception: + # Note: the exception returned here is only '500 internal server error', + # need to look in test.py log for the actual message coming from Scylla. + logging.info(f'expected exception during injection') # Query the storage_service/host_id endpoint to calculate a list of known token ring members' Host IDs # (internally, this endpoint uses token_metadata) diff --git a/test/topology_raft_disabled/test_raft_upgrade.py b/test/topology_raft_disabled/test_raft_upgrade.py index 1f384da4b1..1d0d35ee9b 100644 --- a/test/topology_raft_disabled/test_raft_upgrade.py +++ b/test/topology_raft_disabled/test_raft_upgrade.py @@ -16,7 +16,7 @@ from cassandra.pool import Host # type: ignore # pylint: from test.pylib.manager_client import ManagerClient, IPAddress, ServerInfo from test.pylib.random_tables import RandomTables -from test.pylib.rest_client import ScyllaRESTAPIClient, inject_error +from test.pylib.rest_client import ScyllaRESTAPIClient, inject_error_one_shot from test.pylib.util import wait_for, wait_for_cql_and_get_hosts @@ -163,42 +163,41 @@ async def test_recover_stuck_raft_upgrade(manager: ManagerClient, random_tables: # TODO error injection should probably be done through ScyllaClusterManager (we may need to mark the cluster as dirty). # In this test the cluster is dirty anyway due to a restart so it's safe. - async with inject_error(manager.api, srv1.ip_addr, 'group0_upgrade_before_synchronize', - one_shot=True): - logging.info(f"Enabling Raft on {others} and restarting") - await asyncio.gather(*(enable_raft_and_restart(manager, srv) for srv in others)) - cql = await reconnect_driver(manager) + await inject_error_one_shot(manager.api, srv1.ip_addr, 'group0_upgrade_before_synchronize') + logging.info(f"Enabling Raft on {others} and restarting") + await asyncio.gather(*(enable_raft_and_restart(manager, srv) for srv in others)) + cql = await reconnect_driver(manager) - logging.info(f"Cluster restarted, waiting until driver reconnects to {others}") - hosts = await wait_for_cql_and_get_hosts(cql, others, time.time() + 60) - logging.info(f"Driver reconnected, hosts: {hosts}") + logging.info(f"Cluster restarted, waiting until driver reconnects to {others}") + hosts = await wait_for_cql_and_get_hosts(cql, others, time.time() + 60) + logging.info(f"Driver reconnected, hosts: {hosts}") - logging.info(f"Waiting until {hosts} enter 'synchronize' state") - await asyncio.gather(*(wait_for_upgrade_state('synchronize', cql, h, time.time() + 60) for h in hosts)) - logging.info(f"{hosts} entered synchronize") + logging.info(f"Waiting until {hosts} enter 'synchronize' state") + await asyncio.gather(*(wait_for_upgrade_state('synchronize', cql, h, time.time() + 60) for h in hosts)) + logging.info(f"{hosts} entered synchronize") - # TODO ensure that srv1 failed upgrade - look at logs? - # '[shard 0] raft_group0_upgrade - Raft upgrade failed: std::runtime_error (error injection before group 0 upgrade enters synchronize).' + # TODO ensure that srv1 failed upgrade - look at logs? + # '[shard 0] raft_group0_upgrade - Raft upgrade failed: std::runtime_error (error injection before group 0 upgrade enters synchronize).' - logging.info(f"Setting recovery state on {hosts}") - for host in hosts: - await cql.run_async( - "update system.scylla_local set value = 'recovery' where key = 'group0_upgrade_state'", - host=host) + logging.info(f"Setting recovery state on {hosts}") + for host in hosts: + await cql.run_async( + "update system.scylla_local set value = 'recovery' where key = 'group0_upgrade_state'", + host=host) - logging.info(f"Restarting {others}") - await asyncio.gather(*(restart(manager, srv) for srv in others)) - cql = await reconnect_driver(manager) + logging.info(f"Restarting {others}") + await asyncio.gather(*(restart(manager, srv) for srv in others)) + cql = await reconnect_driver(manager) - logging.info(f"{others} restarted, waiting until driver reconnects to them") - hosts = await wait_for_cql_and_get_hosts(cql, others, time.time() + 60) + logging.info(f"{others} restarted, waiting until driver reconnects to them") + hosts = await wait_for_cql_and_get_hosts(cql, others, time.time() + 60) - logging.info(f"Checking if {hosts} are in recovery state") - for host in hosts: - rs = await cql.run_async( - "select value from system.scylla_local where key = 'group0_upgrade_state'", - host=host) - assert rs[0].value == 'recovery' + logging.info(f"Checking if {hosts} are in recovery state") + for host in hosts: + rs = await cql.run_async( + "select value from system.scylla_local where key = 'group0_upgrade_state'", + host=host) + assert rs[0].value == 'recovery' logging.info("Creating a table while in recovery state") table = await random_tables.add_table(ncolumns=5)