Files
scylladb/test/nodetool/rest_api_mock.py
Taras Veretilnyk 65ade28a9c rest_api_mock: support duplicate query parameters
Previously, only the last value of a repeated query parameter was captured,
which could cause inaccurate request matching in tests. This update ensures
that all values are preserved by storing duplicates as lists in the `params` dict.
2025-10-01 15:53:25 +02:00

342 lines
14 KiB
Python

#
# Copyright 2023-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
#
import aiohttp
import aiohttp.web
import asyncio
import contextlib
import collections
import json
import logging
import requests
import sys
import traceback
from typing import Any, Callable, Dict
# Module-level logger, named after this module; used by both the server and the client helpers below.
logger = logging.getLogger(__name__)
class approximate_value:
    """
    Allow matching query params with a non-exact match, allowing for a given
    difference tolerance (delta).

    The value being compared is first coerced to the type of the expected
    value, so a query-string value such as ``"11"`` can be compared against an
    expected ``10`` with ``delta=2``.
    """

    def __init__(self, value=None, delta=None):
        self._value = value
        self._delta = delta

    def __eq__(self, v):
        """Return True if `v`, coerced to the type of the expected value, is within `delta` of it.

        A value that cannot be coerced (e.g. a non-numeric string, or a list of
        duplicate query values) is simply not equal; it is reported as a
        mismatch rather than raising from inside the comparison.
        """
        try:
            coerced_v = type(self._value)(v)
        except (TypeError, ValueError):
            return False
        return abs(self._value - coerced_v) <= self._delta

    def to_json(self):
        """Return the JSON-serializable form; the "__type__" tag names this class so it can be rebuilt."""
        return {"__type__": "approximate_value", "value": self._value, "delta": self._delta}
class expected_request:
    """A request the mock server should expect, together with the response to serve for it.

    Attributes:
        method: HTTP method of the request.
        path: request path, stored without trailing slashes.
        params: query parameters to match (values may be matcher objects
            providing `to_json()`, such as `approximate_value`).
        body: expected request body (None if no body is expected).
        multiple: one of ANY / ONE / MULTIPLE (see below), or a larger number.
        response: payload to answer with (None means an empty JSON object).
        response_status: HTTP status used together with `response`.
        hit: number of times the request was matched so far.
    """
    ANY = -1  # allow for any number of requests (including no requests at all), similar to the `*` quantity in regexp
    ONE = 0  # exactly one request is allowed
    MULTIPLE = 1  # one or more request is allowed

    def __init__(self, method: str, path: str, params: dict = None, body: Any = None, multiple: int = ONE,
                 response: Dict[str, Any] = None, response_status: int = 200, hit: int = 0):
        self.method = method
        self.path = path.rstrip("/")
        # Every instance gets its own dict: a `{}` default argument would be a
        # single object shared (and mutated) across all requests created
        # without explicit params.
        self.params = {} if params is None else params
        self.body = body
        self.multiple = multiple
        self.response = response
        self.response_status = response_status
        self.hit = hit

    def as_json(self):
        """Return a JSON-serializable dict describing this request."""
        def param_to_json(v):
            # Matcher objects know how to serialize themselves, plain values are used as-is.
            try:
                return v.to_json()
            except AttributeError:
                return v

        return {
            "method": self.method,
            "path": self.path,
            "multiple": self.multiple,
            "params": {k: param_to_json(v) for k, v in self.params.items()},
            "body": self.body,
            "response": self.response,
            "response_status": self.response_status,
            "hit": self.hit}

    def __eq__(self, o):
        """Two requests are equal if method, path, params and body match; response and counters are ignored."""
        return self.method == o.method and self.path == o.path and self.params == o.params and self.body == o.body

    def __str__(self):
        return json.dumps(self.as_json())

    def exhausted(self):
        """Return True if the request needs no (more) hits to be satisfied.

        * ONE: exhausted after it was hit.
        * ANY: always exhausted (no hit is required).
        * MULTIPLE or greater: exhausted once `hit` reached `multiple`.
        """
        return ((self.multiple == self.ONE and self.hit > 0)
                or self.multiple == self.ANY
                or (self.multiple >= self.MULTIPLE and self.hit >= self.multiple))
