Enable debug logging for the "exception" logger inside protocol exception tests. The exceptions will be logged, so if a protocol exception test fails it will be possible to see which ones occurred. Refs #27272 Refs #27325
363 lines
15 KiB
Python
363 lines
15 KiB
Python
# -*- coding: utf-8 -*-
|
||
# Copyright 2025-present ScyllaDB
|
||
#
|
||
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
||
|
||
from cassandra.cluster import NoHostAvailable
|
||
from contextlib import contextmanager
|
||
import pytest
|
||
import re
|
||
import requests
|
||
import socket
|
||
import struct
|
||
from test.cqlpy import nodetool
|
||
from test.cqlpy.util import cql_session
|
||
|
||
def get_protocol_error_metrics(host, port: int = 9180, timeout: float = 30.0) -> int:
    """Return the CQL ``protocol_error`` count summed over all shards.

    Scrapes ``http://{host}:{port}/metrics`` and sums every per-shard
    ``scylla_transport_cql_errors_total{...,type="protocol_error"}`` sample.

    Args:
        host: address of the Scylla node.
        port: metrics (Prometheus) port; defaults to 9180.
        timeout: HTTP timeout in seconds, so an unresponsive endpoint fails
            the test instead of hanging it.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status. This
            keeps a failed scrape from silently being reported as a count of 0,
            which would make before/after comparisons meaningless.
    """
    response = requests.get(f"http://{host}:{port}/metrics", timeout=timeout)
    response.raise_for_status()
    return _sum_protocol_error_counts(response.text)


def _sum_protocol_error_counts(metrics_text: str) -> int:
    """Sum all per-shard protocol_error samples found in Prometheus text *metrics_text*."""
    pattern = re.compile(r'^scylla_transport_cql_errors_total\{shard="\d+",type="protocol_error"\} (\d+)')
    return sum(
        int(match.group(1))
        for match in map(pattern.match, metrics_text.splitlines())
        if match
    )
|
||
|
||
def get_cpp_exceptions_metrics(host, port: int = 9180, timeout: float = 30.0) -> int:
    """Return the reactor C++ exception count summed over all shards.

    Scrapes ``http://{host}:{port}/metrics`` and sums every per-shard
    ``scylla_reactor_cpp_exceptions{shard="N"}`` sample.

    Args:
        host: address of the Scylla node.
        port: metrics (Prometheus) port; defaults to 9180.
        timeout: HTTP timeout in seconds, so an unresponsive endpoint fails
            the test instead of hanging it.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status. This
            keeps a failed scrape from silently being reported as a count of 0.
    """
    response = requests.get(f"http://{host}:{port}/metrics", timeout=timeout)
    response.raise_for_status()
    return _sum_cpp_exception_counts(response.text)


def _sum_cpp_exception_counts(metrics_text: str) -> int:
    """Sum all per-shard C++ exception samples found in Prometheus text *metrics_text*."""
    pattern = re.compile(r'^scylla_reactor_cpp_exceptions\{shard="\d+"\} (\d+)')
    return sum(
        int(match.group(1))
        for match in map(pattern.match, metrics_text.splitlines())
        if match
    )
|
||
|
||
@contextmanager
def cql_with_protocol(host_str, port, creds, protocol_version):
    """Yield a CQL session using *protocol_version*, or None if connecting fails.

    Only a ``NoHostAvailable`` raised while *opening* the session is turned
    into ``None``. Exceptions raised by the caller's ``with`` body propagate
    unchanged. The previous version caught them at the ``yield`` and tried to
    yield a second time, which turns them into a confusing RuntimeError.

    Args:
        host_str: server address.
        port: CQL port.
        creds: dict with "ssl", "username" and "password" keys.
        protocol_version: CQL native protocol version to request.
    """
    from contextlib import ExitStack  # module-level import only brings in contextmanager

    with ExitStack() as stack:
        try:
            session = stack.enter_context(cql_session(
                host=host_str,
                port=port,
                is_ssl=creds["ssl"],
                username=creds["username"],
                password=creds["password"],
                protocol_version=protocol_version,
            ))
        except NoHostAvailable:
            session = None
        try:
            yield session
        finally:
            # Shut down even if the caller's body raised; runs before
            # cql_session's own exit, as in the original ordering.
            if session is not None:
                session.shutdown()
