diff --git a/test/nodetool/rest_api_mock.py b/test/nodetool/rest_api_mock.py index eac6e1e1e8..3157a7808c 100644 --- a/test/nodetool/rest_api_mock.py +++ b/test/nodetool/rest_api_mock.py @@ -43,11 +43,12 @@ class expected_request: ONE = 0 # exactly one request is allowed MULTIPLE = 1 # one or more request is allowed - def __init__(self, method: str, path: str, params: dict = {}, multiple: int = ONE, + def __init__(self, method: str, path: str, params: dict = {}, body: Any = None, multiple: int = ONE, response: Dict[str, Any] = None, response_status: int = 200, hit: int = 0): self.method = method self.path = path.rstrip("/") self.params = params + self.body = body self.multiple = multiple self.response = response self.response_status = response_status @@ -65,12 +66,13 @@ class expected_request: "path": self.path, "multiple": self.multiple, "params": {k: param_to_json(v) for k, v in self.params.items()}, + "body": self.body, "response": self.response, "response_status": self.response_status, "hit": self.hit} def __eq__(self, o): - return self.method == o.method and self.path == o.path and self.params == o.params + return self.method == o.method and self.path == o.path and self.params == o.params and self.body == o.body def __str__(self): return json.dumps(self.as_json()) @@ -95,6 +97,7 @@ def _make_expected_request(req_json): req_json["method"], req_json["path"], params={k: _make_param_value(v) for k, v in req_json.get("params", dict()).items()}, + body=req_json.get("body"), multiple=req_json.get("multiple", expected_request.ONE), response=req_json.get("response"), response_status=req_json.get("response_status", 200), @@ -140,7 +143,12 @@ class rest_server(): self.unexpected_requests += 1 return aiohttp.web.Response(status=404, text=f"Request {request_key} not found in expected requests") - this_req = expected_request(request.method, request.path, params=dict(request.query)) + body = None + if request.can_read_body: + # only JSON-encoded payload is supported + body = await request.json() + + this_req = expected_request(request.method, request.path, params=dict(request.query), body=body) if len(expected_requests) == 0: self.unexpected_requests += 1