diff --git a/test/pylib/scylla_cluster.py b/test/pylib/scylla_cluster.py index d812639ba3..86df9544b0 100644 --- a/test/pylib/scylla_cluster.py +++ b/test/pylib/scylla_cluster.py @@ -365,7 +365,7 @@ class ScyllaServer: start_time: float sleep_interval: float log_file: BufferedWriter | None - host_id: HostID # Host id (UUID) + _host_id: HostID # Host id (UUID) newid = itertools.count(start=1).__next__ # Sequential unique id def __init__(self, mode: str, version: ScyllaVersionDescription, vardir: str | pathlib.Path, @@ -681,16 +681,25 @@ class ScyllaServer: caslog.setLevel(oldlevel) # Any other exception may indicate a problem, and is passed to the caller. - async def get_host_id(self, api: ScyllaRESTAPIClient) -> bool: + async def try_get_host_id(self, api: ScyllaRESTAPIClient) -> Optional[HostID]: """Try to get the host id (also tests Scylla REST API is serving)""" + + if hasattr(self, "_host_id"): + return self._host_id try: - self.host_id = await api.get_host_id(self.ip_addr) - return True + self._host_id = await api.get_host_id(self.ip_addr) + return self._host_id except (aiohttp.ClientConnectionError, HTTPError) as exc: if isinstance(exc, HTTPError) and exc.code >= 500: raise exc - return False - # Any other exception may indicate a problem, and is passed to the caller. + # Any other exception may indicate a problem, and is passed to the caller. + return None + + async def get_host_id(self, api: ScyllaRESTAPIClient) -> HostID: + result = await self.try_get_host_id(api) + if result is None: + raise RuntimeError(f"Failed to get host_id for {self}") + return result @start_stop_lock async def start(self, @@ -732,9 +741,9 @@ class ScyllaServer: self.start_time = time.time() sleep_interval = 0.1 - def report_error(message: str) -> NoReturn: + async def report_error(message: str) -> NoReturn: message += f", server_id {self.server_id}, IP {self.ip_addr}, workdir {self.workdir.name}" - message += f", host_id {getattr(self, 'host_id', '')}" + message += f", host_id {await self.try_get_host_id(api) or ''}" if expected_error is not None: message += f", the node log was expected to contain the string [{expected_error}]" self.logger.error(message) @@ -757,16 +766,16 @@ class ScyllaServer: for line in log_file: if expected_error in line: return - report_error("the node startup failed, but the log file doesn't contain the expected error") - report_error("failed to start the node") + await report_error("the node startup failed, but the log file doesn't contain the expected error") + await report_error("failed to start the node") - if hasattr(self, "host_id") or await self.get_host_id(api): + if await self.try_get_host_id(api): if server_up_state == ServerUpState.PROCESS_STARTED: server_up_state = ServerUpState.HOST_ID_QUERIED server_up_state = await self.get_cql_up_state() or server_up_state if server_up_state == expected_server_up_state: if expected_error is not None: - report_error( + await report_error( f"the node has reached {server_up_state} state," f" but was expected to fail with the expected error" ) @@ -776,9 +785,9 @@ class ScyllaServer: await asyncio.sleep(sleep_interval) if self.stop_event.is_set(): - report_error('failed to start the node as it was requested to be stopped in the meantime') + await report_error('failed to start the node as it was requested to be stopped in the meantime') else: - report_error( + await report_error( f"the node failed to reach the expected state ({expected_server_up_state}) within the timeout," f" last seen state {server_up_state}" ) @@ -918,7 +927,7 @@ class ScyllaServer: self.logger = logger def __str__(self): - host_id = getattr(self, 'host_id', 'undefined id') + host_id = getattr(self, '_host_id', 'undefined id') return f"ScyllaServer({self.server_id}, {self.ip_addr}, {host_id})" def _write_config_file(self) -> None: @@ -1064,7 +1073,7 @@ class ScyllaCluster: replaced_srv = self.servers[replaced_id] if replace_cfg.use_host_id: - extra_config['replace_node_first_boot'] = replaced_srv.host_id + extra_config['replace_node_first_boot'] = await replaced_srv.get_host_id(self.api) else: extra_config['replace_address_first_boot'] = replaced_srv.ip_addr @@ -1653,9 +1662,7 @@ class ScyllaClusterManager: """Host ID of a server.""" server = self.cluster.servers[ServerNum(int(request.match_info["server_id"]))] - if not hasattr(server, "host_id") and not await server.get_host_id(api=self.cluster.api): - raise RuntimeError(f"Failed to get host_id for {server}") - return server.host_id + return await server.get_host_id(self.cluster.api) async def _before_test_req(self, request) -> str: cluster_str = await self._before_test(request.match_info['test_case_name']) @@ -1811,7 +1818,9 @@ class ScyllaClusterManager: # initiate remove try: - await self.cluster.api.remove_node(initiator.ip_addr, to_remove.host_id, ignore_dead, + await self.cluster.api.remove_node(initiator.ip_addr, + await to_remove.get_host_id(self.cluster.api), + ignore_dead, timeout=ScyllaServer.TOPOLOGY_TIMEOUT) except (RuntimeError, HTTPError) as exc: if expected_error: