scylladb/test/cluster/tasks/test_node_ops_tasks.py
Aleksandra Martyniuk d4fdeb4839 tasks: pass token_metadata_ptr to task_manager::virtual_task::impl::get_children
In get_children we get the vector of alive nodes with get_nodes.
Yet, between this call and sending RPCs to those nodes there may
be a preemption. Currently, a node's liveness is checked once
again before the RPCs, but only against the gossiper, not the
topology (unlike get_nodes).

Modify get_children so that it holds a token_metadata_ptr, which
keeps the topology from changing between get_nodes and the RPCs.

Remove test_get_children, as it checked that get_children does not
fail when a node goes down after get_nodes, which can no longer
happen.
2026-03-18 15:37:24 +01:00
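Below is a minimal Python sketch of the pattern the commit applies, using hypothetical names (the actual change is to ScyllaDB's C++ task manager, not this test file): hold one immutable snapshot of the topology for the whole operation, so the node set computed up front is exactly the set the RPCs target, even if the coroutine is preempted in between.

import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class TopologySnapshot:
    nodes: tuple[str, ...]  # alive nodes at snapshot time; immutable, cannot change under us


async def rpc_get_children(node: str) -> list[str]:
    await asyncio.sleep(0)  # stand-in for the real RPC round trip
    return [f"child-task-on-{node}"]


async def get_children(snapshot: TopologySnapshot) -> list[str]:
    # Holding `snapshot` across the awaits plays the role of the held
    # token_metadata_ptr: no node can drop out of `snapshot.nodes` between
    # enumerating the nodes and sending the RPCs.
    results = await asyncio.gather(*(rpc_get_children(node) for node in snapshot.nodes))
    return [child for children in results for child in children]

# Usage: asyncio.run(get_children(TopologySnapshot(nodes=("n1", "n2"))))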

#
# Copyright (C) 2024-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
import asyncio
import logging
import time
from functools import partial
from typing import Optional

import pytest

from test.pylib.internal_types import IPAddress, ServerInfo
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import InjectionHandler, inject_error_one_shot
from test.pylib.scylla_cluster import ReplaceConfig
from test.pylib.util import wait_for
from test.cluster.tasks import extra_scylla_cmdline_options as cmdline
from test.cluster.tasks.task_manager_client import TaskManagerClient
from test.cluster.tasks.task_manager_types import TaskID, TaskStats, TaskStatus
from test.cluster.util import new_test_keyspace

logger = logging.getLogger(__name__)

async def get_status_allow_peer_connection_failure(tm: TaskManagerClient, node_ip: IPAddress, task_id: TaskID) -> Optional[TaskStatus]:
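    """Fetch the status of task_id from node_ip, returning None if the peer connection was already closed."""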
    ret = await tm.api.client.get_json(f"/task_manager/task_status/{task_id}", host=node_ip, allow_failed=True)
    resp_status = ret.get("code", 200)
    if resp_status == 200:
        assert isinstance(ret, dict)
        return TaskStatus(**ret)
    else:
        assert resp_status == 500 and "seastar::rpc::closed_error (connection is closed)" in ret["message"]
        return None

async def compare_status_on_all_servers(task_id: TaskID, tm: TaskManagerClient, servers: list[ServerInfo]) -> TaskStatus:
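    """Fetch the status of task_id from every server and assert that all reachable servers agree on it."""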
    statuses = [await get_status_allow_peer_connection_failure(tm, server.ip_addr, task_id) for server in servers]
    statuses = [s for s in statuses if s is not None]
    assert statuses, "No statuses to compare"
    assert all(status.id == statuses[0].id and status.start_time == statuses[0].start_time for status in statuses)
    return statuses[0]

async def get_new_virtual_tasks_list(tm: TaskManagerClient, module_name: str, server: ServerInfo,
                                     previous_vts: list[TaskID]) -> list[TaskStats]:
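    """List cluster (virtual) tasks visible on the server, excluding those in previous_vts."""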
    return [task for task in await tm.list_tasks(server.ip_addr, module_name)
            if task.kind == "cluster" and task.task_id not in previous_vts]

async def get_new_virtual_tasks_statuses(tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                         previous_vts: list[TaskID], expected_task_num: int) -> list[TaskStatus]:
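    """Collect the statuses of new virtual tasks, asserting their count and their cross-server consistency."""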
    vts_list = await get_new_virtual_tasks_list(tm, module_name, servers[0], previous_vts)
    assert len(vts_list) == expected_task_num, "Wrong cluster tasks number"
    return [await compare_status_on_all_servers(stats.task_id, tm, servers) for stats in vts_list]

def check_virtual_task_status(virtual_task: TaskStatus, expected_state: str, expected_type: str, expected_children_num: int) -> None:
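    """Assert the state, type, kind, scope, parent and child count of a cluster (virtual) task."""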
    assert virtual_task.state == expected_state
    assert virtual_task.type == expected_type or virtual_task.type == ""
    assert virtual_task.kind == "cluster"
    assert virtual_task.scope == "cluster"
    assert virtual_task.parent_id == "none"
    assert len(virtual_task.children_ids) == expected_children_num

def check_regular_task_status(task: TaskStatus, expected_state: str, expected_type: str, expected_scope: str,
                              expected_parent_id: str, expected_children_num: int) -> None:
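    """Assert the state, type, kind, scope, parent and child count of a regular (node-local) task."""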
    assert task.state == expected_state
    assert task.type == expected_type
    assert task.kind == "node"
    assert task.scope == expected_scope
    assert not task.is_abortable
    assert task.parent_id == expected_parent_id
    assert len(task.children_ids) == expected_children_num

async def check_bootstrap_tasks_tree(tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                     previous_vts: list[TaskID] = []) -> tuple[list[ServerInfo], list[TaskID]]:
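    """Validate the bootstrap virtual tasks and their streaming children; return the servers and the new virtual task ids."""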
    # The first node creates the cluster, so it does not produce a bootstrap task.
    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, len(servers) - 1)
    for virtual_task in virtual_tasks:
        check_virtual_task_status(virtual_task, "done", "bootstrap", 1)
        child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
        check_regular_task_status(child, "done", "bootstrap: streaming", "node", virtual_task.id, 0)
    return (servers, [vt.id for vt in virtual_tasks])

async def check_replace_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                   previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
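    """Replace servers[0] with a fresh node and validate the replace virtual task tree."""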
    assert servers, "No servers available"
    replaced_server = servers[0]
    logger.info(f"Stopping node {replaced_server}")
    await manager.server_stop_gracefully(replaced_server.server_id)

    logger.info(f"Replacing node {replaced_server}")
    replace_cfg = ReplaceConfig(replaced_id=replaced_server.server_id, reuse_ip_addr=False, use_host_id=False)
    replacing_server = await manager.server_add(replace_cfg=replace_cfg, cmdline=cmdline)
    servers = servers[1:] + [replacing_server]  # Drop the replaced node, append its replacement.

    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
    virtual_task = virtual_tasks[0]
    check_virtual_task_status(virtual_task, "done", "replace", 1)
    child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
    check_regular_task_status(child, "done", "replace: streaming", "node", virtual_task.id, 0)
    return servers, previous_vts + [virtual_task.id]

async def check_rebuild_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                   previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
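    """Rebuild servers[0] and validate the rebuild virtual task tree."""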
    async def _all_alive():
        # wait_for retries until this returns a truthy value.
        if len(servers) == len(await manager.running_servers()):
            return True

    assert servers, "No servers available"
    logger.info(f"Rebuilding node {servers[0]}")
    rebuilt_server = servers[0]
    await manager.api.rebuild_node(rebuilt_server.ip_addr, 60)
    await wait_for(_all_alive, time.time() + 60)

    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
    virtual_task = virtual_tasks[0]
    check_virtual_task_status(virtual_task, "done", "rebuild", 1)
    child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
    check_regular_task_status(child, "done", "rebuild: streaming", "node", virtual_task.id, 0)
    return servers, previous_vts + [virtual_task.id]

async def check_remove_node_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                       previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
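    """Remove a stopped servers[0] through servers[1] and validate the remove node virtual task tree."""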
    assert servers, "No servers available"
    removed_server = servers[0]
    removing_server = servers[1]
    logger.info(f"Stopping node {removed_server}")
    await manager.server_stop_gracefully(removed_server.server_id)

    logger.info(f"Removing node {removed_server} using {removing_server}")
    await manager.remove_node(removing_server.server_id, removed_server.server_id)
    servers = servers[1:]

    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
    virtual_task = virtual_tasks[0]
    # The remove node virtual task has one child per remaining node.
    check_virtual_task_status(virtual_task, "done", "remove node", len(servers))
    child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
    check_regular_task_status(child, "done", "removenode: streaming", "node", virtual_task.id, 0)
    return servers, previous_vts + [virtual_task.id]

async def poll_for_task(tm: TaskManagerClient, module_name: str, server: ServerInfo, expected_kind: str, expected_type: str):
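    """Wait until a task of the expected kind and type appears on the server, then return its stats."""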
    async def _get_streaming_tasks(server: ServerInfo) -> list[TaskStats]:
        return [stats for stats in await tm.list_tasks(server.ip_addr, module_name)
                if stats.kind == expected_kind and stats.type == expected_type]

    async def _has_streaming_tasks(server: ServerInfo):
        if len(await _get_streaming_tasks(server)) > 0:
            return True

    await wait_for(partial(_has_streaming_tasks, server), time.time() + 100.0)  # Wait until the expected task is created.
    streaming_tasks = await _get_streaming_tasks(server)
    assert len(streaming_tasks) == 1
    return streaming_tasks[0]

async def check_decommission_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                        previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
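    """Decommission servers[0], checking the virtual task tree while the operation is still running."""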
    async def _check_virtual_task(decommissioned_server: ServerInfo, handler):
        logger.info("Checking top level decommission node task")
        await poll_for_task(tm, module_name, decommissioned_server, "node", "decommission: streaming")
        virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
        virtual_task = virtual_tasks[0]
        check_virtual_task_status(virtual_task, "running", "decommission", 1)
        child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
        check_regular_task_status(child, "running", "decommission: streaming", "node", virtual_task.id, 0)
        # Resume the decommission paused by the error injection.
        await handler.message()

    assert servers, "No servers available"
    decommissioned_server = servers[0]
    # Pause the decommission mid-streaming so the task tree can be observed in the "running" state.
    injection = "streaming_task_impl_decommission_run"
    handler = await inject_error_one_shot(manager.api, decommissioned_server.ip_addr, injection)
    logger.info(f"Decommissioning node {decommissioned_server}")
    await asyncio.gather(*(manager.decommission_node(decommissioned_server.server_id),
                           _check_virtual_task(decommissioned_server, handler)))

    servers = servers[1:]
    vts_list = await get_new_virtual_tasks_list(tm, module_name, servers[0], previous_vts)
    return servers, previous_vts + [vts_list[0].task_id]

@pytest.mark.asyncio
async def test_node_ops_tasks_tree(manager: ManagerClient):
"""Test node ops task manager tasks."""
module_name = "node_ops"
tm = TaskManagerClient(manager.api)
servers = [await manager.server_add(cmdline=cmdline) for _ in range(3)]
assert module_name in await tm.list_modules(servers[0].ip_addr), "node_ops module wasn't registered"
cql = manager.get_cql()
async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND tablets = {'initial': 1}") as ks:
await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int);")
await cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES ({1}, {1});")
await cql.run_async(f"TRUNCATE {ks}.test;")
servers, vt_ids = await check_bootstrap_tasks_tree(tm, module_name, servers)
servers, vt_ids = await check_replace_tasks_tree(manager, tm, module_name, servers, vt_ids)
servers, vt_ids = await check_rebuild_tasks_tree(manager, tm, module_name, servers, vt_ids)
servers, vt_ids = await check_remove_node_tasks_tree(manager, tm, module_name, servers, vt_ids)
servers, vt_ids = await check_decommission_tasks_tree(manager, tm, module_name, servers, vt_ids)
@pytest.mark.asyncio
async def test_node_ops_tasks_ttl(manager: ManagerClient):
"""Test node ops virtual tasks' ttl."""
module_name = "node_ops"
tm = TaskManagerClient(manager.api)
servers = [await manager.server_add(cmdline=cmdline) for _ in range(2)]
[await tm.set_user_task_ttl(server.ip_addr, 3) for server in servers]
time.sleep(3)
await get_new_virtual_tasks_statuses(tm, module_name, servers, [], expected_task_num=0)
@pytest.mark.asyncio
async def test_node_ops_task_wait(manager: ManagerClient):
"""Test node ops virtual task's wait."""
async def _decommission(manager: ManagerClient, server: ServerInfo):
await manager.decommission_node(server.server_id)
async def _wait_for_task(tm: TaskManagerClient, module_name: str, server: ServerInfo, handler: InjectionHandler):
task = await poll_for_task(tm, module_name, servers[1], "cluster", "decommission")
assert task.state == "running"
await handler.message()
status = await tm.wait_for_task(server.ip_addr, task.task_id)
assert status.state == "done"
module_name = "node_ops"
tm = TaskManagerClient(manager.api)
servers = [await manager.server_add(cmdline=cmdline) for _ in range(2)]
injection = "streaming_task_impl_decommission_run"
handler = await inject_error_one_shot(manager.api, servers[0].ip_addr, injection)
decommission_task = asyncio.create_task(
_decommission(manager, servers[0]))
waiting_task = asyncio.create_task(
_wait_for_task(tm, module_name, servers[1], handler))
await decommission_task
await waiting_task