|
||
|
||
def try_connect(host, port, creds, protocol_version):
    """Attempt a single CQL connection.

    Returns 1 if ``cql_with_protocol`` produced a session, 0 if it produced None.
    """
    with cql_with_protocol(host, port, creds, protocol_version) as maybe_session:
        outcome = int(bool(maybe_session))
    return outcome
|
||
|
||
@pytest.fixture
def debug_exceptions_logging(request, cql):
    """Switch the "exception" logger to debug for one test, then restore it.

    The logger level is read before the test. It is then set to "debug" and
    verified by reading it back. After the test the previously observed level
    is re-applied, or "info" if it could not be read. Failures to read or set
    the level are printed and otherwise ignored, so this fixture never fails a
    test by itself. ``request`` is accepted but not used.
    """
    logger_name = "exception"

    def current_level() -> str | None:
        # Normalized (whitespace/quotes stripped, lowercased) level, or None on error.
        try:
            raw = nodetool.getlogginglevel(cql, logger_name)
            return raw.strip().strip('"').lower() if raw else raw
        except Exception as exc:
            print(f"Failed to read exception logger level: {exc}")
            return None

    def apply_level(wanted: str) -> bool:
        # Set the level, then confirm it by reading it back.
        try:
            nodetool.setlogginglevel(cql, logger_name, wanted)
        except Exception as exc:
            print(f"Failed to set exception logger level to '{wanted}': {exc}")
            return False

        seen = current_level()
        if seen != wanted:
            print(f"Exception logger level observed as '{seen}' while expecting '{wanted}'")
            return False
        return True

    original_level = current_level()
    debug_enabled = apply_level("debug")

    yield

    # Skip restoring only when debug was never applied AND the prior level is unknown.
    if debug_enabled or original_level is not None:
        apply_level(original_level or "info")
|
||
|
||
# If there is a protocol version mismatch, the server should
# raise a protocol error, which is counted in the metrics.
def test_protocol_version_mismatch(scylla_only, debug_exceptions_logging, request, host):
    """Repeated connections with protocol version 42 must be rejected.

    The rejections must raise the protocol_error metric, and the C++ exception
    count may grow by at most ``cpp_exception_threshold``.
    """
    run_count = 200
    cpp_exception_threshold = 20

    cpp_exception_metrics_before = get_cpp_exceptions_metrics(host)
    protocol_exception_metrics_before = get_protocol_error_metrics(host)

    port = request.config.getoption("--port")
    # Use the default superuser credentials, which work for both Scylla and Cassandra
    creds = {
        "ssl": request.config.getoption("--ssl"),
        "username": request.config.getoption("--auth_username") or "cassandra",
        "password": request.config.getoption("--auth_password") or "cassandra",
    }

    # Sanity check: a supported protocol version connects.
    successful_session_count = try_connect(host, port, creds, protocol_version=4)
    assert successful_session_count == 1, "Expected to connect successfully with protocol version 4"

    for _ in range(run_count):
        successful_session_count = try_connect(host, port, creds, protocol_version=42)
        assert successful_session_count == 0, "Expected to fail connecting with protocol version 42"

    protocol_exception_metrics_after = get_protocol_error_metrics(host)
    assert protocol_exception_metrics_after > protocol_exception_metrics_before, (
        f"Expected protocol errors to increase after the test "
        f"(before={protocol_exception_metrics_before}, after={protocol_exception_metrics_after})")

    cpp_exception_metrics_after = get_cpp_exceptions_metrics(host)
    cpp_exception_delta = cpp_exception_metrics_after - cpp_exception_metrics_before
    assert cpp_exception_delta <= cpp_exception_threshold, (
        f"Expected C++ exceptions to increase by at most {cpp_exception_threshold} "
        f"after the test, but they increased by {cpp_exception_delta}")
