# -*- coding: utf-8 -*-
# Copyright 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.1

from cassandra.cluster import NoHostAvailable
from cassandra.connection import Connection
from contextlib import ExitStack, contextmanager
import pytest
import re
import requests
import socket
import struct
from unittest import mock

from test.cqlpy import nodetool
from test.cqlpy.util import cql_session
from test.pylib.skip_types import skip_env

# Port of the HTTP endpoint serving Prometheus-format metrics.
_METRICS_PORT = 9180
# Timeout (seconds) for metric scrapes and raw CQL socket operations, so an
# unresponsive server fails the test instead of hanging it forever.
_IO_TIMEOUT = 30

_PROTOCOL_ERROR_PATTERN = re.compile(
    r'^scylla_transport_cql_errors_total\{shard="\d+",type="protocol_error"\} (\d+)')
_CPP_EXCEPTIONS_PATTERN = re.compile(
    r'^scylla_reactor_cpp_exceptions\{shard="\d+"\} (\d+)')


def _sum_metric(host, pattern: re.Pattern) -> int:
    """Scrape the metrics endpoint of *host* and sum the captured value of
    every line matching *pattern* (the patterns match one line per shard).

    Raises:
        requests.HTTPError: if the metrics endpoint answers with an error
            status (rather than silently counting it as 0).
    """
    response = requests.get(f"http://{host}:{_METRICS_PORT}/metrics", timeout=_IO_TIMEOUT)
    response.raise_for_status()
    total = 0
    for metric_line in response.text.split('\n'):
        match = pattern.match(metric_line)
        if match:
            total += int(match.group(1))
    return total


def get_protocol_error_metrics(host) -> int:
    """Return `scylla_transport_cql_errors_total{type="protocol_error"}` summed over all shards."""
    return _sum_metric(host, _PROTOCOL_ERROR_PATTERN)


def get_cpp_exceptions_metrics(host) -> int:
    """Return `scylla_reactor_cpp_exceptions` summed over all shards."""
    return _sum_metric(host, _CPP_EXCEPTIONS_PATTERN)


@contextmanager
def cql_with_protocol(host_str, port, creds, protocol_version):
    """Open a CQL session to *host_str* using native protocol *protocol_version*.

    Yields the session, or ``None`` if the connection attempt raises
    ``NoHostAvailable``. A successfully opened session is shut down when the
    ``with`` block exits.

    Only the connection attempt is guarded: exceptions raised inside the
    caller's ``with`` body propagate unchanged. Catching them around the
    ``yield`` and yielding again would make ``contextlib`` raise
    ``RuntimeError: generator didn't stop after throw()``.
    """
    with ExitStack() as stack:
        try:
            session = stack.enter_context(cql_session(
                host=host_str,
                port=port,
                is_ssl=creds["ssl"],
                username=creds["username"],
                password=creds["password"],
                protocol_version=protocol_version,
            ))
        except NoHostAvailable:
            session = None
        else:
            # Callbacks run LIFO, so the session is shut down before
            # cql_session's own cleanup runs.
            stack.callback(session.shutdown)
        yield session


def try_connect(host, port, creds, protocol_version):
    """Return 1 if a session can be opened with *protocol_version*, else 0."""
    with cql_with_protocol(host, port, creds, protocol_version) as session:
        return 1 if session is not None else 0


@pytest.fixture
def debug_exceptions_logging(request, cql):
    """Set the server's ``exception`` logger to ``debug`` for the duration of a test.

    Best-effort: failures to read or set the level are printed, not raised.
    On teardown the level observed before the test is restored. If it could
    not be read but ``debug`` was enabled, ``info`` is restored instead. If
    neither the read nor the enable succeeded, nothing is restored.
    """
    def _read_level() -> str | None:
        # Returns the current level lower-cased with quotes stripped, or None on failure.
        try:
            level = nodetool.getlogginglevel(cql, "exception")
            if level:
                level = level.strip().strip('"').lower()
            return level
        except Exception as exc:
            print(f"Failed to read exception logger level: {exc}")
            return None

    def _set_and_verify(level: str) -> bool:
        # Sets the level, then reads it back; True only if the read-back matches.
        try:
            nodetool.setlogginglevel(cql, "exception", level)
        except Exception as exc:
            print(f"Failed to set exception logger level to '{level}': {exc}")
            return False

        observed = _read_level()
        if observed == level:
            return True

        print(f"Exception logger level observed as '{observed}' while expecting '{level}'")
        return False

    def _restore_logging():
        # `enabled` and `previous_level` are bound below, before the yield.
        if not enabled and previous_level is None:
            return
        target_level = previous_level or "info"
        _set_and_verify(target_level)

    previous_level = _read_level()
    enabled = _set_and_verify("debug")
    yield
    _restore_logging()


