Mirror of https://github.com/scylladb/scylladb.git, synced 2026-04-20 16:40:35 +00:00
The test exercises all five node operations (bootstrap, replace, rebuild, removenode, decommission), and by the end only one node out of four remains alive. The CQL driver session, however, still holds stale references to the dead hosts in its connection pool and load-balancing policy state. When the new_test_keyspace context manager exits and attempts DROP KEYSPACE, the driver routes the query to the dead hosts first, gets ConnectionShutdown from each, and raises NoHostAvailable before ever trying the single live node.

Fix by calling driver_connect() after the decommission step, which closes the old session and creates a fresh one connected only to the servers the test manager reports as running.

Fixes: https://scylladb.atlassian.net/browse/SCYLLADB-1313
Closes scylladb/scylladb#29306
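As a rough illustration of the cleanup pattern the fix relies on (this is a sketch, not part of the actual change: the test name and keyspace options are hypothetical, and only ManagerClient.driver_connect and new_test_keyspace from the harness used in the file below are assumed):

import pytest

from test.pylib.manager_client import ManagerClient
from test.cluster.util import new_test_keyspace


@pytest.mark.asyncio
async def test_cleanup_after_topology_changes(manager: ManagerClient):
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}") as ks:
        # ... node operations that stop, replace, or remove servers ...

        # Refresh the CQL session before the context manager exits: the old
        # session may only know hosts that are now dead, and DROP KEYSPACE
        # would then fail with NoHostAvailable. driver_connect() reconnects
        # to the servers the test manager still reports as running.
        await manager.driver_connect()
    # On exit, new_test_keyspace issues DROP KEYSPACE over the fresh session.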
262 lines
13 KiB
Python
#
# Copyright (C) 2024-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#

from functools import partial
from typing import Optional
from test.pylib.internal_types import IPAddress, ServerInfo
from test.pylib.manager_client import ManagerClient
from test.pylib.rest_client import InjectionHandler, inject_error_one_shot
from test.pylib.scylla_cluster import ReplaceConfig
from test.pylib.util import wait_for
from test.cluster.tasks.task_manager_client import TaskManagerClient
from test.cluster.util import new_test_keyspace

import asyncio
import logging
import pytest
import time

from test.cluster.tasks.task_manager_types import TaskID, TaskStats, TaskStatus
from test.cluster.tasks import extra_scylla_cmdline_options as cmdline

logger = logging.getLogger(__name__)

async def get_status_allow_peer_connection_failure(tm: TaskManagerClient, node_ip: IPAddress, task_id: TaskID) -> Optional[TaskStatus]:
    # Fetch the task status over the REST API. A 500 caused by a closed RPC
    # connection is tolerated, since the queried node may lose a peer while
    # the topology operation is still in progress.
    ret = await tm.api.client.get_json(f"/task_manager/task_status/{task_id}", host = node_ip, allow_failed = True)
    resp_status = ret.get("code", 200)
    if resp_status == 200:
        assert(type(ret) == dict)
        return TaskStatus(**ret)
    else:
        assert resp_status == 500 and "seastar::rpc::closed_error (connection is closed)" in ret["message"]
        return None

async def compare_status_on_all_servers(task_id: TaskID, tm: TaskManagerClient, servers: list[ServerInfo]) -> TaskStatus:
    statuses = [await get_status_allow_peer_connection_failure(tm, server.ip_addr, task_id) for server in servers]
    statuses = [s for s in statuses if s is not None]
    assert statuses, "No statuses to compare"
    assert all(status.id == statuses[0].id and status.start_time == statuses[0].start_time for status in statuses)
    return statuses[0]

async def get_new_virtual_tasks_list(tm: TaskManagerClient, module_name: str, server: ServerInfo,
                                     previous_vts: list[TaskID]) -> list[TaskStats]:
    return [task for task in await tm.list_tasks(server.ip_addr, module_name)
            if task.kind == "cluster" and task.task_id not in previous_vts]

async def get_new_virtual_tasks_statuses(tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                         previous_vts: list[TaskID], expected_task_num: int) -> list[TaskStatus]:
    vts_list = await get_new_virtual_tasks_list(tm, module_name, servers[0], previous_vts)
    assert len(vts_list) == expected_task_num, "Wrong cluster tasks number"

    return [await compare_status_on_all_servers(stats.task_id, tm, servers) for stats in vts_list]

def check_virtual_task_status(virtual_task: TaskStatus, expected_state: str, expected_type: str, expected_children_num: int) -> None:
    assert virtual_task.state == expected_state
    assert virtual_task.type == expected_type or virtual_task.type == ""
    assert virtual_task.kind == "cluster"
    assert virtual_task.scope == "cluster"
    assert virtual_task.parent_id == "none"
    assert len(virtual_task.children_ids) == expected_children_num

def check_regular_task_status(task: TaskStatus, expected_state: str, expected_type: str, expected_scope: str,
                              expected_parent_id: str, expected_children_num: int) -> None:
    assert task.state == expected_state
    assert task.type == expected_type
    assert task.kind == "node"
    assert task.scope == expected_scope
    assert not task.is_abortable
    assert task.parent_id == expected_parent_id
    assert len(task.children_ids) == expected_children_num

async def check_bootstrap_tasks_tree(tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                     previous_vts: list[TaskID] = []) -> tuple[list[ServerInfo], list[TaskID]]:
    # Bootstrap of the first node is omitted.
    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, len(servers) - 1)

    for virtual_task in virtual_tasks:
        check_virtual_task_status(virtual_task, "done", "bootstrap", 1)

        child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
        check_regular_task_status(child, "done", "bootstrap: streaming", "node", virtual_task.id, 0)

    return (servers, [vt.id for vt in virtual_tasks])

async def check_replace_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                   previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
    assert servers, "No servers available"

    replaced_server = servers[0]
    logger.info(f"Stopping node {replaced_server}")
    await manager.server_stop_gracefully(replaced_server.server_id)

    logger.info(f"Replacing node {replaced_server}")
    replace_cfg = ReplaceConfig(replaced_id = replaced_server.server_id, reuse_ip_addr = False, use_host_id = False)
    replacing_server = await manager.server_add(replace_cfg=replace_cfg, cmdline=cmdline)

    servers = servers[1:] + [replacing_server]
    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
    virtual_task = virtual_tasks[0]
    check_virtual_task_status(virtual_task, "done", "replace", 1)

    child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
    check_regular_task_status(child, "done", "replace: streaming", "node", virtual_task.id, 0)

    return servers, previous_vts + [virtual_task.id]

async def check_rebuild_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                   previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
    async def _all_alive():
        if len(servers) == len(await manager.running_servers()):
            return True

    assert servers, "No servers available"

    logger.info(f"Rebuilding node {servers[0]}")
    rebuilt_server = servers[0]
    await manager.api.rebuild_node(rebuilt_server.ip_addr, 60)
    await wait_for(_all_alive, time.time() + 60)

    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
    virtual_task = virtual_tasks[0]
    check_virtual_task_status(virtual_task, "done", "rebuild", 1)

    child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
    check_regular_task_status(child, "done", "rebuild: streaming", "node", virtual_task.id, 0)

    return servers, previous_vts + [virtual_task.id]

async def check_remove_node_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                       previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
    assert servers, "No servers available"

    removed_server = servers[0]
    removing_server = servers[1]
    logger.info(f"Stopping node {removed_server}")
    await manager.server_stop_gracefully(removed_server.server_id)

    logger.info(f"Removing node {removed_server} using {removing_server}")
    await manager.remove_node(removing_server.server_id, removed_server.server_id)

    servers = servers[1:]

    virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
    virtual_task = virtual_tasks[0]
    check_virtual_task_status(virtual_task, "done", "remove node", len(servers))

    child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
    check_regular_task_status(child, "done", "removenode: streaming", "node", virtual_task.id, 0)

    return servers, previous_vts + [virtual_task.id]

async def poll_for_task(tm: TaskManagerClient, module_name: str, server: ServerInfo, expected_kind: str, expected_type: str):
    async def _get_streaming_tasks(server: ServerInfo) -> list[TaskStats]:
        return [stats for stats in await tm.list_tasks(server.ip_addr, module_name) if stats.kind == expected_kind and stats.type == expected_type]

    async def _has_streaming_tasks(server: ServerInfo):
        if len(await _get_streaming_tasks(server)) > 0:
            return True

    await wait_for(partial(_has_streaming_tasks, server), time.time() + 100.)  # Wait until streaming task is created.

    streaming_tasks = await _get_streaming_tasks(server)
    assert len(streaming_tasks) == 1

    return streaming_tasks[0]


async def check_decommission_tasks_tree(manager: ManagerClient, tm: TaskManagerClient, module_name: str, servers: list[ServerInfo],
                                        previous_vts: list[TaskID]) -> tuple[list[ServerInfo], list[TaskID]]:
    async def _check_virtual_task(decommissioned_server: ServerInfo, handler):
        logger.info("Checking top level decommission node task")

        await poll_for_task(tm, module_name, decommissioned_server, "node", "decommission: streaming")

        virtual_tasks = await get_new_virtual_tasks_statuses(tm, module_name, servers, previous_vts, 1)
        virtual_task = virtual_tasks[0]
        check_virtual_task_status(virtual_task, "running", "decommission", 1)

        child = await tm.get_task_status(virtual_task.children_ids[0]["node"], virtual_task.children_ids[0]["task_id"])
        check_regular_task_status(child, "running", "decommission: streaming", "node", virtual_task.id, 0)

        await handler.message()

    assert servers, "No servers available"

    decommissioned_server = servers[0]
    injection = "streaming_task_impl_decommission_run"
    handler = await inject_error_one_shot(manager.api, decommissioned_server.ip_addr, injection)
    logger.info(f"Decommissioning node {decommissioned_server}")
    await asyncio.gather(*(manager.decommission_node(decommissioned_server.server_id),
                           _check_virtual_task(decommissioned_server, handler)))

    servers = servers[1:]
    vts_list = await get_new_virtual_tasks_list(tm, module_name, servers[0], previous_vts)
    return servers, previous_vts + [vts_list[0].task_id]

@pytest.mark.asyncio
async def test_node_ops_tasks_tree(manager: ManagerClient):
    """Test node ops task manager tasks."""
    module_name = "node_ops"
    tm = TaskManagerClient(manager.api)
    servers = [await manager.server_add(cmdline=cmdline) for _ in range(3)]
    assert module_name in await tm.list_modules(servers[0].ip_addr), "node_ops module wasn't registered"

    cql = manager.get_cql()
    async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND tablets = {'initial': 1}") as ks:
        await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int);")
        await cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES ({1}, {1});")
        await cql.run_async(f"TRUNCATE {ks}.test;")

        servers, vt_ids = await check_bootstrap_tasks_tree(tm, module_name, servers)
        servers, vt_ids = await check_replace_tasks_tree(manager, tm, module_name, servers, vt_ids)
        servers, vt_ids = await check_rebuild_tasks_tree(manager, tm, module_name, servers, vt_ids)
        servers, vt_ids = await check_remove_node_tasks_tree(manager, tm, module_name, servers, vt_ids)
        servers, vt_ids = await check_decommission_tasks_tree(manager, tm, module_name, servers, vt_ids)

        # Reconnect the driver after topology changes (replace, removenode,
        # decommission) so that the new_test_keyspace cleanup can reach a
        # live node for DROP KEYSPACE.
        await manager.driver_connect()

@pytest.mark.asyncio
async def test_node_ops_tasks_ttl(manager: ManagerClient):
    """Test node ops virtual tasks' ttl."""
    module_name = "node_ops"
    tm = TaskManagerClient(manager.api)
    servers = [await manager.server_add(cmdline=cmdline) for _ in range(2)]
    [await tm.set_user_task_ttl(server.ip_addr, 3) for server in servers]
    # Wait out the 3-second TTL so the finished virtual tasks are unregistered
    # before listing them.
    time.sleep(3)
    await get_new_virtual_tasks_statuses(tm, module_name, servers, [], expected_task_num=0)

@pytest.mark.asyncio
async def test_node_ops_task_wait(manager: ManagerClient):
    """Test node ops virtual task's wait."""
    async def _decommission(manager: ManagerClient, server: ServerInfo):
        await manager.decommission_node(server.server_id)

    async def _wait_for_task(tm: TaskManagerClient, module_name: str, server: ServerInfo, handler: InjectionHandler):
        task = await poll_for_task(tm, module_name, servers[1], "cluster", "decommission")
        assert task.state == "running"

        await handler.message()

        status = await tm.wait_for_task(server.ip_addr, task.task_id)
        assert status.state == "done"

    module_name = "node_ops"
    tm = TaskManagerClient(manager.api)
    servers = [await manager.server_add(cmdline=cmdline) for _ in range(2)]
    injection = "streaming_task_impl_decommission_run"
    handler = await inject_error_one_shot(manager.api, servers[0].ip_addr, injection)

    decommission_task = asyncio.create_task(
        _decommission(manager, servers[0]))

    waiting_task = asyncio.create_task(
        _wait_for_task(tm, module_name, servers[1], handler))

    await decommission_task
    await waiting_task