In case of an error, we want to see the contents of the system.clients table to better understand what happened — whether the row(s) are really missing, or whether they are present but a single digit doesn't match or the row is only half-written. We therefore query the whole table on the CQL side and filter out the rows we want to proceed with on the Python side. This way we can dump the contents of the whole system.clients table if something goes wrong.
716 lines
24 KiB
Python
716 lines
24 KiB
Python
# Copyright 2025-present ScyllaDB
|
|
#
|
|
# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
|
|
|
|
"""
|
|
Tests for CQL native transport ports with proxy protocol v2 support.
|
|
|
|
These tests verify that when Scylla is configured with proxy protocol ports,
|
|
clients can connect using the PROXY protocol v2 header and the original
|
|
client addresses are properly reported in system.clients.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import pytest
|
|
import socket
|
|
import ssl
|
|
import struct
|
|
import time
|
|
|
|
from test.pylib.manager_client import ManagerClient
|
|
from test.pylib.util import wait_for
|
|
|
|
logger = logging.getLogger(__name__)

# The fixed 12-byte signature that opens every PROXY protocol v2 header,
# i.e. "\r\n\r\n\0\r\nQUIT\n".
PROXY_V2_SIGNATURE = b'\r\n\r\n\x00\r\nQUIT\n'

# Native transport ports that expect a PROXY protocol header. They are kept
# distinct from the standard CQL ports so both sets can coexist on one node.
PROXY_PORT, PROXY_PORT_SSL = 29042, 29142
PROXY_SHARD_AWARE_PORT, PROXY_SHARD_AWARE_PORT_SSL = 39042, 39142
|
|
|
|
|
|
def make_proxy_v2_header(src_addr: str, src_port: int, dst_addr: str, dst_port: int) -> bytes:
    """
    Construct a proxy protocol v2 PROXY header for a TCP connection.

    Two IPv4 addresses produce the AF_INET/STREAM form (12-byte address block).
    Two IPv6 addresses produce the AF_INET6/STREAM form (36-byte address block),
    so the helper also works when the server under test listens on IPv6.

    The proxy protocol v2 binary format is defined at:
    https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt

    :param src_addr: client address to announce, in textual form
    :param src_port: client port to announce
    :param dst_addr: destination address to announce; same family as src_addr
    :param dst_port: destination port to announce
    :return: the complete header, to be sent before any other byte on the connection
    :raises ValueError: if the two addresses belong to different families
    :raises OSError: if an address is not a valid textual IPv4/IPv6 address
    :raises struct.error: if a port does not fit in an unsigned 16-bit field
    """
    def family_of(addr: str) -> int:
        # Textual IPv6 addresses always contain ':'; IPv4 ones never do.
        return socket.AF_INET6 if ':' in addr else socket.AF_INET

    family = family_of(src_addr)
    if family_of(dst_addr) != family:
        raise ValueError(f"Address family mismatch: source {src_addr!r}, destination {dst_addr!r}")

    # Address block: source address, destination address (4 bytes each for
    # IPv4, 16 for IPv6), then source and destination port (big-endian u16).
    # inet_pton() is strict (no "127.1"-style shorthands) and handles IPv6.
    address_block = (
        socket.inet_pton(family, src_addr)
        + socket.inet_pton(family, dst_addr)
        + struct.pack('!HH', src_port, dst_port)
    )

    # Upper nibble: address family (1 = AF_INET, 2 = AF_INET6);
    # lower nibble: transport protocol (1 = STREAM/TCP).
    family_and_protocol = 0x11 if family == socket.AF_INET else 0x21

    return b''.join((
        PROXY_V2_SIGNATURE,                      # 12-byte signature
        bytes([0x21, family_and_protocol]),      # version 2 | PROXY command, then family/protocol
        struct.pack('!H', len(address_block)),   # length of the address block that follows
        address_block,
    ))
|
|
|
|
|
|
def build_cql_startup_frame(stream: int = 1) -> bytes:
    """
    Return a complete CQL native protocol v4 STARTUP frame.

    The body is a [string map] holding the single entry CQL_VERSION=3.0.0.
    The 9-byte header carries version 4, no flags, the given stream id,
    opcode STARTUP (0x01) and the body length.
    """
    def cql_string(text: str) -> bytes:
        # [string]: unsigned 16-bit length followed by the raw bytes.
        raw = text.encode()
        return struct.pack('!H', len(raw)) + raw

    options = {'CQL_VERSION': '3.0.0'}
    payload = struct.pack('!H', len(options)) + b''.join(
        cql_string(key) + cql_string(value) for key, value in options.items()
    )
    header = struct.pack('!BBHBI', 0x04, 0x00, stream, 0x01, len(payload))
    return header + payload
|
|
|
|
|
|
def build_cql_query_frame(query: str, stream: int = 2) -> bytes:
    """
    Return a complete CQL native protocol v4 QUERY frame for `query`.

    The body is the UTF-8 query as a [long string], followed by the minimal
    query parameters: consistency ONE (0x0001) and an empty flags byte.
    No bound values, paging or serial consistency are sent.
    """
    encoded = query.encode('utf-8')
    payload = b''.join((
        struct.pack('!I', len(encoded)),    # [long string] length prefix
        encoded,
        struct.pack('!HB', 0x0001, 0x00),   # consistency ONE, no flags
    ))
    return struct.pack('!BBHBI', 0x04, 0x00, stream, 0x07, len(payload)) + payload
|
|
|
|
|
|
async def do_cql_handshake(reader, writer):
    """
    Complete the CQL handshake (STARTUP and optional AUTH) on an open stream.

    Sends STARTUP and expects READY. If the server answers AUTHENTICATE instead,
    it sends an AUTH_RESPONSE carrying a SASL PLAIN token for the default
    cassandra/cassandra credentials and expects AUTH_SUCCESS.

    :param reader: asyncio.StreamReader of the connection
    :param writer: asyncio.StreamWriter of the connection
    :raises RuntimeError: if the server closes the connection mid-frame or
        answers with an unexpected opcode
    :raises asyncio.TimeoutError: if a response frame does not arrive in time
    """
    async def read_frame() -> bytes:
        # A frame is a 9-byte header whose last 4 bytes hold the body length.
        # Read exactly that much rather than one fixed-size chunk, so a frame
        # split across TCP segments is still received whole.
        received = b''
        try:
            received = await asyncio.wait_for(reader.readexactly(9), timeout=100.0)
            (body_len,) = struct.unpack('!I', received[5:9])
            received += await asyncio.wait_for(reader.readexactly(body_len), timeout=100.0)
        except asyncio.IncompleteReadError as e:
            raise RuntimeError(f"Short response from server: {len(received) + len(e.partial)} bytes") from e
        return received

    # Send CQL STARTUP
    writer.write(build_cql_startup_frame(stream=1))
    await writer.drain()

    # Response is READY or AUTHENTICATE
    opcode = (await read_frame())[4]
    if opcode == 0x03:  # AUTHENTICATE
        # The AUTH_RESPONSE body is a single [bytes] value: the SASL PLAIN
        # token (authzid NUL user NUL password) prefixed by its 32-bit length.
        token = b'\x00cassandra\x00cassandra'
        auth_body = struct.pack('!i', len(token)) + token
        writer.write(struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body)
        await writer.drain()

        opcode = (await read_frame())[4]
        if opcode != 0x10:  # AUTH_SUCCESS
            raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
    elif opcode != 0x02:  # READY
        raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")
|
|
|
|
|
|
async def send_cql_query(reader, writer, query: str, stream: int = 2) -> bytes:
    """
    Send `query` as a QUERY frame and return the server's complete response frame.

    :param reader: asyncio.StreamReader of an already handshaken connection
    :param writer: asyncio.StreamWriter of the same connection
    :param query: CQL text to execute
    :param stream: stream id to use for the request
    :return: the raw response frame (9-byte header followed by its body)
    :raises RuntimeError: if the server closes the connection mid-frame
    :raises asyncio.TimeoutError: if the response does not arrive in time
    """
    writer.write(build_cql_query_frame(query, stream=stream))
    await writer.drain()

    # Read the fixed-size header first, then exactly the advertised body
    # length, so a large or TCP-fragmented result is returned whole instead
    # of being cut at whatever a single read() happened to deliver.
    received = b''
    try:
        received = await asyncio.wait_for(reader.readexactly(9), timeout=100.0)
        (body_len,) = struct.unpack('!I', received[5:9])
        received += await asyncio.wait_for(reader.readexactly(body_len), timeout=100.0)
    except asyncio.IncompleteReadError as e:
        raise RuntimeError(f"Short response from server: {len(received) + len(e.partial)} bytes") from e
    return received
|
|
|
|
|
|
async def send_cql_with_proxy_header(
    host: str,
    port: int,
    proxy_src_addr: str,
    proxy_src_port: int,
    proxy_dst_addr: str,
    proxy_dst_port: int,
    query: str
) -> bytes:
    """
    Open a plain TCP connection to host:port and announce the given endpoints
    via a PROXY protocol v2 header. Then perform the CQL handshake and run `query`.

    The connection is closed before returning, whether or not an error occurred.

    :return: the raw response bytes for `query`, as returned by send_cql_query
    """
    reader, writer = await asyncio.open_connection(host, port)
    try:
        # The PROXY header must precede every CQL byte on the wire.
        writer.write(make_proxy_v2_header(proxy_src_addr, proxy_src_port, proxy_dst_addr, proxy_dst_port))
        await writer.drain()

        await do_cql_handshake(reader, writer)
        result = await send_cql_query(reader, writer, query)
    finally:
        writer.close()
        await writer.wait_closed()
    return result
|
|
|
|
|
|
async def send_cql_with_proxy_header_tls(
    host: str,
    port: int,
    proxy_src_addr: str,
    proxy_src_port: int,
    proxy_dst_addr: str,
    proxy_dst_port: int,
    query: str
) -> bytes:
    """
    Connect to a CQL server using proxy protocol v2 with TLS and execute a query.

    The proxy protocol header is sent first over the raw TCP connection,
    then the connection is upgraded to TLS before the CQL protocol begins.

    The exchange uses blocking socket calls, each bounded by a timeout. It is
    therefore run in the event loop's default executor instead of stalling the
    loop. The socket is closed on every path, including a failed connect or
    TLS handshake.

    :return: the complete raw response frame for `query`
    :raises RuntimeError: on a truncated frame or an unexpected opcode
    :raises OSError: on connection failures and socket timeouts
    """
    io_timeout = 100.0

    def recv_exactly(conn: ssl.SSLSocket, count: int) -> bytes:
        # recv() may return fewer bytes than requested; loop until `count` arrive.
        buf = bytearray()
        while len(buf) < count:
            chunk = conn.recv(count - len(buf))
            if not chunk:
                raise RuntimeError(f"Connection closed after {len(buf)} of {count} expected bytes")
            buf.extend(chunk)
        return bytes(buf)

    def recv_frame(conn: ssl.SSLSocket) -> bytes:
        # A CQL frame is a 9-byte header whose last 4 bytes give the body length.
        header = recv_exactly(conn, 9)
        (body_len,) = struct.unpack('!I', header[5:])
        return header + recv_exactly(conn, body_len)

    def exchange() -> bytes:
        # Create SSL context (don't verify server certificate for testing)
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        with socket.create_connection((host, port), timeout=io_timeout) as raw_sock:
            # Send proxy header on the raw socket BEFORE the TLS handshake
            raw_sock.sendall(make_proxy_v2_header(proxy_src_addr, proxy_src_port, proxy_dst_addr, proxy_dst_port))

            # wrap_socket() takes over the descriptor and performs the TLS
            # handshake immediately, bounded by the timeout inherited from raw_sock.
            with ssl_context.wrap_socket(raw_sock, server_hostname=host) as conn:
                conn.sendall(build_cql_startup_frame(stream=1))
                opcode = recv_frame(conn)[4]
                if opcode == 0x03:  # AUTHENTICATE
                    # AUTH_RESPONSE body is one [bytes] value: the SASL PLAIN
                    # token prefixed by its 32-bit length.
                    token = b'\x00cassandra\x00cassandra'
                    auth_body = struct.pack('!i', len(token)) + token
                    conn.sendall(struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body)
                    opcode = recv_frame(conn)[4]
                    if opcode != 0x10:  # AUTH_SUCCESS
                        raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
                elif opcode != 0x02:  # READY
                    raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")

                conn.sendall(build_cql_query_frame(query, stream=2))
                return recv_frame(conn)

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, exchange)
|
|
|
|
|
|
async def wait_for_results(cql, query: str, expected_count: int, timeout: float = 30.0, filter_fn=None):
    """
    Repeatedly run `query` until at least `expected_count` rows pass `filter_fn`.

    `filter_fn` receives the full list of rows from one poll and returns the
    rows of interest; without it, every row counts. Polling is delegated to
    test.pylib's `wait_for` with a deadline `timeout` seconds from now and
    `period=0.1`.

    The whole result is fetched and filtered client-side on purpose. If
    waiting fails, every row from the last poll is logged to show what the
    table really contained, and the exception is re-raised.

    :return: whatever `wait_for` yields for the successful poll; callers in
        this module use it as the list of matching rows
    """
    select = filter_fn if filter_fn is not None else (lambda rows: rows)
    snapshot: list = []

    async def poll():
        nonlocal snapshot
        snapshot = list(await cql.run_async(query))
        wanted = select(snapshot)
        return wanted if len(wanted) >= expected_count else None

    try:
        return await wait_for(poll, time.time() + timeout, period=0.1)
    except Exception:
        logger.error('Timed out waiting for %d matching rows in system.clients. Last poll returned %d total rows:\n%s',
                     expected_count, len(snapshot), '\n'.join(map(str, snapshot)))
        raise
|
|
|
|
# Server configuration shared by every test in this module.
#
# Besides the four proxy-protocol ports, the regular native transport ports are
# pinned explicitly, with TLS confined to the dedicated *_ssl ports. This keeps
# the standard ports unencrypted so the Python driver can connect without TLS.
PROXY_SERVER_CONFIG = dict(
    # Ports on which clients send a PROXY protocol v2 header first.
    native_transport_port_proxy_protocol=PROXY_PORT,
    native_transport_port_ssl_proxy_protocol=PROXY_PORT_SSL,
    native_shard_aware_transport_port_proxy_protocol=PROXY_SHARD_AWARE_PORT,
    native_shard_aware_transport_port_ssl_proxy_protocol=PROXY_SHARD_AWARE_PORT_SSL,
    # Regular ports: plain on 9042/19042, TLS on 9142/19142.
    native_transport_port=9042,
    native_shard_aware_transport_port=19042,
    native_transport_port_ssl=9142,
    native_shard_aware_transport_port_ssl=19142,
    client_encryption_options=dict(
        enabled=True,
        certificate='conf/scylla.crt',
        keyfile='conf/scylla.key',
    ),
)
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def proxy_server(manager: ManagerClient):
    """
    Start a fresh server configured with PROXY_SERVER_CONFIG, i.e. with every
    proxy protocol port enabled.

    Yields a (server, manager) pair, so a test can reach both the server's
    address and the ManagerClient (e.g. to obtain a driver session).
    """
    new_server = await manager.server_add(config=PROXY_SERVER_CONFIG)
    yield new_server, manager
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_basic(proxy_server):
    """
    Connect through the plain proxy protocol port, announcing a made-up client
    address in the PROXY header. Over that same connection, query
    system.clients for the announced address.

    NOTE(review): only the RESULT opcode is checked. The returned rows are not
    decoded, so this alone does not prove the address was recorded.
    """
    server, _ = proxy_server

    # TEST-NET-3 address, so it cannot be confused with a real peer.
    spoofed_addr, spoofed_port = "203.0.113.42", 12345

    response = await send_cql_with_proxy_header(
        host=server.ip_addr,
        port=PROXY_PORT,
        proxy_src_addr=spoofed_addr,
        proxy_src_port=spoofed_port,
        proxy_dst_addr=server.ip_addr,
        proxy_dst_port=PROXY_PORT,
        query=f"SELECT address, port FROM system.clients WHERE address = '{spoofed_addr}' ALLOW FILTERING",
    )

    assert len(response) > 9, "Expected a valid CQL response"
    result_opcode = response[4]
    assert result_opcode == 0x08, f"Expected RESULT opcode (0x08), got {hex(result_opcode)}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_shard_aware(proxy_server):
    """
    Check that the shard-aware proxy protocol port routes each connection by
    the source port announced in the PROXY header (port % num_shards). One
    connection is opened per shard, and system.clients must show each one on
    the shard its announced port maps to.
    """
    server, manager = proxy_server
    num_shards = 2  # the test harness runs with --smp 2 by default
    cql = manager.get_cql()

    fake_src_addr = "203.0.113.43"
    base_port = 10000  # a multiple of num_shards, so base_port + shard maps to `shard`

    # Entries are (reader, writer, announced source port, expected shard).
    # Connections stay open until the shard assignment has been verified.
    connections = []
    try:
        for shard in range(num_shards):
            announced_port = base_port + shard

            reader, writer = await asyncio.open_connection(server.ip_addr, PROXY_SHARD_AWARE_PORT)
            connections.append((reader, writer, announced_port, shard))

            writer.write(make_proxy_v2_header(
                fake_src_addr, announced_port,
                server.ip_addr, PROXY_SHARD_AWARE_PORT
            ))
            await writer.drain()
            await do_cql_handshake(reader, writer)

        rows = await wait_for_results(
            cql,
            'SELECT address, port, shard_id FROM system.clients',
            expected_count=num_shards,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )
        shard_by_port = {row.port: row.shard_id for row in rows}

        for _, _, announced_port, expected_shard in connections:
            assert announced_port in shard_by_port, f"Port {announced_port} not found in system.clients"
            actual_shard = shard_by_port[announced_port]
            assert actual_shard == expected_shard, \
                f"Port {announced_port} expected shard {expected_shard}, got {actual_shard}"
    finally:
        for _, writer, _, _ in connections:
            writer.close()
            await writer.wait_closed()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_multiple_connections(proxy_server):
    """
    Open several connections in turn through the proxy protocol port. Each one
    announces a different source address/port in its PROXY header, and each
    must complete the CQL handshake and get a RESULT for a simple query.

    NOTE(review): each connection is closed before the next one is opened,
    and system.clients is never queried. So this does not check how the
    announced addresses are recorded there.
    """
    server, _ = proxy_server

    announced_peers = (
        ("203.0.113.101", 10001),
        ("203.0.113.102", 10002),
        ("203.0.113.103", 10003),
    )

    for peer_addr, peer_port in announced_peers:
        reply = await send_cql_with_proxy_header(
            host=server.ip_addr,
            port=PROXY_PORT,
            proxy_src_addr=peer_addr,
            proxy_src_port=peer_port,
            proxy_dst_addr=server.ip_addr,
            proxy_dst_port=PROXY_PORT,
            query="SELECT * FROM system.local",
        )

        assert len(reply) > 9, f"Expected a valid CQL response for {peer_addr}"
        reply_opcode = reply[4]
        assert reply_opcode == 0x08, f"Expected RESULT opcode for {peer_addr}, got {hex(reply_opcode)}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_port_preserved_in_system_clients(proxy_server):
    """
    Check that system.clients reports the source port announced in the PROXY
    header, not the real TCP source port. Shard-aware routing depends on this.
    The proxied connection is kept open while system.clients is polled through
    the regular driver session.
    """
    server, manager = proxy_server

    fake_src_addr = "203.0.113.200"
    fake_src_port = 44444

    reader, writer = await asyncio.open_connection(server.ip_addr, PROXY_PORT)
    try:
        writer.write(make_proxy_v2_header(fake_src_addr, fake_src_port, server.ip_addr, PROXY_PORT))
        await writer.drain()
        await do_cql_handshake(reader, writer)

        rows = await wait_for_results(
            manager.get_cql(),
            'SELECT address, port FROM system.clients',
            expected_count=1,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        assert rows, f"Expected to find connection from {fake_src_addr} in system.clients"
        assert any(row.port == fake_src_port for row in rows), \
            f"Expected to find port {fake_src_port} in system.clients, got ports: {[r.port for r in rows]}"
    finally:
        writer.close()
        await writer.wait_closed()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_ssl_basic(proxy_server):
    """
    Run a query over the TLS proxy protocol port.

    The PROXY header goes out in cleartext first. Only then is the connection
    upgraded to TLS and the CQL exchange performed.
    """
    server, _ = proxy_server
    tls_port = PROXY_PORT_SSL

    reply = await send_cql_with_proxy_header_tls(
        host=server.ip_addr,
        port=tls_port,
        proxy_src_addr="203.0.113.50",
        proxy_src_port=55555,
        proxy_dst_addr=server.ip_addr,
        proxy_dst_port=tls_port,
        query="SELECT * FROM system.local",
    )

    assert len(reply) > 9, "Expected a valid CQL response"
    reply_opcode = reply[4]
    assert reply_opcode == 0x08, f"Expected RESULT opcode (0x08), got {hex(reply_opcode)}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_ssl_shard_aware(proxy_server):
    """
    Test proxy protocol with TLS on the shard-aware port. We connect to all
    shards by using source ports that map to each shard (port % num_shards).

    Per-connection setup involves a TCP connect, the cleartext PROXY header, the
    TLS handshake and the CQL handshake. It uses blocking socket calls bounded by
    a timeout, so it runs in the event loop's default executor instead of stalling
    the loop. A connection that fails during setup is closed before the error
    propagates; established ones are closed in the `finally` clause.
    """
    server, manager = proxy_server

    # The test harness runs with --smp 2 by default
    num_shards = 2

    cql = manager.get_cql()

    fake_src_addr = "203.0.113.51"
    base_port = 20000
    io_timeout = 100.0

    # Don't verify the server certificate; only the transport matters here.
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    def recv_exactly(conn: ssl.SSLSocket, count: int) -> bytes:
        # recv() may return fewer bytes than requested; loop until `count` arrive.
        buf = bytearray()
        while len(buf) < count:
            chunk = conn.recv(count - len(buf))
            if not chunk:
                raise RuntimeError(f"Connection closed after {len(buf)} of {count} expected bytes")
            buf.extend(chunk)
        return bytes(buf)

    def recv_frame(conn: ssl.SSLSocket) -> bytes:
        # A CQL frame is a 9-byte header whose last 4 bytes give the body length.
        header = recv_exactly(conn, 9)
        (body_len,) = struct.unpack('!I', header[5:])
        return header + recv_exactly(conn, body_len)

    def cql_handshake(conn: ssl.SSLSocket) -> None:
        conn.sendall(build_cql_startup_frame(stream=1))
        opcode = recv_frame(conn)[4]
        if opcode == 0x03:  # AUTHENTICATE
            # AUTH_RESPONSE body is one [bytes] value: the SASL PLAIN token
            # prefixed by its 32-bit length.
            token = b'\x00cassandra\x00cassandra'
            auth_body = struct.pack('!i', len(token)) + token
            conn.sendall(struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body)
            opcode = recv_frame(conn)[4]
            if opcode != 0x10:  # AUTH_SUCCESS
                raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
        elif opcode != 0x02:  # READY
            raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")

    def open_proxied_tls_connection(src_port: int) -> ssl.SSLSocket:
        raw = socket.create_connection((server.ip_addr, PROXY_SHARD_AWARE_PORT_SSL), timeout=io_timeout)
        try:
            # The PROXY header is sent in cleartext, before the TLS handshake.
            raw.sendall(make_proxy_v2_header(
                fake_src_addr, src_port,
                server.ip_addr, PROXY_SHARD_AWARE_PORT_SSL
            ))
            # wrap_socket() takes over the descriptor and handshakes right away,
            # bounded by the timeout inherited from `raw`.
            conn = ssl_context.wrap_socket(raw)
        except BaseException:
            raw.close()
            raise
        try:
            cql_handshake(conn)
        except BaseException:
            conn.close()
            raise
        return conn

    loop = asyncio.get_running_loop()

    # Keep connections open while we verify shard assignments
    ssl_sockets = []
    try:
        # Connect to each shard by using source ports that map to each shard
        for shard in range(num_shards):
            # Choose a port that maps to this shard: port % num_shards == shard
            fake_src_port = base_port + shard
            ssl_sock = await loop.run_in_executor(None, open_proxied_tls_connection, fake_src_port)
            ssl_sockets.append((ssl_sock, fake_src_port, shard))

        # Now query system.clients to verify shard assignments
        rows = await wait_for_results(
            cql,
            'SELECT address, port, shard_id, ssl_enabled FROM system.clients',
            expected_count=num_shards,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        # Build a map of port -> (shard_id, ssl_enabled) from the results
        port_to_info = {row.port: (row.shard_id, row.ssl_enabled) for row in rows}

        # Verify each connection landed on the expected shard with SSL enabled
        for _, fake_src_port, expected_shard in ssl_sockets:
            assert fake_src_port in port_to_info, f"Port {fake_src_port} not found in system.clients"
            actual_shard, ssl_enabled = port_to_info[fake_src_port]
            assert actual_shard == expected_shard, \
                f"Port {fake_src_port} expected shard {expected_shard}, got {actual_shard}"
            assert ssl_enabled, f"Port {fake_src_port} expected ssl_enabled=True"
    finally:
        for ssl_sock, _, _ in ssl_sockets:
            ssl_sock.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_proxy_protocol_ssl_port_preserved(proxy_server):
    """
    Test that the source port from the proxy protocol header is correctly
    preserved in system.clients when using TLS.

    Connection setup involves a TCP connect, the cleartext PROXY header, the TLS
    handshake and the CQL handshake. It uses blocking socket calls bounded by a
    timeout and runs in the event loop's default executor. The socket is closed
    on every path, including a failed connect or handshake.
    """
    server, manager = proxy_server

    fake_src_addr = "203.0.113.201"
    fake_src_port = 44445
    io_timeout = 100.0

    def recv_exactly(conn: ssl.SSLSocket, count: int) -> bytes:
        # recv() may return fewer bytes than requested; loop until `count` arrive.
        buf = bytearray()
        while len(buf) < count:
            chunk = conn.recv(count - len(buf))
            if not chunk:
                raise RuntimeError(f"Connection closed after {len(buf)} of {count} expected bytes")
            buf.extend(chunk)
        return bytes(buf)

    def recv_frame(conn: ssl.SSLSocket) -> bytes:
        # A CQL frame is a 9-byte header whose last 4 bytes give the body length.
        header = recv_exactly(conn, 9)
        (body_len,) = struct.unpack('!I', header[5:])
        return header + recv_exactly(conn, body_len)

    def open_proxied_tls_connection() -> ssl.SSLSocket:
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        raw = socket.create_connection((server.ip_addr, PROXY_PORT_SSL), timeout=io_timeout)
        try:
            # Send proxy header on the raw socket, before TLS
            raw.sendall(make_proxy_v2_header(
                fake_src_addr, fake_src_port,
                server.ip_addr, PROXY_PORT_SSL
            ))
            # wrap_socket() takes over the descriptor and handshakes right away,
            # bounded by the timeout inherited from `raw`.
            conn = ssl_context.wrap_socket(raw, server_hostname=server.ip_addr)
        except BaseException:
            raw.close()
            raise

        try:
            conn.sendall(build_cql_startup_frame(stream=1))
            opcode = recv_frame(conn)[4]
            if opcode == 0x03:  # AUTHENTICATE
                # AUTH_RESPONSE body is one [bytes] value: the SASL PLAIN token
                # prefixed by its 32-bit length.
                token = b'\x00cassandra\x00cassandra'
                auth_body = struct.pack('!i', len(token)) + token
                conn.sendall(struct.pack('!BBHBI', 0x04, 0x00, 1, 0x0F, len(auth_body)) + auth_body)
                opcode = recv_frame(conn)[4]
                if opcode != 0x10:  # AUTH_SUCCESS
                    raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}")
            elif opcode != 0x02:  # READY
                raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}")
        except BaseException:
            conn.close()
            raise
        return conn

    # The connection stays open while system.clients is polled.
    ssl_sock = await asyncio.get_running_loop().run_in_executor(None, open_proxied_tls_connection)
    try:
        # Now query system.clients using the driver to see our connection
        cql = manager.get_cql()
        rows = await wait_for_results(
            cql,
            'SELECT address, port, ssl_enabled FROM system.clients',
            expected_count=1,
            filter_fn=lambda all_rows: [r for r in all_rows if str(r.address) == fake_src_addr],
        )

        # We should find our connection
        assert len(rows) > 0, f"Expected to find connection from {fake_src_addr} in system.clients"

        matching = [row for row in rows if row.port == fake_src_port]
        assert matching, f"Expected to find port {fake_src_port} in system.clients, got ports: {[r.port for r in rows]}"
        assert matching[0].ssl_enabled, "Expected connection to have ssl_enabled=True"
    finally:
        ssl_sock.close()
|