|
||
|
||
def _build_frame(*, opcode: int, stream: int, body: bytes) -> bytearray:
    """Encode a CQL native-protocol v4 request frame.

    The 9-byte big-endian header is: version (0x04), flags (0x00),
    stream id (uint16), opcode (uint8), body length (uint32). The body
    follows it unchanged.
    """
    header = struct.pack("!BBHBI", 0x04, 0x00, stream, opcode, len(body))
    return bytearray(header) + body
|
||
|
||
def _send_frame(sock: socket.socket, *, opcode: int, stream: int, body: bytes) -> None:
    """Build a v4 request frame and write all of it to *sock*.

    Uses ``sendall``. A bare ``send`` may transmit only part of the buffer,
    which would leave the server waiting for the rest of the frame.
    """
    sock.sendall(_build_frame(opcode=opcode, stream=stream, body=body))
|
||
|
||
def _recv_frame(sock: socket.socket) -> bytes:
    """Read exactly one CQL frame (9-byte header plus body) from *sock*.

    The body length is taken from the uint32 at header offset 5. A single
    ``recv`` call could return a partial frame (large bodies, TCP splits) or
    several frames merged together. If the peer closes the connection early,
    whatever bytes were received are returned; this may be ``b""``, and callers
    already check for that.
    """
    header_len = 9
    buf = bytearray()

    def fill(target: int) -> bool:
        # Receive until buf holds `target` bytes; False if the peer closed first.
        while len(buf) < target:
            chunk = sock.recv(target - len(buf))
            if not chunk:
                return False
            buf.extend(chunk)
        return True

    if not fill(header_len):
        return bytes(buf)
    (body_len,) = struct.unpack_from("!I", buf, 5)
    fill(header_len + body_len)
    return bytes(buf)
