diff --git a/locator/tablets.cc b/locator/tablets.cc index 6a4d8351bf..d944e27082 100644 --- a/locator/tablets.cc +++ b/locator/tablets.cc @@ -38,6 +38,10 @@ write_replica_set_selector get_selector_for_writes(tablet_transition_stage stage return write_replica_set_selector::next; case tablet_transition_stage::cleanup: return write_replica_set_selector::next; + case tablet_transition_stage::cleanup_target: + return write_replica_set_selector::previous; + case tablet_transition_stage::revert_migration: + return write_replica_set_selector::previous; case tablet_transition_stage::end_migration: return write_replica_set_selector::next; } @@ -59,6 +63,10 @@ read_replica_set_selector get_selector_for_reads(tablet_transition_stage stage) return read_replica_set_selector::next; case tablet_transition_stage::cleanup: return read_replica_set_selector::next; + case tablet_transition_stage::cleanup_target: + return read_replica_set_selector::previous; + case tablet_transition_stage::revert_migration: + return read_replica_set_selector::previous; case tablet_transition_stage::end_migration: return read_replica_set_selector::next; } @@ -275,6 +283,8 @@ static const std::unordered_map tablet_transit {tablet_transition_stage::streaming, "streaming"}, {tablet_transition_stage::use_new, "use_new"}, {tablet_transition_stage::cleanup, "cleanup"}, + {tablet_transition_stage::cleanup_target, "cleanup_target"}, + {tablet_transition_stage::revert_migration, "revert_migration"}, {tablet_transition_stage::end_migration, "end_migration"}, }; diff --git a/locator/tablets.hh b/locator/tablets.hh index 6c4aebd4a5..85696ba9ab 100644 --- a/locator/tablets.hh +++ b/locator/tablets.hh @@ -157,6 +157,8 @@ enum class tablet_transition_stage { write_both_read_new, use_new, cleanup, + cleanup_target, + revert_migration, end_migration, }; diff --git a/service/storage_service.cc b/service/storage_service.cc index 8e43a25a79..d772ced8a6 100644 --- a/service/storage_service.cc +++ b/service/storage_service.cc @@ -5505,15 +5505,21 @@ future<> storage_service::cleanup_tablet(locator::global_tablet_id tablet) { if (!trinfo) { throw std::runtime_error(fmt::format("No transition info for tablet {}", tablet)); } - if (trinfo->stage != locator::tablet_transition_stage::cleanup) { - throw std::runtime_error(fmt::format("Tablet {} stage is not at cleanup", tablet)); + + if (trinfo->stage == locator::tablet_transition_stage::cleanup) { + auto& tinfo = tmap.get_tablet_info(tablet.tablet); + locator::tablet_replica leaving_replica = locator::get_leaving_replica(tinfo, *trinfo); + if (leaving_replica.host != tm->get_my_id()) { + throw std::runtime_error(fmt::format("Tablet {} has leaving replica different than this one", tablet)); + } + } else if (trinfo->stage == locator::tablet_transition_stage::cleanup_target) { + if (trinfo->pending_replica.host != tm->get_my_id()) { + throw std::runtime_error(fmt::format("Tablet {} has pending replica different than this one", tablet)); + } + } else { + throw std::runtime_error(fmt::format("Tablet {} stage is not at cleanup/cleanup_target", tablet)); } - auto& tinfo = tmap.get_tablet_info(tablet.tablet); - locator::tablet_replica leaving_replica = locator::get_leaving_replica(tinfo, *trinfo); - if (leaving_replica.host != tm->get_my_id()) { - throw std::runtime_error(fmt::format("Tablet {} has leaving replica different than this one", tablet)); - } auto shard_opt = tmap.get_shard(tablet.tablet, tm->get_my_id()); if (!shard_opt) { on_internal_error(rtlogger, format("Tablet {} has no shard on this node", tablet)); diff --git a/service/tablet_allocator.cc b/service/tablet_allocator.cc index e613194604..0a1d8de37e 100644 --- a/service/tablet_allocator.cc +++ b/service/tablet_allocator.cc @@ -425,6 +425,10 @@ private: return false; case tablet_transition_stage::cleanup: return false; + case tablet_transition_stage::cleanup_target: + return false; + case tablet_transition_stage::revert_migration: + return false; case tablet_transition_stage::end_migration: return false; } diff --git a/service/topology_coordinator.cc b/service/topology_coordinator.cc index 696194eb30..a6df910293 100644 --- a/service/topology_coordinator.cc +++ b/service/topology_coordinator.cc @@ -991,6 +991,23 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { } }; + auto check_excluded_replicas = [&] { + auto tsi = get_migration_streaming_info(get_token_metadata().get_topology(), tmap.get_tablet_info(gid.tablet), trinfo); + for (auto r : tsi.read_from) { + if (is_excluded(raft::server_id(r.host.uuid()))) { + rtlogger.debug("Aborting streaming of {} because read-from {} is marked as ignored", gid, r); + return true; + } + } + for (auto r : tsi.written_to) { + if (is_excluded(raft::server_id(r.host.uuid()))) { + rtlogger.debug("Aborting streaming of {} because written-to {} is marked as ignored", gid, r); + return true; + } + } + return false; + }; + switch (trinfo.stage) { case locator::tablet_transition_stage::allow_write_both_read_old: if (do_barrier()) { @@ -1014,6 +1031,14 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { utils::get_local_injector().inject("stream_tablet_fail_on_drain", [] { throw std::runtime_error("stream_tablet failed due to error injection"); }); } + + if (tablet_state.streaming && tablet_state.streaming->failed()) { + if (check_excluded_replicas()) { + transition_to_with_barrier(locator::tablet_transition_stage::cleanup_target); + break; + } + } + if (advance_in_background(gid, tablet_state.streaming, "streaming", [&] { rtlogger.info("Initiating tablet streaming ({}) of {} to {}", trinfo.transition, gid, trinfo.pending_replica); auto dst = trinfo.pending_replica.host; @@ -1047,6 +1072,30 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { transition_to(locator::tablet_transition_stage::end_migration); } break; + case locator::tablet_transition_stage::cleanup_target: + if (advance_in_background(gid, tablet_state.cleanup, "cleanup_target", [&] { + locator::tablet_replica dst = trinfo.pending_replica; + if (is_excluded(raft::server_id(dst.host.uuid()))) { + rtlogger.info("Tablet cleanup of {} on {} skipped because node is excluded and doesn't need to revert migration", gid, dst); + return make_ready_future<>(); + } + rtlogger.info("Initiating tablet cleanup of {} on {} to revert migration", gid, dst); + return ser::storage_service_rpc_verbs::send_tablet_cleanup(&_messaging, + netw::msg_addr(id2ip(dst.host)), _as, raft::server_id(dst.host.uuid()), gid); + })) { + transition_to(locator::tablet_transition_stage::revert_migration); + } + break; + case locator::tablet_transition_stage::revert_migration: + // Need a separate stage and a barrier after cleanup RPC to cut off stale RPCs. + // See do_tablet_operation() doc. + if (do_barrier()) { + _tablets.erase(gid); + updates.emplace_back(get_mutation_builder() + .del_transition(last_token) + .build()); + } + break; case locator::tablet_transition_stage::end_migration: // Need a separate stage and a barrier after cleanup RPC to cut off stale RPCs. // See do_tablet_operation() doc. diff --git a/test/pylib/tablets.py b/test/pylib/tablets.py new file mode 100644 index 0000000000..b2fec1fbe5 --- /dev/null +++ b/test/pylib/tablets.py @@ -0,0 +1,61 @@ +# +# Copyright (C) 2024-present ScyllaDB +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# + +from test.pylib.util import read_barrier +from test.pylib.manager_client import ManagerClient +from test.pylib.internal_types import ServerInfo, HostID +from typing import NamedTuple + +class TabletReplicas(NamedTuple): + last_token: int + replicas: list[tuple[HostID, int]] + +async def get_all_tablet_replicas(manager: ManagerClient, server: ServerInfo, keyspace_name: str, table_name: str) -> list[TabletReplicas]: + """ + Retrieves the tablet distribution for a given table. + This call is guaranteed to see all prior changes applied to group0 tables. + + :param server: server to query. Can be any live node. + """ + + host = manager.get_cql().cluster.metadata.get_host(server.ip_addr) + + # read_barrier is needed to ensure that local tablet metadata on the queried node + # reflects the finalized tablet movement. + await read_barrier(manager.get_cql(), host) + + table_id = await manager.get_table_id(keyspace_name, table_name) + rows = await manager.get_cql().run_async(f"SELECT last_token, replicas FROM system.tablets where " + f"table_id = {table_id}", host=host) + return [TabletReplicas( + last_token=x.last_token, + replicas=[(HostID(str(host)), shard) for (host, shard) in x.replicas] + ) for x in rows] + +async def get_tablet_replicas(manager: ManagerClient, server: ServerInfo, keyspace_name: str, table_name: str, token: int) -> list[tuple[HostID, int]]: + """ + Gets tablet replicas of the tablet which owns a given token of a given table. + This call is guaranteed to see all prior changes applied to group0 tables. + + :param server: server to query. Can be any live node. + """ + rows = await get_all_tablet_replicas(manager, server, keyspace_name, table_name) + for row in rows: + if row.last_token >= token: + return row.replicas + return [] + + +async def get_tablet_replica(manager: ManagerClient, server: ServerInfo, keyspace_name: str, table_name: str, token: int) -> tuple[HostID, int]: + """ + Get the first replica of the tablet which owns a given token of a given table. + This call is guaranteed to see all prior changes applied to group0 tables. + + :param server: server to query. Can be any live node. + """ + replicas = await get_tablet_replicas(manager, server, keyspace_name, table_name, token) + return replicas[0] + diff --git a/test/topology_custom/test_tablets_migration.py b/test/topology_custom/test_tablets_migration.py new file mode 100644 index 0000000000..e1ed3d637b --- /dev/null +++ b/test/topology_custom/test_tablets_migration.py @@ -0,0 +1,91 @@ +# +# Copyright (C) 2024-present ScyllaDB +# +# SPDX-License-Identifier: AGPL-3.0-or-later +# +from cassandra.query import SimpleStatement, ConsistencyLevel +from test.pylib.manager_client import ManagerClient +from test.pylib.rest_client import HTTPError +from test.pylib.tablets import get_all_tablet_replicas +from test.topology.conftest import skip_mode +import pytest +import logging +import asyncio + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("fail_replica", ["source", "destination"]) +@pytest.mark.parametrize("fail_stage", ["streaming"]) +@pytest.mark.asyncio +@skip_mode('release', 'error injections are not supported in release mode') +async def test_node_failure_during_tablet_migration(manager: ManagerClient, fail_replica, fail_stage): + logger.info("Bootstrapping cluster") + cfg = {'enable_user_defined_functions': False, 'experimental_features': ['tablets', 'consistent-topology-changes']} + host_ids = [] + servers = [] + + async def make_server(): + s = await manager.server_add(config=cfg) + servers.append(s) + host_ids.append(await manager.get_host_id(s.server_id)) + await manager.api.disable_tablet_balancing(s.ip_addr) + + await make_server() + cql = manager.get_cql() + + await cql.run_async("CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 2} AND tablets = {'initial': 1}") + await make_server() + await cql.run_async("CREATE TABLE test.test (pk int PRIMARY KEY, c int);") + + keys = range(256) + await asyncio.gather(*[cql.run_async(f"INSERT INTO test.test (pk, c) VALUES ({k}, {k});") for k in keys]) + await make_server() + + logger.info(f"Cluster is [{host_ids}]") + + replicas = await get_all_tablet_replicas(manager, servers[0], 'test', 'test') + logger.info(f"Tablet is on [{replicas}]") + assert len(replicas) == 1 and len(replicas[0].replicas) == 2 + + old_replica = None + for r in replicas[0].replicas: + assert r[0] != host_ids[2], "Tablet got migrated to node2" + if r[0] == host_ids[1]: + old_replica = r + assert old_replica is not None + new_replica = (host_ids[2], 0) + logger.info(f"Moving tablet {old_replica} -> {new_replica}") + + fail_idx = 1 if fail_replica == "source" else 2 + + logger.info(f"Will fail {fail_stage}") + if fail_stage == "streaming": + await manager.api.enable_injection(servers[2].ip_addr, "stream_mutation_fragments", one_shot=True) + s2_log = await manager.server_open_log(servers[2].server_id) + s2_mark = await s2_log.mark() + else: + assert False, f"Unknown stage {fail_stage}" + + migration_task = asyncio.create_task( + manager.api.move_tablet(servers[0].ip_addr, "test", "test", old_replica[0], old_replica[1], new_replica[0], new_replica[1], 0)) + + logger.info(f"Wait for {fail_stage} to happen") + if fail_stage == "streaming": + await s2_log.wait_for('stream_mutation_fragments: waiting', from_mark=s2_mark) + else: + assert False + + logger.info(f"Stop {fail_replica} {host_ids[fail_idx]}") + await manager.server_stop(servers[fail_idx].server_id) + logger.info(f"Remove {fail_replica} {host_ids[fail_idx]}") + await manager.remove_node(servers[0].server_id, servers[fail_idx].server_id) + + logger.info("Done, waiting for migration to finish") + await migration_task + + replicas = await get_all_tablet_replicas(manager, servers[0], 'test', 'test') + logger.info(f"Tablet is now on [{replicas}]") + assert len(replicas) == 1 + for r in replicas[0].replicas: + assert r[0] != host_ids[fail_idx] diff --git a/test/topology_experimental_raft/test_tablets.py b/test/topology_experimental_raft/test_tablets.py index 9b97895942..d203b4e0e4 100644 --- a/test/topology_experimental_raft/test_tablets.py +++ b/test/topology_experimental_raft/test_tablets.py @@ -10,16 +10,15 @@ from test.pylib.manager_client import ManagerClient from test.pylib.rest_client import inject_error_one_shot, HTTPError from test.pylib.rest_client import inject_error from test.pylib.util import wait_for_cql_and_get_hosts, read_barrier +from test.pylib.tablets import get_tablet_replica, get_all_tablet_replicas from test.topology.conftest import skip_mode from test.topology.util import reconnect_driver -from test.pylib.internal_types import HostID import pytest import asyncio import logging import time import random -from typing import NamedTuple logger = logging.getLogger(__name__) @@ -33,56 +32,6 @@ async def inject_error_on(manager, error_name, servers): errs = [manager.api.enable_injection(s.ip_addr, error_name, False) for s in servers] await asyncio.gather(*errs) -class TabletReplicas(NamedTuple): - last_token: int - replicas: list[tuple[HostID, int]] - -async def get_all_tablet_replicas(manager: ManagerClient, server: ServerInfo, keyspace_name: str, table_name: str) -> list[TabletReplicas]: - """ - Retrieves the tablet distribution for a given table. - This call is guaranteed to see all prior changes applied to group0 tables. - - :param server: server to query. Can be any live node. - """ - - host = manager.get_cql().cluster.metadata.get_host(server.ip_addr) - - # read_barrier is needed to ensure that local tablet metadata on the queried node - # reflects the finalized tablet movement. - await read_barrier(manager.get_cql(), host) - - table_id = await manager.get_table_id(keyspace_name, table_name) - rows = await manager.get_cql().run_async(f"SELECT last_token, replicas FROM system.tablets where " - f"table_id = {table_id}", host=host) - return [TabletReplicas( - last_token=x.last_token, - replicas=[(HostID(str(host)), shard) for (host, shard) in x.replicas] - ) for x in rows] - -async def get_tablet_replicas(manager: ManagerClient, server: ServerInfo, keyspace_name: str, table_name: str, token: int) -> list[tuple[HostID, int]]: - """ - Gets tablet replicas of the tablet which owns a given token of a given table. - This call is guaranteed to see all prior changes applied to group0 tables. - - :param server: server to query. Can be any live node. - """ - rows = await get_all_tablet_replicas(manager, server, keyspace_name, table_name) - for row in rows: - if row.last_token >= token: - return row.replicas - return [] - - -async def get_tablet_replica(manager: ManagerClient, server: ServerInfo, keyspace_name: str, table_name: str, token: int) -> tuple[HostID, int]: - """ - Get the first replica of the tablet which owns a given token of a given table. - This call is guaranteed to see all prior changes applied to group0 tables. - - :param server: server to query. Can be any live node. - """ - replicas = await get_tablet_replicas(manager, server, keyspace_name, table_name, token) - return replicas[0] - async def repair_on_node(manager: ManagerClient, server: ServerInfo, servers: list[ServerInfo]): node = server.ip_addr await manager.servers_see_each_other(servers)