# If there is a protocol version mismatch, the server should
# raise a protocol error, which is counted in the metrics.
def test_protocol_version_mismatch(scylla_only, debug_exceptions_logging, request, host):
    run_count = 200
    cpp_exception_threshold = 20

    cpp_exception_metrics_before = get_cpp_exceptions_metrics(host)
    protocol_exception_metrics_before = get_protocol_error_metrics(host)

    port = request.config.getoption("--port")
    # Fall back to the default superuser credentials when none are configured.
    creds = {
        "ssl": request.config.getoption("--ssl"),
        "username": request.config.getoption("--auth_username") or "cassandra",
        "password": request.config.getoption("--auth_password") or "cassandra",
    }

    successful_session_count = try_connect(host, port, creds, protocol_version=4)
    assert successful_session_count == 1, "Expected to connect successfully with protocol version 4"

    for _ in range(run_count):
        successful_session_count = try_connect(host, port, creds, protocol_version=42)
        assert successful_session_count == 0, "Expected to fail connecting with protocol version 42"

    protocol_exception_metrics_after = get_protocol_error_metrics(host)
    assert protocol_exception_metrics_after > protocol_exception_metrics_before, \
        "Expected protocol errors to increase after the test"

    cpp_exception_metrics_after = get_cpp_exceptions_metrics(host)
    cpp_exception_delta = cpp_exception_metrics_after - cpp_exception_metrics_before
    assert cpp_exception_delta <= cpp_exception_threshold, (
        f"Expected at most {cpp_exception_threshold} new C++ exceptions after the test, "
        f"got {cpp_exception_delta}")


# Native protocol version stamped on every hand-crafted frame.
_CQL_PROTOCOL_VERSION = 0x04
# v4 frame header: version(1) flags(1) stream(2) opcode(1) body length(4).
_FRAME_HEADER = struct.Struct("!BBHBI")

# CQL v4 opcodes used by the raw-socket tests below.
_OP_ERROR = 0x00
_OP_STARTUP = 0x01
_OP_READY = 0x02
_OP_AUTHENTICATE = 0x03
_OP_OPTIONS = 0x05
_OP_QUERY = 0x07
_OP_PREPARE = 0x09
_OP_EXECUTE = 0x0A
_OP_REGISTER = 0x0B
_OP_BATCH = 0x0D
_OP_AUTH_RESPONSE = 0x0F
_OP_AUTH_SUCCESS = 0x10


def _cql_string(value: bytes) -> bytes:
    """Encode a CQL [string]: uint16 length followed by the bytes."""
    return struct.pack("!H", len(value)) + value


def _cql_string_map(entries: dict[bytes, bytes]) -> bytes:
    """Encode a CQL [string map]: uint16 entry count, then key/value [string]s."""
    encoded = bytearray(struct.pack("!H", len(entries)))
    for key, value in entries.items():
        encoded += _cql_string(key) + _cql_string(value)
    return bytes(encoded)


def _build_frame(*, opcode: int, stream: int, body: bytes) -> bytearray:
    """Return a v4 request frame with no flags set, wrapping *body*."""
    frame = bytearray(_FRAME_HEADER.pack(_CQL_PROTOCOL_VERSION, 0x00, stream, opcode, len(body)))
    frame += body
    return frame


def _send_frame(sock: socket.socket, *, opcode: int, stream: int, body: bytes) -> None:
    # sendall: a plain send() may transmit only part of the buffer.
    sock.sendall(_build_frame(opcode=opcode, stream=stream, body=body))


def _recv_exact(sock: socket.socket, size: int) -> bytes:
    """Read *size* bytes, stopping early only if the peer closes the connection."""
    buf = bytearray()
    while len(buf) < size:
        chunk = sock.recv(size - len(buf))
        if not chunk:
            break
        buf += chunk
    return bytes(buf)