|
||
|
||
# Many protocol errors are caused by sending malformed messages.
# It is not possible to reproduce them with the Python driver,
# so we use a low-level socket connection to send the messages.
# To avoid duplicating this low-level code, we use a common
# implementation function with parameters. To trigger a specific
# protocol error, set the corresponding trigger to True.
def _protocol_error_impl(
        host, *,
        trigger_bad_batch=False,
        trigger_unexpected_auth=False,
        trigger_process_startup_invalid_string_map=False,
        trigger_unknown_compression=False,
        trigger_process_query_internal_malformed_query=False,
        trigger_process_query_internal_fail_read_options=False,
        trigger_process_prepare_malformed_query=False,
        trigger_process_execute_internal_malformed_cache_key=False,
        trigger_process_register_malformed_string_list=False,
        port=9042,
        username="cassandra",
        password="cassandra",
        timeout=30.0):
    """Open a raw CQL v4 connection and optionally send one malformed request.

    With no trigger set, only the handshake runs. It sends STARTUP; if the
    server answers AUTHENTICATE, it then sends a SASL PLAIN AUTH_RESPONSE built
    from *username*/*password* and expects AUTH_SUCCESS.

    The two STARTUP triggers replace the handshake. ``trigger_unexpected_auth``
    sends OPTIONS instead of AUTH_RESPONSE, and skips the test if the server
    answers READY. Every other trigger sends its malformed request after the
    handshake. Triggers are checked in a fixed order and each returns after
    its request, so only the first one set has an effect.

    Args:
        host: server address.
        port: CQL native transport port (default 9042).
        username, password: credentials placed in the AUTH_RESPONSE token.
        timeout: socket timeout in seconds for connect/send/recv, so an
            unresponsive server fails the test instead of hanging it.
    """
    # create_connection + `with` closes the socket on every path, including a
    # failed connect (the old code connected outside its try/finally).
    with socket.create_connection((host, port), timeout=timeout) as s:
        if trigger_process_startup_invalid_string_map:
            # STARTUP opcode = 0x01.
            # Body: map count = 1 (uint16), but no entries -> truncated
            body = b'\x00\x01'
            _send_frame(s, opcode=0x01, stream=1, body=body)
            _recv_frame(s)
            return

        if trigger_unknown_compression:
            # Send STARTUP with an unknown COMPRESSION option:
            # two entries in the string map, CQL_VERSION and COMPRESSION.
            # The [string] length prefix is computed so the value is well-formed
            # and only the algorithm name itself is invalid.
            algorithm = b'invalid_compression_algorithm'
            body = (
                b'\x00\x02'
                b'\x00\x0bCQL_VERSION\x00\x053.0.0'
                b'\x00\x0bCOMPRESSION' + struct.pack("!H", len(algorithm)) + algorithm
            )
            _send_frame(s, opcode=0x01, stream=1, body=body)
            _recv_frame(s)
            return

        # STARTUP with CQL_VERSION=3.0.0
        body = b'\x00\x01\x00\x0bCQL_VERSION\x00\x053.0.0'
        _send_frame(s, opcode=0x01, stream=1, body=body)
        frame = _recv_frame(s)

        # The opcode is header byte 4; a shorter frame means the server closed.
        assert len(frame) > 4, "connection closed before a response to STARTUP was received"
        op = frame[4]
        assert op in (0x02, 0x03), f"expected READY(2) or AUTHENTICATE(3), got {hex(op)}"

        if op == 0x02:
            # READY path: the server did not ask for authentication.
            if trigger_unexpected_auth:
                pytest.skip("server not configured with authentication, skipping unexpected auth test")
        elif op == 0x03:
            # AUTHENTICATE path
            if trigger_unexpected_auth:
                # Send OPTIONS instead of AUTH_RESPONSE to trigger an auth-state exception.
                _send_frame(s, opcode=0x05, stream=2, body=b'')
                _recv_frame(s)
                return

            # SASL PLAIN token: \0<username>\0<password>.
            # NOTE(review): the CQL v4 spec defines the AUTH_RESPONSE body as [bytes]
            # (an [int] length prefix, then the token). This sends the bare token,
            # presumably relying on the server consuming the remaining body — confirm.
            body = b'\x00' + username.encode() + b'\x00' + password.encode()
            _send_frame(s, opcode=0x0F, stream=2, body=body)
            # Wait for AUTH_SUCCESS (0x10).
            resp = _recv_frame(s)
            assert len(resp) > 4, "connection closed before a response to AUTH_RESPONSE was received"
            assert resp[4] == 0x10, f"expected AUTH_SUCCESS, got {hex(resp[4])}"

        if trigger_bad_batch:
            # BATCH opcode = 0x0D.
            # Body: batch type = LOGGED (0), 1 statement, but invalid kind = 255
            body = bytearray()
            body += struct.pack("!B", 0x00)  # batch type = LOGGED
            body += struct.pack("!H", 0x01)  # 1 statement
            body += struct.pack("!B", 0xFF)  # INVALID kind = 255
            _send_frame(s, opcode=0x0D, stream=3, body=body)
            _recv_frame(s)
            return

        if trigger_process_query_internal_malformed_query:
            # QUERY opcode = 0x07.
            # Body: long-string length (uint32) = 100, but send only 2 bytes -> truncated
            long_string_len = 100
            body = struct.pack("!I", long_string_len) + b"AB"
            _send_frame(s, opcode=0x07, stream=2, body=body)
            _assert_error_frame(_recv_frame(s), "QUERY with truncated query string")
            return

        if trigger_process_query_internal_fail_read_options:
            # QUERY opcode = 0x07.
            # Body: long-string query (uint32 len) + options; the PAGE_SIZE flag is set
            # but page_size is truncated (only 2 bytes provided instead of 4).
            query = b"SELECT 1"
            long_len = struct.pack("!I", len(query))
            # options: consistency (uint16) + flags (byte, PAGE_SIZE bit = 0x04) + truncated page_size
            options = struct.pack("!H", 0x0001) + struct.pack("!B", 0x04) + b'\x00\x10'
            body = long_len + query + options
            _send_frame(s, opcode=0x07, stream=2, body=body)
            _assert_error_frame(_recv_frame(s), "QUERY with truncated page_size option")
            return

        if trigger_process_prepare_malformed_query:
            # PREPARE opcode = 0x09.
            # Body: long-string length (uint32) = 100, but send only 2 bytes -> truncated
            long_string_len = 100
            body = struct.pack("!I", long_string_len) + b"AB"
            _send_frame(s, opcode=0x09, stream=2, body=body)
            _assert_error_frame(_recv_frame(s), "PREPARE with truncated query string")
            return

        if trigger_process_execute_internal_malformed_cache_key:
            # EXECUTE opcode = 0x0A.
            # Body: [short bytes] id with declared length 5 but only 3 bytes of id
            # provided -> truncated id (whole body is 2 + 3 = 5 bytes).
            declared_len = 5
            body = struct.pack("!H", declared_len) + b'ABC'
            _send_frame(s, opcode=0x0A, stream=2, body=body)
            _assert_error_frame(_recv_frame(s), "EXECUTE with truncated statement id")
            return

        if trigger_process_register_malformed_string_list:
            # REGISTER opcode = 0x0B.
            # Body: string list count = 1 (uint16), then a string with declared
            # length = 5 but only 3 bytes provided -> truncated
            body = b'\x00\x01' + b'\x00\x05' + b'ABC'
            _send_frame(s, opcode=0x0B, stream=2, body=body)
            _assert_error_frame(_recv_frame(s), "REGISTER with truncated string list")


