Files
scylladb/transport/server.cc
Sergey Zolotukhin ea311be12b generic_server: Two-step connection shutdown.
When shutting down in `generic_server`, connections are now closed in two steps.
First, only the RX (receive) side is shut down. Then, after all ongoing requests
are completed or a timeout occurs, the connections are fully closed.

Fixes scylladb/scylladb#24481
2025-07-28 10:08:06 +02:00

2239 lines
107 KiB
C++

/*
* Copyright (C) 2015-present ScyllaDB
*/
/*
* SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0
*/
#include "server.hh"
#include "cql3/statements/batch_statement.hh"
#include "cql3/statements/modification_statement.hh"
#include <seastar/core/scheduling.hh>
#include <seastar/core/semaphore.hh>
#include "types/collection.hh"
#include "types/list.hh"
#include "types/set.hh"
#include "types/map.hh"
#include "types/vector.hh"
#include "dht/token-sharding.hh"
#include "service/migration_manager.hh"
#include "service/storage_service.hh"
#include "service/memory_limiter.hh"
#include "service/storage_proxy.hh"
#include "service/qos/service_level_controller.hh"
#include "db/consistency_level_type.hh"
#include "db/write_type.hh"
#include <seastar/core/coroutine.hh>
#include <seastar/core/future-util.hh>
#include <seastar/core/seastar.hh>
#include <seastar/coroutine/as_future.hh>
#include <seastar/net/byteorder.hh>
#include <seastar/core/metrics.hh>
#include <seastar/net/byteorder.hh>
#include <seastar/net/tls.hh>
#include <seastar/util/lazy.hh>
#include <seastar/util/short_streams.hh>
#include <seastar/core/execution_stage.hh>
#include "utils/assert.hh"
#include "utils/exception_container.hh"
#include "utils/log.hh"
#include "utils/result_try.hh"
#include "utils/result_combinators.hh"
#include "db/operation_type.hh"
#include "enum_set.hh"
#include "service/query_state.hh"
#include "service/client_state.hh"
#include "exceptions/exceptions.hh"
#include "client_data.hh"
#include "cql3/query_processor.hh"
#include "auth/authenticator.hh"
#include <cassert>
#include <string>
#include <snappy-c.h>
#include <lz4.h>
#include "response.hh"
#include "request.hh"
#include "types/user.hh"
#include "transport/cql_protocol_extension.hh"
#include "utils/bit_cast.hh"
#include "utils/labels.hh"
#include "utils/reusable_buffer.hh"
// Shorthand for the coordinator-side result type used throughout this file:
// either a value of type T or a coordinator exception carried as a value.
template<typename T = void>
using coordinator_result = exceptions::coordinator_result<T>;
namespace cql_transport {
static logging::logger clogger("cql_server");
/**
* Skip registering CQL metrics for these SGs - these are internal scheduling groups that are not supposed to handle CQL
* requests.
*/
static const std::vector<sstring> non_cql_scheduling_classes_names = {
    // Matched by exact scheduling-group name in cql_sg_stats's constructor.
    "atexit",
    "background_reclaim",
    "compaction",
    "gossip",
    "main",
    "mem_compaction",
    "memtable",
    "memtable_to_cache",
    "streaming"
};
// Thrown/returned when a CQL binary frame header is malformed (e.g. wrong size).
struct cql_frame_error : std::exception {
    // `throw()` dynamic exception specifications are deprecated since C++11
    // and removed in C++20; `noexcept` is the modern equivalent and matches
    // the std::exception::what() contract.
    const char* what() const noexcept override {
        return "bad cql binary frame";
    }
};
// Maps an internal consistency level to its CQL wire-protocol code.
// Throws std::runtime_error for a level with no wire representation.
inline int16_t consistency_to_wire(db::consistency_level c)
{
    using cl = db::consistency_level;
    switch (c) {
    case cl::ANY:          return 0x0000;
    case cl::ONE:          return 0x0001;
    case cl::TWO:          return 0x0002;
    case cl::THREE:        return 0x0003;
    case cl::QUORUM:       return 0x0004;
    case cl::ALL:          return 0x0005;
    case cl::LOCAL_QUORUM: return 0x0006;
    case cl::EACH_QUORUM:  return 0x0007;
    case cl::SERIAL:       return 0x0008;
    case cl::LOCAL_SERIAL: return 0x0009;
    case cl::LOCAL_ONE:    return 0x000A;
    default: throw std::runtime_error("Invalid consistency level");
    }
}
// Renders a topology-change type as its wire-protocol event name.
sstring to_string(const event::topology_change::change_type t) {
    using ct = event::topology_change::change_type;
    if (t == ct::NEW_NODE) {
        return "NEW_NODE";
    }
    if (t == ct::REMOVED_NODE) {
        return "REMOVED_NODE";
    }
    if (t == ct::MOVED_NODE) {
        return "MOVED_NODE";
    }
    throw std::invalid_argument("unknown change type");
}
// Renders a CQL binary opcode as a human-readable name (used for metric labels
// and logging); unknown values produce a descriptive placeholder.
sstring to_string(cql_binary_opcode op) {
    using bo = cql_binary_opcode;
    switch (op) {
    case bo::ERROR:          return "ERROR";
    case bo::STARTUP:        return "STARTUP";
    case bo::READY:          return "READY";
    case bo::AUTHENTICATE:   return "AUTHENTICATE";
    case bo::CREDENTIALS:    return "CREDENTIALS";
    case bo::OPTIONS:        return "OPTIONS";
    case bo::SUPPORTED:      return "SUPPORTED";
    case bo::QUERY:          return "QUERY";
    case bo::RESULT:         return "RESULT";
    case bo::PREPARE:        return "PREPARE";
    case bo::EXECUTE:        return "EXECUTE";
    case bo::REGISTER:       return "REGISTER";
    case bo::EVENT:          return "EVENT";
    case bo::BATCH:          return "BATCH";
    case bo::AUTH_CHALLENGE: return "AUTH_CHALLENGE";
    case bo::AUTH_RESPONSE:  return "AUTH_RESPONSE";
    case bo::AUTH_SUCCESS:   return "AUTH_SUCCESS";
    case bo::OPCODES_COUNT:  return "OPCODES_COUNT";
    }
    return format("Unknown CQL binary opcode {}", static_cast<unsigned>(op));
}
// Renders a status-change type as its wire-protocol event name.
sstring to_string(const event::status_change::status_type t) {
    using st = event::status_change::status_type;
    if (t == st::UP) {
        return "UP";
    }
    if (t == st::DOWN) {
        return "DOWN";
    }
    throw std::invalid_argument("unknown change type");
}
// Renders a schema-change type as its wire-protocol event name.
sstring to_string(const event::schema_change::change_type t) {
    using ct = event::schema_change::change_type;
    switch (t) {
    case ct::CREATED: return "CREATED";
    case ct::UPDATED: return "UPDATED";
    case ct::DROPPED: return "DROPPED";
    }
    // All enumerators are handled above.
    SCYLLA_ASSERT(false && "unreachable");
}
// Renders a schema-change target as its wire-protocol name.
sstring to_string(const event::schema_change::target_type t) {
    using tt = event::schema_change::target_type;
    switch (t) {
    case tt::KEYSPACE:  return "KEYSPACE";
    case tt::TABLE:     return "TABLE";
    case tt::TYPE:      return "TYPE";
    case tt::FUNCTION:  return "FUNCTION";
    case tt::AGGREGATE: return "AGGREGATE";
    }
    // All enumerators are handled above.
    SCYLLA_ASSERT(false && "unreachable");
}
// Returns true when the client negotiated the USE_METADATA_ID protocol
// extension on this connection.
bool is_metadata_id_supported(const service::client_state& client_state) {
    // TODO: metadata_id is mandatory in CQLv5, so extend the check below
    // when CQLv5 support is implemented
    return client_state.is_protocol_extension_set(cql_transport::cql_protocol_extension::USE_METADATA_ID);
}
// Parses the event-type string from a REGISTER message. Returns the parsed
// enum value, or a protocol_exception (as a value) for unknown names.
utils::result_with_exception<event::event_type, exceptions::protocol_exception>
parse_event_type(const sstring& value)
{
    using et = event::event_type;
    if (value == "TOPOLOGY_CHANGE") {
        return et::TOPOLOGY_CHANGE;
    }
    if (value == "STATUS_CHANGE") {
        return et::STATUS_CHANGE;
    }
    if (value == "SCHEMA_CHANGE") {
        return et::SCHEMA_CHANGE;
    }
    return exceptions::protocol_exception(format("Invalid value '{}' for Event.Type", value));
}
// Per-scheduling-group CQL statistics. Metrics are registered only for
// scheduling groups that actually serve CQL requests, and never for the
// maintenance socket.
cql_sg_stats::cql_sg_stats(maintenance_socket_enabled used_by_maintenance_socket)
    : _cql_requests_stats(static_cast<uint8_t>(cql_binary_opcode::OPCODES_COUNT))
{
    if (used_by_maintenance_socket) {
        return;
    }
    const auto& excluded = non_cql_scheduling_classes_names;
    const auto sg_name = current_scheduling_group().name();
    if (std::find(excluded.begin(), excluded.end(), sg_name) == excluded.end()) {
        _use_metrics = true;
        register_metrics();
    }
}
// (Re-)builds the per-scheduling-group CQL metrics: for every CQL opcode,
// registers counters for request count, received bytes, and sent bytes,
// labelled with the opcode name and the current scheduling group's name.
// Called from the constructor and again from rename_metrics().
void cql_sg_stats::register_metrics()
{
    namespace sm = seastar::metrics;
    auto new_metrics = sm::metric_groups();
    std::vector<sm::metric_definition> transport_metrics;
    auto cur_sg_name = current_scheduling_group().name();
    for (uint8_t i = 0; i < static_cast<uint8_t>(cql_binary_opcode::OPCODES_COUNT); ++i) {
        cql_binary_opcode opcode = cql_binary_opcode{i};
        transport_metrics.emplace_back(
            sm::make_counter("cql_requests_count", [this, opcode] { return get_cql_opcode_stats(opcode).count; },
                sm::description("Counts the total number of CQL messages of a specific kind."),
                {{"kind", to_string(opcode)}, {"scheduling_group_name", cur_sg_name}}).set_skip_when_empty()
        );
        transport_metrics.emplace_back(
            sm::make_counter("cql_request_bytes", [this, opcode] { return get_cql_opcode_stats(opcode).request_size; },
                sm::description("Counts the total number of received bytes in CQL messages of a specific kind."),
                {{"kind", to_string(opcode)}, {"scheduling_group_name", cur_sg_name}}).set_skip_when_empty()
        );
        transport_metrics.emplace_back(
            sm::make_counter("cql_response_bytes", [this, opcode] { return get_cql_opcode_stats(opcode).response_size; },
                sm::description("Counts the total number of sent response bytes for CQL requests of a specific kind."),
                {{"kind", to_string(opcode)}, {"scheduling_group_name", cur_sg_name}}).set_skip_when_empty()
        );
    }
    new_metrics.add_group("transport", std::move(transport_metrics));
    // Move the freshly built group into place, replacing any previously
    // registered one; std::exchange guarantees new_metrics is left empty.
    _metrics = std::exchange(new_metrics, {});
}
// Re-registers the metrics after a scheduling-group rename, so the
// scheduling_group_name label reflects the new name. No-op when this
// group never registered metrics.
void cql_sg_stats::rename_metrics() {
    if (!_use_metrics) {
        return;
    }
    register_metrics();
}
// Constructs the CQL transport server and registers its shard-level metrics.
// A server backing the maintenance socket registers no metrics.
cql_server::cql_server(distributed<cql3::query_processor>& qp, auth::service& auth_service,
        service::memory_limiter& ml, cql_server_config config,
        qos::service_level_controller& sl_controller, gms::gossiper& g, scheduling_group_key stats_key,
        maintenance_socket_enabled used_by_maintenance_socket)
    : server("CQLServer", clogger, generic_server::config{std::move(config.uninitialized_connections_semaphore_cpu_concurrency), config.request_timeout_on_shutdown_in_seconds})
    , _query_processor(qp)
    , _config(std::move(config))
    , _memory_available(ml.get_semaphore())
    , _notifier(std::make_unique<event_notifier>(*this))
    , _auth_service(auth_service)
    , _sl_controller(sl_controller)
    , _gossiper(g)
    , _stats_key(stats_key)
{
    namespace sm = seastar::metrics;
    if (used_by_maintenance_socket) {
        return;
    }
    // Connection- and request-level counters/gauges, including memory-quota
    // back-pressure indicators driven by the shared memory semaphore.
    auto ls = {
        sm::make_counter("cql-connections", _stats.connects,
            sm::description("Counts a number of client connections.")),
        sm::make_gauge("current_connections", _stats.connections,
            sm::description("Holds a current number of client connections."))(basic_level),
        sm::make_counter("requests_served", _stats.requests_served,
            sm::description("Counts a number of served requests."))(basic_level),
        sm::make_gauge("requests_serving", _stats.requests_serving,
            sm::description("Holds a number of requests that are being processed right now.")),
        sm::make_gauge("requests_blocked_memory_current", [this] { return _memory_available.waiters(); },
            sm::description(
                seastar::format("Holds the number of requests that are currently blocked due to reaching the memory quota limit ({}B). "
                    "Non-zero value indicates that our bottleneck is memory and more specifically - the memory quota allocated for the \"CQL transport\" component.", _config.max_request_size))),
        sm::make_counter("requests_blocked_memory", _stats.requests_blocked_memory,
            sm::description(
                seastar::format("Holds an incrementing counter with the requests that ever blocked due to reaching the memory quota limit ({}B). "
                    "The first derivative of this value shows how often we block due to memory exhaustion in the \"CQL transport\" component.", _config.max_request_size))),
        sm::make_counter("requests_shed", _stats.requests_shed,
            sm::description("Holds an incrementing counter with the requests that were shed due to overload (threshold configured via max_concurrent_requests_per_shard). "
                "The first derivative of this value shows how often we shed requests due to overload in the \"CQL transport\" component."))(basic_level),
        sm::make_counter("connections_shed", _shed_connections,
            sm::description("Holds an incrementing counter with the CQL connections that were shed due to concurrency semaphore timeout (threshold configured via uninitialized_connections_semaphore_cpu_concurrency). "
                "This typically can happen during connection storm. ")),
        sm::make_counter("connections_blocked", _blocked_connections,
            sm::description("Holds an incrementing counter with the CQL connections that were blocked before being processed due to threshold configured via uninitialized_connections_semaphore_cpu_concurrency. "
                "Blocks are normal when we have multiple connections initialized at once. If connections are timing out and this value is high it indicates either connections storm or unusually slow processing.")),
        sm::make_gauge("requests_memory_available", [this] { return _memory_available.current(); },
            sm::description(
                seastar::format("Holds the amount of available memory for admitting new requests (max is {}B)."
                    "Zero value indicates that our bottleneck is memory and more specifically - the memory quota allocated for the \"CQL transport\" component.", _config.max_request_size)))
    };
    std::vector<sm::metric_definition> transport_metrics;
    for (auto& m : ls) {
        transport_metrics.emplace_back(std::move(m));
    }
    // One cql_errors_total counter per known CQL error code, labelled by
    // error type; counts are kept in _stats.errors.
    sm::label cql_error_label("type");
    for (const auto& e : exceptions::exception_map()) {
        _stats.errors.insert({e.first, 0});
        auto label_instance = cql_error_label(e.second);
        transport_metrics.emplace_back(
            sm::make_counter("cql_errors_total", sm::description("Counts the total number of returned CQL errors."),
                {label_instance, basic_level},
                [this, code = e.first] { auto it = _stats.errors.find(code); return it != _stats.errors.end() ? it->second : 0; }).set_skip_when_empty()
        );
    }
    _metrics.add_group("transport", std::move(transport_metrics));
}
// Out-of-line so ~cql_server sees the complete event_notifier definition
// (held via unique_ptr _notifier).
cql_server::~cql_server() = default;
// Factory hook called by the generic server for every accepted socket;
// wraps it in a CQL connection object.
shared_ptr<generic_server::connection>
cql_server::make_connection(socket_address server_addr, connected_socket&& fd, socket_address addr, named_semaphore& sem, semaphore_units<named_semaphore_exception_factory> initial_sem_units) {
    return make_shared<connection>(*this, server_addr, std::move(fd), std::move(addr), sem, std::move(initial_sem_units));
}
// Size of the CQL v3/v4 binary frame header in bytes:
// version(1) + flags(1) + stream(2) + opcode(1) + body length(4) = 9.
unsigned
cql_server::connection::frame_size() const {
    return 9;
}
// Parses a frame header out of exactly frame_size() bytes. Returns the decoded
// header, or an error value: cql_frame_error for a wrong-sized buffer, or a
// protocol_exception for an unsupported or inconsistent protocol version.
utils::result_with_exception<cql_binary_frame_v3, exceptions::protocol_exception, cql_frame_error>
cql_server::connection::parse_frame(temporary_buffer<char> buf) const {
    if (buf.size() != frame_size()) {
        return cql_frame_error();
    }
    cql_binary_frame_v3 v3;
    switch (_version) {
    case 3:
    case 4: {
        // The header may be unaligned in the buffer; read it as raw bytes and
        // convert from network (big-endian) to host byte order.
        cql_binary_frame_v3 raw = read_unaligned<cql_binary_frame_v3>(buf.get());
        v3 = net::ntoh(raw);
        break;
    }
    default:
        return exceptions::protocol_exception(format("Invalid or unsupported protocol version: {:d}", _version));
    }
    // The protocol version must not change mid-connection.
    if (v3.version != _version) {
        return exceptions::protocol_exception(format("Invalid message version. Got {:d} but previous messages on this connection had version {:d}", v3.version, _version));
    }
    return v3;
}
// Reads one frame header from the connection. Resolves to std::nullopt on a
// clean EOF, to the parsed header on success, and to an exceptional future on
// protocol errors. The first frame on a connection also establishes _version.
future<std::optional<cql_binary_frame_v3>>
cql_server::connection::read_frame() {
    using ret_type = std::optional<cql_binary_frame_v3>;
    if (!_version) {
        // We don't know the frame size before reading the first frame,
        // so read just one byte, and then read the rest of the frame.
        return _read_buf.read_exactly(1).then([this] (temporary_buffer<char> buf) {
            if (buf.empty()) {
                // EOF before the first byte — clean disconnect.
                return make_ready_future<ret_type>();
            }
            _version = buf[0];
            if (_version < 3 || _version > current_version) {
                // Respond with the newest version we do support, per protocol.
                auto client_version = _version;
                _version = current_version;
                return make_exception_future<ret_type>(exceptions::protocol_exception(format("Invalid or unsupported protocol version: {:d}", client_version)));
            }
            return _read_buf.read_exactly(frame_size() - 1).then([this] (temporary_buffer<char> tail) {
                // Reassemble the full header: version byte + remaining 8 bytes.
                temporary_buffer<char> full(frame_size());
                full.get_write()[0] = _version;
                std::copy(tail.get(), tail.get() + tail.size(), full.get_write() + 1);
                auto frame = parse_frame(std::move(full));
                if (!frame) {
                    return std::move(frame).assume_error().as_exception_future<ret_type>();
                }
                // This is the very first frame, so reject obviously incorrect frames, to
                // avoid allocating large amounts of memory for the message body
                if (frame.value().length > 100'000) {
                    // The STARTUP message body is a [string map] containing just a few options,
                    // so it should be smaller than 100kB. See #4366.
                    return make_exception_future<ret_type>(exceptions::protocol_exception(format("Initial message size too large ({:d}), rejecting as invalid", uint32_t(frame.value().length))));
                }
                return make_ready_future<ret_type>(std::move(frame).value());
            });
        });
    } else {
        // Not the first frame, so we know the size.
        return _read_buf.read_exactly(frame_size()).then([this] (temporary_buffer<char> buf) {
            if (buf.empty()) {
                // EOF between frames — clean disconnect.
                return make_ready_future<ret_type>();
            }
            auto frame = parse_frame(std::move(buf));
            if (!frame) {
                return std::move(frame).assume_error().as_exception_future<ret_type>();
            };
            return make_ready_future<ret_type>(std::move(frame).value());
        });
    }
}
// This function intentionally sleeps to the end of the query timeout in CQL server.
// It was introduced to remove similar waiting in storage_proxy (ref. scylladb#3699),
// because storage proxy was blocking ERM (thus topology changes).
// The sleep is abortable via the server's abort source so shutdown isn't delayed;
// either way the prepared timeout response is returned to the client.
future<foreign_ptr<std::unique_ptr<cql_server::response>>> cql_server::connection::sleep_until_timeout_passes(const seastar::lowres_clock::time_point& timeout, std::unique_ptr<cql_server::response>&& resp) const {
    auto time_left = timeout - seastar::lowres_clock::now();
    return seastar::sleep_abortable(time_left, _server._abort_source).then_wrapped([resp = std::move(resp)](auto&& f) mutable {
        if (f.failed()) {
            // Consume the exception (e.g. sleep_aborted) so it isn't reported as ignored.
            clogger.debug("Got exception {} while waiting for a request timeout.", f.get_exception());
        }
        // Return timeout error no matter if sleep was aborted or not
        return utils::result_into_future<result_with_foreign_response_ptr>(std::move(resp));
    });
}
// Processes one CQL request frame end-to-end: sets up tracing, validates the
// opcode against the connection's authentication state, dispatches to the
// per-opcode handler, advances the auth state machine based on the response,
// and converts any thrown exception into the corresponding CQL error response.
future<foreign_ptr<std::unique_ptr<cql_server::response>>>
cql_server::connection::process_request_one(fragmented_temporary_buffer::istream fbuf, uint8_t op, uint16_t stream, service::client_state& client_state, tracing_request_type tracing_request, service_permit permit) {
    using auth_state = service::client_state::auth_state;
    auto cqlop = static_cast<cql_binary_opcode>(op);
    tracing::trace_state_props_set trace_props;
    trace_props.set_if<tracing::trace_state_props::log_slow_query>(tracing::tracing::get_local_tracing_instance().slow_query_tracing_enabled());
    trace_props.set_if<tracing::trace_state_props::full_tracing>(tracing_request != tracing_request_type::not_requested);
    tracing::trace_state_ptr trace_state;
    if (trace_props) {
        // Only data-path opcodes get a tracing session.
        if (cqlop == cql_binary_opcode::QUERY ||
            cqlop == cql_binary_opcode::PREPARE ||
            cqlop == cql_binary_opcode::EXECUTE ||
            cqlop == cql_binary_opcode::BATCH) {
            trace_props.set_if<tracing::trace_state_props::write_on_close>(tracing_request == tracing_request_type::write_on_close);
            trace_state = tracing::tracing::get_local_tracing_instance().create_session(tracing::trace_type::QUERY, trace_props);
        }
    }
    // Per-opcode accounting (request count and received bytes).
    cql_sg_stats::request_kind_stats& cql_stats = _server.get_cql_opcode_stats(cqlop);
    tracing::set_request_size(trace_state, fbuf.bytes_left());
    cql_stats.request_size += fbuf.bytes_left();
    ++cql_stats.count;
    // Keeps linearized fragments alive for the duration of the request.
    auto linearization_buffer = std::make_unique<bytes_ostream>();
    auto linearization_buffer_ptr = linearization_buffer.get();
    return futurize_invoke([this, cqlop, stream, &fbuf, &client_state, linearization_buffer_ptr, permit = std::move(permit), trace_state] () mutable {
        // When using authentication, we need to ensure we are doing proper state transitions,
        // i.e. we cannot simply accept any query/exec ops unless auth is complete
        switch (client_state.get_auth_state()) {
            case auth_state::UNINITIALIZED:
                if (cqlop != cql_binary_opcode::STARTUP && cqlop != cql_binary_opcode::OPTIONS) {
                    return make_exception_future<result_with_foreign_response_ptr>(exceptions::protocol_exception(format("Unexpected message {:d}, expecting STARTUP or OPTIONS", int(cqlop))));
                }
                break;
            case auth_state::AUTHENTICATION:
                // Support both SASL auth from protocol v2 and the older style Credentials auth from v1
                if (cqlop != cql_binary_opcode::AUTH_RESPONSE && cqlop != cql_binary_opcode::CREDENTIALS) {
                    return make_exception_future<result_with_foreign_response_ptr>(exceptions::protocol_exception(format("Unexpected message {:d}, expecting {}", int(cqlop), "SASL_RESPONSE")));
                }
                break;
            case auth_state::READY: default:
                if (cqlop == cql_binary_opcode::STARTUP) {
                    return make_exception_future<result_with_foreign_response_ptr>(exceptions::protocol_exception("Unexpected message STARTUP, the connection is already initialized"));
                }
                break;
        }
        tracing::set_username(trace_state, client_state.user());
        // Adapter for handlers that return a plain response pointer.
        auto wrap_in_foreign = [] (future<std::unique_ptr<cql_server::response>> f) {
            return f.then([] (std::unique_ptr<cql_server::response> p) {
                return make_ready_future<result_with_foreign_response_ptr>(make_foreign(std::move(p)));
            });
        };
        auto in = request_reader(std::move(fbuf), *linearization_buffer_ptr);
        // Dispatch to the per-opcode handler.
        switch (cqlop) {
        case cql_binary_opcode::STARTUP:       return wrap_in_foreign(process_startup(stream, std::move(in), client_state, trace_state));
        case cql_binary_opcode::AUTH_RESPONSE: return wrap_in_foreign(process_auth_response(stream, std::move(in), client_state, trace_state));
        case cql_binary_opcode::OPTIONS:       return wrap_in_foreign(process_options(stream, std::move(in), client_state, trace_state));
        case cql_binary_opcode::QUERY:         return process_query(stream, std::move(in), client_state, std::move(permit), trace_state);
        case cql_binary_opcode::PREPARE:       return wrap_in_foreign(process_prepare(stream, std::move(in), client_state, trace_state));
        case cql_binary_opcode::EXECUTE:       return process_execute(stream, std::move(in), client_state, std::move(permit), trace_state);
        case cql_binary_opcode::BATCH:         return process_batch(stream, std::move(in), client_state, std::move(permit), trace_state);
        case cql_binary_opcode::REGISTER:      return wrap_in_foreign(process_register(stream, std::move(in), client_state, trace_state));
        default:                               return make_exception_future<result_with_foreign_response_ptr>(exceptions::protocol_exception(format("Unknown opcode {:d}", int(cqlop))));
        }
    }).then_wrapped([this, cqlop, &cql_stats, stream, &client_state, linearization_buffer = std::move(linearization_buffer), trace_state] (future<result_with_foreign_response_ptr> f) {
        auto stop_trace = defer([&] {
            tracing::stop_foreground(trace_state);
        });
        --_server._stats.requests_serving;
        return seastar::futurize_invoke([&] () {
            if (f.failed()) {
                return make_exception_future<foreign_ptr<std::unique_ptr<cql_server::response>>>(std::move(f).get_exception());
            }
            result_with_foreign_response_ptr res = f.get();
            if (!res) {
                return std::move(res).assume_error().as_exception_future<foreign_ptr<std::unique_ptr<cql_server::response>>>();
            }
            auto response = std::move(res).assume_value();
            auto res_op = response->opcode();
            // and modify state now that we've generated a response.
            switch (client_state.get_auth_state()) {
                case auth_state::UNINITIALIZED:
                    if (cqlop == cql_binary_opcode::STARTUP) {
                        if (res_op == cql_binary_opcode::AUTHENTICATE) {
                            client_state.set_auth_state(auth_state::AUTHENTICATION);
                        } else if (res_op == cql_binary_opcode::READY) {
                            client_state.set_auth_state(auth_state::READY);
                        }
                    }
                    break;
                case auth_state::AUTHENTICATION:
                    // Support both SASL auth from protocol v2 and the older style Credentials auth from v1
                    if (cqlop != cql_binary_opcode::AUTH_RESPONSE && cqlop != cql_binary_opcode::CREDENTIALS) {
                        return make_exception_future<foreign_ptr<std::unique_ptr<cql_server::response>>>(exceptions::protocol_exception(format("Unexpected message {:d}, expecting AUTH_RESPONSE or CREDENTIALS", int(cqlop))));
                    }
                    if (res_op == cql_binary_opcode::READY || res_op == cql_binary_opcode::AUTH_SUCCESS) {
                        client_state.set_auth_state(auth_state::READY);
                    }
                    break;
                default:
                case auth_state::READY:
                    break;
            }
            tracing::set_response_size(trace_state, response->size());
            cql_stats.response_size += response->size();
            return make_ready_future<foreign_ptr<std::unique_ptr<cql_server::response>>>(std::move(response));
        }).handle_exception([this, stream, &client_state, trace_state] (std::exception_ptr eptr) {
            // Map each known exception type to its CQL error response; the
            // per-code error counters are best-effort (failures are ignored).
            if (auto* exp = try_catch<exceptions::unavailable_exception>(eptr)) {
                clogger.debug("{}: request resulted in unavailable_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_unavailable_error(stream, exp->code(), exp->what(), exp->consistency, exp->required, exp->alive, trace_state));
            } else if (auto* exp = try_catch<exceptions::read_failure_exception_with_timeout>(eptr)) {
                clogger.debug("{}: request resulted in read_failure_exception_with_timeout, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                // Return read timeout exception, as we wait here until the timeout passes
                return sleep_until_timeout_passes(
                    exp->_timeout,
                    make_read_timeout_error(stream, exp->_timeout_exception.code(), exp->_timeout_exception.what(), exp->_timeout_exception.consistency, exp->_timeout_exception.received, exp->_timeout_exception.block_for, exp->_timeout_exception.data_present, trace_state)
                );
            } else if (auto* exp = try_catch<exceptions::read_timeout_exception>(eptr)) {
                clogger.debug("{}: request resulted in read_timeout_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_read_timeout_error(stream, exp->code(), exp->what(), exp->consistency, exp->received, exp->block_for, exp->data_present, trace_state));
            } else if (auto* exp = try_catch<exceptions::read_failure_exception>(eptr)) {
                clogger.debug("{}: request resulted in read_failure_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_read_failure_error(stream, exp->code(), exp->what(), exp->consistency, exp->received, exp->failures, exp->block_for, exp->data_present, trace_state));
            } else if (auto* exp = try_catch<exceptions::mutation_write_timeout_exception>(eptr)) {
                clogger.debug("{}: request resulted in mutation_write_timeout_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_mutation_write_timeout_error(stream, exp->code(), exp->what(), exp->consistency, exp->received, exp->block_for, exp->type, trace_state));
            } else if (auto* exp = try_catch<exceptions::mutation_write_failure_exception>(eptr)) {
                clogger.debug("{}: request resulted in mutation_write_failure_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_mutation_write_failure_error(stream, exp->code(), exp->what(), exp->consistency, exp->received, exp->failures, exp->block_for, exp->type, trace_state));
            } else if (auto* exp = try_catch<exceptions::already_exists_exception>(eptr)) {
                clogger.debug("{}: request resulted in already_exists_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_already_exists_error(stream, exp->code(), exp->what(), exp->ks_name, exp->cf_name, trace_state));
            } else if (auto* exp = try_catch<exceptions::prepared_query_not_found_exception>(eptr)) {
                clogger.debug("{}: request resulted in unprepared_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_unprepared_error(stream, exp->code(), exp->what(), exp->id, trace_state));
            } else if (auto* exp = try_catch<exceptions::function_execution_exception>(eptr)) {
                clogger.debug("{}: request resulted in function_failure_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_function_failure_error(stream, exp->code(), exp->what(), exp->ks_name, exp->func_name, exp->args, trace_state));
            } else if (auto* exp = try_catch<exceptions::rate_limit_exception>(eptr)) {
                clogger.debug("{}: request resulted in rate_limit_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_rate_limit_error(stream, exp->code(), exp->what(), exp->op_type, exp->rejected_by_coordinator, trace_state, client_state));
            } else if (auto* exp = try_catch<exceptions::cassandra_exception>(eptr)) {
                clogger.debug("{}: request resulted in cassandra_error, stream {}, code {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->code(), exp->what());
                // Note: the CQL protocol specifies that many types of errors have
                // mandatory parameters. These cassandra_exception subclasses MUST
                // be handled above. This default "cassandra_exception" case is
                // only appropriate for the specific types of errors which do not have
                // additional information, such as invalid_request_exception.
                // TODO: consider listing those types explicitly, instead of the
                // catch-all type cassandra_exception.
                try { ++_server._stats.errors[exp->code()]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_error(stream, exp->code(), exp->what(), trace_state));
            } else if (auto* exp = try_catch<std::exception>(eptr)) {
                clogger.debug("{}: request resulted in error, stream {}, message [{}]",
                        _client_state.get_remote_address(), stream, exp->what());
                try { ++_server._stats.errors[exceptions::exception_code::SERVER_ERROR]; } catch(...) {}
                sstring msg = exp->what();
                // Append the message of a nested exception, if any.
                try {
                    std::rethrow_if_nested(*exp);
                } catch (...) {
                    std::ostringstream ss;
                    ss << msg << ": " << std::current_exception();
                    msg = ss.str();
                }
                return utils::result_into_future<result_with_foreign_response_ptr>(make_error(stream, exceptions::exception_code::SERVER_ERROR, msg, trace_state));
            } else {
                clogger.debug("{}: request resulted in unknown error, stream {}",
                        _client_state.get_remote_address(), stream);
                try { ++_server._stats.errors[exceptions::exception_code::SERVER_ERROR]; } catch(...) {}
                return utils::result_into_future<result_with_foreign_response_ptr>(make_error(stream, exceptions::exception_code::SERVER_ERROR, "unknown error", trace_state));
            }
        });
    });
}
// Constructs a CQL connection wrapper around an accepted socket and updates
// the per-shard connection counters.
cql_server::connection::connection(cql_server& server, socket_address server_addr, connected_socket&& fd, socket_address addr, named_semaphore& sem, semaphore_units<named_semaphore_exception_factory> initial_sem_units)
    : generic_server::connection{server, std::move(fd), sem, std::move(initial_sem_units)}
    , _server(server)
    , _server_addr(server_addr)
    , _client_state(service::client_state::external_tag{}, server._auth_service, &server._sl_controller, server.timeout_config(), addr)
    , _current_scheduling_group(default_scheduling_group())
{
    // When the shedding timer fires, flag the connection so subsequent
    // incoming requests are shed until the flag is cleared.
    _shedding_timer.set_callback([this] {
        clogger.debug("Shedding all incoming requests due to overload");
        _shed_incoming_requests = true;
    });
    ++_server._stats.connects;
    ++_server._stats.connections;
    if (clogger.is_enabled(logging::log_level::trace)) {
        const auto ip = get_client_state().get_client_address().addr();
        const auto port = get_client_state().get_client_port();
        clogger.trace("Advertising new connection from CQL client {}:{}", ip, port);
    }
}
// Deregisters the connection from event notifications and updates the
// per-shard connection gauge.
cql_server::connection::~connection() {
    _server._notifier->unregister_connection(this);
    --_server._stats.connections;
    if (clogger.is_enabled(logging::log_level::trace)) {
        const auto ip = get_client_state().get_client_address().addr();
        const auto port = get_client_state().get_client_port();
        clogger.trace("Advertising disconnection of CQL client {}:{}", ip, port);
    }
}
// Builds a snapshot of this connection's identity and handshake state for
// client-reporting purposes.
client_data cql_server::connection::make_client_data() const {
    client_data data;
    data.ip = _client_state.get_client_address().addr();
    data.port = _client_state.get_client_port();
    data.shard_id = this_shard_id();
    data.protocol_version = _version;
    data.driver_name = _client_state.get_driver_name();
    data.driver_version = _client_state.get_driver_version();
    if (const auto user = _client_state.user(); user) {
        data.username = user->name;
    }
    // The connection stage is only overridden once the handshake progressed;
    // otherwise client_data keeps its default-initialized stage.
    if (_ready) {
        data.connection_stage = client_connection_stage::ready;
    } else if (_authenticating) {
        data.connection_stage = client_connection_stage::authenticating;
    }
    data.scheduling_group_name = _current_scheduling_group.name();
    return data;
}
// Per-shard execution stage shared by all connections; process_request()
// pushes the hot opcodes (QUERY/PREPARE/EXECUTE/BATCH) through it via
// process_request_one().
thread_local cql_server::connection::execution_stage_type
        cql_server::connection::_process_request_stage{"transport", &connection::process_request_one};
// Reports a failure from the connection's processing loop back to the client.
// A non-failed future is a no-op; otherwise the error is logged, counted in
// the per-code error stats and answered with an ERROR response on stream 0
// (i.e. not tied to any particular client request).
void cql_server::connection::handle_error(future<>&& f) {
    if (!f.failed()) {
        return;
    }
    std::exception_ptr eptr = f.get_exception();
    if (auto* ex = try_catch<exceptions::cassandra_exception>(eptr)) {
        clogger.debug("{}: connection error, code {}, message [{}]", _client_state.get_remote_address(), ex->code(), ex->what());
        // Stats bump is best-effort: never let accounting mask the real error.
        try { ++_server._stats.errors[ex->code()]; } catch(...) {}
        write_response(make_error(0, ex->code(), ex->what(), tracing::trace_state_ptr()));
    } else if (auto* ex = try_catch<std::exception>(eptr)) {
        clogger.debug("{}: connection error, message [{}]", _client_state.get_remote_address(), ex->what());
        try { ++_server._stats.errors[exceptions::exception_code::SERVER_ERROR]; } catch(...) {}
        write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, ex->what(), tracing::trace_state_ptr()));
    } else {
        // Not derived from std::exception - report a generic server error.
        clogger.debug("{}: connection error, unknown error", _client_state.get_remote_address());
        try { ++_server._stats.errors[exceptions::exception_code::SERVER_ERROR]; } catch(...) {}
        write_response(make_error(0, exceptions::exception_code::SERVER_ERROR, "unknown error", tracing::trace_state_ptr()));
    }
}
// Reads one CQL frame from the socket and dispatches it.
//
// Flow: read the frame header; apply load shedding / size / concurrency
// limits; reserve memory from the per-shard semaphore; read (and, if needed,
// decompress) the body; then dispatch to process_request_one(), through the
// shared execution stage for the parallelizable opcodes.
//
// The returned future resolves when we are ready to read the *next* frame.
// For parallelized opcodes that happens before the response is produced;
// completion of the in-flight request is tracked by _pending_requests_gate.
future<> cql_server::connection::process_request() {
    return read_frame().then_wrapped([this] (future<std::optional<cql_binary_frame_v3>>&& v) {
        if (v.failed()) {
            return std::move(v).discard_result();
        }
        auto maybe_frame = v.get();
        if (!maybe_frame) {
            // eof
            return make_ready_future<>();
        }
        auto& f = *maybe_frame;
        // Only interactive workloads are subject to shedding.
        const bool allow_shedding = _client_state.get_workload_type() == service::client_state::workload_type::interactive;
        if (allow_shedding && _shed_incoming_requests) {
            ++_server._stats.requests_shed;
            // Consume the body so the stream stays in sync, then tell the
            // client we are overloaded.
            return _read_buf.skip(f.length).then([this, stream = f.stream] {
                const char* message = "request shed due to coordinator overload";
                clogger.debug("{}: {}, stream {}", _client_state.get_remote_address(), message, uint16_t(stream));
                write_response(make_error(stream, exceptions::exception_code::OVERLOADED,
                        message, tracing::trace_state_ptr()));
                return make_ready_future<>();
            });
        }
        tracing_request_type tracing_requested = tracing_request_type::not_requested;
        if (f.flags & cql_frame_flags::tracing) {
            // If tracing is requested for a specific CQL command - flush
            // tracing info right after the command is over.
            tracing_requested = tracing_request_type::write_on_close;
        } else if (tracing::tracing::get_local_tracing_instance().trace_next_query()) {
            tracing_requested = tracing_request_type::no_write_on_close;
        }
        auto op = f.opcode;
        auto stream = f.stream;
        auto mem_estimate = f.length * 2 + 8000; // Allow for extra copies and bookkeeping
        if (mem_estimate > _server._config.max_request_size) {
            const auto message = format("request size too large (frame size {:d}; estimate {:d}; allowed {:d})",
                    uint32_t(f.length), mem_estimate, _server._config.max_request_size);
            clogger.debug("{}: {}, request dropped", _client_state.get_remote_address(), message);
            write_response(make_error(stream, exceptions::exception_code::INVALID, message, tracing::trace_state_ptr()));
            // This frame can never be admitted: flush the queued error to the
            // client, then close and drain the input stream.
            return std::exchange(_ready_to_respond, make_ready_future<>())
                    .then([this] { return _read_buf.close(); })
                    .then([this] { return util::skip_entire_stream(_read_buf); });
        }
        if (_server._stats.requests_serving > _server._config.max_concurrent_requests) {
            ++_server._stats.requests_shed;
            return _read_buf.skip(f.length).then([this, stream = f.stream] {
                const auto message = format("too many in-flight requests (configured via max_concurrent_requests_per_shard): {}",
                        _server._stats.requests_serving);
                clogger.debug("{}: {}, request dropped", _client_state.get_remote_address(), message);
                write_response(make_error(stream, exceptions::exception_code::OVERLOADED,
                        message,
                        tracing::trace_state_ptr()));
                return make_ready_future<>();
            });
        }
        const auto shedding_timeout = std::chrono::milliseconds(50);
        // Reserve memory for processing the request. Interactive clients wait
        // at most shedding_timeout for units; on timeout the body is skipped
        // and the semaphore error propagates (ignored below).
        auto fut = allow_shedding
                ? get_units(_server._memory_available, mem_estimate, shedding_timeout).then_wrapped([this, length = f.length] (auto f) {
                    try {
                        return make_ready_future<semaphore_units<>>(f.get());
                    } catch (semaphore_timed_out& sto) {
                        // Cancel shedding in case no more requests are going to do that on completion
                        if (_pending_requests_gate.get_count() == 0) {
                            _shed_incoming_requests = false;
                        }
                        return _read_buf.skip(length).then([sto = std::move(sto)] () mutable {
                            return make_exception_future<semaphore_units<>>(std::move(sto));
                        });
                    }
                })
                : get_units(_server._memory_available, mem_estimate);
        if (_server._memory_available.waiters()) {
            // Memory is contended: arm the timer that starts shedding new
            // interactive requests if the contention persists.
            if (allow_shedding && !_shedding_timer.armed()) {
                _shedding_timer.arm(shedding_timeout);
            }
            ++_server._stats.requests_blocked_memory;
        }
        return fut.then_wrapped([this, length = f.length, flags = f.flags, op, stream, tracing_requested] (auto mem_permit_fut) {
            if (mem_permit_fut.failed()) {
                // Ignore semaphore errors - they are expected if load shedding took place
                mem_permit_fut.ignore_ready_future();
                return make_ready_future<>();
            }
            semaphore_units<> mem_permit = mem_permit_fut.get();
            return this->read_and_decompress_frame(length, flags).then([this, op, stream, tracing_requested, mem_permit = make_service_permit(std::move(mem_permit))] (fragmented_temporary_buffer buf) mutable {
                ++_server._stats.requests_served;
                ++_server._stats.requests_serving;
                // Track the in-flight request so shutdown can wait for it.
                _pending_requests_gate.enter();
                auto leave = defer([this] {
                    _shedding_timer.cancel();
                    _shed_incoming_requests = false;
                    _pending_requests_gate.leave();
                });
                auto istream = buf.get_istream();
                // Parallelize only the performance sensitive requests:
                // QUERY, PREPARE, EXECUTE, BATCH
                bool should_paralelize = (op == uint8_t(cql_binary_opcode::QUERY) ||
                        op == uint8_t(cql_binary_opcode::PREPARE) ||
                        op == uint8_t (cql_binary_opcode::EXECUTE) ||
                        op == uint8_t(cql_binary_opcode::BATCH));
                future<foreign_ptr<std::unique_ptr<cql_server::response>>> request_process_future = should_paralelize ?
                        _process_request_stage(this, istream, op, stream, seastar::ref(_client_state), tracing_requested, mem_permit) :
                        process_request_one(istream, op, stream, seastar::ref(_client_state), tracing_requested, mem_permit);
                future<> request_response_future = request_process_future.then_wrapped([this, buf = std::move(buf), mem_permit, leave = std::move(leave), stream] (future<foreign_ptr<std::unique_ptr<cql_server::response>>> response_f) mutable {
                    try {
                        if (response_f.failed()) {
                            const auto message = format("request processing failed, error [{}]", response_f.get_exception());
                            clogger.error("{}: {}", _client_state.get_remote_address(), message);
                            write_response(make_error(stream, exceptions::exception_code::SERVER_ERROR,
                                    message,
                                    tracing::trace_state_ptr()));
                        } else {
                            write_response(response_f.get(), std::move(mem_permit), _compression);
                        }
                        // Keep the gate held (via `leave`) until the queued
                        // response has been sent.
                        _ready_to_respond = _ready_to_respond.finally([leave = std::move(leave)] {});
                    } catch (...) {
                        clogger.error("{}: request processing failed: {}",
                                _client_state.get_remote_address(), std::current_exception());
                    }
                });
                if (should_paralelize) {
                    return make_ready_future<>();
                } else {
                    return request_response_future;
                }
            });
        });
    });
}
// Contiguous buffers for use with compression primitives.
// Be careful when dealing with them, because they are shared and
// can be modified on preemption points.
// See the comments on reusable_buffer for a discussion.
// Per-shard scratch buffer used as (de)compression *input*; the 600s value is
// passed to reusable_buffer - presumably a reclaim/expiry period, see
// reusable_buffer for details.
static utils::reusable_buffer_guard input_buffer_guard() {
    using namespace std::chrono_literals;
    static thread_local utils::reusable_buffer<lowres_clock> buf(600s);
    return buf;
}
// Per-shard scratch buffer used as (de)compression *output*; same caveats and
// lifetime as input_buffer_guard() above.
static utils::reusable_buffer_guard output_buffer_guard() {
    using namespace std::chrono_literals;
    static thread_local utils::reusable_buffer<lowres_clock> buf(600s);
    return buf;
}
// Reads a frame body of `length` bytes, decompressing it first when the
// frame's compression flag is set. The algorithm was negotiated at STARTUP
// (see process_startup()): LZ4 bodies carry a 4-byte uncompressed-length
// prefix, Snappy bodies are self-describing.
future<fragmented_temporary_buffer> cql_server::connection::read_and_decompress_frame(size_t length, uint8_t flags)
{
    if (flags & cql_frame_flags::compression) {
        if (_compression == cql_compression::lz4) {
            if (length < 4) {
                // Need at least the 4-byte uncompressed-length prefix.
                return make_exception_future<fragmented_temporary_buffer>(std::runtime_error(fmt::format("CQL frame truncated: expected to have at least 4 bytes, got {}", length)));
            }
            return _buffer_reader.read_exactly(_read_buf, length).then([] (fragmented_temporary_buffer buf) {
                // Shared per-shard scratch buffers; see the warning on the
                // guard helpers above about preemption points.
                auto input_buffer = input_buffer_guard();
                auto output_buffer = output_buffer_guard();
                auto v = fragmented_temporary_buffer::view(buf);
                int32_t uncomp_len = read_simple<int32_t>(v);
                if (uncomp_len < 0) {
                    return make_exception_future<fragmented_temporary_buffer>(std::runtime_error("CQL frame uncompressed length is negative: " + std::to_string(uncomp_len)));
                }
                auto in = input_buffer.get_linearized_view(v);
                return utils::result_into_future(output_buffer.make_fragmented_temporary_buffer(uncomp_len, [&in] (bytes_mutable_view out) -> utils::result_with_exception<size_t, std::runtime_error> {
                    auto ret = LZ4_decompress_safe(reinterpret_cast<const char*>(in.data()), reinterpret_cast<char*>(out.data()), in.size(), out.size());
                    if (ret < 0) {
                        return bo::failure(std::runtime_error("CQL frame LZ4 uncompression failure"));
                    }
                    // The client-declared uncompressed size must match exactly.
                    if (static_cast<size_t>(ret) != out.size()) { // ret is known to be positive here
                        return bo::failure(std::runtime_error("Malformed CQL frame - provided uncompressed size different than real uncompressed size"));
                    }
                    return bo::success(static_cast<size_t>(ret));
                }));
            });
        } else if (_compression == cql_compression::snappy) {
            return _buffer_reader.read_exactly(_read_buf, length).then([] (fragmented_temporary_buffer buf) {
                auto input_buffer = input_buffer_guard();
                auto output_buffer = output_buffer_guard();
                auto in = input_buffer.get_linearized_view(fragmented_temporary_buffer::view(buf));
                // Snappy encodes the uncompressed length in the stream itself.
                size_t uncomp_len;
                if (snappy_uncompressed_length(reinterpret_cast<const char*>(in.data()), in.size(), &uncomp_len) != SNAPPY_OK) {
                    return make_exception_future<fragmented_temporary_buffer>(std::runtime_error("CQL frame Snappy uncompressed size is unknown"));
                }
                return utils::result_into_future(output_buffer.make_fragmented_temporary_buffer(uncomp_len, [&in] (bytes_mutable_view out) -> utils::result_with_exception<size_t, std::runtime_error> {
                    size_t output_len = out.size();
                    if (snappy_uncompress(reinterpret_cast<const char*>(in.data()), in.size(), reinterpret_cast<char*>(out.data()), &output_len) != SNAPPY_OK) {
                        return bo::failure(std::runtime_error("CQL frame Snappy uncompression failure"));
                    }
                    if (output_len != out.size()) {
                        return bo::failure(std::runtime_error("Malformed CQL frame - provided uncompressed size different than real uncompressed size"));
                    }
                    return bo::success(output_len);
                }));
            });
        } else {
            return make_exception_future<fragmented_temporary_buffer>(exceptions::protocol_exception("Unknown compression algorithm"));
        }
    }
    // Uncompressed frame: read the body as-is.
    return _buffer_reader.read_exactly(_read_buf, length);
}
// Handles the STARTUP message: negotiates compression, records the driver's
// name/version, enables the protocol extensions the client requested, and
// either begins authentication (AUTHENTICATE response) or completes the
// handshake (READY response).
future<std::unique_ptr<cql_server::response>> cql_server::connection::process_startup(uint16_t stream, request_reader in, service::client_state& client_state,
        tracing::trace_state_ptr trace_state) {
    auto options = in.read_string_map();
    auto compression_opt = options.find("COMPRESSION");
    if (compression_opt != options.end()) {
        auto compression = compression_opt->second;
        // Compression names are matched case-insensitively.
        std::transform(compression.begin(), compression.end(), compression.begin(), ::tolower);
        if (compression == "lz4") {
            _compression = cql_compression::lz4;
        } else if (compression == "snappy") {
            _compression = cql_compression::snappy;
        } else {
            co_return coroutine::exception(std::make_exception_ptr(exceptions::protocol_exception(format("Unknown compression algorithm: {}", compression))));
        }
    }
    if (auto driver_ver_opt = options.find("DRIVER_VERSION"); driver_ver_opt != options.end()) {
        _client_state.set_driver_version(driver_ver_opt->second);
    }
    if (auto driver_name_opt = options.find("DRIVER_NAME"); driver_name_opt != options.end()) {
        _client_state.set_driver_name(driver_name_opt->second);
    }
    // Enable every supported protocol extension the client asked for.
    cql_protocol_extension_enum_set cql_proto_exts;
    for (cql_protocol_extension ext : supported_cql_protocol_extensions()) {
        if (options.contains(protocol_extension_name(ext))) {
            cql_proto_exts.set(ext);
        }
    }
    _client_state.set_protocol_extensions(std::move(cql_proto_exts));
    std::unique_ptr<cql_server::response> res;
    if (auto& a = client_state.get_auth_service()->underlying_authenticator(); a.require_authentication()) {
        _authenticating = true;
        // The authenticator may consult the client's TLS certificate: it gets
        // the subject DN eagerly and a lazy getter for the alt names.
        auto opt_user = co_await a.authenticate([this]() -> future<std::optional<auth::certificate_info>> {
            auto dn_info = co_await tls::get_dn_information(this->_fd);
            if (dn_info) {
                co_return auth::certificate_info{ dn_info->subject, [this]() -> future<std::string> {
                    auto altnames = co_await tls::get_alt_name_information(this->_fd);
                    auto res = fmt::format("{}", fmt::join(altnames, ","));
                    co_return res;
                } };
            }
            co_return std::nullopt;
        });
        if (opt_user) {
            // Authenticator resolved the user without a SASL exchange.
            client_state.set_login(std::move(*opt_user));
            co_await client_state.check_user_can_login();
            co_await client_state.maybe_update_per_service_level_params();
            res = make_ready(stream, trace_state);
        } else {
            // Client must continue the exchange with AUTH_RESPONSE messages.
            res = make_autheticate(stream, a.qualified_java_name(), trace_state);
        }
    } else {
        _ready = true;
        on_connection_ready();
        res = make_ready(stream, trace_state);
    }
    co_return res;
}
// Re-binds this connection's processing loop to the scheduling group and
// service level of the currently logged-in user. Invoked after a successful
// login (see process_auth_response()).
void cql_server::connection::update_scheduling_group() {
    switch_tenant([this] (noncopyable_function<future<> ()> process_loop) -> future<> {
        auto shg = co_await _server._sl_controller.get_user_scheduling_group(_client_state.user());
        _current_scheduling_group = shg;
        // Run the remaining processing loop under the user's service level.
        co_return co_await _server._sl_controller.with_user_service_level(_client_state.user(), std::move(process_loop));
    });
}
// Handles AUTH_RESPONSE: feeds the client's SASL payload to a fresh challenge
// from the authenticator. If the exchange is complete, the login attempt is
// audited, the user is installed on the client state and AUTH_SUCCESS is
// returned; otherwise AUTH_CHALLENGE asks the client to continue.
future<std::unique_ptr<cql_server::response>> cql_server::connection::process_auth_response(uint16_t stream, request_reader in, service::client_state& client_state,
        tracing::trace_state_ptr trace_state) {
    auto sasl_challenge = client_state.get_auth_service()->underlying_authenticator().new_sasl_challenge();
    auto buf = in.read_raw_bytes_view(in.bytes_left());
    auto challenge = sasl_challenge->evaluate_response(buf);
    if (sasl_challenge->is_complete()) {
        return sasl_challenge->get_authenticated_user().then_wrapped([this, sasl_challenge, stream, &client_state, challenge = std::move(challenge), trace_state](future<auth::authenticated_user> f) mutable {
            // Audit both successful and failed logins. On failure, ff.get()
            // below rethrows the original authentication error *after* the
            // audit record has been written.
            bool failed = f.failed();
            return audit::inspect_login(sasl_challenge->get_username(), client_state.get_client_address().addr(), failed).then(
                    [this, stream, challenge = std::move(challenge), &client_state, sasl_challenge, ff = std::move(f), trace_state = std::move(trace_state)] () mutable {
                client_state.set_login(ff.get());
                // Switch to the authenticated user's scheduling group.
                update_scheduling_group();
                auto f = client_state.check_user_can_login();
                f = f.then([&client_state] {
                    return client_state.maybe_update_per_service_level_params();
                });
                return f.then([this, stream, challenge = std::move(challenge), trace_state]() mutable {
                    _authenticating = false;
                    _ready = true;
                    on_connection_ready();
                    return make_ready_future<std::unique_ptr<cql_server::response>>(make_auth_success(stream, std::move(challenge), trace_state));
                });
            });
        });
    }
    return make_ready_future<std::unique_ptr<cql_server::response>>(make_auth_challenge(stream, std::move(challenge), trace_state));
}
// Handles OPTIONS: the message carries no payload, so simply report the
// server's supported features.
future<std::unique_ptr<cql_server::response>> cql_server::connection::process_options(uint16_t stream, request_reader in, service::client_state& client_state,
        tracing::trace_state_ptr trace_state) {
    auto res = make_supported(stream, std::move(trace_state));
    return make_ready_future<std::unique_ptr<cql_server::response>>(std::move(res));
}
// Forward declaration (defined elsewhere in this file): builds a RESULT
// response frame from a coordinator result message. Used by the
// process_*_internal helpers below.
std::unique_ptr<cql_server::response>
make_result(int16_t stream, messages::result_message& msg, const tracing::trace_state_ptr& tr_state,
        cql_protocol_version_type version, cql_metadata_id_wrapper&& metadata_id, bool skip_metadata = false);
// Re-executes a request on another shard after coordinator logic returned a
// result_message::bounce_to_shard. The raw request bytes (`is`) are re-parsed
// on the target shard; client and tracing state travel via their cross-shard
// wrappers, and any already-computed partition-key function call results are
// carried over in `cached_vals`.
template <typename Process>
requires std::is_invocable_r_v<future<cql_server::process_fn_return_type>,
        Process,
        service::client_state&,
        distributed<cql3::query_processor>&,
        request_reader,
        uint16_t,
        cql_protocol_version_type,
        service_permit,
        tracing::trace_state_ptr,
        bool,
        cql3::computed_function_values,
        cql3::dialect>
future<cql_server::process_fn_return_type>
cql_server::connection::process_on_shard(shard_id shard, uint16_t stream, fragmented_temporary_buffer::istream is, service::client_state& cs,
        tracing::trace_state_ptr trace_state, cql3::dialect dialect, cql3::computed_function_values&& cached_vals, Process process_fn) {
    auto sg = _server._config.bounce_request_smp_service_group;
    auto gcs = cs.move_to_other_shard();
    auto gt = tracing::global_trace_state_ptr(std::move(trace_state));
    co_return co_await _server.container().invoke_on(shard, sg, [&, stream, dialect] (cql_server& server) -> future<process_fn_return_type> {
        bytes_ostream linearization_buffer;
        request_reader in(is, linearization_buffer);
        auto client_state = gcs.get();
        auto trace_state = gt.get();
        // init_trace=false: tracing was already set up on the first pass in
        // process().
        co_return co_await process_fn(client_state, server._query_processor, in, stream, _version,
                /* FIXME */empty_service_permit(), std::move(trace_state), false, cached_vals, dialect);
    });
}
// Extracts the stored coordinator exception from an error result message.
// Callers invoke this only after msg->is_exception() returned true, so the
// downcast is expected to succeed.
static inline cql_server::result_with_foreign_response_ptr convert_error_message_to_coordinator_result(messages::result_message* msg) {
    auto* error_msg = dynamic_cast<messages::result_message::exception*>(msg);
    return std::move(*error_msg).get_exception();
}
// Common driver for QUERY/EXECUTE/BATCH processing: runs `process_fn` on the
// local shard first and, as long as the result is a bounce_to_shard message,
// replays the request on the shard the message points at, forwarding any
// partition-key function call results computed so far.
template <typename Process>
requires std::is_invocable_r_v<future<cql_server::process_fn_return_type>,
        Process,
        service::client_state&,
        distributed<cql3::query_processor>&,
        request_reader,
        uint16_t,
        cql_protocol_version_type,
        service_permit,
        tracing::trace_state_ptr,
        bool,
        cql3::computed_function_values,
        cql3::dialect>
future<cql_server::result_with_foreign_response_ptr>
cql_server::connection::process(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit,
        tracing::trace_state_ptr trace_state, Process process_fn) {
    // Remember the raw request bytes so they can be re-parsed on another
    // shard if a bounce is requested.
    fragmented_temporary_buffer::istream is = in.get_stream();
    auto dialect = get_dialect();
    auto f = co_await coroutine::as_future(process_fn(client_state, _server._query_processor, in, stream,
            _version, permit, trace_state, true, {}, dialect));
    if (f.failed()) {
        co_return coroutine::exception(f.get_exception());
    }
    auto msg = std::move(f.get());
    while (auto* bounce_msg = std::get_if<result_with_bounce_to_shard>(&msg)) {
        auto shard = (*bounce_msg)->move_to_shard().value();
        auto&& cached_vals = (*bounce_msg)->take_cached_pk_function_calls();
        msg = co_await process_on_shard(shard, stream, is, client_state, trace_state, dialect, std::move(cached_vals), process_fn);
    }
    co_return std::get<cql_server::result_with_foreign_response_ptr>(std::move(msg));
}
// Executes an unprepared QUERY statement. Shared between first-pass local
// execution and bounced re-execution: in the bounced case init_trace is
// false and cached_pk_fn_calls carries results computed on the first pass.
static future<cql_server::process_fn_return_type>
process_query_internal(service::client_state& client_state, distributed<cql3::query_processor>& qp, request_reader in,
        uint16_t stream, cql_protocol_version_type version,
        service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls,
        cql3::dialect dialect) {
    auto query = in.read_long_string_view();
    // q_state owns the parsed options (and the permit) for the duration of
    // the execution; it is kept alive by the final continuation.
    auto q_state = std::make_unique<cql_query_state>(client_state, trace_state, std::move(permit));
    auto& query_state = q_state->query_state;
    q_state->options = in.read_options(version, qp.local().get_cql_config());
    auto& options = *q_state->options;
    if (!cached_pk_fn_calls.empty()) {
        options.set_cached_pk_function_calls(std::move(cached_pk_fn_calls));
    }
    auto skip_metadata = options.skip_metadata();
    if (init_trace) {
        tracing::set_page_size(trace_state, options.get_page_size());
        tracing::add_query(trace_state, query);
        tracing::set_common_query_parameters(trace_state, options.get_consistency(),
                options.get_serial_consistency(), options.get_specific_options().timestamp);
        tracing::begin(trace_state, "Execute CQL3 query", client_state.get_client_address());
    }
    return qp.local().execute_direct_without_checking_exception_message(query, query_state, dialect, options).then([q_state = std::move(q_state), stream, skip_metadata, version] (auto msg) {
        // Three possible outcomes: bounce to another shard, a coordinator
        // exception, or an actual result to serialize.
        if (msg->move_to_shard()) {
            return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg)));
        } else if (msg->is_exception()) {
            return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
        } else {
            tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result");
            return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, cql_metadata_id_wrapper{}, skip_metadata)));
        }
    });
}
// QUERY opcode entry point: delegates to the shared process() driver with
// the query-specific handler.
future<cql_server::result_with_foreign_response_ptr>
cql_server::connection::process_query(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state) {
    return process(stream, in, client_state, std::move(permit), std::move(trace_state), process_query_internal);
}
// Handles PREPARE: prepares the statement on all *other* shards first, then
// on the local shard, and builds the prepared-result response. A metadata id
// is attached when the client supports it (is_metadata_id_supported()).
future<std::unique_ptr<cql_server::response>> cql_server::connection::process_prepare(uint16_t stream, request_reader in, service::client_state& client_state,
        tracing::trace_state_ptr trace_state) {
    auto query = sstring(in.read_long_string_view());
    auto dialect = get_dialect();
    tracing::add_query(trace_state, query);
    tracing::begin(trace_state, "Preparing CQL3 query", client_state.get_client_address());
    return _server._query_processor.invoke_on_others([query, &client_state, dialect] (auto& qp) mutable {
        return qp.prepare(std::move(query), client_state, dialect).discard_result();
    }).then([this, query, stream, &client_state, trace_state, dialect] () mutable {
        tracing::trace(trace_state, "Done preparing on remote shards");
        return _server._query_processor.local().prepare(std::move(query), client_state, dialect).then([this, stream, &client_state, trace_state] (auto msg) {
            tracing::trace(trace_state, "Done preparing on a local shard - preparing a result. ID is [{}]", seastar::value_of([&msg] {
                return messages::result_message::prepared::cql::get_id(msg);
            }));
            cql_metadata_id_wrapper metadata_id = is_metadata_id_supported(client_state)
                    ? cql_metadata_id_wrapper(msg->get_metadata_id())
                    : cql_metadata_id_wrapper();
            return make_result(stream, *msg, trace_state, _version, std::move(metadata_id));
        });
    });
}
// Executes a prepared statement (EXECUTE). Shared between first-pass local
// execution and bounced re-execution, like process_query_internal().
static future<cql_server::process_fn_return_type>
process_execute_internal(service::client_state& client_state, distributed<cql3::query_processor>& qp, request_reader in,
        uint16_t stream, cql_protocol_version_type version,
        service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls,
        cql3::dialect dialect) {
    cql3::prepared_cache_key_type cache_key(in.read_short_bytes(), dialect);
    auto& id = cql3::prepared_cache_key_type::cql_id(cache_key);
    bool needs_authorization = false;
    // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there
    // look for the prepared statement and then authorize it.
    auto prepared = qp.local().get_prepared(client_state.user(), cache_key);
    if (!prepared) {
        needs_authorization = true;
        prepared = qp.local().get_prepared(cache_key);
    }
    if (!prepared) {
        // Statement id unknown on this shard - the client must re-prepare.
        throw exceptions::prepared_query_not_found_exception(id);
    }
    // When the client supports metadata ids, the request carries the
    // client's view of the result metadata id alongside the server's.
    cql_metadata_id_wrapper metadata_id = is_metadata_id_supported(client_state)
            ? cql_metadata_id_wrapper(cql3::cql_metadata_id_type(in.read_short_bytes()), prepared->get_metadata_id())
            : cql_metadata_id_wrapper();
    auto q_state = std::make_unique<cql_query_state>(client_state, trace_state, std::move(permit));
    auto& query_state = q_state->query_state;
    q_state->options = in.read_options(version, qp.local().get_cql_config());
    auto& options = *q_state->options;
    if (!cached_pk_fn_calls.empty()) {
        options.set_cached_pk_function_calls(std::move(cached_pk_fn_calls));
    }
    auto skip_metadata = options.skip_metadata();
    if (init_trace) {
        tracing::set_page_size(trace_state, options.get_page_size());
        tracing::add_query(trace_state, prepared->statement->raw_cql_statement);
        tracing::add_prepared_statement(trace_state, prepared);
        tracing::set_common_query_parameters(trace_state, options.get_consistency(),
                options.get_serial_consistency(), options.get_specific_options().timestamp);
        tracing::begin(trace_state, seastar::value_of([&id] { return seastar::format("Execute CQL3 prepared query [{}]", id); }),
                client_state.get_client_address());
    }
    auto stmt = prepared->statement;
    tracing::trace(query_state.get_trace_state(), "Checking bounds");
    // The number of supplied bind values must match the statement's markers.
    if (stmt->get_bound_terms() != options.get_values_count()) {
        const auto msg = format("Invalid amount of bind variables: expected {:d} received {:d}",
                stmt->get_bound_terms(),
                options.get_values_count());
        tracing::trace(query_state.get_trace_state(), "{}", msg);
        throw exceptions::invalid_request_exception(msg);
    }
    options.prepare(prepared->bound_names);
    if (init_trace) {
        tracing::add_prepared_query_options(trace_state, options);
    }
    tracing::trace(trace_state, "Processing a statement");
    return qp.local().execute_prepared_without_checking_exception_message(query_state, std::move(stmt), options, std::move(prepared), std::move(cache_key), needs_authorization)
            .then([trace_state = query_state.get_trace_state(), skip_metadata, q_state = std::move(q_state), stream, version, metadata_id = std::move(metadata_id)] (auto msg) mutable {
        // Same outcome triage as process_query_internal(): bounce, error,
        // or a result to serialize.
        if (msg->move_to_shard()) {
            return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg)));
        } else if (msg->is_exception()) {
            return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
        } else {
            tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result");
            return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, q_state->query_state.get_trace_state(), version, std::move(metadata_id), skip_metadata)));
        }
    });
}
// EXECUTE opcode entry point: delegates to the shared process() driver with
// the prepared-statement handler.
future<cql_server::result_with_foreign_response_ptr> cql_server::connection::process_execute(uint16_t stream, request_reader in,
        service::client_state& client_state, service_permit permit, tracing::trace_state_ptr trace_state) {
    return process(stream, in, client_state, std::move(permit), std::move(trace_state), process_execute_internal);
}
// Executes a BATCH of modification statements. Each entry is either an
// inline query string (kind 0) or a prepared statement id (kind 1); only
// UPDATE/INSERT/DELETE statements are accepted. Shared between first-pass
// local execution and bounced re-execution, like process_query_internal().
static future<cql_server::process_fn_return_type>
process_batch_internal(service::client_state& client_state, distributed<cql3::query_processor>& qp, request_reader in,
        uint16_t stream, cql_protocol_version_type version,
        service_permit permit, tracing::trace_state_ptr trace_state, bool init_trace, cql3::computed_function_values cached_pk_fn_calls, cql3::dialect dialect) {
    const auto type = in.read_byte();
    const unsigned n = in.read_short();
    std::vector<cql3::statements::batch_statement::single_statement> modifications;
    std::vector<cql3::raw_value_view_vector_with_unset> values;
    std::unordered_map<cql3::prepared_cache_key_type, cql3::authorized_prepared_statements_cache::value_type> pending_authorization_entries;
    modifications.reserve(n);
    values.reserve(n);
    if (init_trace) {
        tracing::begin(trace_state, "Execute batch of CQL3 queries", client_state.get_client_address());
    }
    for ([[gnu::unused]] auto i : std::views::iota(0u, n)) {
        const auto kind = in.read_byte();
        std::unique_ptr<cql3::statements::prepared_statement> stmt_ptr;
        cql3::statements::prepared_statement::checked_weak_ptr ps;
        // Inline statements (kind 0) always require authorization; prepared
        // statements may already be in the authorized cache.
        bool needs_authorization(kind == 0);
        switch (kind) {
        case 0: {
            // Inline query string, parsed on the spot.
            auto query = in.read_long_string_view();
            stmt_ptr = qp.local().get_statement(query, client_state, dialect);
            ps = stmt_ptr->checked_weak_from_this();
            if (init_trace) {
                tracing::add_query(trace_state, query);
            }
            break;
        }
        case 1: {
            // Reference to a previously prepared statement.
            cql3::prepared_cache_key_type cache_key(in.read_short_bytes(), dialect);
            auto& id = cql3::prepared_cache_key_type::cql_id(cache_key);
            // First, try to lookup in the cache of already authorized statements. If the corresponding entry is not found there
            // look for the prepared statement and then authorize it.
            ps = qp.local().get_prepared(client_state.user(), cache_key);
            if (!ps) {
                ps = qp.local().get_prepared(cache_key);
                if (!ps) {
                    return make_exception_future<cql_server::process_fn_return_type>(exceptions::prepared_query_not_found_exception(id));
                }
                // authorize a particular prepared statement only once
                needs_authorization = pending_authorization_entries.emplace(std::move(cache_key), ps->checked_weak_from_this()).second;
            }
            if (init_trace) {
                tracing::add_query(trace_state, ps->statement->raw_cql_statement);
            }
            break;
        }
        default:
            return make_exception_future<cql_server::process_fn_return_type>(exceptions::protocol_exception(
                    "Invalid query kind in BATCH messages. Must be 0 or 1 but got "
                    + std::to_string(int(kind))));
        }
        // Batches may only contain data-modification statements.
        if (dynamic_cast<cql3::statements::modification_statement*>(ps->statement.get()) == nullptr) {
            return make_exception_future<cql_server::process_fn_return_type>(exceptions::invalid_request_exception("Invalid statement in batch: only UPDATE, INSERT and DELETE statements are allowed."));
        }
        ::shared_ptr<cql3::statements::modification_statement> modif_statement_ptr = static_pointer_cast<cql3::statements::modification_statement>(ps->statement);
        if (init_trace) {
            tracing::add_table_name(trace_state, modif_statement_ptr->keyspace(), modif_statement_ptr->column_family());
            tracing::add_prepared_statement(trace_state, ps);
        }
        modifications.emplace_back(std::move(modif_statement_ptr), needs_authorization);
        // Read this statement's bind values and validate their count.
        std::vector<cql3::raw_value_view> tmp;
        cql3::unset_bind_variable_vector unset;
        in.read_value_view_list(version, tmp, unset);
        auto stmt = ps->statement;
        if (stmt->get_bound_terms() != tmp.size()) {
            return make_exception_future<cql_server::process_fn_return_type>(
                    exceptions::invalid_request_exception(format("There were {:d} markers(?) in CQL but {:d} bound variables",
                            stmt->get_bound_terms(), tmp.size())));
        }
        values.emplace_back(cql3::raw_value_view_vector_with_unset(std::move(tmp), std::move(unset)));
    }
    auto q_state = std::make_unique<cql_query_state>(client_state, trace_state, std::move(permit));
    auto& query_state = q_state->query_state;
    // #563. CQL v2 encodes query_options in v1 format for batch requests.
    q_state->options = std::make_unique<cql3::query_options>(cql3::query_options::make_batch_options(std::move(*in.read_options(version,
            qp.local().get_cql_config())), std::move(values)));
    auto& options = *q_state->options;
    if (!cached_pk_fn_calls.empty()) {
        options.set_cached_pk_function_calls(std::move(cached_pk_fn_calls));
    }
    if (init_trace) {
        tracing::add_prepared_query_options(trace_state, options);
        tracing::set_common_query_parameters(trace_state, options.get_consistency(),
                options.get_serial_consistency(), options.get_specific_options().timestamp);
        tracing::trace(trace_state, "Creating a batch statement");
    }
    auto batch = ::make_shared<cql3::statements::batch_statement>(cql3::statements::batch_statement::type(type), std::move(modifications), cql3::attributes::none(), qp.local().get_cql_stats());
    return qp.local().execute_batch_without_checking_exception_message(batch, query_state, options, std::move(pending_authorization_entries))
            .then([stream, batch, q_state = std::move(q_state), trace_state = query_state.get_trace_state(), version] (auto msg) {
        // Same outcome triage as process_query_internal(): bounce, error,
        // or a result to serialize.
        if (msg->move_to_shard()) {
            return cql_server::process_fn_return_type(make_foreign(dynamic_pointer_cast<messages::result_message::bounce_to_shard>(msg)));
        } else if (msg->is_exception()) {
            return cql_server::process_fn_return_type(convert_error_message_to_coordinator_result(msg.get()));
        } else {
            tracing::trace(q_state->query_state.get_trace_state(), "Done processing - preparing a result");
            return cql_server::process_fn_return_type(make_foreign(make_result(stream, *msg, trace_state, version, cql_metadata_id_wrapper{})));
        }
    });
}
cql3::dialect
cql_server::connection::get_dialect() const {
    // The dialect is derived from server-wide configuration so that every
    // connection parses statements under the same rules.
    auto d = cql3::dialect{
        .duplicate_bind_variable_names_refer_to_same_variable = _server._config.cql_duplicate_bind_variable_names_refer_to_same_variable,
    };
    return d;
}
future<cql_server::result_with_foreign_response_ptr>
cql_server::connection::process_batch(uint16_t stream, request_reader in, service::client_state& client_state, service_permit permit,
        tracing::trace_state_ptr trace_state) {
    // Thin adapter: runs the shared request-processing pipeline with the
    // BATCH-specific handler (process_batch_internal) plugged in.
    return process(stream, in, client_state, permit, std::move(trace_state), process_batch_internal);
}
future<std::unique_ptr<cql_server::response>>
cql_server::connection::process_register(uint16_t stream, request_reader in, service::client_state& client_state,
        tracing::trace_state_ptr trace_state) {
    using ret_type = std::unique_ptr<cql_server::response>;
    // Decode the list of event names the client wants to subscribe to.
    std::vector<sstring> event_types;
    in.read_string_list(event_types);
    for (auto&& type_name : event_types) {
        auto parsed = parse_event_type(type_name);
        if (!parsed) {
            // Unknown event name: surface the protocol error to the client.
            return std::move(parsed).assume_error().into_exception_future<ret_type>();
        }
        _server._notifier->register_event(std::move(parsed).value(), this);
    }
    // Subscription complete; the connection may now receive pushed events.
    _ready = true;
    on_connection_ready();
    return make_ready_future<ret_type>(make_ready(stream, std::move(trace_state)));
}
std::unique_ptr<cql_server::response> cql_server::connection::make_unavailable_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t required, int32_t alive, const tracing::trace_state_ptr& tr_state) const
{
    // UNAVAILABLE error body: <code><message><cl><required><alive>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_consistency(cl);
    resp->write_int(required);
    resp->write_int(alive);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_read_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, bool data_present, const tracing::trace_state_ptr& tr_state) const
{
    // READ_TIMEOUT error body: <code><message><cl><received><blockfor><data_present>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_consistency(cl);
    resp->write_int(received);
    resp->write_int(blockfor);
    resp->write_byte(data_present);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_read_failure_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t numfailures, int32_t blockfor, bool data_present, const tracing::trace_state_ptr& tr_state) const
{
    // READ_FAILURE only exists from protocol v4 on; older clients get the
    // closest representable error, a READ_TIMEOUT.
    if (_version < 4) {
        return make_read_timeout_error(stream, exceptions::exception_code::READ_TIMEOUT, std::move(msg), cl, received, blockfor, data_present, tr_state);
    }
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_consistency(cl);
    resp->write_int(received);
    resp->write_int(blockfor);
    resp->write_int(numfailures);
    resp->write_byte(data_present);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_mutation_write_timeout_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t blockfor, db::write_type type, const tracing::trace_state_ptr& tr_state) const
{
    // WRITE_TIMEOUT error body: <code><message><cl><received><blockfor><write_type>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_consistency(cl);
    resp->write_int(received);
    resp->write_int(blockfor);
    resp->write_string(format("{}", type));
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_mutation_write_failure_error(int16_t stream, exceptions::exception_code err, sstring msg, db::consistency_level cl, int32_t received, int32_t numfailures, int32_t blockfor, db::write_type type, const tracing::trace_state_ptr& tr_state) const
{
    // WRITE_FAILURE only exists from protocol v4 on; older clients get the
    // closest representable error, a WRITE_TIMEOUT.
    if (_version < 4) {
        return make_mutation_write_timeout_error(stream, exceptions::exception_code::WRITE_TIMEOUT, std::move(msg), cl, received, blockfor, type, tr_state);
    }
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_consistency(cl);
    resp->write_int(received);
    resp->write_int(blockfor);
    resp->write_int(numfailures);
    resp->write_string(format("{}", type));
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_already_exists_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring cf_name, const tracing::trace_state_ptr& tr_state) const
{
    // ALREADY_EXISTS error body: <code><message><keyspace><table>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_string(ks_name);
    resp->write_string(cf_name);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_unprepared_error(int16_t stream, exceptions::exception_code err, sstring msg, bytes id, const tracing::trace_state_ptr& tr_state) const
{
    // UNPREPARED error body: <code><message><statement id>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_short_bytes(id);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_function_failure_error(int16_t stream, exceptions::exception_code err, sstring msg, sstring ks_name, sstring func_name, std::vector<sstring> args, const tracing::trace_state_ptr& tr_state) const
{
    // FUNCTION_FAILURE error body: <code><message><keyspace><function><arg types>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_string(ks_name);
    resp->write_string(func_name);
    resp->write_string_list(args);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_rate_limit_error(int16_t stream, exceptions::exception_code err, sstring msg, db::operation_type op_type, bool rejected_by_coordinator, const tracing::trace_state_ptr& tr_state, const service::client_state& client_state) const
{
    // RATE_LIMIT_ERROR is a Scylla protocol extension; if this client did not
    // negotiate it, fall back to a generic CONFIG_ERROR.
    if (!client_state.is_protocol_extension_set(cql_protocol_extension::RATE_LIMIT_ERROR)) {
        return make_error(stream, exceptions::exception_code::CONFIG_ERROR, std::move(msg), tr_state);
    }
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    resp->write_byte(static_cast<uint8_t>(op_type));
    resp->write_byte(static_cast<uint8_t>(rejected_by_coordinator));
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_error(int16_t stream, exceptions::exception_code err, sstring msg, const tracing::trace_state_ptr& tr_state) const
{
    // Generic ERROR body: <code><message>.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::ERROR, tr_state);
    resp->write_int(static_cast<int32_t>(err));
    resp->write_string(msg);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_ready(int16_t stream, const tracing::trace_state_ptr& tr_state) const
{
    // READY has an empty body; the opcode alone signals success.
    return std::make_unique<cql_server::response>(stream, cql_binary_opcode::READY, tr_state);
}
std::unique_ptr<cql_server::response> cql_server::connection::make_autheticate(int16_t stream, std::string_view clz, const tracing::trace_state_ptr& tr_state) const
{
    // AUTHENTICATE body carries the authenticator class name the client must
    // use. (NOTE(review): function name is misspelled — "autheticate" — but
    // it is part of the declared interface, so it is kept as-is.)
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::AUTHENTICATE, tr_state);
    resp->write_string(clz);
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_auth_success(int16_t stream, bytes b, const tracing::trace_state_ptr& tr_state) const {
    // AUTH_SUCCESS body: opaque final SASL token.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::AUTH_SUCCESS, tr_state);
    resp->write_bytes(std::move(b));
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_auth_challenge(int16_t stream, bytes b, const tracing::trace_state_ptr& tr_state) const {
    // AUTH_CHALLENGE body: opaque SASL challenge token.
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::AUTH_CHALLENGE, tr_state);
    resp->write_bytes(std::move(b));
    return resp;
}
std::unique_ptr<cql_server::response> cql_server::connection::make_supported(int16_t stream, const tracing::trace_state_ptr& tr_state) const
{
    // Build the SUPPORTED response listing everything this server can
    // negotiate. A multimap is used because some keys (e.g. COMPRESSION)
    // carry several values.
    std::multimap<sstring, sstring> opts;
    opts.insert({"CQL_VERSION", cql3::query_processor::CQL_VERSION});
    opts.insert({"COMPRESSION", "lz4"});
    opts.insert({"COMPRESSION", "snappy"});
    if (_server._config.allow_shard_aware_drivers) {
        // Scylla-specific keys that let shard-aware drivers connect directly
        // to the shard owning the data they target.
        opts.insert({"SCYLLA_SHARD", format("{:d}", this_shard_id())});
        opts.insert({"SCYLLA_NR_SHARDS", format("{:d}", smp::count)});
        opts.insert({"SCYLLA_SHARDING_ALGORITHM", dht::cpu_sharding_algorithm_name()});
        if (_server._config.shard_aware_transport_port) {
            opts.insert({"SCYLLA_SHARD_AWARE_PORT", format("{:d}", *_server._config.shard_aware_transport_port)});
        }
        if (_server._config.shard_aware_transport_port_ssl) {
            opts.insert({"SCYLLA_SHARD_AWARE_PORT_SSL", format("{:d}", *_server._config.shard_aware_transport_port_ssl)});
        }
        opts.insert({"SCYLLA_SHARDING_IGNORE_MSB", format("{:d}", _server._config.sharding_ignore_msb)});
        opts.insert({"SCYLLA_PARTITIONER", _server._config.partitioner_name});
    }
    // Advertise supported protocol extensions; an extension without extra
    // parameters is announced with a single empty value.
    for (cql_protocol_extension ext : supported_cql_protocol_extensions()) {
        const sstring ext_key_name = protocol_extension_name(ext);
        std::vector<sstring> params = additional_options_for_proto_ext(ext);
        if (params.empty()) {
            opts.emplace(ext_key_name, "");
        } else {
            for (sstring val : params) {
                opts.emplace(ext_key_name, std::move(val));
            }
        }
    }
    auto response = std::make_unique<cql_server::response>(stream, cql_binary_opcode::SUPPORTED, tr_state);
    response->write_string_multimap(std::move(opts));
    return response;
}
// Serializes a result_message into a RESULT frame body. Each visit()
// overload writes the 32-bit result "kind" tag followed by the
// kind-specific payload.
class cql_server::fmt_visitor : public messages::result_message::visitor_base {
private:
    uint8_t _version;                     // negotiated protocol version
    cql_server::response& _response;      // frame body being appended to
    bool _skip_metadata;                  // client asked to omit result metadata
    cql_metadata_id_wrapper _metadata_id; // request/response metadata ids, if any
public:
    fmt_visitor(uint8_t version, cql_server::response& response, bool skip_metadata, cql_metadata_id_wrapper&& metadata_id)
        : _version{version}
        , _response{response}
        , _skip_metadata{skip_metadata}
        , _metadata_id(std::move(metadata_id))
    { }
    virtual void visit(const messages::result_message::void_message&) override {
        _response.write_int(0x0001); // kind: Void
    }
    virtual void visit(const messages::result_message::set_keyspace& m) override {
        _response.write_int(0x0003); // kind: Set_keyspace
        _response.write_string(m.get_keyspace());
    }
    virtual void visit(const messages::result_message::prepared::cql& m) override {
        _response.write_int(0x0004); // kind: Prepared
        _response.write_short_bytes(m.get_id());
        // The result metadata id is only emitted when one was computed for
        // this response.
        if (_metadata_id.has_response_metadata_id()) {
            _response.write_short_bytes(_metadata_id.get_response_metadata_id()._metadata_id);
        }
        _response.write(m.metadata(), _version);
        _response.write(*m.result_metadata(), _metadata_id);
    }
    virtual void visit(const messages::result_message::schema_change& m) override {
        auto change = m.get_change();
        switch (change->type) {
        case event::event_type::SCHEMA_CHANGE: {
            auto sc = static_pointer_cast<event::schema_change>(change);
            _response.write_int(0x0005); // kind: Schema_change
            _response.serialize(*sc, _version);
            break;
        }
        default:
            // Only schema-change events can appear in a RESULT frame.
            SCYLLA_ASSERT(0);
        }
    }
    virtual void visit(const messages::result_message::rows& m) override {
        _response.write_int(0x0002); // kind: Rows
        auto& rs = m.rs();
        _response.write(rs.get_metadata(), _metadata_id, _skip_metadata);
        // The row count precedes the rows on the wire but is only known after
        // visiting them, so reserve space and patch it afterwards.
        auto row_count_plhldr = _response.write_int_placeholder();
        // Streams each row's cells into the response while counting rows.
        class visitor {
            cql_server::response& _response;
            int64_t _row_count = 0;
        public:
            visitor(cql_server::response& r) : _response(r) { }
            void start_row() {
                _row_count++;
            }
            void accept_value(std::optional<managed_bytes_view> cell) {
                _response.write_value(cell);
            }
            void end_row() { }
            int64_t row_count() const { return _row_count; }
        };
        auto v = visitor(_response);
        rs.visit(v);
        row_count_plhldr.write(v.row_count()); // even though the placeholder is for int32_t we won't overflow because of memory limits
    }
};
std::unique_ptr<cql_server::response>
make_result(int16_t stream, messages::result_message& msg, const tracing::trace_state_ptr& tr_state,
        cql_protocol_version_type version, cql_metadata_id_wrapper&& metadata_id, bool skip_metadata) {
    auto resp = std::make_unique<cql_server::response>(stream, cql_binary_opcode::RESULT, tr_state);
    // Warnings are only representable from protocol v4 onwards.
    if (version > 3 && !msg.warnings().empty()) [[unlikely]] {
        resp->set_frame_flag(cql_frame_flags::warning);
        resp->write_string_list(msg.warnings());
    }
    if (msg.custom_payload()) {
        resp->set_frame_flag(cql_frame_flags::custom_payload);
        resp->write_string_bytes_map(msg.custom_payload().value());
    }
    // Delegate the kind-specific body encoding to the visitor.
    cql_server::fmt_visitor visitor{version, *resp, skip_metadata, std::move(metadata_id)};
    msg.accept(visitor);
    return resp;
}
std::unique_ptr<cql_server::response>
cql_server::connection::make_topology_change_event(const event::topology_change& event) const
{
    // Server-pushed events use the reserved stream id -1 and carry no trace.
    auto resp = std::make_unique<cql_server::response>(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr());
    resp->write_string("TOPOLOGY_CHANGE");
    resp->write_string(to_string(event.change));
    resp->write_inet(event.node);
    return resp;
}
std::unique_ptr<cql_server::response>
cql_server::connection::make_status_change_event(const event::status_change& event) const
{
    // Server-pushed events use the reserved stream id -1 and carry no trace.
    auto resp = std::make_unique<cql_server::response>(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr());
    resp->write_string("STATUS_CHANGE");
    resp->write_string(to_string(event.status));
    resp->write_inet(event.node);
    return resp;
}
std::unique_ptr<cql_server::response>
cql_server::connection::make_schema_change_event(const event::schema_change& event) const
{
    // Server-pushed events use the reserved stream id -1 and carry no trace.
    auto resp = std::make_unique<cql_server::response>(-1, cql_binary_opcode::EVENT, tracing::trace_state_ptr());
    resp->write_string("SCHEMA_CHANGE");
    resp->serialize(event, _version);
    return resp;
}
void cql_server::connection::write_response(foreign_ptr<std::unique_ptr<cql_server::response>>&& response, service_permit permit, cql_compression compression)
{
    // Chain onto _ready_to_respond so responses are serialized to the socket
    // one at a time, in the order this function was called.
    _ready_to_respond = _ready_to_respond.then([this, compression, response = std::move(response), permit = std::move(permit)] () mutable {
        auto message = response->make_message(_version, compression);
        // Keep the response — which owns the body fragments the scattered
        // message references — alive until the message has been sent.
        message.on_delete([response = std::move(response)] { });
        return _write_buf.write(std::move(message)).then([this] {
            return _write_buf.flush();
        });
    });
}
scattered_message<char> cql_server::response::make_message(uint8_t version, cql_compression compression) {
    // Optionally compress the body in place, then assemble a zero-copy
    // scattered message: a freshly-built frame header followed by references
    // to the body fragments (which must outlive the message — see
    // write_response's on_delete hook).
    if (compression != cql_compression::none) {
        compress(compression);
    }
    scattered_message<char> msg;
    auto frame = make_frame(version, _body.size());
    msg.append(std::move(frame));
    for (auto&& fragment : _body.fragments()) {
        // append_static: no copy — the fragment memory stays owned by _body.
        msg.append_static(reinterpret_cast<const char*>(fragment.data()), fragment.size());
    }
    return msg;
}
void cql_server::response::compress(cql_compression compression)
{
    // Compress the body in place with the negotiated algorithm, then mark
    // the frame as compressed so the client knows to inflate it.
    if (compression == cql_compression::lz4) {
        compress_lz4();
    } else if (compression == cql_compression::snappy) {
        compress_snappy();
    } else {
        throw std::invalid_argument("Invalid CQL compression algorithm");
    }
    set_frame_flag(cql_frame_flags::compression);
}
void cql_server::response::compress_lz4()
{
    auto input_buffer = input_buffer_guard();
    auto output_buffer = output_buffer_guard();
    // LZ4 needs contiguous input, so linearize the (possibly fragmented) body.
    auto in = input_buffer.get_linearized_view(_body);
    // Worst-case compressed size plus 4 bytes for the uncompressed-length prefix.
    size_t output_len = LZ4_COMPRESSBOUND(in.size()) + 4;
    auto bytes_ostream = output_buffer.make_bytes_ostream(output_len, [&in] (bytes_mutable_view out) -> utils::result_with_exception<size_t, std::runtime_error> {
        // The CQL lz4 format prefixes the payload with the uncompressed
        // length as a 32-bit big-endian integer.
        out.data()[0] = (in.size() >> 24) & 0xFF;
        out.data()[1] = (in.size() >> 16) & 0xFF;
        out.data()[2] = (in.size() >> 8) & 0xFF;
        out.data()[3] = in.size() & 0xFF;
        auto ret = LZ4_compress_default(reinterpret_cast<const char*>(in.data()), reinterpret_cast<char*>(out.data() + 4), in.size(), out.size() - 4);
        // LZ4_compress_default returns 0 on failure, else the compressed size.
        if (ret == 0) {
            return bo::failure(std::runtime_error("CQL frame LZ4 compression failure"));
        }
        return bo::success(static_cast<size_t>(ret) + 4);
    });
    if (!bytes_ostream) {
        throw std::move(bytes_ostream).as_failure();
    }
    // Replace the uncompressed body with the compressed one.
    _body = std::move(bytes_ostream).value();
}
void cql_server::response::compress_snappy()
{
    auto input_buffer = input_buffer_guard();
    auto output_buffer = output_buffer_guard();
    // Snappy needs contiguous input, so linearize the (possibly fragmented) body.
    auto in = input_buffer.get_linearized_view(_body);
    size_t output_len = snappy_max_compressed_length(in.size());
    auto bytes_ostream = output_buffer.make_bytes_ostream(output_len, [&in] (bytes_mutable_view out) -> utils::result_with_exception<size_t, std::runtime_error> {
        // FIXME: snappy internally performs allocations greater than 128 kiB.
        const memory::scoped_large_allocation_warning_threshold slawt{256*1024};
        // snappy_compress reads the available capacity from actual_len and
        // overwrites it with the actual compressed size.
        size_t actual_len = out.size();
        if (snappy_compress(reinterpret_cast<const char*>(in.data()), in.size(), reinterpret_cast<char*>(out.data()), &actual_len) != SNAPPY_OK) {
            return bo::failure(std::runtime_error("CQL frame Snappy compression failure"));
        }
        return bo::success(actual_len);
    });
    if (!bytes_ostream) {
        throw std::move(bytes_ostream).as_failure();
    }
    // Replace the uncompressed body with the compressed one.
    _body = std::move(bytes_ostream).value();
}
void cql_server::response::serialize(const event::schema_change& event, uint8_t version)
{
    // SCHEMA_CHANGE payload: <change><target><keyspace>[target-specific args].
    // NOTE(review): `version` is unused here — the layout appears identical
    // across the versions this path handles; confirm against the spec.
    write_string(to_string(event.change));
    write_string(to_string(event.target));
    write_string(event.keyspace);
    switch (event.target) {
    case event::schema_change::target_type::KEYSPACE:
        // Keyspace-level change: nothing beyond the keyspace name.
        break;
    case event::schema_change::target_type::TYPE:
    case event::schema_change::target_type::TABLE:
        // Affected type/table name.
        write_string(event.arguments[0]);
        break;
    case event::schema_change::target_type::FUNCTION:
    case event::schema_change::target_type::AGGREGATE:
        // Function/aggregate name followed by its argument type names.
        write_string(event.arguments[0]);
        write_string_list(std::vector<sstring>(event.arguments.begin() + 1, event.arguments.end()));
        break;
    }
}
void cql_server::response::write_byte(uint8_t b)
{
    // Append a single octet to the body.
    auto* p = reinterpret_cast<const int8_t*>(&b);
    _body.write(bytes_view(p, sizeof(b)));
}
void cql_server::response::write_int(int32_t n)
{
    // The CQL wire format is big-endian; convert before appending.
    auto be = htonl(n);
    auto* p = reinterpret_cast<const int8_t*>(&be);
    _body.write(bytes_view(p, sizeof(be)));
}
cql_server::response::placeholder<int32_t> cql_server::response::write_int_placeholder() {
    // Reserve 4 bytes in the body to be patched later — used for values that
    // are only known after subsequent data is written (e.g. a row count).
    return placeholder<int32_t>(_body.write_place_holder(sizeof(int32_t)));
}
void cql_server::response::write_long(int64_t n)
{
    // The CQL wire format is big-endian; convert before appending.
    auto be = htonq(n);
    auto* p = reinterpret_cast<const int8_t*>(&be);
    _body.write(bytes_view(p, sizeof(be)));
}
void cql_server::response::write_short(uint16_t n)
{
    // The CQL wire format is big-endian; convert before appending.
    auto be = htons(n);
    auto* p = reinterpret_cast<const int8_t*>(&be);
    _body.write(bytes_view(p, sizeof(be)));
}
// Narrow a size_t to T, throwing instead of silently truncating when the
// value does not fit. Used to enforce the wire format's length-field limits.
template<typename T>
inline
T cast_if_fits(size_t v) {
    constexpr size_t limit = std::numeric_limits<T>::max();
    if (v > limit) {
        throw std::runtime_error(format("Value too large, {:d} > {:d}", v, limit));
    }
    return static_cast<T>(v);
}
void cql_server::response::write_string(std::string_view s)
{
    // [string]: 16-bit length prefix followed by the raw bytes.
    auto* data = reinterpret_cast<const int8_t*>(s.data());
    write_short(cast_if_fits<uint16_t>(s.size()));
    _body.write(bytes_view(data, s.size()));
}
void cql_server::response::write_bytes_as_string(bytes_view s)
{
    // Same layout as [string], but the payload is arbitrary bytes.
    const auto len = s.size();
    write_short(cast_if_fits<uint16_t>(len));
    _body.write(s);
}
void cql_server::response::write_long_string(const sstring& s)
{
    // [long string]: 32-bit length prefix followed by the raw bytes.
    auto* data = reinterpret_cast<const int8_t*>(s.data());
    write_int(cast_if_fits<int32_t>(s.size()));
    _body.write(bytes_view(data, s.size()));
}
void cql_server::response::write_string_list(std::vector<sstring> string_list)
{
    // [string list]: 16-bit element count, then each element as [string].
    write_short(cast_if_fits<uint16_t>(string_list.size()));
    for (const auto& item : string_list) {
        write_string(item);
    }
}
void cql_server::response::write_bytes(bytes b)
{
    // [bytes]: 32-bit length prefix followed by the payload.
    const auto len = b.size();
    write_int(cast_if_fits<int32_t>(len));
    _body.write(b);
}
void cql_server::response::write_short_bytes(bytes b)
{
    // [short bytes]: 16-bit length prefix followed by the payload.
    const auto len = b.size();
    write_short(cast_if_fits<uint16_t>(len));
    _body.write(b);
}
void cql_server::response::write_inet(socket_address inet)
{
    // [inet]: one-byte address length, the address bytes, then a 32-bit port.
    auto addr = inet.addr();
    auto* raw = static_cast<const int8_t*>(addr.data());
    write_byte(uint8_t(addr.size()));
    _body.write(bytes_view(raw, addr.size()));
    write_int(inet.port());
}
void cql_server::response::write_consistency(db::consistency_level c)
{
    // [consistency]: 16-bit code per the wire mapping in consistency_to_wire.
    write_short(consistency_to_wire(c));
}
void cql_server::response::write_string_map(std::map<sstring, sstring> string_map)
{
    // [string map]: 16-bit entry count, then each key/value as [string]s.
    write_short(cast_if_fits<uint16_t>(string_map.size()));
    for (const auto& [key, value] : string_map) {
        write_string(key);
        write_string(value);
    }
}
void cql_server::response::write_string_multimap(std::multimap<sstring, sstring> string_map)
{
std::vector<sstring> keys;
for (auto it = string_map.begin(), end = string_map.end(); it != end; it = string_map.upper_bound(it->first)) {
keys.push_back(it->first);
}
write_short(cast_if_fits<uint16_t>(keys.size()));
for (auto&& key : keys) {
std::vector<sstring> values;
auto range = string_map.equal_range(key);
for (auto it = range.first; it != range.second; ++it) {
values.push_back(it->second);
}
write_string(key);
write_string_list(values);
}
}
void cql_server::response::write_string_bytes_map(const std::unordered_map<sstring, bytes>& map)
{
    // [bytes map]: 16-bit entry count, then [string] keys with [bytes] values.
    write_short(cast_if_fits<uint16_t>(map.size()));
    for (const auto& [key, value] : map) {
        write_string(key);
        write_bytes(value);
    }
}
void cql_server::response::write_value(bytes_opt value)
{
    // [value]: length -1 encodes null, otherwise length followed by payload.
    if (value) {
        write_int(value->size());
        _body.write(*value);
    } else {
        write_int(-1);
    }
}
void cql_server::response::write_value(std::optional<managed_bytes_view> value)
{
    // [value]: length -1 encodes null, otherwise length followed by payload.
    if (!value) {
        write_int(-1);
        return;
    }
    write_int(value->size_bytes());
    // Stream the fragmented buffer fragment by fragment, without linearizing.
    for (; !value->empty(); value->remove_current()) {
        _body.write(value->current_fragment());
    }
}
// Translates internal data_type objects into the CQL wire-protocol [option]
// encoding: a 16-bit type id, optionally followed by type parameters for
// parameterized types (collections, tuples, UDTs).
class type_codec {
private:
    // Type ids as defined by the CQL binary protocol specification.
    enum class type_id : int16_t {
        CUSTOM    = 0x0000,
        ASCII     = 0x0001,
        BIGINT    = 0x0002,
        BLOB      = 0x0003,
        BOOLEAN   = 0x0004,
        COUNTER   = 0x0005,
        DECIMAL   = 0x0006,
        DOUBLE    = 0x0007,
        FLOAT     = 0x0008,
        INT       = 0x0009,
        TIMESTAMP = 0x000B,
        UUID      = 0x000C,
        VARCHAR   = 0x000D,
        VARINT    = 0x000E,
        TIMEUUID  = 0x000F,
        INET      = 0x0010,
        DATE      = 0x0011,
        TIME      = 0x0012,
        SMALLINT  = 0x0013,
        TINYINT   = 0x0014,
        DURATION  = 0x0015,
        LIST      = 0x0020,
        MAP       = 0x0021,
        SET       = 0x0022,
        UDT       = 0x0030,
        TUPLE     = 0x0031,
    };
    using type_id_to_type_type = std::unordered_map<data_type, type_id>;
    static thread_local const type_id_to_type_type type_id_to_type;
public:
    // Recursively writes the [option] encoding of `type` into `r`.
    static void encode(cql_server::response& r, data_type type) {
        // Normalize to the underlying type before looking it up.
        type = type->underlying_type();
        // For compatibility sake, we still return DateType as the timestamp type in resultSet metadata (#5723)
        if (type == date_type) {
            type = timestamp_type;
        }
        // Native types map directly to a bare type id.
        auto i = type_id_to_type.find(type);
        if (i != type_id_to_type.end()) {
            r.write_short(static_cast<std::underlying_type<type_id>::type>(i->second));
            return;
        }
        if (type->is_reversed()) {
            fail(unimplemented::cause::REVERSED);
        }
        if (type->is_user_type()) {
            // UDT: keyspace, name, then each field's name and type.
            r.write_short(uint16_t(type_id::UDT));
            auto udt = static_pointer_cast<const user_type_impl>(type);
            r.write_string(udt->_keyspace);
            r.write_bytes_as_string(udt->_name);
            r.write_short(udt->size());
            for (auto&& i : std::views::iota(0u, udt->size())) {
                r.write_bytes_as_string(udt->field_name(i));
                encode(r, udt->field_type(i));
            }
            return;
        }
        if (type->is_tuple()) {
            // Tuple: element count followed by each element type.
            r.write_short(uint16_t(type_id::TUPLE));
            auto ttype = static_pointer_cast<const tuple_type_impl>(type);
            r.write_short(ttype->size());
            for (auto&& t : ttype->all_types()) {
                encode(r, t);
            }
            return;
        }
        if (type->is_vector()) {
            // Vectors are sent as a CUSTOM type identified by its class name.
            r.write_short(uint16_t(type_id::CUSTOM));
            r.write_string(type->name());
            return;
        }
        if (type->is_collection()) {
            // Collections: the element/key/value types follow the id.
            auto&& ctype = static_cast<const collection_type_impl*>(type.get());
            if (ctype->get_kind() == abstract_type::kind::map) {
                r.write_short(uint16_t(type_id::MAP));
                auto&& mtype = static_cast<const map_type_impl*>(ctype);
                encode(r, mtype->get_keys_type());
                encode(r, mtype->get_values_type());
            } else if (ctype->get_kind() == abstract_type::kind::set) {
                r.write_short(uint16_t(type_id::SET));
                auto&& stype = static_cast<const set_type_impl*>(ctype);
                encode(r, stype->get_elements_type());
            } else if (ctype->get_kind() == abstract_type::kind::list) {
                r.write_short(uint16_t(type_id::LIST));
                auto&& ltype = static_cast<const list_type_impl*>(ctype);
                encode(r, ltype->get_elements_type());
            } else {
                // Unknown collection kind — cannot be encoded.
                abort();
            }
            return;
        }
        // Unhandled type category — cannot be encoded.
        abort();
    }
};
// Native types that encode as a bare type id (see type_codec::encode).
thread_local const type_codec::type_id_to_type_type type_codec::type_id_to_type {
    { ascii_type, type_id::ASCII },
    { long_type, type_id::BIGINT },
    { bytes_type, type_id::BLOB },
    { boolean_type, type_id::BOOLEAN },
    { counter_type, type_id::COUNTER },
    { decimal_type, type_id::DECIMAL },
    { double_type, type_id::DOUBLE },
    { float_type, type_id::FLOAT },
    { int32_type, type_id::INT },
    { byte_type, type_id::TINYINT },
    { duration_type, type_id::DURATION },
    { short_type, type_id::SMALLINT },
    { timestamp_type, type_id::TIMESTAMP },
    { uuid_type, type_id::UUID },
    { utf8_type, type_id::VARCHAR },
    { varint_type, type_id::VARINT },
    { timeuuid_type, type_id::TIMEUUID },
    { simple_date_type, type_id::DATE },
    { time_type, type_id::TIME },
    { inet_addr_type, type_id::INET },
};
// Writes result-set metadata: flags, column count, optional paging state and,
// unless suppressed, the per-column specifications.
void cql_server::response::write(const cql3::metadata& m, const cql_metadata_id_wrapper& metadata_id, bool no_metadata) {
    auto flags = m.flags();
    bool global_tables_spec = m.flags().contains<cql3::metadata::flag::GLOBAL_TABLES_SPEC>();
    bool has_more_pages = m.flags().contains<cql3::metadata::flag::HAS_MORE_PAGES>();
    if (no_metadata) {
        flags.set<cql3::metadata::flag::NO_METADATA>();
    }
    // If the client sent a metadata id and it no longer matches the current
    // result metadata, force full metadata back with METADATA_CHANGED set.
    if (metadata_id.has_request_metadata_id() && metadata_id.has_response_metadata_id()) {
        if (metadata_id.get_request_metadata_id() != metadata_id.get_response_metadata_id()) {
            flags.remove<cql3::metadata::flag::NO_METADATA>();
            flags.set<cql3::metadata::flag::METADATA_CHANGED>();
            no_metadata = false;
        }
    }
    write_int(flags.mask());
    write_int(m.column_count());
    if (has_more_pages) {
        write_value(m.paging_state()->serialize());
    }
    if (no_metadata) {
        return;
    }
    if (flags.contains<cql3::metadata::flag::METADATA_CHANGED>()) {
        write_short_bytes(metadata_id.get_response_metadata_id()._metadata_id);
    }
    auto names_i = m.get_names().begin();
    if (global_tables_spec) {
        // Shared keyspace/table written once up front.
        // NOTE(review): dereferences the first name unconditionally — assumes
        // at least one column when GLOBAL_TABLES_SPEC is set; confirm upstream.
        auto first_spec = *names_i;
        write_string(first_spec->ks_name);
        write_string(first_spec->cf_name);
    }
    for (uint32_t i = 0; i < m.column_count(); ++i, ++names_i) {
        lw_shared_ptr<cql3::column_specification> name = *names_i;
        if (!global_tables_spec) {
            write_string(name->ks_name);
            write_string(name->cf_name);
        }
        write_string(name->name->text());
        type_codec::encode(*this, name->type);
    }
}
void cql_server::response::write(const cql3::prepared_metadata& m, uint8_t version)
{
    // Writes prepared-statement metadata: flags, bind-variable count, partition
    // key bind indices (v4+), and the per-variable specifications.
    bool global_tables_spec = m.flags().contains<cql3::prepared_metadata::flag::GLOBAL_TABLES_SPEC>();
    write_int(m.flags().mask());
    write_int(m.names().size());
    if (version >= 4) {
        // pk_count / pk indices: lets drivers compute the partition key (and
        // hence routing) from the bound values. Only written when the bind
        // variables share one table.
        if (!global_tables_spec) {
            write_int(0);
        } else {
            write_int(m.partition_key_bind_indices().size());
            for (uint16_t bind_index : m.partition_key_bind_indices()) {
                write_short(bind_index);
            }
        }
    }
    if (global_tables_spec) {
        // Shared keyspace/table written once up front.
        write_string(m.names()[0]->ks_name);
        write_string(m.names()[0]->cf_name);
    }
    for (auto const& name : m.names()) {
        if (!global_tables_spec) {
            write_string(name->ks_name);
            write_string(name->cf_name);
        }
        write_string(name->name->text());
        type_codec::encode(*this, name->type);
    }
}
bool cql_metadata_id_wrapper::has_request_metadata_id() const {
    // True when the client supplied a metadata id with its request.
    return static_cast<bool>(_request_metadata_id);
}
bool cql_metadata_id_wrapper::has_response_metadata_id() const {
    // True when a metadata id was computed for the response.
    return static_cast<bool>(_response_metadata_id);
}
const cql3::cql_metadata_id_type& cql_metadata_id_wrapper::get_request_metadata_id() const {
    // Callers must check has_request_metadata_id() first.
    if (!_request_metadata_id) {
        on_internal_error(clogger, "request metadata_id is empty");
    }
    return *_request_metadata_id;
}
const cql3::cql_metadata_id_type& cql_metadata_id_wrapper::get_response_metadata_id() const {
    // Callers must check has_response_metadata_id() first.
    if (!_response_metadata_id) {
        on_internal_error(clogger, "response metadata_id is empty");
    }
    return *_response_metadata_id;
}
future<utils::chunked_vector<client_data>> cql_server::get_client_data() {
    // Collect a snapshot of per-connection client information, yielding
    // between connections via for_each_gently.
    utils::chunked_vector<client_data> result;
    co_await for_each_gently([&result] (const generic_server::connection& gc) {
        result.emplace_back(dynamic_cast<const connection&>(gc).make_client_data());
    });
    co_return result;
}
future<> cql_server::update_connections_scheduling_group() {
    // Re-resolve the scheduling group of every open connection, yielding
    // between connections via for_each_gently.
    return for_each_gently([] (generic_server::connection& gc) {
        dynamic_cast<connection&>(gc).update_scheduling_group();
    });
}
future<> cql_server::update_connections_service_level_params() {
    if (!_sl_controller.is_v2()) {
        // Auto update of connections' service level params requires
        // service levels in v2.
        return make_ready_future<>();
    }
    return for_each_gently([this] (generic_server::connection& gc) {
        connection& cql_conn = dynamic_cast<connection&>(gc);
        auto& cs = cql_conn.get_client_state();
        if (const auto& user = cs.user(); user && user->name) {
            // Refresh timeouts/workload from the user's cached effective
            // service level, when one is known.
            if (auto slo = _sl_controller.find_cached_effective_service_level(user->name.value())) {
                cs.update_per_service_level_params(*slo);
            }
        }
        cql_conn.update_scheduling_group();
    });
}
future<std::vector<connection_service_level_params>> cql_server::get_connections_service_level_params() {
    // Snapshot each connection's effective service-level parameters,
    // yielding between connections via for_each_gently.
    std::vector<connection_service_level_params> params;
    co_await for_each_gently([&params] (const generic_server::connection& gc) {
        const connection& cql_conn = dynamic_cast<const connection&>(gc);
        auto& cs = cql_conn.get_client_state();
        const auto& user = cs.user();
        // Role resolution: connections without a user are unauthenticated;
        // a user without a name is anonymous.
        sstring role_name = "UNAUTHENTICATED";
        if (user) {
            role_name = user->name ? *(user->name) : "ANONYMOUS";
        }
        params.emplace_back(std::move(role_name), cs.get_timeout_config(), cs.get_workload_type(), cql_conn.get_scheduling_group().name());
    });
    co_return params;
}
}