diff --git a/node_ops/task_manager_module.cc b/node_ops/task_manager_module.cc index 3909196282..ff5d0b39aa 100644 --- a/node_ops/task_manager_module.cc +++ b/node_ops/task_manager_module.cc @@ -91,7 +91,7 @@ future> node_ops_virtual_task::get_status_help .entity = "", .progress_units = "", .progress = tasks::task_manager::task::progress{}, - .children = started ? co_await get_children(get_module(), id, [&gossiper = _ss.gossiper()] (gms::inet_address addr) { return gossiper.is_alive(addr); }) : std::vector{} + .children = started ? co_await get_children(get_module(), id, _ss.get_token_metadata_ptr()) : std::vector{} }; } diff --git a/service/task_manager_module.cc b/service/task_manager_module.cc index cf97a6f476..a316a404d7 100644 --- a/service/task_manager_module.cc +++ b/service/task_manager_module.cc @@ -156,7 +156,7 @@ future> tablet_virtual_task::wait(tasks::task_ } else if (is_resize_task(task_type)) { auto new_tablet_count = _ss.get_token_metadata().tablets().get_tablet_map(table).tablet_count(); res->status.state = new_tablet_count == tablet_count ? tasks::task_manager::task_state::suspended : tasks::task_manager::task_state::done; - res->status.children = task_type == locator::tablet_task_type::split ? co_await get_children(get_module(), id, [&gossiper = _ss.gossiper()] (gms::inet_address addr) { return gossiper.is_alive(addr); }) : std::vector{}; + res->status.children = task_type == locator::tablet_task_type::split ? co_await get_children(get_module(), id, _ss.get_token_metadata_ptr()) : std::vector{}; } res->status.end_time = db_clock::now(); // FIXME: Get precise end time. co_return res->status; @@ -262,7 +262,7 @@ future> tablet_virtual_task::get_status_helper(task if (task_info.tablet_task_id.uuid() == id.uuid()) { update_status(task_info, res.status, sched_nr); res.status.state = tasks::task_manager::task_state::running; - res.status.children = task_type == locator::tablet_task_type::split ? co_await get_children(get_module(), id, [&gossiper = _ss.gossiper()] (gms::inet_address addr) { return gossiper.is_alive(addr); }) : std::vector{}; + res.status.children = task_type == locator::tablet_task_type::split ? co_await get_children(get_module(), id, _ss.get_token_metadata_ptr()) : std::vector{}; co_return res; } } diff --git a/tasks/task_manager.cc b/tasks/task_manager.cc index 86ec6aec50..4abc9ada90 100644 --- a/tasks/task_manager.cc +++ b/tasks/task_manager.cc @@ -389,7 +389,7 @@ task_manager::virtual_task::impl::impl(module_ptr module) noexcept : _module(std::move(module)) {} -future> task_manager::virtual_task::impl::get_children(module_ptr module, task_id parent_id, std::function is_host_alive) { +future> task_manager::virtual_task::impl::get_children(module_ptr module, task_id parent_id, locator::token_metadata_ptr tmptr) { auto ms = module->get_task_manager()._messaging; if (!ms) { auto ids = co_await module->get_task_manager().get_virtual_task_children(parent_id); @@ -406,19 +406,18 @@ future> task_manager::virtual_task::impl::get_childre tmlogger.info("tasks_vt_get_children: waiting"); co_await handler.wait_for_message(std::chrono::steady_clock::now() + std::chrono::seconds{10}); }); - co_return co_await map_reduce(nodes, [ms, parent_id, is_host_alive = std::move(is_host_alive)] (auto addr) -> future> { - if (is_host_alive(addr)) { - return ms->send_tasks_get_children(netw::msg_addr{addr}, parent_id).then([addr] (auto resp) { - return resp | std::views::transform([addr] (auto id) { - return task_identity{ - .node = addr, - .task_id = id - }; - }) | std::ranges::to>(); - }); - } else { - return make_ready_future>(); - } + co_return co_await map_reduce(nodes, [ms, parent_id] (auto addr) -> future> { + return ms->send_tasks_get_children(netw::msg_addr{addr}, parent_id).then([addr] (auto resp) { + return resp | std::views::transform([addr] (auto id) { + return task_identity{ + .node = addr, + .task_id = id + }; + }) | std::ranges::to>(); + }).handle_exception_type([addr, parent_id] (const rpc::closed_error& ex) { + tmlogger.warn("Failed to get children of virtual task with id={} from node {}: {}", parent_id, addr, ex); + return std::vector{}; + }); }, std::vector{}, concat); } diff --git a/tasks/task_manager.hh b/tasks/task_manager.hh index 954e9bd5ba..782c642f1a 100644 --- a/tasks/task_manager.hh +++ b/tasks/task_manager.hh @@ -19,6 +19,7 @@ #include "db_clock.hh" #include "utils/log.hh" #include "gms/inet_address.hh" +#include "locator/token_metadata_fwd.hh" #include "schema/schema_fwd.hh" #include "tasks/types.hh" #include "utils/chunked_vector.hh" @@ -279,7 +280,7 @@ public: impl& operator=(impl&&) = delete; virtual ~impl() = default; protected: - static future> get_children(module_ptr module, task_id parent_id, std::function is_host_alive); + static future> get_children(module_ptr module, task_id parent_id, locator::token_metadata_ptr tmptr); public: virtual task_group get_group() const noexcept = 0; // Returns std::nullopt if an operation with task_id isn't tracked by this virtual_task. diff --git a/test/topology_tasks/test_node_ops_tasks.py b/test/topology_tasks/test_node_ops_tasks.py index 4a61284ef7..0bebdfe7ca 100644 --- a/test/topology_tasks/test_node_ops_tasks.py +++ b/test/topology_tasks/test_node_ops_tasks.py @@ -258,27 +258,3 @@ async def test_node_ops_task_wait(manager: ManagerClient): await decommission_task await waiting_task - -@pytest.mark.asyncio -async def test_get_children(manager: ManagerClient): - module_name = "node_ops" - tm = TaskManagerClient(manager.api) - servers = [await manager.server_add() for _ in range(2)] - - injection = "tasks_vt_get_children" - handler = await inject_error_one_shot(manager.api, servers[0].ip_addr, injection) - - log = await manager.server_open_log(servers[0].server_id) - mark = await log.mark() - - bootstrap_task = [task for task in await tm.list_tasks(servers[0].ip_addr, module_name) if task.kind == "cluster"][0] - - async def _decommission(): - await log.wait_for('tasks_vt_get_children: waiting', from_mark=mark) - await manager.decommission_node(servers[1].server_id) - await handler.message() - - async def _get_status(): - await tm.get_task_status(servers[0].ip_addr, bootstrap_task.task_id) - - await asyncio.gather(*(_decommission(), _get_status()))