def _assert_error_frame(resp: bytes, what: str) -> None:
    """Assert that *resp* is a response frame whose opcode (header byte 4) is ERROR (0x00)."""
    assert len(resp) > 4, f"{what}: connection closed before a response frame was received"
    assert resp[4] == 0x00, f"{what}: expected ERROR(0x00) response, got {hex(resp[4])}"
|
||
|
||
def _test_impl(host, flag, *, run_count=200, cpp_exception_threshold=20):
    """Run ``_protocol_error_impl`` with trigger *flag* enabled and check the metrics.

    The malformed request is sent *run_count* times. The protocol_error metric
    must increase, and the C++ exception count may grow by at most
    *cpp_exception_threshold*.

    Args:
        host: server address.
        flag: name of a ``trigger_*`` keyword argument of ``_protocol_error_impl``.
        run_count: number of connections/requests to issue.
        cpp_exception_threshold: allowed growth of the C++ exception counter.
    """
    # The same trigger is used on every iteration, so build it once.
    trigger = {flag: True}

    cpp_exception_metrics_before = get_cpp_exceptions_metrics(host)
    protocol_exception_metrics_before = get_protocol_error_metrics(host)

    for _ in range(run_count):
        _protocol_error_impl(host, **trigger)

    protocol_exception_metrics_after = get_protocol_error_metrics(host)
    assert protocol_exception_metrics_after > protocol_exception_metrics_before, (
        f"Expected protocol errors to increase after running test with {flag} "
        f"(before={protocol_exception_metrics_before}, after={protocol_exception_metrics_after})")

    cpp_exception_metrics_after = get_cpp_exceptions_metrics(host)
    cpp_exception_delta = cpp_exception_metrics_after - cpp_exception_metrics_before
    assert cpp_exception_delta <= cpp_exception_threshold, (
        f"Expected C++ exceptions to increase by at most {cpp_exception_threshold} "
        f"after running test with {flag}, but they increased by {cpp_exception_delta}")
