Files
scylladb/test/cluster/test_proxy_protocol.py
Piotr Smaron f12e4ea42b test_proxy_protocol: introduce extra logging to aid debugging
In case of an error, we want to see the contents of the system.clients
table to have a better understanding of what happened - whether the
row(s) are really missing or maybe they are there, but 1 digit doesn't
match or the row is half-written.
We'll therefore query for the whole table on the CQL side, and then
filter out the rows we want to later proceed with on the python side.
This way we can dump the contents of the whole system.clients table if
something goes south.
2026-03-06 14:50:12 +01:00

716 lines
24 KiB
Python

# Copyright 2025-present ScyllaDB
#
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
"""
Tests for CQL native transport ports with proxy protocol v2 support.
These tests verify that when Scylla is configured with proxy protocol ports,
clients can connect using the PROXY protocol v2 header and the original
client addresses are properly reported in system.clients.
"""
import asyncio
import logging
import pytest
import socket
import ssl
import struct
import time
from test.pylib.manager_client import ManagerClient
from test.pylib.util import wait_for
# Module-level logger; wait_for_results() uses it to dump the last query
# result when polling fails.
logger = logging.getLogger(__name__)

# Proxy protocol v2 signature (12 bytes). make_proxy_v2_header() emits it
# verbatim as the first field of every header.
PROXY_V2_SIGNATURE = b'\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a'

# Port numbers for proxy protocol ports (to avoid conflicts with standard ports).
# They are handed to the server through PROXY_SERVER_CONFIG and reused by the
# tests as the port to connect to.
PROXY_PORT = 29042  # native_transport_port_proxy_protocol
PROXY_PORT_SSL = 29142  # native_transport_port_ssl_proxy_protocol
PROXY_SHARD_AWARE_PORT = 39042  # native_shard_aware_transport_port_proxy_protocol
PROXY_SHARD_AWARE_PORT_SSL = 39142  # native_shard_aware_transport_port_ssl_proxy_protocol
def make_proxy_v2_header(src_addr: str, src_port: int, dst_addr: str, dst_port: int) -> bytes:
    """
    Build the binary proxy protocol v2 header for an IPv4 TCP connection.

    The format is specified at:
    https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt

    The result is 28 bytes long: the 12-byte signature, one version/command
    byte, one family/transport byte, a 2-byte address-block length, and the
    12-byte IPv4 address block (source address, destination address, source
    port, destination port). All integers are big-endian.
    """
    # Fixed part: signature, then version 2 + PROXY command (0x21),
    # AF_INET + STREAM (0x11), and the IPv4 address-block length (4+4+2+2 = 12).
    fixed_part = PROXY_V2_SIGNATURE + struct.pack('!BBH', 0x21, 0x11, 12)
    # Address block: both addresses first, then both ports.
    address_part = socket.inet_aton(src_addr) + socket.inet_aton(dst_addr)
    port_part = struct.pack('!HH', src_port, dst_port)
    return fixed_part + address_part + port_part
def build_cql_startup_frame(stream: int = 1) -> bytes:
    """
    Return a CQL protocol v4 STARTUP frame for the given stream id.

    The body is a one-entry string map: CQL_VERSION -> 3.0.0.
    """
    # [string map]: entry count (1), then length-prefixed key and value
    body = b'\x00\x01\x00\x0bCQL_VERSION\x00\x053.0.0'
    # 9-byte header: version 4, no flags, stream id, opcode STARTUP (0x01), body length
    header = struct.pack('!BBHBI', 0x04, 0x00, stream, 0x01, len(body))
    return header + body
def build_cql_query_frame(query: str, stream: int = 2) -> bytes:
    """
    Return a CQL protocol v4 QUERY frame that runs `query` at consistency ONE.

    The query text is UTF-8 encoded and no optional query flags are set.
    """
    encoded_query = query.encode('utf-8')
    # Body: [long string] query (4-byte length + bytes), then the query
    # parameters: consistency ONE (2 bytes) and an empty flags byte.
    body = (
        struct.pack('!I', len(encoded_query))
        + encoded_query
        + struct.pack('!HB', 0x0001, 0x00)
    )
    # 9-byte header: version 4, no flags, stream id, opcode QUERY (0x07), body length
    header = struct.pack('!BBHBI', 0x04, 0x00, stream, 0x07, len(body))
    return header + body
