diff --git a/configure.py b/configure.py index 53d3c9abdc..0587856f8b 100755 --- a/configure.py +++ b/configure.py @@ -1493,7 +1493,8 @@ idls = ['idl/gossip_digest.idl.hh', 'idl/gossip.idl.hh', 'idl/migration_manager.idl.hh', "idl/node_ops.idl.hh", - "idl/tasks.idl.hh" + "idl/tasks.idl.hh", + "idl/client_state.idl.hh", ] scylla_tests_generic_dependencies = [ diff --git a/idl/CMakeLists.txt b/idl/CMakeLists.txt index d56cdf0d01..03132638f5 100644 --- a/idl/CMakeLists.txt +++ b/idl/CMakeLists.txt @@ -68,6 +68,7 @@ set(idl_headers migration_manager.idl.hh node_ops.idl.hh tasks.idl.hh + client_state.idl.hh ) foreach(idl_header ${idl_headers}) diff --git a/idl/client_state.idl.hh b/idl/client_state.idl.hh new file mode 100644 index 0000000000..5df22acb2f --- /dev/null +++ b/idl/client_state.idl.hh @@ -0,0 +1,33 @@ +/* + * Copyright 2026-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "db/timeout_clock.hh" +#include "gms/inet_address_serializer.hh" + +struct timeout_config { + db::timeout_clock::duration read_timeout; + db::timeout_clock::duration write_timeout; + db::timeout_clock::duration range_read_timeout; + db::timeout_clock::duration counter_write_timeout; + db::timeout_clock::duration truncate_timeout; + db::timeout_clock::duration cas_timeout; + db::timeout_clock::duration other_timeout; +}; + +namespace service { + +struct forwarded_client_state { + sstring keyspace; + std::optional username; + timeout_config timeout_config; + uint64_t protocol_extensions_mask; + gms::inet_address remote_address; + uint16_t remote_port; +}; + +} diff --git a/service/client_state.cc b/service/client_state.cc index d51d1dd47c..ee436b2fd7 100644 --- a/service/client_state.cc +++ b/service/client_state.cc @@ -369,3 +369,27 @@ future<> service::client_state::set_client_options( _client_options.emplace_back(std::move(cached_key), std::move(cached_value)); } } + +service::forwarded_client_state::forwarded_client_state( + sstring keyspace, + std::optional username, + ::timeout_config timeout_config, + uint64_t protocol_extensions_mask, + gms::inet_address remote_address, + uint16_t remote_port) + : keyspace(std::move(keyspace)) + , username(std::move(username)) + , timeout_config(std::move(timeout_config)) + , protocol_extensions_mask(protocol_extensions_mask) + , remote_address(std::move(remote_address)) + , remote_port(remote_port) +{ } + +service::forwarded_client_state::forwarded_client_state(const client_state& cs) + : keyspace(cs.get_raw_keyspace()) + , username(cs.user() ? std::optional{cs.user()->name} : std::nullopt) + , timeout_config(cs.get_timeout_config()) + , protocol_extensions_mask(static_cast(cs.get_protocol_extensions().mask())) + , remote_address(cs.get_client_address()) + , remote_port(cs.get_client_port()) +{ } diff --git a/service/client_state.hh b/service/client_state.hh index d6b5a19757..bf636b3315 100644 --- a/service/client_state.hh +++ b/service/client_state.hh @@ -33,6 +33,24 @@ class database; namespace service { +class client_state; +struct forwarded_client_state { + sstring keyspace; + std::optional username; + timeout_config timeout_config; + uint64_t protocol_extensions_mask; + gms::inet_address remote_address; + uint16_t remote_port; + + forwarded_client_state(sstring keyspace, + std::optional username, + ::timeout_config timeout_config, + uint64_t protocol_extensions_mask, + gms::inet_address remote_address, + uint16_t remote_port); + forwarded_client_state(const client_state& cs); +}; + /** * State related to a client connection. */ @@ -246,6 +264,23 @@ public: , _sl_controller(&sl_controller) {} + client_state(auth::service& auth_service, + qos::service_level_controller* sl_controller, + forwarded_client_state&& forwarded_state) + : _keyspace(std::move(forwarded_state.keyspace)) + , _user(forwarded_state.username ? auth::authenticated_user(*forwarded_state.username) : auth::authenticated_user{}) + , _auth_state(auth_state::READY) + , _is_internal(false) + , _bypass_auth_checks(false) + , _remote_address(socket_address(forwarded_state.remote_address, forwarded_state.remote_port)) + , _auth_service(&auth_service) + , _sl_controller(sl_controller) + , _default_timeout_config(forwarded_state.timeout_config) + , _timeout_config(std::move(forwarded_state.timeout_config)) + , _enabled_protocol_extensions(cql_transport::cql_protocol_extension_enum_set::from_mask( + forwarded_state.protocol_extensions_mask)) + {} + client_state(const client_state&) = delete; client_state(client_state&&) = default; @@ -454,6 +489,10 @@ public: return _enabled_protocol_extensions.contains(ext); } + cql_transport::cql_protocol_extension_enum_set get_protocol_extensions() const { + return _enabled_protocol_extensions; + } + void set_protocol_extensions(cql_transport::cql_protocol_extension_enum_set exts) { _enabled_protocol_extensions = std::move(exts); }