|
||
|
||
@pytest.fixture
def no_ssl(request):
    """Skip the requesting test when the run was started with ``--ssl``.

    The raw-socket tests in this module connect with a plain TCP socket and
    do not wrap it in TLS.
    """
    ssl_run = request.config.getoption("--ssl")
    if ssl_run:
        pytest.skip("skipping non-SSL test on SSL-enabled run")
    yield
|
||
|
||
def test_invalid_kind_in_batch_message(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A BATCH whose statement has an invalid kind (255) must trigger a protocol error."""
    _test_impl(host, flag="trigger_bad_batch")
|
||
|
||
def test_unexpected_message_during_auth(scylla_only, no_ssl, debug_exceptions_logging, host):
    """Sending OPTIONS while the server awaits AUTH_RESPONSE must trigger an auth-state error."""
    _test_impl(host, flag="trigger_unexpected_auth")
|
||
|
||
def test_process_startup_invalid_string_map(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A STARTUP whose string map is truncated must produce a protocol error."""
    _test_impl(host, flag="trigger_process_startup_invalid_string_map")
|
||
|
||
def test_unknown_compression_algorithm(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A STARTUP requesting an unknown COMPRESSION algorithm must produce a protocol error."""
    _test_impl(host, flag="trigger_unknown_compression")
|
||
|
||
def test_process_query_internal_malformed_query(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A QUERY whose long-string length exceeds the bytes sent must produce a protocol error."""
    _test_impl(host, flag="trigger_process_query_internal_malformed_query")
|
||
|
||
def test_process_query_internal_fail_read_options(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A QUERY with the PAGE_SIZE flag set but a truncated page_size must produce a protocol error."""
    _test_impl(host, flag="trigger_process_query_internal_fail_read_options")
|
||
|
||
def test_process_prepare_malformed_query(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A PREPARE whose long-string length exceeds the bytes sent must produce a protocol error."""
    _test_impl(host, flag="trigger_process_prepare_malformed_query")
|
||
|
||
def test_process_execute_internal_malformed_cache_key(scylla_only, no_ssl, debug_exceptions_logging, host):
    """An EXECUTE whose short-bytes id is shorter than declared must produce a protocol error."""
    _test_impl(host, flag="trigger_process_execute_internal_malformed_cache_key")
|
||
|
||
def test_process_register_malformed_string_list(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A REGISTER whose string list holds a truncated string must produce a protocol error."""
    _test_impl(host, flag="trigger_process_register_malformed_string_list")
|
||
|
||
# Running only the happy path (STARTUP plus, if the server requires it,
# authentication) must leave the protocol-error counter unchanged. It must
# not increase (no spurious protocol errors), and it must not decrease (the
# counter is not cleared or reset during the test).
def test_no_protocol_exceptions(scylla_only, no_ssl, debug_exceptions_logging, host):
    """Well-formed handshakes must not change the protocol_error metric.

    The C++ exception count may grow by at most ``cpp_exception_threshold``.
    """
    run_count = 200
    cpp_exception_threshold = 20

    cpp_exception_metrics_before = get_cpp_exceptions_metrics(host)
    protocol_exception_metrics_before = get_protocol_error_metrics(host)

    for _ in range(run_count):
        _protocol_error_impl(host)

    protocol_exception_metrics_after = get_protocol_error_metrics(host)
    assert protocol_exception_metrics_after == protocol_exception_metrics_before, (
        f"Expected protocol errors to stay unchanged during the happy path "
        f"(before={protocol_exception_metrics_before}, after={protocol_exception_metrics_after})")

    cpp_exception_metrics_after = get_cpp_exceptions_metrics(host)
    cpp_exception_delta = cpp_exception_metrics_after - cpp_exception_metrics_before
    assert cpp_exception_delta <= cpp_exception_threshold, (
        f"Expected C++ exceptions to increase by at most {cpp_exception_threshold} "
        f"during the happy path, but they increased by {cpp_exception_delta}")
|