async def _read_cql_frame(reader, timeout: float) -> bytes:
    """
    Read exactly one CQL frame from `reader`: the 9-byte header followed by
    the body whose length is stored big-endian in header bytes 5..8.

    Raises RuntimeError if the connection is closed before the frame is
    complete, and asyncio.TimeoutError if the whole frame does not arrive
    within `timeout` seconds.
    """
    async def read_frame() -> bytes:
        header = await reader.readexactly(9)
        (body_length,) = struct.unpack('!I', header[5:9])
        return header + await reader.readexactly(body_length)

    try:
        return await asyncio.wait_for(read_frame(), timeout=timeout)
    except asyncio.IncompleteReadError as exc:
        raise RuntimeError(
            f"Short response from server: connection closed after {len(exc.partial)} "
            f"of {exc.expected} expected bytes") from exc


async def do_cql_handshake(reader, writer):
    """
    Complete CQL handshake (STARTUP and optional AUTH).

    Sends STARTUP on stream 1 and accepts either READY, or AUTHENTICATE
    followed by a successful exchange using the default cassandra/cassandra
    credentials.

    Raises RuntimeError on a truncated response or an unexpected opcode.
    """
    # Send CQL STARTUP
    writer.write(build_cql_startup_frame(stream=1))
    await writer.drain()

    # Read response (READY or AUTHENTICATE). A whole frame is read rather than
    # whatever a single read() returns, since TCP may deliver it in pieces.
    response = await _read_cql_frame(reader, timeout=100.0)
    opcode = response[4]
    if opcode == 0x02:  # READY
        return
    if opcode != 0x03:  # not AUTHENTICATE either
        raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")

    # Send AUTH_RESPONSE with default credentials. The body is a single
    # [bytes] value, i.e. the SASL PLAIN token must carry a 4-byte length prefix.
    token = b'\x00cassandra\x00cassandra'
    auth_body = struct.pack('!i', len(token)) + token
    # Header: version 4, no flags, stream 1, opcode AUTH_RESPONSE (0x0F), body length
    auth_frame = struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body
    writer.write(auth_frame)
    await writer.drain()

    response = await _read_cql_frame(reader, timeout=100.0)
    opcode = response[4]
    if opcode != 0x10:  # AUTH_SUCCESS
        raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
async def send_cql_query(reader, writer, query: str, stream: int = 2) -> bytes:
    """
    Send a CQL query and return the response.

    The response is one complete CQL frame (9-byte header plus body). If the
    server closes the connection early, whatever was received so far is
    returned, possibly empty, so callers can still inspect its length.

    Raises asyncio.TimeoutError if the frame does not arrive within 100 seconds.
    """
    writer.write(build_cql_query_frame(query, stream=stream))
    await writer.drain()

    async def read_frame() -> bytes:
        # A single read() may return only part of a frame, so read the fixed
        # header first and then exactly as many body bytes as it announces.
        try:
            header = await reader.readexactly(9)
        except asyncio.IncompleteReadError as exc:
            return exc.partial
        (body_length,) = struct.unpack('!I', header[5:9])
        try:
            return header + await reader.readexactly(body_length)
        except asyncio.IncompleteReadError as exc:
            return header + exc.partial

    return await asyncio.wait_for(read_frame(), timeout=100.0)
async def send_cql_with_proxy_header(
    host: str,
    port: int,
    proxy_src_addr: str,
    proxy_src_port: int,
    proxy_dst_addr: str,
    proxy_dst_port: int,
    query: str
) -> bytes:
    """
    Open a plain TCP connection to `host`:`port`, announce the given
    source/destination through a proxy protocol v2 header, perform the CQL
    handshake, run `query` and return the raw response bytes.

    The connection is always closed before returning.
    """
    reader, writer = await asyncio.open_connection(host, port)
    try:
        # The proxy protocol header must precede any CQL traffic
        writer.write(make_proxy_v2_header(proxy_src_addr, proxy_src_port, proxy_dst_addr, proxy_dst_port))
        await writer.drain()
        # STARTUP (and AUTH if requested), then the actual query
        await do_cql_handshake(reader, writer)
        query_response = await send_cql_query(reader, writer, query)
    finally:
        writer.close()
        await writer.wait_closed()
    return query_response
