Files
scylladb/test/cluster/tasks/task_manager_client.py

104 lines
4.8 KiB
Python

#
# Copyright (C) 2024-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
from time import time
from test.cluster.tasks.task_manager_types import TaskID, TaskStats, TaskStatus, State
from test.pylib.internal_types import IPAddress
from test.pylib.rest_client import ScyllaRESTAPIClient
import asyncio
from typing import Optional
from test.pylib.util import wait_for
class TaskManagerClient:
    """Async client for the Task Manager REST API used by the cluster tests.

    Each method wraps one or more ``/task_manager/...`` endpoints and sends the
    request through ``self.api.client`` to the node given by ``node_ip``.
    Response shapes are checked with ``assert`` so that an unexpected payload
    fails the calling test with a message showing what was received.
    """

    def __init__(self, api: ScyllaRESTAPIClient):
        self.api = api

    async def list_modules(self, node_ip: IPAddress) -> list[str]:
        """Get the list of task manager modules supported by ``node_ip``."""
        modules = await self.api.client.get_json("/task_manager/list_modules", host=node_ip)
        assert isinstance(modules, list), f"list_modules: expected a list, got {modules!r}"
        return modules

    async def list_tasks(self, node_ip: IPAddress, module_name: str, internal: bool = False,
                         keyspace: Optional[str] = None, table: Optional[str] = None) -> list[TaskStats]:
        """Get the stats of the tasks listed for one module.

        :param node_ip: node to query.
        :param module_name: module whose tasks are listed.
        :param internal: sent as the ``internal`` query parameter, formatted with ``str()``.
        :param keyspace: if non-empty, sent as the ``keyspace`` query parameter
                         (presumably a server-side filter).
        :param table: if non-empty, sent as the ``table`` query parameter
                      (presumably a server-side filter).
        :return: one ``TaskStats`` per entry of the JSON list response.
        """
        args = {"internal": str(internal)}
        # Empty strings are treated like None: the parameter is simply not sent.
        if keyspace:
            args["keyspace"] = keyspace
        if table:
            args["table"] = table
        stats_list = await self.api.client.get_json(f"/task_manager/list_module_tasks/{module_name}",
                                                    params=args, host=node_ip)
        assert isinstance(stats_list, list), f"list_module_tasks: expected a list, got {stats_list!r}"
        return [TaskStats(**stats_dict) for stats_dict in stats_list]

    async def wait_task_appears(self, node_ip: IPAddress, module_name: str,
                                task_type: Optional[str] = None,
                                entity: Optional[str] = None,
                                deadline: Optional[float] = None) -> TaskStats:
        """
        Wait for a task in the "running" state that matches the given filter.

        A task matches when its ``type`` equals ``task_type`` and its ``entity``
        equals ``entity``. A filter argument left as ``None`` matches any value.
        Tasks are listed with ``list_tasks``' default ``internal=False``.
        Throws an exception if no such task appears before the deadline.

        :param deadline: absolute ``time.time()`` timestamp. Defaults to 60
                         seconds from the moment of the call.
        :return: stats of the first task matching the filter.
        """
        def matches(stats: TaskStats) -> bool:
            return (stats.state == State.running and
                    (task_type is None or stats.type == task_type) and
                    (entity is None or stats.entity == entity))

        async def find_task() -> Optional[TaskStats]:
            # Returns None when nothing matches yet; presumably wait_for keeps
            # polling (with backoff) until a non-None value is produced.
            tasks = await self.list_tasks(node_ip, module_name)
            return next((stats for stats in tasks if matches(stats)), None)

        # Explicit None check: only a missing deadline falls back to the default.
        if deadline is None:
            deadline = time() + 60
        return await wait_for(find_task, deadline, period=0.1, backoff_factor=1.2, max_period=1)

    async def get_task_status(self, node_ip: IPAddress, task_id: TaskID) -> TaskStatus:
        """Get the current status of one task without waiting for it."""
        status = await self.api.client.get_json(f"/task_manager/task_status/{task_id}", host=node_ip)
        assert isinstance(status, dict), f"task_status: expected a dict, got {status!r}"
        return TaskStatus(**status)

    async def abort_task(self, node_ip: IPAddress, task_id: TaskID) -> None:
        """Request that a task be aborted. The response body is ignored."""
        await self.api.client.post(f"/task_manager/abort_task/{task_id}", host=node_ip)

    async def wait_for_task(self, node_ip: IPAddress, task_id: TaskID) -> TaskStatus:
        """Wait for a task via the ``wait_task`` endpoint and return its status."""
        status = await self.api.client.get_json(f"/task_manager/wait_task/{task_id}", host=node_ip)
        assert isinstance(status, dict), f"wait_task: expected a dict, got {status!r}"
        return TaskStatus(**status)

    async def set_task_ttl(self, node_ip: IPAddress, ttl: int) -> int:
        """Set the task ``ttl`` on ``node_ip`` and return the previous value."""
        old_ttl = await self.api.client.post_json("/task_manager/ttl", params={"ttl": str(ttl)}, host=node_ip)
        # Exact type check on purpose: a JSON boolean would pass isinstance(..., int).
        assert type(old_ttl) is int, f"ttl: expected an int, got {old_ttl!r}"
        return old_ttl

    async def set_user_task_ttl(self, node_ip: IPAddress, ttl: int) -> int:
        """Set the user task ttl (``user_ttl``) on ``node_ip`` and return the previous value."""
        old_ttl = await self.api.client.post_json("/task_manager/user_ttl", params={"user_ttl": str(ttl)},
                                                  host=node_ip)
        # Exact type check on purpose: a JSON boolean would pass isinstance(..., int).
        assert type(old_ttl) is int, f"user_ttl: expected an int, got {old_ttl!r}"
        return old_ttl

    async def get_task_status_recursively(self, node_ip: IPAddress, task_id: TaskID) -> list[TaskStatus]:
        """Get the status of a task and of all its descendants."""
        status_list = await self.api.client.get_json(f"/task_manager/task_status_recursive/{task_id}",
                                                     host=node_ip)
        assert isinstance(status_list, list), f"task_status_recursive: expected a list, got {status_list!r}"
        return [TaskStatus(**status_dict) for status_dict in status_list]

    async def drain_module_tasks(self, node_ip: IPAddress, module_name: str, internal: bool = False) -> None:
        """Wait for every currently listed task of a module, then drain the module.

        The ``wait_task`` requests are issued concurrently and their responses are
        discarded. They are sent with ``allow_failed=True``, presumably so that an
        error response for one task does not prevent the final ``drain`` request.
        """
        tasks = await self.list_tasks(node_ip, module_name, internal=internal)
        await asyncio.gather(*(self.api.client.get(f"/task_manager/wait_task/{stats.task_id}",
                                                   host=node_ip, allow_failed=True)
                               for stats in tasks))
        await self.api.client.post(f"/task_manager/drain/{module_name}", host=node_ip)