Files
scylladb/test/nodetool/rest_api_mock.py
Piotr Smaron a3360ee385 test/nodetool: fix mock server port race by using a fixed port on a unique IP
Symptom: the rest_api_mock subprocess exits with status 1 during fixture
setup, e.g.:

    subprocess.CalledProcessError: Command '[..., 'rest_api_mock.py',
        '127.29.88.1', '34093']' returned non-zero exit status 1

Root cause: aiohttp's TCPSite.start() raises OSError(EADDRINUSE) and the
process exits 1. The bind fails because of how the (ip, port) pair is
chosen across modules within one test.py process:

  * Each test module leases a 127.x.y.z IP from the host registry. The
    registry recycles released IPs, so the same IP is shared across
    modules sequentially.
  * The original code picked the port via random.randint(10000, 65535).
    A previous module on the same IP could have left that port in
    TIME_WAIT (or worse, still actively in use) when a later module
    happened to pick the same port.

SCYLLADB-1275 (PR 29314) tried to fix this by binding a probe socket to
(ip, 0) to obtain an OS-assigned free port, closing the probe, then
launching the mock server which would bind to that port. Two issues
remained:

  1. TOCTOU: between probe close and mock-server bind, any other process
     on the host could grab the just-freed port.
  2. TIME_WAIT could still bite if the host registry recycled an IP and
     the OS reused the same port number for the probe.

Fix: drop port discovery entirely. Use a fixed port (12345, matching the
unshare-namespace path already in this fixture) on the unique IP from
the host registry. Because IPs are unique per test module within one
test.py process, the (ip, 12345) pair is unique to each module, so no
port-collision dance is needed.

reuse_address=True on TCPSite handles the residual TIME_WAIT case when
the host registry recycles an IP within the same test.py process and
the previous mock server's socket has not finished TIME_WAIT yet.
reuse_port=True is dropped, as it was only useful while attempting to
have multiple processes share a single port.

This mirrors the design used in test/cqlpy/run.py: pick a unique IP,
keep the port fixed.

Fixes: SCYLLADB-1718

Closes scylladb/scylladb#29656
2026-05-04 15:33:19 +02:00

345 lines
14 KiB
Python

