diff --git a/test/pylib/log_browsing.py b/test/pylib/log_browsing.py index 57cd0106ef..ac5c831be1 100644 --- a/test/pylib/log_browsing.py +++ b/test/pylib/log_browsing.py @@ -4,6 +4,7 @@ from __future__ import annotations +from test.pylib.util import universalasync_typed_wrap import asyncio import logging import re @@ -11,7 +12,6 @@ from pathlib import Path from typing import TYPE_CHECKING import pytest -import universalasync if TYPE_CHECKING: from asyncio import AbstractEventLoop @@ -22,7 +22,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -@universalasync.wrap +@universalasync_typed_wrap class ScyllaLogFile: """Browse a Scylla log file. diff --git a/test/pylib/manager_client.py b/test/pylib/manager_client.py index 52e3117cad..1df6dfcfe8 100644 --- a/test/pylib/manager_client.py +++ b/test/pylib/manager_client.py @@ -17,7 +17,7 @@ from time import time import logging from test.pylib.log_browsing import ScyllaLogFile from test.pylib.rest_client import UnixRESTClient, ScyllaRESTAPIClient, ScyllaMetricsClient -from test.pylib.util import wait_for, wait_for_cql_and_get_hosts, Host +from test.pylib.util import wait_for, wait_for_cql_and_get_hosts, universalasync_typed_wrap, Host from test.pylib.internal_types import ServerNum, IPAddress, HostID, ServerInfo, ServerUpState from test.pylib.scylla_cluster import ReplaceConfig, ScyllaServer, ScyllaVersionDescription from cassandra.cluster import Session as CassandraSession, \ @@ -29,8 +29,6 @@ import aiohttp import asyncio import allure -import universalasync - logger = logging.getLogger(__name__) @@ -39,7 +37,7 @@ class NoSuchProcess(Exception): ... -@universalasync.wrap +@universalasync_typed_wrap class ManagerClient: """Helper Manager API client Args: diff --git a/test/pylib/rest_client.py b/test/pylib/rest_client.py index 0c4560a1a7..5a2904ba16 100644 --- a/test/pylib/rest_client.py +++ b/test/pylib/rest_client.py @@ -16,11 +16,11 @@ from contextlib import asynccontextmanager from typing import Any, Optional, AsyncIterator import pytest -import universalasync from aiohttp import request, BaseConnector, UnixConnector, ClientTimeout from cassandra.pool import Host # type: ignore # pylint: disable=no-name-in-module from test.pylib.internal_types import IPAddress, HostID +from test.pylib.util import universalasync_typed_wrap logger = logging.getLogger(__name__) @@ -150,7 +150,7 @@ class TCPRESTClient(RESTClient): self.default_port: int = port -@universalasync.wrap +@universalasync_typed_wrap class ScyllaRESTAPIClient: """Async Scylla REST API client""" diff --git a/test/pylib/util.py b/test/pylib/util.py index 9ad457b2e0..64072220d3 100644 --- a/test/pylib/util.py +++ b/test/pylib/util.py @@ -13,13 +13,14 @@ import asyncio import logging import pathlib import os +import universalasync from collections.abc import Awaitable, Callable, Coroutine from functools import cache import random import string -from typing import Optional, TypeVar, Any +from typing import Optional, TypeVar, Any, cast from cassandra.cluster import NoHostAvailable, Session, Cluster # type: ignore # pylint: disable=no-name-in-module from cassandra.protocol import InvalidRequest # type: ignore # pylint: disable=no-name-in-module @@ -364,3 +365,7 @@ def execute_with_tracing(cql : Session, statement : str | Statement, log : bool logger.debug("Tracing {}:\n{}\n".format(statement, "\n".join(page_traces))) return ret + + +def universalasync_typed_wrap(cls: T) -> T: + return cast(T, universalasync.wrap(cls))