diff --git a/pgo/exec_cql.py b/pgo/exec_cql.py index d7c5e11812..2bae465000 100644 --- a/pgo/exec_cql.py +++ b/pgo/exec_cql.py @@ -16,6 +16,8 @@ Usage: import argparse, os, sys from typing import Sequence +from test.pylib.driver_utils import safe_driver_shutdown + def read_statements(path: str) -> list[tuple[int, str]]: stms: list[tuple[int, str]] = [] with open(path, 'r', encoding='utf-8') as f: @@ -56,7 +58,7 @@ def exec_statements(statements: list[tuple[int, str]], socket_path: str, timeout print(f"ERROR executing statement from file line {lineno}: {s}\n{e}", file=sys.stderr) return 1 finally: - cluster.shutdown() + safe_driver_shutdown(cluster) return 0 def main(argv: Sequence[str]) -> int: diff --git a/test/cluster/auth_cluster/test_maintenance_socket.py b/test/cluster/auth_cluster/test_maintenance_socket.py index c1e6f54a26..bfe121ba2f 100644 --- a/test/cluster/auth_cluster/test_maintenance_socket.py +++ b/test/cluster/auth_cluster/test_maintenance_socket.py @@ -58,7 +58,7 @@ async def get_ready_maintenance_session(socket_path: str, timeout: int = 60): session.execute("SELECT key FROM system.local LIMIT 1") return session except Exception: - c.shutdown() + safe_driver_shutdown(c) return None session = await wait_for(try_connect, deadline) @@ -90,7 +90,7 @@ async def connect_with_credentials(ip: str, username: str, password: str, timeou try: return c.connect() except NoHostAvailable: - c.shutdown() + safe_driver_shutdown(c) return None return await wait_for(try_connect, time.time() + timeout) @@ -240,7 +240,7 @@ async def test_no_default_superuser_maintenance_socket_ops(manager: ManagerClien except Unauthorized: return True finally: - c.shutdown() + safe_driver_shutdown(c) await wait_for(check_superuser_revoked, time.time() + 60) @@ -257,11 +257,11 @@ async def test_no_default_superuser_maintenance_socket_ops(manager: ManagerClien auth_provider=PlainTextAuthProvider(username=new_role, password=new_role_password)) try: c.connect() - c.shutdown() return None # Still cached, retry except NoHostAvailable: - c.shutdown() return True + finally: + safe_driver_shutdown(c) await wait_for(check_role_dropped, time.time() + 60) diff --git a/test/cluster/auth_cluster/test_prepared_metadata_id.py b/test/cluster/auth_cluster/test_prepared_metadata_id.py index e14e1cce7f..520022cc80 100644 --- a/test/cluster/auth_cluster/test_prepared_metadata_id.py +++ b/test/cluster/auth_cluster/test_prepared_metadata_id.py @@ -19,7 +19,7 @@ from cassandra.policies import WhiteListRoundRobinPolicy from cassandra.protocol import ResultMessage from test.cluster.auth_cluster import extra_scylla_config_options as auth_config -from test.pylib.manager_client import ManagerClient +from test.pylib.manager_client import ManagerClient, safe_driver_shutdown from test.pylib.util import unique_name @@ -138,7 +138,7 @@ def _prepare_and_execute(host: str, query: str) -> tuple[bytes, bool, int]: return prepared_metadata_id, captured["metadata_changed"], len(rows) finally: session.shutdown() - cluster.shutdown() + safe_driver_shutdown(cluster) @pytest.mark.asyncio diff --git a/test/cluster/auth_cluster/test_startup_response.py b/test/cluster/auth_cluster/test_startup_response.py index 8ed34de3ff..985a9a5098 100644 --- a/test/cluster/auth_cluster/test_startup_response.py +++ b/test/cluster/auth_cluster/test_startup_response.py @@ -14,7 +14,7 @@ from unittest import mock from cassandra.cluster import Cluster, DefaultConnection, NoHostAvailable from cassandra import connection from cassandra.auth import PlainTextAuthProvider -from test.pylib.manager_client import ManagerClient +from test.pylib.manager_client import ManagerClient, safe_driver_shutdown from test.cluster.auth_cluster import extra_scylla_config_options as auth_config @pytest.mark.asyncio @@ -51,7 +51,7 @@ async def test_startup_no_auth_response(manager: ManagerClient, build_mode): # We expect failure or timeout pass finally: - c.shutdown() + safe_driver_shutdown(c) def attempt_good_connection(): nonlocal connections_observed @@ -66,7 +66,7 @@ async def test_startup_no_auth_response(manager: ManagerClient, build_mode): if count >= num_connections/2: connections_observed = True finally: - c.shutdown() + safe_driver_shutdown(c) loop = asyncio.get_running_loop() diff --git a/test/cluster/test_describe.py b/test/cluster/test_describe.py index ed56869b1a..54db863d32 100644 --- a/test/cluster/test_describe.py +++ b/test/cluster/test_describe.py @@ -7,7 +7,7 @@ import asyncio import pytest from test.cluster.util import new_test_keyspace, new_test_table -from test.pylib.manager_client import ManagerClient +from test.pylib.manager_client import ManagerClient, safe_driver_shutdown from test.pylib.util import wait_for from cassandra.connection import UnixSocketEndPoint from cassandra.policies import WhiteListRoundRobinPolicy @@ -89,4 +89,4 @@ async def test_describe_cluster_sanity(manager: ManagerClient, mode: str): assert describe_results[0].cluster == system_local_results[0].cluster_name finally: if mode == "maintenance": - cluster.shutdown() + safe_driver_shutdown(cluster) diff --git a/test/manual/bti_cassandra_compatibility_test.py b/test/manual/bti_cassandra_compatibility_test.py index 9839dffe79..4b18036fcf 100644 --- a/test/manual/bti_cassandra_compatibility_test.py +++ b/test/manual/bti_cassandra_compatibility_test.py @@ -39,6 +39,8 @@ import typing import uuid import yaml +from test.pylib.driver_utils import safe_driver_shutdown + ################################################################################ # Common aliases. @@ -612,7 +614,7 @@ async def main(seed: int, partition_count: Optional[int], row_count: Optional[in if list(result_rows) != [row]: raise RuntimeError("Expected: {}, got: {}".format([row], result_rows)) finally: - cluster.shutdown() + safe_driver_shutdown(cluster) if __name__ == "__main__": parser = argparse.ArgumentParser()