def _make_param_value(value):
if type(value) is dict and "__type__" in value:
cls = globals()[value["__type__"]]
del value["__type__"]
return cls(**value)
return value
def _make_expected_request(req_json):
    """Create an `expected_request` from its dict form (see `expected_request.as_json()`).

    "method" and "path" are mandatory keys; every other field falls back to
    the default used here when it is absent. Param values are passed through
    `_make_param_value()`.
    """
    method = req_json["method"]
    path = req_json["path"]

    decoded_params = {}
    for name, raw_value in req_json.get("params", {}).items():
        decoded_params[name] = _make_param_value(raw_value)

    return expected_request(
        method,
        path,
        params=decoded_params,
        body=req_json.get("body"),
        multiple=req_json.get("multiple", expected_request.ONE),
        response=req_json.get("response"),
        response_status=req_json.get("response_status", 200),
        hit=req_json.get("hit", 0))
class rest_server():
    """State and aiohttp request handlers of the mock REST API server.

    Expected requests are registered through the EXPECTED_REQUESTS_PATH
    endpoint; every other request is checked against them by
    `handle_generic_request()`, and the ones that match nothing are counted.
    """
    EXPECTED_REQUESTS_PATH = "__expected_requests__"
    UNEXPECTED_REQUESTS_PATH = "__unexpected_requests__"

    def __init__(self):
        # Expected requests, grouped by the key produced by _request_key().
        self.expected_requests = collections.defaultdict(list)
        # Count of requests that did not match any expected request.
        self.unexpected_requests = 0

    @staticmethod
    def _request_key(method, path):
        """Return the "<method>:<path>" grouping key, with trailing slashes of the path removed."""
        trimmed_path = path.rstrip('/')
        return f"{method}:{trimmed_path}"

    async def get_expected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Respond with the JSON form of all expected requests currently stored."""
        pending = []
        for per_key_requests in self.expected_requests.values():
            pending.extend(r.as_json() for r in per_key_requests)
        return aiohttp.web.json_response(pending)

    async def get_unexpected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Respond with the current value of the unexpected-requests counter."""
        return aiohttp.web.json_response(self.unexpected_requests)

    async def post_expected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Add the expected requests from the JSON list in the request body to the stored ones."""
        payload = await request.json()
        for item in payload:
            new_req = _make_expected_request(item)
            key = self._request_key(new_req.method, new_req.path)
            self.expected_requests[key].append(new_req)
        logger.info(f"expected_requests: {self.expected_requests}")
        return aiohttp.web.json_response({})

    async def delete_expected_requests(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Drop all expected requests and reset the unexpected-requests counter."""
        self.expected_requests.clear()
        self.unexpected_requests = 0
        return aiohttp.web.json_response({})

    async def handle_generic_request(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
        """Match an incoming API request against the expected requests and serve the configured response.

        A request matching no expected request bumps the unexpected-requests
        counter and gets an error response. A matched request registered with
        `expected_request.ONE` is consumed, any other one has its hit counter
        incremented.
        """
        request_key = self._request_key(request.method, request.path)
        # NOTE(review): expected_requests is a defaultdict, so this lookup appears
        # to insert an empty list for an unknown key rather than raise; the
        # KeyError branch looks unreachable -- confirm before relying on the 404.
        try:
            expected_requests = self.expected_requests[request_key]
        except KeyError:
            self.unexpected_requests += 1
            return aiohttp.web.Response(status=404, text=f"Request {request_key} not found in expected requests")

        # only JSON-encoded payload is supported
        body = await request.json() if request.can_read_body else None

        # A query key seen once maps to its value, a repeated key maps to the
        # list of all its values (in order of appearance).
        params = {}
        for name, val in request.query.items():
            if name not in params:
                params[name] = val
                continue
            seen = params[name]
            if isinstance(seen, list):
                seen.append(val)
            else:
                params[name] = [seen, val]

        this_req = expected_request(request.method, request.path, params=params, body=body)

        if not expected_requests:
            self.unexpected_requests += 1
            logger.error(f"unexpected request, expected no request, got {this_req}")
            return aiohttp.web.Response(status=500, text=f"Expected no requests, got {this_req}")

        # The first expected request equal to the incoming one wins.
        expected_req_index = next((i for i, req in enumerate(expected_requests) if this_req == req), None)

        if expected_req_index is None:
            reqs = '\n'.join(str(r) for r in expected_requests)
            self.unexpected_requests += 1
            logger.error(f"unexpected request, request {this_req} matches none of the expected requests:\n{reqs}")
            return aiohttp.web.Response(status=500, text=f"Request {this_req} doesn't match any expected request")

        expected_req = expected_requests[expected_req_index]
        if expected_req.multiple == expected_request.ONE:
            # One-shot request: consume it.
            del expected_requests[expected_req_index]
        else:
            expected_req.hit += 1

        if expected_req.response is not None:
            logger.info(f"expected_request: {expected_req}, response: {expected_req.response}")
            return aiohttp.web.json_response(expected_req.response, status=expected_req.response_status)

        logger.info(f"expected_request: {expected_req}, no response")
        return aiohttp.web.json_response({})
async def run_server(ip, port):
    """Serve the mock REST API on `ip`:`port` until the task is cancelled.

    Sets up logging, registers the control endpoints and all API paths used by
    the tests, then sleeps until cancellation, after which the runner is
    cleaned up.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s.%(msecs)03d %(levelname)s %(name)s - %(message)s",
        datefmt="%H:%M:%S",
    )

    server = rest_server()
    app = aiohttp.web.Application()

    def wrap_handler(handler: Callable) -> Callable:
        async def catching_handler(request) -> aiohttp.web.Response:
            """Catch all exceptions and return them to the client.

            Without this, the client would get an 'Internal server error' message
            without any details. Thanks to this the test log shows the actual error.
            """
            try:
                ret = await handler(request)
            except Exception as e:
                tb = traceback.format_exc()
                logger.error(f'Exception when executing {handler.__name__}: {e}\n{tb}')
                return aiohttp.web.Response(status=500, text=str(e))
            # A handler returning nothing results in an empty response.
            return aiohttp.web.Response() if ret is None else ret
        return catching_handler

    expected_requests_url = f"/{server.EXPECTED_REQUESTS_PATH}"
    unexpected_requests_url = f"/{server.UNEXPECTED_REQUESTS_PATH}"
    routes = [
        aiohttp.web.get(expected_requests_url, wrap_handler(server.get_expected_requests)),
        aiohttp.web.post(expected_requests_url, wrap_handler(server.post_expected_requests)),
        aiohttp.web.delete(expected_requests_url, wrap_handler(server.delete_expected_requests)),
        aiohttp.web.get(unexpected_requests_url, wrap_handler(server.get_unexpected_requests)),
    ]

    # Register all required rest API paths.
    # Unfortunately, we have to register here all the different routes, used by tests.
    # Fortunately, aiohttp supports variable paths and with that, there is not that many paths to register.
    generic_paths = (
        "/cache_service/{part1}",
        "/cache_service/{part1}/{part2}/{part3}",
        "/column_family/",
        "/column_family/{part1}",
        "/column_family/{part1}/{part2}",
        "/column_family/{part1}/{part2}/{part3}",
        "/column_family/{part1}/{part2}/{part3}/{part4}",
        "/compaction_manager/{part1}",
        "/compaction_manager/{part1}/{part2}",
        "/failure_detector/{part1}",
        "/gossiper/{part1}/{part2}",
        "/messaging_service/{part1}/{part2}",
        "/snitch/{part1}",
        "/storage_proxy/{part1}",
        "/storage_proxy/{part1}/{part2}/{part3}",
        "/storage_service/{part1}",
        "/storage_service/{part1}/",
        "/storage_service/{part1}/{part2}",
        "/storage_service/{part1}/{part2}/{part3}",
        "/stream_manager/",
        "/system/{part1}",
        "/system/{part1}/{part2}",
        "/task_manager/{part1}",
        "/task_manager/{part1}/{part2}",
    )
    for generic_path in generic_paths:
        routes.append(aiohttp.web.route("*", generic_path, wrap_handler(server.handle_generic_request)))

    app.router.add_routes(routes)

    logger.info("start serving")
    runner = aiohttp.web.AppRunner(app)
    await runner.setup()
    site = aiohttp.web.TCPSite(runner, ip, port)
    await site.start()

    try:
        while True:
            await asyncio.sleep(3600)  # sleep forever
    except asyncio.exceptions.CancelledError:
        # Cancellation is the signal to stop serving.
        pass

    logger.info("stopping")
    await runner.cleanup()
def get_expected_requests(server):
    """Fetch the expected requests currently stored on the server.

    The result holds the expected requests the server has not consumed yet, so
    it can be used to check whether all expected requests arrived.

    Params:
    * server - resolved `rest_api_mock_server` fixture (see conftest.py).
    """
    ip, port = server
    url = f"http://{ip}:{port}/{rest_server.EXPECTED_REQUESTS_PATH}"
    response = requests.get(url)
    response.raise_for_status()
    try:
        return list(map(_make_expected_request, response.json()))
    except json.decoder.JSONDecodeError:
        logger.exception('unable to decode server response as JSON: %r', response)
        raise
def get_unexpected_requests(server):
    """Fetch the number of unexpected requests from the server.

    Any request which didn't match an expected request is unexpected; the
    server keeps a counter of such requests. The counter is reset when
    clear_expected_requests() is called.

    Params:
    * server - resolved `rest_api_mock_server` fixture (see conftest.py).
    """
    ip, port = server
    url = f"http://{ip}:{port}/{rest_server.UNEXPECTED_REQUESTS_PATH}"
    response = requests.get(url)
    response.raise_for_status()
    return response.json()
def clear_expected_requests(server):
    """Remove all expected requests stored on the server.

    Params:
    * server - resolved `rest_api_mock_server` fixture (see conftest.py).
    """
    ip, port = server
    url = f"http://{ip}:{port}/{rest_server.EXPECTED_REQUESTS_PATH}"
    response = requests.delete(url)
    response.raise_for_status()
def set_expected_requests(server, expected_requests):
    """Send a list of expected requests to the server.

    Params:
    * server - resolved `rest_api_mock_server` fixture (see conftest.py).
    * expected_requests - a list of request objects (each providing `as_json()`)
    """
    ip, port = server
    url = f"http://{ip}:{port}/{rest_server.EXPECTED_REQUESTS_PATH}"
    serialized = [req.as_json() for req in expected_requests]
    payload = json.dumps(serialized)
    response = requests.post(url, data=payload)
    response.raise_for_status()
@contextlib.contextmanager
def expected_requests_manager(server, expected_requests):
    """Context manager installing `expected_requests` on the server for the duration of the `with` body.

    Whatever was on the server before is cleared first, and the expected
    requests are cleared again on exit, even if the body raised.

    Params:
    * server - resolved `rest_api_mock_server` fixture (see conftest.py).
    * expected_requests - a list of request objects
    """
    # Start from a clean slate, then install the requests for this scope.
    clear_expected_requests(server)
    set_expected_requests(server, expected_requests)
    try:
        yield
    finally:
        clear_expected_requests(server)
if __name__ == '__main__':
    # Usage: rest_api_mock.py <ip> <port>
    listen_ip = sys.argv[1]
    listen_port = int(sys.argv[2])
    sys.exit(asyncio.run(run_server(listen_ip, listen_port)))