diff --git a/test/cluster/auth_cluster/test_raft_service_levels.py b/test/cluster/auth_cluster/test_raft_service_levels.py index 44558c5106..ff42186803 100644 --- a/test/cluster/auth_cluster/test_raft_service_levels.py +++ b/test/cluster/auth_cluster/test_raft_service_levels.py @@ -102,7 +102,7 @@ async def test_service_levels_upgrade(request, manager: ManagerClient, build_mod logging.info("Waiting until upgrade finishes") await asyncio.gather(*(wait_until_topology_upgrade_finishes(manager, h.address, time.time() + 60) for h in hosts)) - await wait_until_driver_service_level_created(cql, time.time() + 60) + await wait_until_driver_service_level_created(manager, time.time() + 60) result_v2 = await cql.run_async("SELECT service_level FROM system.service_levels_v2") assert set([sl.service_level for sl in result_v2]) == set(sls + [DRIVER_SL_NAME]) @@ -174,7 +174,7 @@ async def test_service_levels_work_during_recovery(manager: ManagerClient): await manager.servers_see_each_other(servers) await manager.api.upgrade_to_raft_topology(hosts[0].address) await asyncio.gather(*(wait_until_topology_upgrade_finishes(manager, h.address, time.time() + 60) for h in hosts)) - await wait_until_driver_service_level_created(cql, time.time() + 60) + await wait_until_driver_service_level_created(manager, time.time() + 60) logging.info("Validating service levels works in v2 mode after leaving recovery") new_sl = "sl" + unique_name() diff --git a/test/cluster/util.py b/test/cluster/util.py index fd0f7eb834..74288c0bd6 100644 --- a/test/cluster/util.py +++ b/test/cluster/util.py @@ -203,11 +203,14 @@ async def wait_until_topology_upgrade_finishes(manager: ManagerClient, ip_addr: return status == "done" or None await wait_for(check, deadline=deadline, period=1.0) -async def wait_until_driver_service_level_created(cql: Session, deadline: float): +async def wait_until_driver_service_level_created(manager: ManagerClient, deadline: float): + cql = manager.get_cql() async def check(): service_levels = await cql.run_async("LIST ALL SERVICE_LEVELS") return ("driver" in [sl.service_level for sl in service_levels]) or None await wait_for(check, deadline=deadline, period=1.0) + # sync driver service level on all nodes + await asyncio.gather(*(read_barrier(manager.api, s.ip_addr) for s in await manager.running_servers())) async def delete_raft_topology_state(cql: Session, host: Host): await cql.run_async("truncate table system.topology", host=host)