def _recv_frame(sock: socket.socket) -> bytes:
    """Read one complete response frame (header + body) from *sock*.

    A single ``recv`` may return only part of a frame. So the header is read
    first, then exactly the body length it announces. If the server closes the
    connection early, whatever was received (possibly ``b''``) is returned.
    """
    header = _recv_exact(sock, _FRAME_HEADER.size)
    if len(header) < _FRAME_HEADER.size:
        return header
    body_length = _FRAME_HEADER.unpack(header)[-1]
    return header + _recv_exact(sock, body_length)


def _assert_error_frame(frame: bytes) -> None:
    """Assert that *frame* carries a complete header with the ERROR opcode."""
    assert len(frame) >= _FRAME_HEADER.size and frame[4] == _OP_ERROR, \
        f"expected ERROR frame, got {frame!r}"


# Many protocol errors are caused by sending malformed messages.
# It is not possible to reproduce them with the Python driver,
# so we use a low-level socket connection to send the messages.
# To avoid code duplication of this low-level code, we use a common
# implementation function with parameters. To trigger a specific
# protocol error, the appropriate trigger should be set to True.
def _protocol_error_impl(
        host, *,
        trigger_bad_batch=False,
        trigger_unexpected_auth=False,
        trigger_process_startup_invalid_string_map=False,
        trigger_unknown_compression=False,
        trigger_process_query_internal_malformed_query=False,
        trigger_process_query_internal_fail_read_options=False,
        trigger_process_prepare_malformed_query=False,
        trigger_process_execute_internal_malformed_cache_key=False,
        trigger_process_register_malformed_string_list=False,
        port=9042,
        username="cassandra",
        password="cassandra"):
    """Open a raw CQL v4 connection to *host*:*port* and drive one scenario.

    With no trigger set, this sends a well-formed STARTUP. If the server
    answers AUTHENTICATE, it also sends an AUTH_RESPONSE built from *username*
    and *password* and asserts AUTH_SUCCESS. Then the connection is closed.
    Each ``trigger_*`` flag instead sends one specific malformed message. Only
    the first set flag, in the order checked below, takes effect.

    ``trigger_unexpected_auth`` skips the test (via ``skip_env``) when the
    server answers READY, i.e. does not require authentication.
    """
    with socket.create_connection((host, port), timeout=_IO_TIMEOUT) as s:
        if trigger_process_startup_invalid_string_map:
            # STARTUP body: map count = 1 (uint16), but no entries -> truncated.
            _send_frame(s, opcode=_OP_STARTUP, stream=1, body=b'\x00\x01')
            _recv_frame(s)
            return

        if trigger_unknown_compression:
            # STARTUP with a well-formed string map whose COMPRESSION value is
            # not a known algorithm.
            body = _cql_string_map({
                b"CQL_VERSION": b"3.0.0",
                b"COMPRESSION": b"invalid_compression_algorithm",
            })
            _send_frame(s, opcode=_OP_STARTUP, stream=1, body=body)
            _recv_frame(s)
            return

        # Well-formed STARTUP.
        _send_frame(s, opcode=_OP_STARTUP, stream=1, body=_cql_string_map({b"CQL_VERSION": b"3.0.0"}))
        frame = _recv_frame(s)
        assert len(frame) >= _FRAME_HEADER.size, f"expected a response to STARTUP, got {frame!r}"
        op = frame[4]
        assert op in (_OP_READY, _OP_AUTHENTICATE), f"expected READY(2) or AUTHENTICATE(3), got {hex(op)}"

        if op == _OP_READY:
            if trigger_unexpected_auth:
                skip_env("server not configured with authentication, skipping unexpected auth test")
        elif op == _OP_AUTHENTICATE:
            if trigger_unexpected_auth:
                # OPTIONS while the server is waiting for AUTH_RESPONSE.
                _send_frame(s, opcode=_OP_OPTIONS, stream=2, body=b'')
                _recv_frame(s)
                return

            # Token is b"\0<username>\0<password>".
            # NOTE(review): sent as the raw body, without the int32 [bytes]
            # length prefix the v4 spec describes for AUTH_RESPONSE; the
            # AUTH_SUCCESS assertion relies on the server accepting this.
            token = b'\x00' + username.encode() + b'\x00' + password.encode()
            _send_frame(s, opcode=_OP_AUTH_RESPONSE, stream=2, body=token)
            resp = _recv_frame(s)
            assert len(resp) >= _FRAME_HEADER.size and resp[4] == _OP_AUTH_SUCCESS, \
                f"expected AUTH_SUCCESS, got {resp!r}"

        if trigger_bad_batch:
            # BATCH body: type = LOGGED (0), 1 statement, but invalid kind = 255.
            body = struct.pack("!BHB", 0x00, 0x01, 0xFF)
            _send_frame(s, opcode=_OP_BATCH, stream=3, body=body)
            _recv_frame(s)
            return

        if trigger_process_query_internal_malformed_query:
            # QUERY body: [long string] declares 100 bytes but only 2 follow.
            body = struct.pack("!I", 100) + b"AB"
            _send_frame(s, opcode=_OP_QUERY, stream=2, body=body)
            _assert_error_frame(_recv_frame(s))
            return

        if trigger_process_query_internal_fail_read_options:
            # QUERY body: valid [long string] query, then options with
            # consistency ONE (uint16) and flags = PAGE_SIZE (0x04), but the
            # int32 page_size is truncated to 2 bytes.
            query = b"SELECT 1"
            options = struct.pack("!HB", 0x0001, 0x04) + b'\x00\x10'
            body = struct.pack("!I", len(query)) + query + options
            _send_frame(s, opcode=_OP_QUERY, stream=2, body=body)
            _assert_error_frame(_recv_frame(s))
            return

        if trigger_process_prepare_malformed_query:
            # PREPARE body: [long string] declares 100 bytes but only 2 follow.
            body = struct.pack("!I", 100) + b"AB"
            _send_frame(s, opcode=_OP_PREPARE, stream=2, body=body)
            _assert_error_frame(_recv_frame(s))
            return

        if trigger_process_execute_internal_malformed_cache_key:
            # EXECUTE body: [short bytes] id declares 5 bytes but only 3
            # follow (the whole body is 2 + 3 = 5 bytes).
            body = struct.pack("!H", 5) + b'ABC'
            _send_frame(s, opcode=_OP_EXECUTE, stream=2, body=body)
            _assert_error_frame(_recv_frame(s))
            return

        if trigger_process_register_malformed_string_list:
            # REGISTER body: [string list] count = 1, then a [string]
            # declaring 5 bytes with only 3 provided.
            body = b'\x00\x01' + b'\x00\x05' + b'ABC'
            _send_frame(s, opcode=_OP_REGISTER, stream=2, body=body)
            _assert_error_frame(_recv_frame(s))


