diff --git a/db/config.cc b/db/config.cc index f52e4fd3a6..ee894ce1a6 100644 --- a/db/config.cc +++ b/db/config.cc @@ -1105,6 +1105,14 @@ db::config::config(std::shared_ptr exts) "Like native_transport_port, but clients-side port number (modulo smp) is used to route the connection to the specific shard.") , native_shard_aware_transport_port_ssl(this, "native_shard_aware_transport_port_ssl", value_status::Used, 19142, "Like native_transport_port_ssl, but clients-side port number (modulo smp) is used to route the connection to the specific shard.") + , native_transport_port_proxy_protocol(this, "native_transport_port_proxy_protocol", value_status::Used, 0, + "Port on which the CQL native transport listens for clients using proxy protocol v2. Disabled (0) by default.") + , native_transport_port_ssl_proxy_protocol(this, "native_transport_port_ssl_proxy_protocol", value_status::Used, 0, + "Port on which the CQL TLS native transport listens for clients using proxy protocol v2. Disabled (0) by default.") + , native_shard_aware_transport_port_proxy_protocol(this, "native_shard_aware_transport_port_proxy_protocol", value_status::Used, 0, + "Like native_transport_port_proxy_protocol, but clients-side port number (modulo smp) is used to route the connection to the specific shard.") + , native_shard_aware_transport_port_ssl_proxy_protocol(this, "native_shard_aware_transport_port_ssl_proxy_protocol", value_status::Used, 0, + "Like native_transport_port_ssl_proxy_protocol, but clients-side port number (modulo smp) is used to route the connection to the specific shard.") , native_transport_max_threads(this, "native_transport_max_threads", value_status::Invalid, 128, "The maximum number of thread handling requests. The meaning is the same as rpc_max_threads.\n" "Default is different (128 versus unlimited).\n" diff --git a/db/config.hh b/db/config.hh index 543fb6d36e..4076d209dd 100644 --- a/db/config.hh +++ b/db/config.hh @@ -324,6 +324,10 @@ public: named_value native_transport_port_ssl; named_value native_shard_aware_transport_port; named_value native_shard_aware_transport_port_ssl; + named_value native_transport_port_proxy_protocol; + named_value native_transport_port_ssl_proxy_protocol; + named_value native_shard_aware_transport_port_proxy_protocol; + named_value native_shard_aware_transport_port_ssl_proxy_protocol; named_value native_transport_max_threads; named_value native_transport_max_frame_size_in_mb; named_value broadcast_rpc_address; diff --git a/test/cluster/test_proxy_protocol.py b/test/cluster/test_proxy_protocol.py new file mode 100644 index 0000000000..ca8e66331e --- /dev/null +++ b/test/cluster/test_proxy_protocol.py @@ -0,0 +1,679 @@ +# Copyright 2025-present ScyllaDB +# +# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + +""" +Tests for CQL native transport ports with proxy protocol v2 support. + +These tests verify that when Scylla is configured with proxy protocol ports, +clients can connect using the PROXY protocol v2 header and the original +client addresses are properly reported in system.clients. +""" + +import asyncio +import logging +import pytest +import socket +import ssl +import struct + +from test.pylib.manager_client import ManagerClient + +logger = logging.getLogger(__name__) + +# Proxy protocol v2 signature (12 bytes) +PROXY_V2_SIGNATURE = b'\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a' + +# Port numbers for proxy protocol ports (to avoid conflicts with standard ports) +PROXY_PORT = 29042 +PROXY_PORT_SSL = 29142 +PROXY_SHARD_AWARE_PORT = 39042 +PROXY_SHARD_AWARE_PORT_SSL = 39142 + + +def make_proxy_v2_header(src_addr: str, src_port: int, dst_addr: str, dst_port: int) -> bytes: + """ + Construct a proxy protocol v2 header for IPv4 TCP connections. + + The proxy protocol v2 binary format is defined at: + https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + """ + header = bytearray() + + # Signature (12 bytes) + header.extend(PROXY_V2_SIGNATURE) + + # Version (upper 4 bits) and command (lower 4 bits) + # Version 2, PROXY command (0x21) + header.append(0x21) + + # Address family (upper 4 bits) and transport protocol (lower 4 bits) + # AF_INET (1) and STREAM/TCP (1) = 0x11 + header.append(0x11) + + # Length of the address block (12 bytes for IPv4: 4+4+2+2) + header.extend(struct.pack('!H', 12)) + + # Source address (4 bytes) + header.extend(socket.inet_aton(src_addr)) + + # Destination address (4 bytes) + header.extend(socket.inet_aton(dst_addr)) + + # Source port (2 bytes, big-endian) + header.extend(struct.pack('!H', src_port)) + + # Destination port (2 bytes, big-endian) + header.extend(struct.pack('!H', dst_port)) + + return bytes(header) + + +def build_cql_startup_frame(stream: int = 1) -> bytes: + """Build a CQL protocol v4 STARTUP frame.""" + # String map with CQL_VERSION + body = b'\x00\x01\x00\x0bCQL_VERSION\x00\x053.0.0' + + frame = bytearray() + frame.append(0x04) # version 4 + frame.append(0x00) # flags + frame.extend(struct.pack('!H', stream)) # stream + frame.append(0x01) # opcode: STARTUP + frame.extend(struct.pack('!I', len(body))) # body length + frame.extend(body) + + return bytes(frame) + + +def build_cql_query_frame(query: str, stream: int = 2) -> bytes: + """Build a CQL protocol v4 QUERY frame.""" + query_bytes = query.encode('utf-8') + + body = bytearray() + # Long string: 4-byte length + string + body.extend(struct.pack('!I', len(query_bytes))) + body.extend(query_bytes) + # Query parameters: consistency (2 bytes) + flags (1 byte) + body.extend(struct.pack('!H', 0x0001)) # consistency: ONE + body.append(0x00) # flags: none + + frame = bytearray() + frame.append(0x04) # version 4 + frame.append(0x00) # flags + frame.extend(struct.pack('!H', stream)) # stream + frame.append(0x07) # opcode: QUERY + frame.extend(struct.pack('!I', len(body))) # body length + frame.extend(body) + + return bytes(frame) + + +async def do_cql_handshake(reader, writer): + """Complete CQL handshake (STARTUP and optional AUTH).""" + # Send CQL STARTUP + startup_frame = build_cql_startup_frame(stream=1) + writer.write(startup_frame) + await writer.drain() + + # Read response (READY or AUTHENTICATE) + response = await asyncio.wait_for(reader.read(4096), timeout=10.0) + if len(response) < 9: + raise RuntimeError(f"Short response from server: {len(response)} bytes") + + opcode = response[4] + if opcode == 0x03: # AUTHENTICATE + # Send AUTH_RESPONSE with default credentials + auth_body = b'\x00cassandra\x00cassandra' + auth_frame = bytearray() + auth_frame.append(0x04) + auth_frame.append(0x00) + auth_frame.extend(struct.pack('!H', 1)) + auth_frame.append(0x0F) # AUTH_RESPONSE + auth_frame.extend(struct.pack('!I', len(auth_body))) + auth_frame.extend(auth_body) + writer.write(bytes(auth_frame)) + await writer.drain() + response = await asyncio.wait_for(reader.read(4096), timeout=10.0) + opcode = response[4] + if opcode != 0x10: # AUTH_SUCCESS + raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}") + elif opcode != 0x02: # READY + raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}") + + +async def send_cql_query(reader, writer, query: str, stream: int = 2) -> bytes: + """Send a CQL query and return the response.""" + query_frame = build_cql_query_frame(query, stream=stream) + writer.write(query_frame) + await writer.drain() + response = await asyncio.wait_for(reader.read(16384), timeout=10.0) + return response + + +async def send_cql_with_proxy_header( + host: str, + port: int, + proxy_src_addr: str, + proxy_src_port: int, + proxy_dst_addr: str, + proxy_dst_port: int, + query: str +) -> bytes: + """Connect to a CQL server using proxy protocol v2 and execute a query.""" + reader, writer = await asyncio.open_connection(host, port) + + try: + # Send proxy protocol v2 header + proxy_header = make_proxy_v2_header(proxy_src_addr, proxy_src_port, proxy_dst_addr, proxy_dst_port) + writer.write(proxy_header) + await writer.drain() + + # Complete CQL handshake + await do_cql_handshake(reader, writer) + + # Send the query + return await send_cql_query(reader, writer, query) + + finally: + writer.close() + await writer.wait_closed() + + +async def send_cql_with_proxy_header_tls( + host: str, + port: int, + proxy_src_addr: str, + proxy_src_port: int, + proxy_dst_addr: str, + proxy_dst_port: int, + query: str +) -> bytes: + """ + Connect to a CQL server using proxy protocol v2 with TLS and execute a query. + + The proxy protocol header is sent first over the raw TCP connection, + then the connection is upgraded to TLS before the CQL protocol begins. + """ + # Create a socket and connect + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + + loop = asyncio.get_event_loop() + await loop.sock_connect(sock, (host, port)) + + try: + # Send proxy header on raw socket BEFORE TLS handshake + proxy_header = make_proxy_v2_header(proxy_src_addr, proxy_src_port, proxy_dst_addr, proxy_dst_port) + await loop.sock_sendall(sock, proxy_header) + + # Create SSL context (don't verify server certificate for testing) + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + # Wrap the socket with TLS + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=host, do_handshake_on_connect=False) + + # Do TLS handshake (non-blocking) + while True: + try: + ssl_sock.do_handshake() + break + except ssl.SSLWantReadError: + await asyncio.sleep(0.01) + except ssl.SSLWantWriteError: + await asyncio.sleep(0.01) + + # Make the SSL socket blocking for simplicity + ssl_sock.setblocking(True) + + # Send CQL STARTUP + startup_frame = build_cql_startup_frame(stream=1) + ssl_sock.sendall(startup_frame) + + # Read response + response = ssl_sock.recv(4096) + if len(response) < 9: + raise RuntimeError(f"Short response from server: {len(response)} bytes") + + opcode = response[4] + if opcode == 0x03: # AUTHENTICATE + auth_body = b'\x00cassandra\x00cassandra' + auth_frame = bytearray() + auth_frame.append(0x04) + auth_frame.append(0x00) + auth_frame.extend(struct.pack('!H', 1)) + auth_frame.append(0x0F) + auth_frame.extend(struct.pack('!I', len(auth_body))) + auth_frame.extend(auth_body) + ssl_sock.sendall(bytes(auth_frame)) + response = ssl_sock.recv(4096) + opcode = response[4] + if opcode != 0x10: + raise RuntimeError(f"Expected AUTH_SUCCESS (0x10), got {hex(opcode)}") + elif opcode != 0x02: + raise RuntimeError(f"Expected READY (0x02) or AUTHENTICATE (0x03), got {hex(opcode)}") + + # Send query + query_frame = build_cql_query_frame(query, stream=2) + ssl_sock.sendall(query_frame) + + # Read response + response = ssl_sock.recv(16384) + return response + + finally: + try: + ssl_sock.close() + except: + sock.close() + + +# Shared server configuration for all tests +# We configure explicit SSL ports to keep the standard ports unencrypted +# so the Python driver can connect without TLS. +PROXY_SERVER_CONFIG = { + 'native_transport_port_proxy_protocol': PROXY_PORT, + 'native_transport_port_ssl_proxy_protocol': PROXY_PORT_SSL, + 'native_shard_aware_transport_port_proxy_protocol': PROXY_SHARD_AWARE_PORT, + 'native_shard_aware_transport_port_ssl_proxy_protocol': PROXY_SHARD_AWARE_PORT_SSL, + # Set explicit non-SSL and SSL ports so the driver can connect to unencrypted port + 'native_transport_port': 9042, + 'native_shard_aware_transport_port': 19042, + 'native_transport_port_ssl': 9142, + 'native_shard_aware_transport_port_ssl': 19142, + 'client_encryption_options': { + 'enabled': True, + 'certificate': 'conf/scylla.crt', + 'keyfile': 'conf/scylla.key', + }, +} + + +@pytest.fixture(scope="function") +async def proxy_server(manager: ManagerClient): + """ + Fixture that creates a server with all proxy protocol ports enabled. + Returns a tuple of (server, manager). + """ + server = await manager.server_add(config=PROXY_SERVER_CONFIG) + yield (server, manager) + + +@pytest.mark.asyncio +async def test_proxy_protocol_basic(proxy_server): + """ + Test that connections through the proxy protocol port correctly report + the client address from the proxy header in system.clients. + """ + server, manager = proxy_server + + # Use a distinctive fake source address that we can find in system.clients + fake_src_addr = "203.0.113.42" # TEST-NET-3, won't conflict with real addresses + fake_src_port = 12345 + + # Connect through the proxy protocol port with a fake source address + response = await send_cql_with_proxy_header( + host=server.ip_addr, + port=PROXY_PORT, + proxy_src_addr=fake_src_addr, + proxy_src_port=fake_src_port, + proxy_dst_addr=server.ip_addr, + proxy_dst_port=PROXY_PORT, + query=f"SELECT address, port FROM system.clients WHERE address = '{fake_src_addr}' ALLOW FILTERING" + ) + + assert len(response) > 9, "Expected a valid CQL response" + opcode = response[4] + assert opcode == 0x08, f"Expected RESULT opcode (0x08), got {hex(opcode)}" + + +@pytest.mark.asyncio +async def test_proxy_protocol_shard_aware(proxy_server): + """ + Test that shard-aware proxy protocol port correctly uses the source port + from the proxy header for shard routing. We connect to all shards by + using source ports that map to each shard (port % num_shards). + """ + server, manager = proxy_server + + # The test harness runs with --smp 2 by default + num_shards = 2 + + cql = manager.get_cql() + + fake_src_addr = "203.0.113.43" + base_port = 10000 + + # Keep connections open while we verify shard assignments + connections = [] + try: + # Connect to each shard by using source ports that map to each shard + for shard in range(num_shards): + # Choose a port that maps to this shard: port % num_shards == shard + fake_src_port = base_port + shard + + reader, writer = await asyncio.open_connection(server.ip_addr, PROXY_SHARD_AWARE_PORT) + connections.append((reader, writer, fake_src_port, shard)) + + # Send proxy header + proxy_header = make_proxy_v2_header( + fake_src_addr, fake_src_port, + server.ip_addr, PROXY_SHARD_AWARE_PORT + ) + writer.write(proxy_header) + await writer.drain() + + # Complete CQL handshake + await do_cql_handshake(reader, writer) + + # Now query system.clients to verify shard assignments + rows = list(cql.execute( + f"SELECT address, port, shard_id FROM system.clients WHERE address = '{fake_src_addr}' ALLOW FILTERING" + )) + + # Build a map of port -> shard_id from the results + port_to_shard = {row.port: row.shard_id for row in rows} + + # Verify each connection landed on the expected shard + for reader, writer, fake_src_port, expected_shard in connections: + assert fake_src_port in port_to_shard, f"Port {fake_src_port} not found in system.clients" + actual_shard = port_to_shard[fake_src_port] + assert actual_shard == expected_shard, \ + f"Port {fake_src_port} expected shard {expected_shard}, got {actual_shard}" + + finally: + for reader, writer, _, _ in connections: + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_protocol_multiple_connections(proxy_server): + """ + Test that multiple connections through the proxy protocol port + with different source addresses are all correctly recorded. + """ + server, manager = proxy_server + + test_clients = [ + ("203.0.113.101", 10001), + ("203.0.113.102", 10002), + ("203.0.113.103", 10003), + ] + + for fake_src_addr, fake_src_port in test_clients: + response = await send_cql_with_proxy_header( + host=server.ip_addr, + port=PROXY_PORT, + proxy_src_addr=fake_src_addr, + proxy_src_port=fake_src_port, + proxy_dst_addr=server.ip_addr, + proxy_dst_port=PROXY_PORT, + query="SELECT * FROM system.local" + ) + + assert len(response) > 9, f"Expected a valid CQL response for {fake_src_addr}" + opcode = response[4] + assert opcode == 0x08, f"Expected RESULT opcode for {fake_src_addr}, got {hex(opcode)}" + + +@pytest.mark.asyncio +async def test_proxy_protocol_port_preserved_in_system_clients(proxy_server): + """ + Test that the source port from the proxy protocol header is correctly + preserved in system.clients, which is important for shard-aware routing. + """ + server, manager = proxy_server + + fake_src_addr = "203.0.113.200" + fake_src_port = 44444 + + # Keep the connection open while we query system.clients + reader, writer = await asyncio.open_connection(server.ip_addr, PROXY_PORT) + + try: + # Send proxy header + proxy_header = make_proxy_v2_header( + fake_src_addr, fake_src_port, + server.ip_addr, PROXY_PORT + ) + writer.write(proxy_header) + await writer.drain() + + # Complete CQL handshake + await do_cql_handshake(reader, writer) + + # Now query system.clients using the driver to see our connection + cql = manager.get_cql() + rows = list(cql.execute( + f"SELECT address, port FROM system.clients WHERE address = '{fake_src_addr}' ALLOW FILTERING" + )) + + # We should find our connection with the fake source address and port + assert len(rows) > 0, f"Expected to find connection from {fake_src_addr} in system.clients" + + found_correct_port = False + for row in rows: + if row.port == fake_src_port: + found_correct_port = True + break + + assert found_correct_port, f"Expected to find port {fake_src_port} in system.clients, got ports: {[r.port for r in rows]}" + + finally: + writer.close() + await writer.wait_closed() + + +@pytest.mark.asyncio +async def test_proxy_protocol_ssl_basic(proxy_server): + """ + Test proxy protocol with TLS encryption. + The proxy header is sent first, then the connection is upgraded to TLS. + """ + server, manager = proxy_server + + fake_src_addr = "203.0.113.50" + fake_src_port = 55555 + + response = await send_cql_with_proxy_header_tls( + host=server.ip_addr, + port=PROXY_PORT_SSL, + proxy_src_addr=fake_src_addr, + proxy_src_port=fake_src_port, + proxy_dst_addr=server.ip_addr, + proxy_dst_port=PROXY_PORT_SSL, + query="SELECT * FROM system.local" + ) + + assert len(response) > 9, "Expected a valid CQL response" + opcode = response[4] + assert opcode == 0x08, f"Expected RESULT opcode (0x08), got {hex(opcode)}" + + +@pytest.mark.asyncio +async def test_proxy_protocol_ssl_shard_aware(proxy_server): + """ + Test proxy protocol with TLS on the shard-aware port. We connect to all + shards by using source ports that map to each shard (port % num_shards). + """ + server, manager = proxy_server + + # The test harness runs with --smp 2 by default + num_shards = 2 + + cql = manager.get_cql() + + fake_src_addr = "203.0.113.51" + base_port = 20000 + + # Keep connections open while we verify shard assignments + ssl_sockets = [] + try: + # Connect to each shard by using source ports that map to each shard + for shard in range(num_shards): + # Choose a port that maps to this shard: port % num_shards == shard + fake_src_port = base_port + shard + + # Create raw socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + + loop = asyncio.get_event_loop() + await loop.sock_connect(sock, (server.ip_addr, PROXY_SHARD_AWARE_PORT_SSL)) + + # Send proxy header on raw socket before TLS + proxy_header = make_proxy_v2_header( + fake_src_addr, fake_src_port, + server.ip_addr, PROXY_SHARD_AWARE_PORT_SSL + ) + sock.setblocking(True) + sock.sendall(proxy_header) + + # Upgrade to TLS + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ssl_sock = ssl_context.wrap_socket(sock, do_handshake_on_connect=False) + + # Complete TLS handshake + while True: + try: + ssl_sock.do_handshake() + break + except ssl.SSLWantReadError: + await asyncio.sleep(0.01) + except ssl.SSLWantWriteError: + await asyncio.sleep(0.01) + + ssl_sockets.append((ssl_sock, fake_src_port, shard)) + + # Send STARTUP frame + startup_frame = build_cql_startup_frame(stream=1) + ssl_sock.sendall(startup_frame) + + # Read response and handle auth if needed + response = ssl_sock.recv(4096) + opcode = response[4] + if opcode == 0x03: # AUTHENTICATE + auth_body = b'\x00cassandra\x00cassandra' + auth_frame = bytearray() + auth_frame.append(0x04) + auth_frame.append(0x00) + auth_frame.extend(struct.pack('!H', 1)) + auth_frame.append(0x0F) + auth_frame.extend(struct.pack('!I', len(auth_body))) + auth_frame.extend(auth_body) + ssl_sock.sendall(bytes(auth_frame)) + ssl_sock.recv(4096) + + # Now query system.clients to verify shard assignments + rows = list(cql.execute( + f"SELECT address, port, shard_id, ssl_enabled FROM system.clients WHERE address = '{fake_src_addr}' ALLOW FILTERING" + )) + + # Build a map of port -> (shard_id, ssl_enabled) from the results + port_to_info = {row.port: (row.shard_id, row.ssl_enabled) for row in rows} + + # Verify each connection landed on the expected shard with SSL enabled + for ssl_sock, fake_src_port, expected_shard in ssl_sockets: + assert fake_src_port in port_to_info, f"Port {fake_src_port} not found in system.clients" + actual_shard, ssl_enabled = port_to_info[fake_src_port] + assert actual_shard == expected_shard, \ + f"Port {fake_src_port} expected shard {expected_shard}, got {actual_shard}" + assert ssl_enabled, f"Port {fake_src_port} expected ssl_enabled=True" + + finally: + for ssl_sock, _, _ in ssl_sockets: + ssl_sock.close() + + +@pytest.mark.asyncio +async def test_proxy_protocol_ssl_port_preserved(proxy_server): + """ + Test that the source port from the proxy protocol header is correctly + preserved in system.clients when using TLS. + """ + server, manager = proxy_server + + fake_src_addr = "203.0.113.201" + fake_src_port = 44445 + + # Create a connection that stays open + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + + loop = asyncio.get_event_loop() + await loop.sock_connect(sock, (server.ip_addr, PROXY_PORT_SSL)) + + ssl_sock = None + try: + # Send proxy header on raw socket + proxy_header = make_proxy_v2_header( + fake_src_addr, fake_src_port, + server.ip_addr, PROXY_PORT_SSL + ) + await loop.sock_sendall(sock, proxy_header) + + # Wrap with TLS + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + ssl_sock = ssl_context.wrap_socket(sock, server_hostname=server.ip_addr, do_handshake_on_connect=False) + + # Do TLS handshake + while True: + try: + ssl_sock.do_handshake() + break + except ssl.SSLWantReadError: + await asyncio.sleep(0.01) + except ssl.SSLWantWriteError: + await asyncio.sleep(0.01) + + ssl_sock.setblocking(True) + + # Send STARTUP + startup_frame = build_cql_startup_frame(stream=1) + ssl_sock.sendall(startup_frame) + + # Read response and handle auth if needed + response = ssl_sock.recv(4096) + opcode = response[4] + if opcode == 0x03: # AUTHENTICATE + auth_body = b'\x00cassandra\x00cassandra' + auth_frame = bytearray() + auth_frame.append(0x04) + auth_frame.append(0x00) + auth_frame.extend(struct.pack('!H', 1)) + auth_frame.append(0x0F) + auth_frame.extend(struct.pack('!I', len(auth_body))) + auth_frame.extend(auth_body) + ssl_sock.sendall(bytes(auth_frame)) + ssl_sock.recv(4096) + + # Now query system.clients using the driver to see our connection + cql = manager.get_cql() + rows = list(cql.execute( + f"SELECT address, port, ssl_enabled FROM system.clients WHERE address = '{fake_src_addr}' ALLOW FILTERING" + )) + + # We should find our connection + assert len(rows) > 0, f"Expected to find connection from {fake_src_addr} in system.clients" + + found_correct = False + for row in rows: + if row.port == fake_src_port: + found_correct = True + assert row.ssl_enabled, "Expected connection to have ssl_enabled=True" + break + + assert found_correct, f"Expected to find port {fake_src_port} in system.clients, got ports: {[r.port for r in rows]}" + + finally: + if ssl_sock: + ssl_sock.close() + else: + sock.close() diff --git a/transport/controller.cc b/transport/controller.cc index 5c1ff6cfce..f9ca540f01 100644 --- a/transport/controller.cc +++ b/transport/controller.cc @@ -72,9 +72,9 @@ future<> controller::start_server() { return do_start_server().finally([this] { _ops_sem.signal(); }); } -static future<> listen_on_all_shards(sharded& cserver, socket_address addr, std::shared_ptr creds, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions) { - co_await cserver.invoke_on_all([addr, creds, is_shard_aware, keepalive, unix_domain_socket_permissions] (cql_server& server) { - return server.listen(addr, creds, is_shard_aware, keepalive, unix_domain_socket_permissions, [&c = server.container()]() -> auto& { return c.local(); }); +static future<> listen_on_all_shards(sharded& cserver, socket_address addr, std::shared_ptr creds, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions, bool proxy_protocol = false) { + co_await cserver.invoke_on_all([addr, creds, is_shard_aware, keepalive, unix_domain_socket_permissions, proxy_protocol] (cql_server& server) { + return server.listen(addr, creds, is_shard_aware, keepalive, unix_domain_socket_permissions, proxy_protocol, [&c = server.container()]() -> auto& { return c.local(); }); }); } @@ -89,6 +89,7 @@ future<> controller::start_listening_on_tcp_sockets(sharded& cserver socket_address addr; bool is_shard_aware; std::shared_ptr cred; + bool proxy_protocol = false; }; _listen_addresses.clear(); @@ -112,8 +113,9 @@ future<> controller::start_listening_on_tcp_sockets(sharded& cserver } // main should have made sure values are clean and neatish + std::shared_ptr cred; if (utils::is_true(utils::get_or_default(ceo, "enabled", "false"))) { - auto cred = std::make_shared(); + cred = std::make_shared(); utils::configure_tls_creds_builder(*cred, std::move(ceo)).get(); logger.info("Enabling encrypted CQL connections between client and server"); @@ -130,18 +132,37 @@ future<> controller::start_listening_on_tcp_sockets(sharded& cserver if (cfg.native_shard_aware_transport_port_ssl.is_set() && (!cfg.native_shard_aware_transport_port.is_set() || cfg.native_shard_aware_transport_port_ssl() != cfg.native_shard_aware_transport_port())) { - configs.emplace_back(listen_cfg{{ip, cfg.native_shard_aware_transport_port_ssl()}, true, std::move(cred)}); + configs.emplace_back(listen_cfg{{ip, cfg.native_shard_aware_transport_port_ssl()}, true, cred}); _listen_addresses.push_back(configs.back().addr); } else if (native_shard_aware_port_idx >= 0) { - configs[native_shard_aware_port_idx].cred = std::move(cred); + configs[native_shard_aware_port_idx].cred = cred; } } - co_await parallel_for_each(configs, [&cserver, keepalive](const listen_cfg & cfg) -> future<> { - co_await listen_on_all_shards(cserver, cfg.addr, cfg.cred, cfg.is_shard_aware, keepalive, std::nullopt); + // Proxy protocol ports (disabled by default, port 0 means disabled) + if (cfg.native_transport_port_proxy_protocol()) { + configs.emplace_back(listen_cfg{{ip, cfg.native_transport_port_proxy_protocol()}, false, nullptr, true}); + _listen_addresses.push_back(configs.back().addr); + } + if (cfg.native_shard_aware_transport_port_proxy_protocol()) { + configs.emplace_back(listen_cfg{{ip, cfg.native_shard_aware_transport_port_proxy_protocol()}, true, nullptr, true}); + _listen_addresses.push_back(configs.back().addr); + } + if (cfg.native_transport_port_ssl_proxy_protocol() && cred) { + configs.emplace_back(listen_cfg{{ip, cfg.native_transport_port_ssl_proxy_protocol()}, false, cred, true}); + _listen_addresses.push_back(configs.back().addr); + } + if (cfg.native_shard_aware_transport_port_ssl_proxy_protocol() && cred) { + configs.emplace_back(listen_cfg{{ip, cfg.native_shard_aware_transport_port_ssl_proxy_protocol()}, true, cred, true}); + _listen_addresses.push_back(configs.back().addr); + } - logger.info("Starting listening for CQL clients on {} ({}, {})" + co_await parallel_for_each(configs, [&cserver, keepalive](const listen_cfg & cfg) -> future<> { + co_await listen_on_all_shards(cserver, cfg.addr, cfg.cred, cfg.is_shard_aware, keepalive, std::nullopt, cfg.proxy_protocol); + + logger.info("Starting listening for CQL clients on {} ({}, {}{})" , cfg.addr, cfg.cred ? "encrypted" : "unencrypted", cfg.is_shard_aware ? "shard-aware" : "non-shard-aware" + , cfg.proxy_protocol ? ", proxy-protocol" : "" ); }); } diff --git a/transport/generic_server.cc b/transport/generic_server.cc index a115ea39ac..59f8516ca6 100644 --- a/transport/generic_server.cc +++ b/transport/generic_server.cc @@ -307,7 +307,7 @@ future<> server::shutdown() { } future<> -server::listen(socket_address addr, std::shared_ptr builder, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions, std::function get_shard_instance) { +server::listen(socket_address addr, std::shared_ptr builder, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions, bool proxy_protocol, std::function get_shard_instance) { // Note: We are making the assumption that if builder is provided it will be the same for each // invocation, regardless of address etc. In general, only CQL server will call this multiple times, // and if TLS, it will use the same cert set. @@ -339,6 +339,7 @@ server::listen(socket_address addr, std::shared_ptr creds, bool is_shard_aware, bool keepalive, std::optional unix_domain_socket_permissions, + bool proxy_protocol = false, std::function get_shard_instance = {} );