async def send_cql_with_proxy_header_tls(
    host: str,
    port: int,
    proxy_src_addr: str,
    proxy_src_port: int,
    proxy_dst_addr: str,
    proxy_dst_port: int,
    query: str
) -> bytes:
    """
    Connect to a CQL server using proxy protocol v2 with TLS and execute a query.
    The proxy protocol header is sent first over the raw TCP connection,
    then the connection is upgraded to TLS before the CQL protocol begins.

    Returns the raw response frame for `query` (possibly short if the server
    closed the connection early).

    Raises RuntimeError on a truncated or unexpected handshake response, and
    a timeout error if the TLS handshake or a read takes longer than 100 seconds.

    NOTE: once TLS is established the CQL exchange uses blocking socket calls,
    so the event loop does not run other tasks until this function returns.
    """
    io_timeout = 100.0
    loop = asyncio.get_running_loop()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ssl_sock = None
    try:
        # Connect inside the try block so a failed connect does not leak the socket
        sock.setblocking(False)
        await loop.sock_connect(sock, (host, port))

        # Send proxy header on raw socket BEFORE TLS handshake
        proxy_header = make_proxy_v2_header(proxy_src_addr, proxy_src_port, proxy_dst_addr, proxy_dst_port)
        await loop.sock_sendall(sock, proxy_header)

        # Create SSL context (don't verify server certificate for testing)
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        # Wrap the socket with TLS
        ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host, do_handshake_on_connect=False)

        # Do TLS handshake (non-blocking): poll until it completes, but give up
        # after io_timeout instead of spinning forever.
        handshake_deadline = loop.time() + io_timeout
        while True:
            try:
                ssl_sock.do_handshake()
                break
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                if loop.time() >= handshake_deadline:
                    raise asyncio.TimeoutError(
                        f"TLS handshake did not complete within {io_timeout} seconds") from None
                await asyncio.sleep(0.01)

        # Use blocking I/O for simplicity, bounded by a timeout so a silent
        # server cannot hang the test run.
        ssl_sock.settimeout(io_timeout)

        def recv_exactly(count: int) -> bytes:
            # Collect up to `count` bytes; returns fewer only if the peer closes.
            chunks = []
            remaining = count
            while remaining > 0:
                chunk = ssl_sock.recv(remaining)
                if not chunk:
                    break
                chunks.append(chunk)
                remaining -= len(chunk)
            return b''.join(chunks)

        def recv_frame() -> bytes:
            # One CQL frame: 9-byte header, then the body length it announces.
            header = recv_exactly(9)
            if len(header) < 9:
                return header
            (body_length,) = struct.unpack('!I', header[5:9])
            return header + recv_exactly(body_length)

        # Send CQL STARTUP
        ssl_sock.sendall(build_cql_startup_frame(stream=1))

        # Read response (READY or AUTHENTICATE)
        response = recv_frame()
        if len(response) < 9:
            raise RuntimeError(f"Short response from server: {len(response)} bytes")
        opcode = response[4]
        if opcode == 0x03:  # AUTHENTICATE
            # AUTH_RESPONSE body is a single [bytes] value: the SASL PLAIN
            # token must be preceded by its 4-byte length.
            token = b'\x00cassandra\x00cassandra'
            auth_body = struct.pack('!i', len(token)) + token
            auth_frame = struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body
            ssl_sock.sendall(auth_frame)
            response = recv_frame()
            if len(response) < 9:
                raise RuntimeError(f"Short response from server: {len(response)} bytes")
            opcode = response[4]
            if opcode != 0x10:  # AUTH_SUCCESS
                raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
        elif opcode != 0x02:  # READY
            raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")

        # Send query
        ssl_sock.sendall(build_cql_query_frame(query, stream=2))

        # Read response
        return recv_frame()
    finally:
        # Close whichever socket object currently represents the connection
        if ssl_sock is not None:
            ssl_sock.close()
        else:
            sock.close()
