104 lines
4.8 KiB
Python
104 lines
4.8 KiB
Python
#
|
|
# Copyright (C) 2024-present ScyllaDB
|
|
#
|
|
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
|
#
|
|
from time import time
|
|
|
|
from test.cluster.tasks.task_manager_types import TaskID, TaskStats, TaskStatus, State
|
|
from test.pylib.internal_types import IPAddress
|
|
from test.pylib.rest_client import ScyllaRESTAPIClient
|
|
|
|
import asyncio
|
|
from typing import Optional
|
|
|
|
from test.pylib.util import wait_for
|
|
|
|
|
|
class TaskManagerClient():
    """Async Task Manager client.

    Thin wrapper over the ``/task_manager/*`` REST endpoints. Every request is
    issued through ``api.client`` against the node given as ``node_ip``.
    """

    def __init__(self, api: ScyllaRESTAPIClient):
        self.api = api

    async def list_modules(self, node_ip: IPAddress) -> list[str]:
        """Get the list of supported modules."""
        modules = await self.api.client.get_json("/task_manager/list_modules", host=node_ip)
        assert isinstance(modules, list), f"expected a list of modules, got {type(modules).__name__}"
        return modules

    async def list_tasks(self, node_ip: IPAddress, module_name: str, internal: bool = False,
                         keyspace: Optional[str] = None, table: Optional[str] = None) -> list[TaskStats]:
        """Get the list of tasks stats in one module.

        :param module_name: module whose tasks are listed (part of the request path).
        :param internal: sent as the ``internal`` query parameter.
        :param keyspace: sent as the ``keyspace`` query parameter; omitted when None or empty.
        :param table: sent as the ``table`` query parameter; omitted when None or empty.
        :return: one TaskStats per entry of the returned list.
        """
        args = {"internal": str(internal)}
        if keyspace:
            args["keyspace"] = keyspace
        if table:
            args["table"] = table
        stats_list = await self.api.client.get_json(f"/task_manager/list_module_tasks/{module_name}", params=args,
                                                    host=node_ip)
        assert isinstance(stats_list, list), f"expected a list of task stats, got {type(stats_list).__name__}"
        return [TaskStats(**stats_dict) for stats_dict in stats_list]

    async def wait_task_appears(self, node_ip: IPAddress, module_name: str,
                                task_type: Optional[str] = None,
                                entity: Optional[str] = None,
                                deadline: Optional[float] = None) -> TaskStats:
        """
        Waits for a task to appear in "running" state based on the specified task filter.
        A task matches the filter if all of its fields match the specified attributes.
        Throws an exception if no such task appears before the deadline.

        :param task_type: required task type; None matches any type.
        :param entity: required task entity; None matches any entity.
        :param deadline: absolute time (same clock as ``time()``) to give up at;
                         defaults to 60 seconds from now when None.
        :return: stats of the first task matching the filter.
        """
        def matches(stats) -> bool:
            if stats.state != State.running:
                return False
            if task_type is not None and stats.type != task_type:
                return False
            return entity is None or stats.entity == entity

        async def get_tasks():
            # None means "nothing matched yet"; wait_for is expected to poll
            # again until the deadline - NOTE(review): confirm against wait_for.
            tasks = await self.list_tasks(node_ip, module_name)
            return next((stats for stats in tasks if matches(stats)), None)

        # Compare against None explicitly: a falsy deadline such as 0.0 is a
        # caller-supplied value and must not be replaced by the default.
        if deadline is None:
            deadline = time() + 60
        return await wait_for(get_tasks, deadline, period=0.1, backoff_factor=1.2, max_period=1)

    async def get_task_status(self, node_ip: IPAddress, task_id: TaskID) -> TaskStatus:
        """Get status of one task."""
        status = await self.api.client.get_json(f"/task_manager/task_status/{task_id}", host=node_ip)
        assert isinstance(status, dict), f"expected a task status dict, got {type(status).__name__}"
        return TaskStatus(**status)

    async def abort_task(self, node_ip: IPAddress, task_id: TaskID) -> None:
        """Abort a task."""
        await self.api.client.post(f"/task_manager/abort_task/{task_id}", host=node_ip)

    async def wait_for_task(self, node_ip: IPAddress, task_id: TaskID) -> TaskStatus:
        """Wait for a task and get its status."""
        status = await self.api.client.get_json(f"/task_manager/wait_task/{task_id}", host=node_ip)
        assert isinstance(status, dict), f"expected a task status dict, got {type(status).__name__}"
        return TaskStatus(**status)

    async def set_task_ttl(self, node_ip: IPAddress, ttl: int) -> int:
        """Set task ttl and get old value."""
        old_ttl = await self.api.client.post_json("/task_manager/ttl", params={"ttl": str(ttl)}, host=node_ip)
        assert isinstance(old_ttl, int), f"expected an int ttl, got {type(old_ttl).__name__}"
        return old_ttl

    async def set_user_task_ttl(self, node_ip: IPAddress, ttl: int) -> int:
        """Set user task ttl and get old value."""
        old_ttl = await self.api.client.post_json("/task_manager/user_ttl", params={"user_ttl": str(ttl)},
                                                  host=node_ip)
        assert isinstance(old_ttl, int), f"expected an int ttl, got {type(old_ttl).__name__}"
        return old_ttl

    async def get_task_status_recursively(self, node_ip: IPAddress, task_id: TaskID) -> list[TaskStatus]:
        """Get status of a task and all its descendants."""
        status_list = await self.api.client.get_json(f"/task_manager/task_status_recursive/{task_id}", host=node_ip)
        assert isinstance(status_list, list), f"expected a list of task statuses, got {type(status_list).__name__}"
        return [TaskStatus(**status_dict) for status_dict in status_list]

    async def drain_module_tasks(self, node_ip: IPAddress, module_name: str, internal: bool = False) -> None:
        """Drain tasks of one module.

        Lists the module's tasks, concurrently issues a wait_task request for
        each of them, and only then posts the drain request for the module.
        """
        tasks = await self.list_tasks(node_ip, module_name, internal=internal)
        # allow_failed=True is passed so that a failing wait request
        # presumably does not abort the drain - NOTE(review): confirm in rest_client.
        await asyncio.gather(*(self.api.client.get(f"/task_manager/wait_task/{stats.task_id}", host=node_ip,
                                                   allow_failed=True) for stats in tasks))
        await self.api.client.post(f"/task_manager/drain/{module_name}", host=node_ip)
|