#
# Copyright 2023-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1
#
import aiohttp
import aiohttp.web
import asyncio
import contextlib
import collections
import json
import logging
import requests
import sys
import traceback
from typing import Any, Callable, Dict
logger = logging.getLogger(__name__)
class approximate_value:
"""
Allow matching query params with a non-exact match, allowing for a given
difference tolerance (delta).
"""
def __init__(self, value=None, delta=None):
self._value = value
self._delta = delta
def __eq__(self, v):
coerced_v = type(self._value)(v)
return abs(self._value - coerced_v) <= self._delta
def to_json(self):
return {"__type__": "approximate_value", "value": self._value, "delta": self._delta}
class expected_request:
ANY = -1 # allow for any number of requests (including no requests at all), similar to the `*` quantity in regexp
ONE = 0 # exactly one request is allowed
MULTIPLE = 1 # one or more request is allowed
def __init__(self, method: str, path: str, params: dict = {}, body: Any = None, multiple: int = ONE,
response: Dict[str, Any] = None, response_status: int = 200, hit: int = 0):
self.method = method
self.path = path.rstrip("/")
self.params = params
self.body = body
self.multiple = multiple
self.response = response
self.response_status = response_status
self.hit = hit
def as_json(self):
def param_to_json(v):
try:
return v.to_json()
except AttributeError:
return v
return {
"method": self.method,
"path": self.path,
"multiple": self.multiple,
"params": {k: param_to_json(v) for k, v in self.params.items()},
"body": self.body,
"response": self.response,
"response_status": self.response_status,
"hit": self.hit}
def __eq__(self, o):
return self.method == o.method and self.path == o.path and self.params == o.params and self.body == o.body
def __str__(self):
return json.dumps(self.as_json())
def exhausted(self):
return ((self.multiple == self.ONE and self.hit > 0)
or self.multiple == self.ANY
or (self.multiple >= self.MULTIPLE and self.hit >= self.multiple))
def _make_param_value(value):
if type(value) is dict and "__type__" in value:
cls = globals()[value["__type__"]]
del value["__type__"]
return cls(**value)
return value
def _make_expected_request(req_json):
return expected_request(
req_json["method"],
req_json["path"],
params={k: _make_param_value(v) for k, v in req_json.get("params", dict()).items()},
body=req_json.get("body"),
multiple=req_json.get("multiple", expected_request.ONE),
response=req_json.get("response"),
response_status=req_json.get("response_status", 200),
hit=req_json.get("hit", 0))
class rest_server():
EXPECTED_REQUESTS_PATH = "__expected_requests__"
UNEXPECTED_REQUESTS_PATH = "__unexpected_requests__"
def __init__(self):
self.expected_requests = collections.defaultdict(list)
self.unexpected_requests = 0
@staticmethod
def _request_key(method, path):
return f"{method}:{path.rstrip('/')}"
async def get_expected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
return aiohttp.web.json_response([r.as_json() for rl in self.expected_requests.values() for r in rl])
async def get_unexpected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
return aiohttp.web.json_response(self.unexpected_requests)
async def post_expected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
payload = await request.json()
for request in map(_make_expected_request, payload):
self.expected_requests[self._request_key(request.method, request.path)].append(request)
logger.info(f"expected_requests: {self.expected_requests}")
return aiohttp.web.json_response({})
async def delete_expected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
self.expected_requests.clear()
self.unexpected_requests = 0
return aiohttp.web.json_response({})
async def handle_generic_request(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
request_key = self._request_key(request.method, request.path)
try:
expected_requests = self.expected_requests[request_key]
except KeyError:
self.unexpected_requests += 1
return aiohttp.web.Response(status=404, text=f"Request {request_key} not found in expected requests")
body = None
if request.can_read_body:
# only JSON-encoded payload is supported
body = await request.json()
params = {}
for key, value in request.query.items():
if key in params:
# Convert single value to list if we encounter a duplicate key
if not isinstance(params[key], list):
params[key] = [params[key]]
params[key].append(value)
else:
params[key] = value
this_req = expected_request(request.method, request.path, params=params, body=body)
if len(expected_requests) == 0:
self.unexpected_requests += 1
logger.error(f"unexpected request, expected no request, got {this_req}")
return aiohttp.web.Response(status=500, text=f"Expected no requests, got {this_req}")
expected_req = None
expected_req_index = None
for i, req in enumerate(expected_requests):
if this_req == req:
expected_req = req
expected_req_index = i
break
if expected_req is None:
reqs = '\n'.join([str(r) for r in expected_requests])
self.unexpected_requests += 1
logger.error(f"unexpected request, request {this_req} matches none of the expected requests:\n{reqs}")
return aiohttp.web.Response(status=500, text=f"Request {this_req} doesn't match any expected request")
if expected_req.multiple == expected_request.ONE:
del expected_requests[expected_req_index]
else:
expected_req.hit += 1
if expected_req.response is None:
logger.info(f"expected_request: {expected_req}, no response")
return aiohttp.web.json_response({})
else:
logger.info(f"expected_request: {expected_req}, response: {expected_req.response}")
return aiohttp.web.json_response(expected_req.response, status=expected_req.response_status)
async def run_server(ip, port):
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s.%(msecs)03d %(levelname)s %(name)s - %(message)s",
datefmt="%H:%M:%S",
)
server = rest_server()
app = aiohttp.web.Application()
def wrap_handler(handler: Callable) -> Callable:
async def catching_handler(request) -> aiohttp.web.Response:
"""Catch all exceptions and return them to the client.
Without this, the client would get an 'Internal server error' message
without any details. Thanks to this the test log shows the actual error.
"""
try:
ret = await handler(request)
if ret is not None:
return ret
return aiohttp.web.Response()
except Exception as e:
tb = traceback.format_exc()
logger.error(f'Exception when executing {handler.__name__}: {e}\n{tb}')
return aiohttp.web.Response(status=500, text=str(e))
return catching_handler
app.router.add_routes([
aiohttp.web.get(f"/{server.EXPECTED_REQUESTS_PATH}", wrap_handler(server.get_expected_requests)),
aiohttp.web.post(f"/{server.EXPECTED_REQUESTS_PATH}", wrap_handler(server.post_expected_requests)),
aiohttp.web.delete(f"/{server.EXPECTED_REQUESTS_PATH}", wrap_handler(server.delete_expected_requests)),
aiohttp.web.get(f"/{server.UNEXPECTED_REQUESTS_PATH}", wrap_handler(server.get_unexpected_requests)),
# Register all required rest API paths.
# Unfortunately, we have to register here all the different routes, used by tests.
# Fortunately, aiohttp supports variable paths and with that, there is not that many paths to register.
aiohttp.web.route("*", "/cache_service/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/cache_service/{part1}/{part2}/{part3}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/column_family/", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/column_family/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/column_family/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/column_family/{part1}/{part2}/{part3}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/column_family/{part1}/{part2}/{part3}/{part4}",
wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/compaction_manager/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/compaction_manager/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/failure_detector/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/gossiper/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/messaging_service/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/snitch/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/storage_proxy/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/storage_proxy/{part1}/{part2}/{part3}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/storage_service/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/storage_service/{part1}/", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/storage_service/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/storage_service/{part1}/{part2}/{part3}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/stream_manager/", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/system/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/system/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/task_manager/{part1}", wrap_handler(server.handle_generic_request)),
aiohttp.web.route("*", "/task_manager/{part1}/{part2}", wrap_handler(server.handle_generic_request)),
])
logger.info("start serving")
runner = aiohttp.web.AppRunner(app)
await runner.setup()
# reuse_address lets the server bind even if a previous mock server
# left the (ip, port) pair in TIME_WAIT. This can happen when the host
# registry recycles an IP across modules.
site = aiohttp.web.TCPSite(runner, ip, port, reuse_address=True)
await site.start()
try:
while True:
await asyncio.sleep(3600) # sleep forever
except asyncio.exceptions.CancelledError:
pass
logger.info("stopping")
await runner.cleanup()
def get_expected_requests(server):
"""Get the expected requests list from the server.
This will contain all the unconsumed expected request currently on the
server. Can be used to check whether all expected requests arrived.
Params:
* server - resolved `rest_api_mock_server` fixture (see conftest.py).
"""
ip, port = server
r = requests.get(f"http://{ip}:{port}/{rest_server.EXPECTED_REQUESTS_PATH}")
r.raise_for_status()
try:
return [_make_expected_request(r) for r in r.json()]
except json.decoder.JSONDecodeError:
logger.exception('unable to decode server response as JSON: %r', r)
raise
def get_unexpected_requests(server):
"""Get the number of unexpeced requests from the server.
Any requests which didn't match an expected request is unexpected.
The amount of such requests is stored in a counter.
This counter is reset when clear_expected_requests() is called.
"""
ip, port = server
r = requests.get(f"http://{ip}:{port}/{rest_server.UNEXPECTED_REQUESTS_PATH}")
r.raise_for_status()
return r.json()
def clear_expected_requests(server):
"""Clear the expected requests list on the server.
Params:
* server - resolved `rest_api_mock_server` fixture (see conftest.py).
"""
ip, port = server
r = requests.delete(f"http://{ip}:{port}/{rest_server.EXPECTED_REQUESTS_PATH}")
r.raise_for_status()
def set_expected_requests(server, expected_requests):
"""Set the expected requests list on the server.
Params:
* server - resolved `rest_api_mock_server` fixture (see conftest.py).
* requests - a list of request objects
"""
ip, port = server
payload = json.dumps([r.as_json() for r in expected_requests])
r = requests.post(f"http://{ip}:{port}/{rest_server.EXPECTED_REQUESTS_PATH}", data=payload)
r.raise_for_status()
@contextlib.contextmanager
def expected_requests_manager(server, expected_requests):
clear_expected_requests(server)
set_expected_requests(server, expected_requests)
try:
yield
finally:
clear_expected_requests(server)
if __name__ == '__main__':
sys.exit(asyncio.run(run_server(sys.argv[1], int(sys.argv[2]))))