diff --git a/test/pylib/scylla_cluster.py b/test/pylib/scylla_cluster.py index 1f931348f2..273727c43e 100644 --- a/test/pylib/scylla_cluster.py +++ b/test/pylib/scylla_cluster.py @@ -1136,6 +1136,7 @@ class ScyllaClusterManager: self._setup_routes(app) self.runner = aiohttp.web.AppRunner(app) self.tasks_history = dict() + self.server_broken_event = asyncio.Event() def repr_tasks_history(self): out = "Cluster_history" @@ -1207,19 +1208,23 @@ class ScyllaClusterManager: return aiohttp.web.Response(status=500, text=str(e)) return catching_handler - def route_history_wrapper(handler: Callable)-> Callable: - @wraps(handler) - async def inner_wrapper(request): - self.logger.info("[ScyllaClusterManager][%s] %s", asyncio.current_task().get_name(), request.url) - self.tasks_history[asyncio.current_task()] = request - return await handler(request) - return inner_wrapper + def route_history_wrapper(blockable = False)-> Callable: + def outer_wrapper(handler: Callable)-> Callable: + @wraps(handler) + async def inner_wrapper(request): + if blockable and self.server_broken_event.is_set(): + raise Exception("ScyllaClusterManager BROKEN") + self.logger.info("[ScyllaClusterManager][%s] %s", asyncio.current_task().get_name(), request.url) + self.tasks_history[asyncio.current_task()] = request + return await handler(request) + return inner_wrapper + return outer_wrapper def add_get(route: str, handler: Callable): - app.router.add_get(route, make_catching_handler(route_history_wrapper(handler))) + app.router.add_get(route, make_catching_handler(route_history_wrapper()(handler))) def add_put(route: str, handler: Callable): - app.router.add_put(route, make_catching_handler(route_history_wrapper(handler))) + app.router.add_put(route, make_catching_handler(route_history_wrapper(True)(handler))) add_get('/up', self._manager_up) add_get('/cluster/up', self._cluster_up) @@ -1295,12 +1300,19 @@ class ScyllaClusterManager: # copy current tasks tasks = [key for key in self.tasks_history.keys()] # wait for all other tasks in ScyllaClusterManager + try: + for task in tasks: + request = self.tasks_history.pop(task) + if not task.done(): + self.logger.info("wait for task:%s, request:%s", task, request.path_qs) + await asyncio.wait_for(task, timeout=120) + except asyncio.TimeoutError: + self.break_manager(f"error on waiting coro {task.get_name()}") - for task in tasks: - request = self.tasks_history.pop(task) - if not task.done(): - self.logger.info("wait for task:%s, request:%s", task, request.path_qs) - await asyncio.wait_for(task, timeout=120) + # check on tasks leakage + await asyncio.sleep(0.1) + if self.tasks_history: + self.break_manager(f"tasks leakage found {self.tasks_history}") success = _request.match_info["success"] == "True" self.logger.info("Test %s %s, cluster: %s", self.current_test_case_full_name, @@ -1313,6 +1325,12 @@ class ScyllaClusterManager: cluster_str = str(self.cluster) return cluster_str + def break_manager(self, reason): + # make ScyllaClusterManager not operatable from client side + self.logger.error(" %s, BREAK ScyllaClusterManager", reason) + self.server_broken_event.set() + self._mark_dirty(None) + async def _mark_dirty(self, _request) -> None: """Mark current cluster dirty""" assert self.cluster