def _test_impl(host, flag):
    """Run the ``_protocol_error_impl`` scenario named by *flag* repeatedly.

    Asserts that the protocol-error counter increased, and that the C++
    exception counter grew by no more than a fixed threshold.
    """
    run_count = 200
    cpp_exception_threshold = 20
    trigger = {flag: True}

    cpp_exception_metrics_before = get_cpp_exceptions_metrics(host)
    protocol_exception_metrics_before = get_protocol_error_metrics(host)

    for _ in range(run_count):
        _protocol_error_impl(host, **trigger)

    protocol_exception_metrics_after = get_protocol_error_metrics(host)
    assert protocol_exception_metrics_after > protocol_exception_metrics_before, \
        f"Expected protocol errors to increase after running test with {flag}"

    cpp_exception_metrics_after = get_cpp_exceptions_metrics(host)
    cpp_exception_delta = cpp_exception_metrics_after - cpp_exception_metrics_before
    assert cpp_exception_delta <= cpp_exception_threshold, (
        f"Expected at most {cpp_exception_threshold} new C++ exceptions after running test "
        f"with {flag}, got {cpp_exception_delta}")


@pytest.fixture
def no_ssl(request):
    """Skip the test when the run uses SSL; the raw-socket tests speak plain TCP."""
    if request.config.getoption("--ssl"):
        skip_env("skipping non-SSL test on SSL-enabled run")
    yield


# Malformed BATCH with an invalid kind triggers a protocol error.
def test_invalid_kind_in_batch_message(scylla_only, no_ssl, debug_exceptions_logging, host):
    _test_impl(host, "trigger_bad_batch")


# Send OPTIONS during AUTHENTICATE to trigger auth-state error.
def test_unexpected_message_during_auth(scylla_only, no_ssl, debug_exceptions_logging, host):
    """An OPTIONS frame sent while authentication is pending must count as a protocol error."""
    _test_impl(host, flag="trigger_unexpected_auth")


