diff --git a/generic_server.cc b/generic_server.cc index 3d8bafc0e2..0ea776fd8e 100644 --- a/generic_server.cc +++ b/generic_server.cc @@ -353,18 +353,19 @@ server::listen(socket_address addr, std::shared_ptr {}", _server_name, addr, std::current_exception())); } _listeners.emplace_back(std::move(ss)); - _listeners_stopped = when_all(std::move(_listeners_stopped), do_accepts(_listeners.size() - 1, keepalive, addr)).discard_result(); + _listeners_stopped = when_all(std::move(_listeners_stopped), do_accepts(_listeners.size() - 1, keepalive, addr, is_tls)).discard_result(); } -future<> server::do_accepts(int which, bool keepalive, socket_address server_addr) { +future<> server::do_accepts(int which, bool keepalive, socket_address server_addr, bool is_tls) { while (!_gate.is_closed()) { seastar::gate::holder holder(_gate); bool shed = false; @@ -404,10 +405,26 @@ future<> server::do_accepts(int which, bool keepalive, socket_address server_add conn->shutdown(); continue; } + conn->_ssl_enabled = is_tls; // Move the processing into the background. - (void)futurize_invoke([this, conn] { - // Block while monitoring for lifetime/errors. - return conn->process().then_wrapped([this, conn] (auto f) { + (void)futurize_invoke([this, conn, is_tls] { + return (is_tls + ? tls::get_protocol_version(conn->_fd).then([conn](const sstring& protocol) { + return tls::get_cipher_suite(conn->_fd).then( + [conn, protocol](const sstring& cipher_suite) mutable { + conn->_ssl_protocol = protocol; + conn->_ssl_cipher_suite = cipher_suite; + return make_ready_future(true); + }); + }).handle_exception([this, conn](std::exception_ptr ep) { + _logger.warn("Inspecting TLS connection failed: {}", ep); + return make_ready_future(false); + }) + : make_ready_future(true) + ).then([conn] (bool ok){ + // Block while monitoring for lifetime/errors. + return ok ? conn->process() : make_ready_future<>(); + }).then_wrapped([this, conn](auto f) { try { f.get(); } catch (...) { diff --git a/generic_server.hh b/generic_server.hh index e9b1e5c768..952b1eaa19 100644 --- a/generic_server.hh +++ b/generic_server.hh @@ -60,6 +60,10 @@ protected: seastar::named_gate _pending_requests_gate; seastar::gate::holder _hold_server; + bool _ssl_enabled = false; + std::optional _ssl_cipher_suite = std::nullopt; + std::optional _ssl_protocol = std::nullopt;; + private: future<> process_until_tenant_switch(); bool shutdown_input(); @@ -144,7 +148,7 @@ public: std::function get_shard_instance = {} ); - future<> do_accepts(int which, bool keepalive, socket_address server_addr); + future<> do_accepts(int which, bool keepalive, socket_address server_addr, bool is_tls); protected: virtual seastar::shared_ptr make_connection(socket_address server_addr, connected_socket&& fd, socket_address addr, named_semaphore& sem, semaphore_units initial_sem_units) = 0; diff --git a/test/cqlpy/test_ssl.py b/test/cqlpy/test_ssl.py index 35ea11955d..73540d86f2 100644 --- a/test/cqlpy/test_ssl.py +++ b/test/cqlpy/test_ssl.py @@ -10,8 +10,25 @@ import pytest -import ssl import cassandra.cluster +from contextlib import contextmanager +import re +import ssl + + +# This function normalizes the SSL cipher suite name (a string), +# which we need to do because tests use python library and scylla server uses C library, +# and both the python's and the C's library naming conventions are different, +# so we need some translation in order to compare them. +def normalize_cipher(cipher_name: str) -> str: + if cipher_name.startswith("TLS_"): + cipher_name = cipher_name[len("TLS_"):] # Remove leading "TLS_" if present. + cipher_name = cipher_name.replace("_WITH_", "-") + cipher_name = cipher_name.replace("_", "-") + # Remove hyphen between letters and digits: e.g. convert "AES-256" to "AES256" + cipher_name = re.sub(r'([A-Z]+)-(\d+)', r'\1\2', cipher_name) + return cipher_name + # Test that TLS 1.2 is supported (because this is what "cqlsh --ssl" uses # by default), and that other TLS version are either supported - or if @@ -46,6 +63,39 @@ def test_tls_versions(cql): assert 'protocol version' in str(e) or 'no protocols available' in str(e) print(f"{ssl_version} not supported") +# a regression test for #9216 +def test_system_clients_stores_tls_info(cql): + if not cql.cluster.ssl_context: + table_result = cql.execute(f"SELECT * FROM system.clients") + for row in table_result: + assert not row.ssl_enabled + assert row.ssl_protocol is None + assert row.ssl_cipher_suite is None + + if cql.cluster.ssl_context: + # TLS v1.2 must be supported, because this is the default version that + # "cqlsh --ssl" uses. If this fact changes in the future, we may need + # to reconsider this test. + with try_connect(cql.cluster, ssl.TLSVersion.TLSv1_2) as session: + # As of time of writing, python driver spawns 5 to 6 connections for this single session, + # and some connections may already be past the TLS init phase, while others not, when we query system.clients, + # so we need to retry until all connections are initialized and have their TLS info recorded in system.clients, + # otherwise we'd end up with some connections e.g. having their ssl_enabled=True but other fields still None. + expected_ciphers = [normalize_cipher(cipher['name']) for cipher in ssl.create_default_context().get_ciphers()] + for _ in range(1000): # try for up to 1000 * 0.01s = 10s seconds + rows = session.execute(f"SELECT * FROM system.clients") + if rows and all( + row.ssl_enabled + and row.ssl_protocol == 'TLS1.2' + and normalize_cipher(row.ssl_cipher_suite) in expected_ciphers + for row in rows + ): + return + time.sleep(0.01) + pytest.fail(f"Not all connections have TLS data set correctly in system.clients after 10s seconds") + + +@contextmanager def try_connect(orig_cluster, ssl_version): ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS) ssl_context.minimum_version = ssl_version @@ -63,8 +113,11 @@ def try_connect(orig_cluster, ssl_version): # so let's increase them to 60 seconds. See issue #11289. connect_timeout = 60, control_connection_timeout = 60) - cluster.connect() - cluster.shutdown() + try: + session = cluster.connect() + yield session + finally: + cluster.shutdown() # Test that if we try to connect to an SSL port with *unencrypted* CQL, # it doesn't work. diff --git a/transport/server.cc b/transport/server.cc index b89ccca920..063df6f5e4 100644 --- a/transport/server.cc +++ b/transport/server.cc @@ -680,6 +680,11 @@ client_data cql_server::connection::make_client_data() const { cd.connection_stage = client_connection_stage::authenticating; } cd.scheduling_group_name = _current_scheduling_group.name(); + + cd.ssl_enabled = _ssl_enabled; + cd.ssl_protocol = _ssl_protocol; + cd.ssl_cipher_suite = _ssl_cipher_suite; + return cd; }