async def wait_for_results(cql, query: str, expected_count: int, timeout: float = 30.0, filter_fn=None):
    """
    Repeatedly run `query` until enough rows match, and return the matching rows.

    After every poll the full result is passed through `filter_fn` (or used
    as-is when no filter is given); polling stops once at least
    `expected_count` rows remain. If waiting fails, the complete result of
    the last poll is logged before the exception is re-raised.
    """
    latest_rows: list = []

    async def check_resultset():
        nonlocal latest_rows
        # Keep the unfiltered result around so it can be dumped on failure
        latest_rows = list(await cql.run_async(query))
        candidates = latest_rows if filter_fn is None else filter_fn(latest_rows)
        return candidates if len(candidates) >= expected_count else None

    try:
        deadline = time.time() + timeout
        return await wait_for(check_resultset, deadline, period=0.1)
    except Exception:
        rows_dump = '\n'.join(map(str, latest_rows))
        logger.error('Timed out waiting for %d matching rows in system.clients. Last poll returned %d total rows:\n%s',
                     expected_count, len(latest_rows), rows_dump)
        raise
# Shared server configuration for all tests; the proxy_server fixture passes it
# as `config=` to manager.server_add().
# We configure explicit SSL ports to keep the standard ports unencrypted
# so the Python driver can connect without TLS.
PROXY_SERVER_CONFIG = {
    # Proxy protocol listeners: plain and TLS, regular and shard-aware
    'native_transport_port_proxy_protocol': PROXY_PORT,
    'native_transport_port_ssl_proxy_protocol': PROXY_PORT_SSL,
    'native_shard_aware_transport_port_proxy_protocol': PROXY_SHARD_AWARE_PORT,
    'native_shard_aware_transport_port_ssl_proxy_protocol': PROXY_SHARD_AWARE_PORT_SSL,
    # Set explicit non-SSL and SSL ports so the driver can connect to unencrypted port
    'native_transport_port': 9042,
    'native_shard_aware_transport_port': 19042,
    'native_transport_port_ssl': 9142,
    'native_shard_aware_transport_port_ssl': 19142,
    # TLS material for the SSL listeners.
    # NOTE(review): the relative paths are presumably resolved against the
    # server's working directory - confirm against the test harness.
    'client_encryption_options': {
        'enabled': True,
        'certificate': 'conf/scylla.crt',
        'keyfile': 'conf/scylla.key',
    },
}
@pytest.fixture(scope="function")
async def proxy_server(manager: ManagerClient):
    """
    Per-test fixture: adds one server configured with PROXY_SERVER_CONFIG
    (all proxy protocol ports enabled) and yields `(server, manager)`.
    """
    proxy_enabled_server = await manager.server_add(config=PROXY_SERVER_CONFIG)
    yield proxy_enabled_server, manager
@pytest.mark.asyncio
async def test_proxy_protocol_basic(proxy_server):
    """
    Connect through the plain proxy protocol port while announcing a fake
    client address, query system.clients for that address over the same
    connection, and check that the server answers with a RESULT frame.
    """
    server, manager = proxy_server
    # TEST-NET-3 address: distinctive and won't conflict with real addresses
    fake_src_addr = "203.0.113.42"
    fake_src_port = 12345
    clients_query = f"SELECT address, port FROM system.clients WHERE address = '{fake_src_addr}' ALLOW FILTERING"

    # The proxy header announces the fake source; the destination is the
    # proxy protocol port we are actually connecting to.
    response = await send_cql_with_proxy_header(
        server.ip_addr,
        PROXY_PORT,
        fake_src_addr,
        fake_src_port,
        server.ip_addr,
        PROXY_PORT,
        clients_query,
    )

    assert len(response) > 9, "Expected a valid CQL response"
    opcode = response[4]
    assert opcode == 0x08, f"Expected RESULT opcode (0x08), got {hex(opcode)}"