def test_process_startup_invalid_string_map(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A STARTUP whose string map announces an entry that is never sent must count as a protocol error."""
    _test_impl(host, flag="trigger_process_startup_invalid_string_map")


def test_unknown_compression_algorithm(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A STARTUP requesting an unknown COMPRESSION algorithm must count as a protocol error."""
    _test_impl(host, flag="trigger_unknown_compression")


def test_process_query_internal_malformed_query(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A QUERY whose long-string length exceeds the bytes sent must count as a protocol error."""
    _test_impl(host, flag="trigger_process_query_internal_malformed_query")


def test_process_query_internal_fail_read_options(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A QUERY with the PAGE_SIZE flag set but a truncated page_size must count as a protocol error."""
    _test_impl(host, flag="trigger_process_query_internal_fail_read_options")


def test_process_prepare_malformed_query(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A PREPARE whose long-string length exceeds the bytes sent must count as a protocol error."""
    _test_impl(host, flag="trigger_process_prepare_malformed_query")


def test_process_execute_internal_malformed_cache_key(scylla_only, no_ssl, debug_exceptions_logging, host):
    """An EXECUTE whose short-bytes id length exceeds the bytes sent must count as a protocol error."""
    _test_impl(host, flag="trigger_process_execute_internal_malformed_cache_key")


def test_process_register_malformed_string_list(scylla_only, no_ssl, debug_exceptions_logging, host):
    """A REGISTER whose string-list entry is shorter than declared must count as a protocol error."""
    _test_impl(host, flag="trigger_process_register_malformed_string_list")


# Test if the protocol exceptions do not decrease after running the test happy path.
# This is to ensure that the protocol exceptions are not cleared or reset
# during the test execution.
def test_no_protocol_exceptions(scylla_only, no_ssl, debug_exceptions_logging, host):
    """Repeat the well-formed raw-socket handshake.

    The protocol-error counter must stay exactly unchanged, so it neither
    grows nor gets reset. The C++ exception counter may grow by at most a
    small threshold.
    """
    run_count = 200
    cpp_exception_threshold = 20

    cpp_exception_metrics_before = get_cpp_exceptions_metrics(host)
    protocol_exception_metrics_before = get_protocol_error_metrics(host)

    for _ in range(run_count):
        # No trigger set: well-formed STARTUP (plus AUTH_RESPONSE if required).
        _protocol_error_impl(host)

    protocol_exception_metrics_after = get_protocol_error_metrics(host)
    assert protocol_exception_metrics_after == protocol_exception_metrics_before, (
        f"Expected protocol errors to not increase: before={protocol_exception_metrics_before}, "
        f"after={protocol_exception_metrics_after}")

    cpp_exception_metrics_after = get_cpp_exceptions_metrics(host)
    cpp_exception_delta = cpp_exception_metrics_after - cpp_exception_metrics_before
    assert cpp_exception_delta <= cpp_exception_threshold, (
        f"Expected at most {cpp_exception_threshold} new C++ exceptions, got {cpp_exception_delta}")


# Regression test for GitHub issue #27452.
def test_supported_includes_host_id(scylla_only, no_ssl, cql, host, request):
    """The options captured from the driver's OPTIONS-response handler must
    contain ``SCYLLA_HOST_ID`` equal to this node's ``system.local`` host_id."""
    expected_host_id = str(cql.execute("SELECT host_id FROM system.local WHERE key = 'local'").one()[0])

    captured_supported = {}
    original_handle = Connection._handle_options_response

    def _capture_supported(self, options_response):
        # Record the advertised options, then defer to the driver's own handler.
        if hasattr(options_response, "options"):
            captured_supported.update(options_response.options)
        return original_handle(self, options_response)

    with mock.patch.object(Connection, "_handle_options_response", _capture_supported), \
            cql_session(
                host=host,
                port=request.config.getoption("--port"),
                is_ssl=False,
                username=request.config.getoption("--auth_username") or "cassandra",
                password=request.config.getoption("--auth_password") or "cassandra",
            ):
        pass

    # Fail with a clear message (not KeyError) if nothing was captured or the key is absent.
    assert captured_supported, "No OPTIONS response options were captured while connecting"
    actual_host_id = captured_supported.get("SCYLLA_HOST_ID")
    assert actual_host_id == [expected_host_id], (
        f"Expected SCYLLA_HOST_ID == [{expected_host_id!r}], got {actual_host_id!r}")