@pytest.mark.asyncio
async def test_proxy_protocol_shard_aware(proxy_server):
    """
    Check that the shard-aware proxy protocol port routes connections by the
    source port announced in the proxy header. One connection per shard is
    opened, each with a source port chosen so that port % num_shards equals
    the target shard, and system.clients is inspected for the result.
    """
    server, manager = proxy_server
    # The test harness runs with --smp 2 by default
    num_shards = 2
    cql = manager.get_cql()
    fake_src_addr = "203.0.113.43"
    base_port = 10000

    # Connections stay open until the shard assignments have been verified
    connections = []
    try:
        for shard in range(num_shards):
            # base_port is a multiple of num_shards, so this port maps to `shard`
            fake_src_port = base_port + shard
            reader, writer = await asyncio.open_connection(server.ip_addr, PROXY_SHARD_AWARE_PORT)
            # Register right away so the connection is closed even if the handshake fails
            connections.append((reader, writer, fake_src_port, shard))

            # Proxy header first, then the CQL handshake
            writer.write(make_proxy_v2_header(
                fake_src_addr, fake_src_port,
                server.ip_addr, PROXY_SHARD_AWARE_PORT
            ))
            await writer.drain()
            await do_cql_handshake(reader, writer)

        # Fetch the whole table and filter on the python side, so the full
        # contents can be logged if the expected rows never show up
        rows = await wait_for_results(
            cql,
            'SELECT address, port, shard_id FROM system.clients',
            expected_count=num_shards,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        # port -> shard_id as reported by the server
        port_to_shard = dict((row.port, row.shard_id) for row in rows)

        # Every connection must have landed on the shard its port maps to
        for _, _, fake_src_port, expected_shard in connections:
            assert fake_src_port in port_to_shard, f"Port {fake_src_port} not found in system.clients"
            actual_shard = port_to_shard[fake_src_port]
            assert actual_shard == expected_shard, \
                f"Port {fake_src_port} expected shard {expected_shard}, got {actual_shard}"
    finally:
        for _, open_writer, _, _ in connections:
            open_writer.close()
            await open_writer.wait_closed()
@pytest.mark.asyncio
async def test_proxy_protocol_multiple_connections(proxy_server):
    """
    Open several consecutive connections through the proxy protocol port,
    each announcing a different source address and port, and check that
    every one of them gets a RESULT frame back for a simple query.
    """
    server, manager = proxy_server
    test_clients = [
        ("203.0.113.101", 10001),
        ("203.0.113.102", 10002),
        ("203.0.113.103", 10003),
    ]

    for fake_src_addr, fake_src_port in test_clients:
        response = await send_cql_with_proxy_header(
            server.ip_addr,
            PROXY_PORT,
            fake_src_addr,
            fake_src_port,
            server.ip_addr,
            PROXY_PORT,
            "SELECT * FROM system.local",
        )

        assert len(response) > 9, f"Expected a valid CQL response for {fake_src_addr}"
        opcode = response[4]
        assert opcode == 0x08, f"Expected RESULT opcode for {fake_src_addr}, got {hex(opcode)}"
@pytest.mark.asyncio
async def test_proxy_protocol_port_preserved_in_system_clients(proxy_server):
    """
    Check that system.clients reports the source port announced in the proxy
    protocol header, which matters for shard-aware routing.
    """
    server, manager = proxy_server
    fake_src_addr = "203.0.113.200"
    fake_src_port = 44444

    # The connection must stay open while system.clients is queried
    reader, writer = await asyncio.open_connection(server.ip_addr, PROXY_PORT)
    try:
        # Proxy header first, then the CQL handshake
        writer.write(make_proxy_v2_header(
            fake_src_addr, fake_src_port,
            server.ip_addr, PROXY_PORT
        ))
        await writer.drain()
        await do_cql_handshake(reader, writer)

        # Look the connection up through the regular driver session; the whole
        # table is fetched so it can be logged if our row never shows up
        cql = manager.get_cql()
        rows = await wait_for_results(
            cql,
            'SELECT address, port FROM system.clients',
            expected_count=1,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        # Our connection must be listed with the fake source address and port
        assert len(rows) > 0, f"Expected to find connection from {fake_src_addr} in system.clients"
        found_correct_port = any(row.port == fake_src_port for row in rows)
        assert found_correct_port, f"Expected to find port {fake_src_port} in system.clients, got ports: {[r.port for r in rows]}"
    finally:
        writer.close()
        await writer.wait_closed()
@pytest.mark.asyncio
async def test_proxy_protocol_ssl_basic(proxy_server):
    """
    Exercise the TLS proxy protocol port: the proxy header goes out in clear
    text, the connection is then upgraded to TLS, and a simple query must be
    answered with a RESULT frame.
    """
    server, manager = proxy_server
    fake_src_addr = "203.0.113.50"
    fake_src_port = 55555

    response = await send_cql_with_proxy_header_tls(
        server.ip_addr,
        PROXY_PORT_SSL,
        fake_src_addr,
        fake_src_port,
        server.ip_addr,
        PROXY_PORT_SSL,
        "SELECT * FROM system.local",
    )

    assert len(response) > 9, "Expected a valid CQL response"
    opcode = response[4]
    assert opcode == 0x08, f"Expected RESULT opcode (0x08), got {hex(opcode)}"
@pytest.mark.asyncio
async def test_proxy_protocol_ssl_shard_aware(proxy_server):
    """
    Test proxy protocol with TLS on the shard-aware port. We connect to all
    shards by using source ports that map to each shard (port % num_shards),
    then verify through system.clients that each connection landed on the
    expected shard and is reported with ssl_enabled.
    """
    server, manager = proxy_server
    # The test harness runs with --smp 2 by default
    num_shards = 2
    cql = manager.get_cql()
    fake_src_addr = "203.0.113.51"
    base_port = 20000
    io_timeout = 100.0
    loop = asyncio.get_running_loop()

    # Client-side TLS settings are the same for every connection, so build the
    # context once (no certificate verification: the server uses a test cert).
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    # Keep connections open while we verify shard assignments
    ssl_sockets = []
    try:
        # Connect to each shard by using source ports that map to each shard
        for shard in range(num_shards):
            # Choose a port that maps to this shard: port % num_shards == shard
            fake_src_port = base_port + shard

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ssl_sock = None
            try:
                sock.setblocking(False)
                await loop.sock_connect(sock, (server.ip_addr, PROXY_SHARD_AWARE_PORT_SSL))

                # Send proxy header on raw socket before TLS
                proxy_header = make_proxy_v2_header(
                    fake_src_addr, fake_src_port,
                    server.ip_addr, PROXY_SHARD_AWARE_PORT_SSL
                )
                # From here on use blocking I/O, bounded by a timeout so a
                # silent server cannot hang the test run.
                sock.settimeout(io_timeout)
                sock.sendall(proxy_header)

                # Upgrade to TLS
                ssl_sock = ssl_context.wrap_socket(sock, do_handshake_on_connect=False)
            finally:
                # Until the TLS wrapper exists nothing else tracks this socket,
                # so close it here if any of the steps above failed.
                if ssl_sock is None:
                    sock.close()
            # Register immediately so the outer finally closes it from now on
            ssl_sockets.append((ssl_sock, fake_src_port, shard))

            # Complete TLS handshake. The socket is in blocking (timeout) mode,
            # so this returns when the handshake is done; no polling is needed.
            ssl_sock.do_handshake()

            # Send STARTUP frame
            ssl_sock.sendall(build_cql_startup_frame(stream=1))

            # Read response and handle auth if needed
            response = ssl_sock.recv(4096)
            if len(response) < 9:
                raise RuntimeError(f"Short response from server: {len(response)} bytes")
            opcode = response[4]
            if opcode == 0x03:  # AUTHENTICATE
                # AUTH_RESPONSE body is a single [bytes] value: the SASL PLAIN
                # token must be preceded by its 4-byte length.
                token = b'\x00cassandra\x00cassandra'
                auth_body = struct.pack('!i', len(token)) + token
                auth_frame = struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body
                ssl_sock.sendall(auth_frame)
                response = ssl_sock.recv(4096)
                if len(response) < 9:
                    raise RuntimeError(f"Short response from server: {len(response)} bytes")
                opcode = response[4]
                if opcode != 0x10:  # AUTH_SUCCESS
                    raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
            elif opcode != 0x02:  # READY
                raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")

        # Now query system.clients to verify shard assignments
        rows = await wait_for_results(
            cql,
            'SELECT address, port, shard_id, ssl_enabled FROM system.clients',
            expected_count=num_shards,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        # Build a map of port -> (shard_id, ssl_enabled) from the results
        port_to_info = {row.port: (row.shard_id, row.ssl_enabled) for row in rows}

        # Verify each connection landed on the expected shard with SSL enabled
        for _, fake_src_port, expected_shard in ssl_sockets:
            assert fake_src_port in port_to_info, f"Port {fake_src_port} not found in system.clients"
            actual_shard, ssl_enabled = port_to_info[fake_src_port]
            assert actual_shard == expected_shard, \
                f"Port {fake_src_port} expected shard {expected_shard}, got {actual_shard}"
            assert ssl_enabled, f"Port {fake_src_port} expected ssl_enabled=True"
    finally:
        for open_ssl_sock, _, _ in ssl_sockets:
            open_ssl_sock.close()
@pytest.mark.asyncio
async def test_proxy_protocol_ssl_port_preserved(proxy_server):
    """
    Test that the source port from the proxy protocol header is correctly
    preserved in system.clients when using TLS, and that the connection is
    reported with ssl_enabled.
    """
    server, manager = proxy_server
    fake_src_addr = "203.0.113.201"
    fake_src_port = 44445
    io_timeout = 100.0
    loop = asyncio.get_running_loop()

    # Create a connection that stays open while system.clients is queried
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ssl_sock = None
    try:
        # Connect inside the try block so a failed connect does not leak the socket
        sock.setblocking(False)
        await loop.sock_connect(sock, (server.ip_addr, PROXY_PORT_SSL))

        # Send proxy header on raw socket
        proxy_header = make_proxy_v2_header(
            fake_src_addr, fake_src_port,
            server.ip_addr, PROXY_PORT_SSL
        )
        await loop.sock_sendall(sock, proxy_header)

        # Wrap with TLS (no certificate verification: the server uses a test cert)
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        ssl_sock = ssl_context.wrap_socket(sock, server_hostname=server.ip_addr, do_handshake_on_connect=False)

        # Do TLS handshake: the socket is non-blocking, so poll until it
        # completes, but give up after io_timeout instead of spinning forever.
        handshake_deadline = loop.time() + io_timeout
        while True:
            try:
                ssl_sock.do_handshake()
                break
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                if loop.time() >= handshake_deadline:
                    raise asyncio.TimeoutError(
                        f"TLS handshake did not complete within {io_timeout} seconds") from None
                await asyncio.sleep(0.01)

        # Blocking I/O for the short CQL handshake, bounded by a timeout
        ssl_sock.settimeout(io_timeout)

        # Send STARTUP
        ssl_sock.sendall(build_cql_startup_frame(stream=1))

        # Read response and handle auth if needed
        response = ssl_sock.recv(4096)
        if len(response) < 9:
            raise RuntimeError(f"Short response from server: {len(response)} bytes")
        opcode = response[4]
        if opcode == 0x03:  # AUTHENTICATE
            # AUTH_RESPONSE body is a single [bytes] value: the SASL PLAIN
            # token must be preceded by its 4-byte length.
            token = b'\x00cassandra\x00cassandra'
            auth_body = struct.pack('!i', len(token)) + token
            auth_frame = struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body
            ssl_sock.sendall(auth_frame)
            response = ssl_sock.recv(4096)
            if len(response) < 9:
                raise RuntimeError(f"Short response from server: {len(response)} bytes")
            opcode = response[4]
            if opcode != 0x10:  # AUTH_SUCCESS
                raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
        elif opcode != 0x02:  # READY
            raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")

        # Now query system.clients using the driver to see our connection
        cql = manager.get_cql()
        rows = await wait_for_results(
            cql,
            'SELECT address, port, ssl_enabled FROM system.clients',
            expected_count=1,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        # We should find our connection
        assert len(rows) > 0, f"Expected to find connection from {fake_src_addr} in system.clients"
        found_correct = False
        for row in rows:
            if row.port == fake_src_port:
                found_correct = True
                assert row.ssl_enabled, "Expected connection to have ssl_enabled=True"
                break
        assert found_correct, f"Expected to find port {fake_src_port} in system.clients, got ports: {[r.port for r in rows]}"
    finally:
        # Close whichever socket object currently represents the connection
        if ssl_sock is not None:
            ssl_sock.close()
        else:
            sock.close()